From 64bd6eae6dd3a4f0179f0c2ea9bd13868e1c98bb Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 18 Dec 2022 17:02:42 -0500 Subject: [PATCH 01/17] Fix sampling behavior. Reduce duplicated codes. --- algorithm_distillation/task.py | 83 ++++++++++++++++++---------------- tests/test_env.py | 9 ++-- 2 files changed, 49 insertions(+), 43 deletions(-) diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index 58e407e..0e43a8e 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -177,37 +177,12 @@ def _get_most_recent_history( """ pos = buffer.pos total_length = length * (skip + 1) - if pos >= total_length: - return ( - self._obs_post_process( - buffer.observations[pos - total_length : pos : skip + 1] - ), - self._act_post_process( - buffer.actions[pos - total_length : pos : skip + 1] - ), - self._rew_post_process( - buffer.rewards[pos - total_length : pos : skip + 1] - ), - ) - else: - latest = ( - self._obs_post_process(buffer.observations[: pos : skip + 1]), - self._act_post_process(buffer.actions[: pos : skip + 1]), - self._rew_post_process(buffer.rewards[: pos : skip + 1]), - ) - if buffer.full: - extra = ( - self._obs_post_process( - buffer.observations[pos - length :: skip + 1] - ), - self._act_post_process(buffer.actions[pos - length :: skip + 1]), - self._rew_post_process(buffer.rewards[pos - length :: skip + 1]), - ) - latest = tuple( - [np.concatenate([extra[i], latest[i]], axis=0) for i in range(3)] - ) - - return latest + assert buffer.buffer_size > total_length, "Replay buffer size must be larger than the sequence length." + + start = (pos - total_length + buffer.buffer_size) % buffer.buffer_size + end = pos + + return self._get_obs_act_rew(buffer, start, end, skip) def _randomly_sample_buffer( self, buffer: ReplayBuffer, length: int, skip: int @@ -220,21 +195,51 @@ def _randomly_sample_buffer( :param skip: the amount to skip between states. :return: an (observations, actions, rewards) tuple. """ - upper_bound = buffer.size if buffer.full else buffer.pos total_length = length * (skip + 1) + assert buffer.buffer_size > total_length, "Replay buffer size must be larger than the sequence length." + + if not buffer.full: + start = random.randint(0, buffer.pos - total_length) + end = start + total_length + else: + start = random.randint(0, buffer.buffer_size) + end = (start + total_length) % buffer.buffer_size + + return self._get_obs_act_rew(buffer, start, end, skip) - if total_length > upper_bound: - raise IndexError("Buffer contains fewer samples than necessary.") + @staticmethod + def _get_range(array: np.ndarray, start: int, end: int, interval: int) -> np.ndarray: + """ + A helper function to either slice array[start:end:interval] or combine array[start::interval] and + array[:end:interval] depending on whether start < end. + :param array: the sliced array. + :param start: the starting index. + :param end: the ending index (exclusive). + :param interval: the interval. + :return: the sliced sub-array. + """ + if start < end: + return array[start:end:interval] + else: + return np.concatenate([array[start::interval], array[:end:interval]], axis=0) - start = random.randint(0, upper_bound - total_length) + def _get_obs_act_rew(self, buffer: ReplayBuffer, start: int, end: int, skip: int): + """ + Return a tuple (obs, act, rew) sampled according to the buffer and the parameters. + :param buffer: the replay buffer. + :param start: the starting index. + :param end: the ending index. + :param skip: the amount of states to skip. 
+ :return: the tuple (obs, act, rew) + """ return ( self._obs_post_process( - buffer.observations[start : start + total_length : skip + 1] + self._get_range(buffer.observations, start, end, skip + 1) ), self._act_post_process( - buffer.actions[start : start + total_length : skip + 1] + self._get_range(buffer.actions, start, end, skip + 1) ), self._rew_post_process( - buffer.rewards[start : start + total_length : skip + 1] + self._get_range(buffer.rewards, start, end, skip + 1) ), - ) + ) \ No newline at end of file diff --git a/tests/test_env.py b/tests/test_env.py index 34fc082..78e7777 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -26,10 +26,11 @@ def test_gym_task(): assert sample[2].shape == (10, 1) # Rewards. # More than buffer - sample = task.sample_history(1000, skip=0, most_recent=True) - assert sample[0].shape == (100, 16) # Observations are discrete classes - assert sample[1].shape == (100, 1) # Actions are discrete classes - assert sample[2].shape == (100, 1) # Rewards. + try: + sample = task.sample_history(1000, skip=0, most_recent=True) + assert False, "Error should have been raised." + except AssertionError: + assert True """# Please install gym[atari, accept-rom-license] manually if you want to run Atari. env = gym.make('Alien-v4') From 03167680223eabcea39fe831d6cc44754b734ae2 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 18 Dec 2022 17:09:33 -0500 Subject: [PATCH 02/17] Minor fix on requirements.txt. --- requirements.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5887b2c..029afe5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ stable-baselines3 -transformers~=4.24.0 -torch~=1.12.1 -pytest~=7.2.0 -gym~=0.21.0 -numpy~=1.23.5 \ No newline at end of file +transformers>=4.24.0 +torch>=1.12.1 +pytest>=7.2.0 +gym==0.21.0 +numpy>=1.23.5 \ No newline at end of file From c98720e2d09eadceba671c57a16bb76e6901b77e Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 18 Dec 2022 18:31:49 -0500 Subject: [PATCH 03/17] Added customized logger and included its usage in the test. --- algorithm_distillation/sb3_util/__init__.py | 4 ++ algorithm_distillation/sb3_util/logger.py | 61 +++++++++++++++++++++ algorithm_distillation/task.py | 2 +- tests/test_ad.py | 12 ++++ 4 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 algorithm_distillation/sb3_util/__init__.py create mode 100644 algorithm_distillation/sb3_util/logger.py diff --git a/algorithm_distillation/sb3_util/__init__.py b/algorithm_distillation/sb3_util/__init__.py new file mode 100644 index 0000000..09fe6a3 --- /dev/null +++ b/algorithm_distillation/sb3_util/__init__.py @@ -0,0 +1,4 @@ +from .logger import CustomLogger, configure + + +__all__ = ['CustomLogger', 'configure'] diff --git a/algorithm_distillation/sb3_util/logger.py b/algorithm_distillation/sb3_util/logger.py new file mode 100644 index 0000000..63a54fa --- /dev/null +++ b/algorithm_distillation/sb3_util/logger.py @@ -0,0 +1,61 @@ +import datetime +import os +import tempfile +from collections import defaultdict +from typing import Any, Optional, Union, Tuple, List +from stable_baselines3.common.logger import Logger, make_output_format + + +class CustomLogger(Logger): + """ + A logger object can be plugged into an SB3 agent to record the metric. Here we customize it by overriding + in order to add a history. One can further customize to connect with wandb. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.history_value = defaultdict(list) + self.history_mean_value = defaultdict(list) + + def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + super().record(key, value, exclude) + self.history_value[key].append(value) + + def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + super().record_mean(key, value, exclude) + self.history_mean_value[key].append(self.name_to_value[key]) + + +def configure(folder: Optional[str] = None, + format_strings: Optional[List[str]] = None, + logger_class=Logger) -> Logger: + """ + (This is an adaptation of SB3's logger configuration helper function. The change was very minor, only one line + towards the end to allow for a customized logger class.) + Configure the current logger. + :param folder: the save location + (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time]) + :param format_strings: the output logging format + (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv']) + :param logger_class: (Optional) the custom logger class. + :return: The logger object. + """ + if folder is None: + folder = os.getenv("SB3_LOGDIR") + if folder is None: + folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f")) + assert isinstance(folder, str) + os.makedirs(folder, exist_ok=True) + + log_suffix = "" + if format_strings is None: + format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",") + + format_strings = list(filter(None, format_strings)) + output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings] + + logger = logger_class(folder=folder, output_formats=output_formats) + # Only print when some files will be saved + if len(format_strings) > 0 and format_strings != ["stdout"]: + logger.log(f"Logging to {folder}") + return logger diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index 0e43a8e..c9c8d89 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -96,7 +96,7 @@ def sample_history( (note: length of training steps instead of length of sequence! Every step includes obs, act, rew. Thus, the final sequence is 3x as long.) :param skip: (Optional) skip certain amount of steps between two states - :param most_recent: (Optional) sample from the most recent histories. False to sample randomly. + :param most_recent: (Optional) get the most recent histories. False to sample randomly. :return: a tuple of (observations, actions, rewards). Each is a tensor. """ if self.algorithm in self._on_policy: diff --git a/tests/test_ad.py b/tests/test_ad.py index 746ec02..1a48d8c 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -2,6 +2,7 @@ from algorithm_distillation import GymTask, TaskManager, GymAD from algorithm_distillation.models import GPT2AD +from algorithm_distillation.sb3_util import CustomLogger, configure def test_ad(): @@ -13,8 +14,19 @@ def test_ad(): model = GPT2AD(env.observation_space.n, env.action_space.n, 12, max_ep_len=16) task = GymTask(env, 'DQN', config) + # Inject a customized logger + logger = configure(None, None, CustomLogger) + task.agent.set_logger(logger) + task_manager = TaskManager([task]) task_manager.train(100) + assert 'history_value' in task.agent.logger.__dir__() + # 100 total time-steps, but training only happens upon the finish of an episode. 
We don't know how many gradient + # steps are trained, but we are sure it is nonzero. + assert len(task.agent.logger.history_value['train/loss']) != 0 + # But we are sure that rollout happens 100 times. + assert len(task.agent.logger.history_value['rollout/exploration_rate']) == 100 + ad = GymAD(model) ad.train(task_manager, 100, 10, skip=0, batch_size=8) From 23eb20bf6e831c6fde9af281f0c8fdc92e95c737 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 18 Dec 2022 22:42:04 -0500 Subject: [PATCH 04/17] Added rollout method for AD. --- algorithm_distillation/ad.py | 75 +++++++++++++++++++++++--- algorithm_distillation/models/gpt2.py | 66 +++++++++++++---------- algorithm_distillation/models/util.py | 4 ++ algorithm_distillation/task.py | 39 ++++++++++---- algorithm_distillation/task_manager.py | 5 +- tests/test_ad.py | 3 ++ tests/test_env.py | 4 +- tests/test_util.py | 4 ++ 8 files changed, 153 insertions(+), 47 deletions(-) diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/ad.py index f42a70f..05828fd 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/ad.py @@ -1,8 +1,10 @@ import abc +import numpy as np import torch from algorithm_distillation.models.ad_transformer import ADTransformer +from .task import Task, GymTask from .task_manager import TaskManager @@ -27,6 +29,13 @@ def train( ): pass + @abc.abstractmethod + def rollout(self, + task: Task, + steps: int, + skip: int) -> tuple: + pass + class GymAD(AlgorithmDistillation): def train( @@ -75,9 +84,63 @@ def train( losses.append(loss.item()) optimizer.step() - self.model.eval() # By default, set to eval mode outside of training. + self.model.eval() # By default, set to eval mode outside training. + + def rollout(self, task: GymTask, steps: int, skip: int) -> tuple: + """ + Roll out for `steps` amount of steps (ignore the policy embedded in `task` and only uses its _env). + + :param task: the task to perform rollout on. + :param steps: the amount of steps to roll out. + :param skip: the amount of steps to skip (normally should be the same as `skip` during training). + :return: the full sequences (observations, actions, rewards), each of length `steps`. 
+ """ + device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + self.model.to(device) - print(losses) + for st in ['obs', 'act']: + if getattr(task, f'{st}_dim') != getattr(self.model, f'{st}_dim'): + raise ValueError(f"The task must have observation dimension {self.model.obs_dim}") + env = task.env + observations = torch.zeros((steps, task.obs_dim), device=device, dtype=torch.float) + + # Predicted action logits + action_logits = torch.zeros((steps, task.act_dim), device=device, dtype=torch.float) + # The actual actions taken (argmax of action_logits) + actions = torch.zeros((steps,), device=device, dtype=torch.long) + # The actual one-hot encoded actions (nn.one_hot of actions) + actions_one_hot = torch.zeros((steps, task.act_dim), device=device, dtype=torch.float) + + rewards = torch.zeros((steps, 1), device=device, dtype=torch.float) + + obs, done = None, True + for step in range(steps): + if done: + obs, done = torch.tensor(task.obs_post_process(np.array([env.reset()])), + device=device, + dtype=torch.float), False + + # TODO: can be optimized using cache + with torch.inference_mode(): + action_logits[step] = self.model( + None if step < skip + 1 else observations[None, : step: skip + 1], + None if step < skip + 1 else actions_one_hot[None, : step: skip + 1], + None if step < skip + 1 else rewards[None, : step : skip + 1], + current_obs=obs[None, 0], + action_only=True)[0, step] + actions[step] = torch.argmax(action_logits[step]).type(torch.long) + actions_one_hot[step] = torch.nn.functional.one_hot(actions[step], + num_classes=task.act_dim).type(torch.float) + + observations[step] = obs[None, 0] + obs, rew, done, _ = env.step(actions[step].item()) + obs = torch.tensor(task.obs_post_process(np.array([obs])), device=device, dtype=torch.float) + rew = torch.tensor(task.rew_post_process(np.array([rew])), device=device, dtype=torch.float) + rewards[step] = rew[0] + + return observations, actions, rewards @staticmethod def _get_data_iter( @@ -90,13 +153,13 @@ def _get_data_iter( yield ( torch.tensor( - [sample[0] for sample in samples], dtype=torch.float, device=device + np.array([sample[0] for sample in samples]), dtype=torch.float, device=device ), # observations torch.tensor( - [sample[1] for sample in samples], dtype=torch.long, device=device + np.array([sample[1] for sample in samples]), dtype=torch.long, device=device ), # actions torch.tensor( - [sample[2] for sample in samples], dtype=torch.float, device=device + np.array([sample[2] for sample in samples]), dtype=torch.float, device=device ), # rewards ) @@ -109,5 +172,5 @@ def _compute_loss(x, y) -> torch.Tensor: """ assert y.dtype == torch.long assert x.shape[:-1] + (1,) == y.shape - x = torch.nn.functional.log_softmax(x) # (b, length, action_num) + x = torch.nn.functional.log_softmax(x, dim=-1) # (b, length, action_num) return -torch.take_along_dim(x, y, dim=len(y.shape) - 1).sum(-1).mean() diff --git a/algorithm_distillation/models/gpt2.py b/algorithm_distillation/models/gpt2.py index 281bc60..9db6093 100644 --- a/algorithm_distillation/models/gpt2.py +++ b/algorithm_distillation/models/gpt2.py @@ -110,9 +110,9 @@ def forward( But other cases are allowed, e.g., `..., latest_obs, latest_act, None` -> `next_reward` - :param obs: (b, t, obs_dim) - :param actions: (b, t, act_dim) - :param rewards: (b, t, 1) + :param obs: (b, t, obs_dim) or None if b==0 + :param actions: (b, t, act_dim) or None if b==0 + :param rewards: (b, t, 1) or None if b==0 :param current_obs: (Optional) (b, obs_dim) :param 
current_action: (Optional) shape (b, act_dim) :param current_reward: (Optional) shape (b, 1) @@ -122,9 +122,14 @@ def forward( :param action_only: (Optional) return predicted actions only. :return: predicted action logits (if action_only) or predicted action logits, rewards, and obs. """ - device = obs.device + if obs is None: + assert current_obs is not None, "Empty input." + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + batch_size, timestep = current_obs.shape[0], 0 + else: + device = obs.device + batch_size, timestep, _ = obs.shape - batch_size, timestep, _ = obs.shape if current_step_id is None: if step_ids is not None: logger.warning( @@ -132,22 +137,21 @@ def forward( ) current_step_id = timestep - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, timestep), dtype=torch.float, device=device - ) - - if step_ids is None: + if step_ids is None and timestep > 0: step_ids = torch.arange(0, timestep, dtype=torch.long, device=device).view( 1, timestep + ).repeat((batch_size, 1)) + + if timestep == 0: + embedded_obs, embedded_act, embedded_rew = None, None, None + else: + embedded_steps = self.step_embedding(step_ids).view( + batch_size, timestep, self.hidden_size ) - embedded_steps = self.step_embedding(step_ids).view( - 1, timestep, self.hidden_size - ) + embedded_obs = self.obs_embedding(obs) + embedded_steps + embedded_act = self.act_embedding(actions) + embedded_steps + embedded_rew = self.rew_embedding(rewards) + embedded_steps - embedded_obs = self.obs_embedding(obs) + embedded_steps - embedded_act = self.act_embedding(actions) + embedded_steps - embedded_rew = self.rew_embedding(rewards) + embedded_steps embedded_latest_step = self.step_embedding( torch.tensor([current_step_id], dtype=torch.long, device=device) ) @@ -168,16 +172,24 @@ def forward( # Note: only affects axis 1. Axis 0 (batch) and axis 2 (embedding) are preserved. input_seq = stack_seq(embedded_obs, embedded_act, embedded_rew, extra) input_seq = self.layer_norm_in_embedding(input_seq) - attention_mask = ( - attention_mask.unsqueeze(-1).repeat((1, 1, 3)).view(batch_size, -1) - ) - attention_mask = torch.concat( - [ - attention_mask, - torch.ones((batch_size, num_extra), dtype=torch.float, device=device), - ], - dim=1, - ) + + if timestep == 0: + attention_mask = torch.ones((batch_size, num_extra), dtype=torch.float, device=device) + else: + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, timestep), dtype=torch.float, device=device + ) + attention_mask = ( + attention_mask.unsqueeze(-1).repeat((1, 1, 3)).view(batch_size, -1) + ) + attention_mask = torch.concat( + [ + attention_mask, + torch.ones((batch_size, num_extra), dtype=torch.float, device=device), + ], + dim=1, + ) # Do inference using the underlying transformer. output = self.transformers( diff --git a/algorithm_distillation/models/util.py b/algorithm_distillation/models/util.py index a631ced..541d970 100644 --- a/algorithm_distillation/models/util.py +++ b/algorithm_distillation/models/util.py @@ -11,6 +11,10 @@ def stack_seq(obs, act, rew, extra=None) -> torch.Tensor: :param extra: (Optional) shape (b, i, hidden_size) where i can be 1, 2 or 3 :return: shape (b, 3*t+i, hidden_size) """ + if obs is None: + # batch size is 0. We return extra. 
+ return extra + batch_size, timestep, _ = obs.shape stacked = torch.stack((obs, act, rew), dim=2).view(batch_size, 3 * timestep, -1) if extra is None: diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index c9c8d89..c2d9d7a 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -6,6 +6,7 @@ import gym import numpy as np import stable_baselines3 +import torch.nn.functional from stable_baselines3.common.buffers import ReplayBuffer @@ -54,7 +55,10 @@ class GymTask(Task): # If the environment has a discrete obs space of n classes, return an (n,)-shape array for each obs. # If the environment has a continuous obs space of certain shape, return a flattened array for each obs. obs_dim: int + act_dim: int + obs_cls: str + act_cls: str def __init__(self, env, algorithm: str, config: Optional[dict] = None): """ @@ -64,9 +68,9 @@ def __init__(self, env, algorithm: str, config: Optional[dict] = None): :param algorithm: the stable-baselines3 algorithm :param config: (Optional) the stable-baselines3 algorithm config (except for the gym environment) """ - self.env = env + self._env = env if not isinstance(env.action_space, gym.spaces.Discrete): - raise ValueError("Only support discrete action spaces.") + raise NotImplementedError("Only supports discrete action spaces for now.") self.algorithm = algorithm if algorithm not in self._algorithms: @@ -82,6 +86,7 @@ def __init__(self, env, algorithm: str, config: Optional[dict] = None): self.agent = self._algorithms[algorithm](**self.config) self.obs_dim, self.obs_cls = self._get_obs_specs() + self.act_dim, self.act_cls = self._get_act_specs() def train(self, steps: int): self.agent.learn(total_timesteps=steps) @@ -113,8 +118,6 @@ def sample_history( return self._randomly_sample_buffer(buffer, length, skip) def _get_obs_specs(self) -> tuple: - if not hasattr(self, "env"): - raise ValueError('Need to assign "env" attribute first.') obs_space = self.env.observation_space if isinstance(obs_space, gym.spaces.Discrete): @@ -126,7 +129,23 @@ def _get_obs_specs(self) -> tuple: f"The observation space does not support {type(obs_space)}." ) - def _obs_post_process(self, obs: np.ndarray) -> np.ndarray: + def _get_act_specs(self) -> tuple: + act_space = self.env.action_space + + if isinstance(act_space, gym.spaces.Discrete): + return act_space.n, "discrete" + elif isinstance(act_space, gym.spaces.Box): + raise NotImplementedError( + f"The observation space does not support {type(act_space)}." + ) + + @property + def env(self) -> gym.Env: + if not hasattr(self, "_env"): + raise ValueError('Need to assign "_env" attribute first.') + return self._env + + def obs_post_process(self, obs: np.ndarray) -> np.ndarray: """ Post-process the observations according to its type and shape. @@ -145,7 +164,7 @@ def _obs_post_process(self, obs: np.ndarray) -> np.ndarray: raise RuntimeError("Impossible code path.") @staticmethod - def _act_post_process(act: np.ndarray) -> np.ndarray: + def act_post_process(act: np.ndarray) -> np.ndarray: """ Post-process the actions. Assume actions are discrete with shape (length, 1) or (length). @@ -155,7 +174,7 @@ def _act_post_process(act: np.ndarray) -> np.ndarray: return act.reshape((-1, 1)).astype(int) @staticmethod - def _rew_post_process(rew: np.ndarray) -> np.ndarray: + def rew_post_process(rew: np.ndarray) -> np.ndarray: """ Post-process the rewards. Rewards are scalars with shape (length,). 
@@ -233,13 +252,13 @@ def _get_obs_act_rew(self, buffer: ReplayBuffer, start: int, end: int, skip: int :return: the tuple (obs, act, rew) """ return ( - self._obs_post_process( + self.obs_post_process( self._get_range(buffer.observations, start, end, skip + 1) ), - self._act_post_process( + self.act_post_process( self._get_range(buffer.actions, start, end, skip + 1) ), - self._rew_post_process( + self.rew_post_process( self._get_range(buffer.rewards, start, end, skip + 1) ), ) \ No newline at end of file diff --git a/algorithm_distillation/task_manager.py b/algorithm_distillation/task_manager.py index d268fbd..c893e80 100644 --- a/algorithm_distillation/task_manager.py +++ b/algorithm_distillation/task_manager.py @@ -16,8 +16,9 @@ def __init__(self, tasks: List[Task]): if not tasks: raise ValueError("The task list cannot be empty.") for task in tasks[1:]: - if task.obs_dim != tasks[0].obs_dim: - raise ValueError("All tasks must have the same obs_dim.") + for st in ['obs', 'act']: + if getattr(task, f'{st}_dim') != getattr(tasks[0], f'{st}_dim'): + raise ValueError(f"All tasks must have the same {st}_dim.") self.tasks = tasks diff --git a/tests/test_ad.py b/tests/test_ad.py index 1a48d8c..68266b1 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -30,3 +30,6 @@ def test_ad(): ad = GymAD(model) ad.train(task_manager, 100, 10, skip=0, batch_size=8) + + obs, act, rew = ad.rollout(task, 16, 0) + print(obs, act, rew) diff --git a/tests/test_env.py b/tests/test_env.py index 78e7777..ff049f2 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -33,12 +33,12 @@ def test_gym_task(): assert True """# Please install gym[atari, accept-rom-license] manually if you want to run Atari. - env = gym.make('Alien-v4') + _env = gym.make('Alien-v4') config = {'learning_starts': 10, 'buffer_size': 1000, 'policy': 'MlpPolicy' } - task = GymTask(env, 'DQN', config) + task = GymTask(_env, 'DQN', config) task.train(100) """ diff --git a/tests/test_util.py b/tests/test_util.py index e888e59..28b825d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -30,3 +30,7 @@ def test_stack_seq(): result = stack_seq(obs, act, rew, extra) assert torch.all(result.isclose(r)) + + # when batch size is 0, all of obs, act, rew should be None. It should just return the extra. + result = stack_seq(None, None, None, extra) + assert torch.all(result.isclose(extra)) From 88faa77f66d9b65ae42250530fccba4526c64b86 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 18 Dec 2022 22:47:55 -0500 Subject: [PATCH 05/17] Fix type. 
--- algorithm_distillation/ad.py | 76 ++++++++++++++------- algorithm_distillation/models/gpt2.py | 22 ++++-- algorithm_distillation/sb3_util/__init__.py | 3 +- algorithm_distillation/sb3_util/logger.py | 30 ++++++-- algorithm_distillation/task.py | 23 +++++-- algorithm_distillation/task_manager.py | 4 +- 6 files changed, 108 insertions(+), 50 deletions(-) diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/ad.py index 05828fd..46529c6 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/ad.py @@ -4,8 +4,8 @@ import torch from algorithm_distillation.models.ad_transformer import ADTransformer -from .task import Task, GymTask +from .task import GymTask from .task_manager import TaskManager @@ -25,15 +25,12 @@ def train( length: int, skip: int, batch_size: int, - **config + **config, ): pass @abc.abstractmethod - def rollout(self, - task: Task, - steps: int, - skip: int) -> tuple: + def rollout(self, task, steps: int, skip: int) -> tuple: pass @@ -45,7 +42,7 @@ def train( length: int, skip: int, batch_size: int, - **config + **config, ): """ Collect samples and train `steps` amount of gradient steps. @@ -100,44 +97,65 @@ def rollout(self, task: GymTask, steps: int, skip: int) -> tuple: ) self.model.to(device) - for st in ['obs', 'act']: - if getattr(task, f'{st}_dim') != getattr(self.model, f'{st}_dim'): - raise ValueError(f"The task must have observation dimension {self.model.obs_dim}") + for st in ["obs", "act"]: + if getattr(task, f"{st}_dim") != getattr(self.model, f"{st}_dim"): + raise ValueError( + f"The task must have observation dimension {self.model.obs_dim}" + ) env = task.env - observations = torch.zeros((steps, task.obs_dim), device=device, dtype=torch.float) + observations = torch.zeros( + (steps, task.obs_dim), device=device, dtype=torch.float + ) # Predicted action logits - action_logits = torch.zeros((steps, task.act_dim), device=device, dtype=torch.float) + action_logits = torch.zeros( + (steps, task.act_dim), device=device, dtype=torch.float + ) # The actual actions taken (argmax of action_logits) actions = torch.zeros((steps,), device=device, dtype=torch.long) # The actual one-hot encoded actions (nn.one_hot of actions) - actions_one_hot = torch.zeros((steps, task.act_dim), device=device, dtype=torch.float) + actions_one_hot = torch.zeros( + (steps, task.act_dim), device=device, dtype=torch.float + ) rewards = torch.zeros((steps, 1), device=device, dtype=torch.float) obs, done = None, True for step in range(steps): if done: - obs, done = torch.tensor(task.obs_post_process(np.array([env.reset()])), - device=device, - dtype=torch.float), False + obs, done = ( + torch.tensor( + task.obs_post_process(np.array([env.reset()])), + device=device, + dtype=torch.float, + ), + False, + ) # TODO: can be optimized using cache with torch.inference_mode(): action_logits[step] = self.model( - None if step < skip + 1 else observations[None, : step: skip + 1], - None if step < skip + 1 else actions_one_hot[None, : step: skip + 1], + None if step < skip + 1 else observations[None, : step : skip + 1], + None + if step < skip + 1 + else actions_one_hot[None, : step : skip + 1], None if step < skip + 1 else rewards[None, : step : skip + 1], current_obs=obs[None, 0], - action_only=True)[0, step] + action_only=True, + )[0, step] actions[step] = torch.argmax(action_logits[step]).type(torch.long) - actions_one_hot[step] = torch.nn.functional.one_hot(actions[step], - num_classes=task.act_dim).type(torch.float) + actions_one_hot[step] = torch.nn.functional.one_hot( + 
actions[step], num_classes=task.act_dim + ).type(torch.float) observations[step] = obs[None, 0] obs, rew, done, _ = env.step(actions[step].item()) - obs = torch.tensor(task.obs_post_process(np.array([obs])), device=device, dtype=torch.float) - rew = torch.tensor(task.rew_post_process(np.array([rew])), device=device, dtype=torch.float) + obs = torch.tensor( + task.obs_post_process(np.array([obs])), device=device, dtype=torch.float + ) + rew = torch.tensor( + task.rew_post_process(np.array([rew])), device=device, dtype=torch.float + ) rewards[step] = rew[0] return observations, actions, rewards @@ -153,13 +171,19 @@ def _get_data_iter( yield ( torch.tensor( - np.array([sample[0] for sample in samples]), dtype=torch.float, device=device + np.array([sample[0] for sample in samples]), + dtype=torch.float, + device=device, ), # observations torch.tensor( - np.array([sample[1] for sample in samples]), dtype=torch.long, device=device + np.array([sample[1] for sample in samples]), + dtype=torch.long, + device=device, ), # actions torch.tensor( - np.array([sample[2] for sample in samples]), dtype=torch.float, device=device + np.array([sample[2] for sample in samples]), + dtype=torch.float, + device=device, ), # rewards ) diff --git a/algorithm_distillation/models/gpt2.py b/algorithm_distillation/models/gpt2.py index 9db6093..a9a1525 100644 --- a/algorithm_distillation/models/gpt2.py +++ b/algorithm_distillation/models/gpt2.py @@ -124,7 +124,11 @@ def forward( """ if obs is None: assert current_obs is not None, "Empty input." - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) batch_size, timestep = current_obs.shape[0], 0 else: device = obs.device @@ -138,9 +142,11 @@ def forward( current_step_id = timestep if step_ids is None and timestep > 0: - step_ids = torch.arange(0, timestep, dtype=torch.long, device=device).view( - 1, timestep - ).repeat((batch_size, 1)) + step_ids = ( + torch.arange(0, timestep, dtype=torch.long, device=device) + .view(1, timestep) + .repeat((batch_size, 1)) + ) if timestep == 0: embedded_obs, embedded_act, embedded_rew = None, None, None @@ -174,7 +180,9 @@ def forward( input_seq = self.layer_norm_in_embedding(input_seq) if timestep == 0: - attention_mask = torch.ones((batch_size, num_extra), dtype=torch.float, device=device) + attention_mask = torch.ones( + (batch_size, num_extra), dtype=torch.float, device=device + ) else: if attention_mask is None: attention_mask = torch.ones( @@ -186,7 +194,9 @@ def forward( attention_mask = torch.concat( [ attention_mask, - torch.ones((batch_size, num_extra), dtype=torch.float, device=device), + torch.ones( + (batch_size, num_extra), dtype=torch.float, device=device + ), ], dim=1, ) diff --git a/algorithm_distillation/sb3_util/__init__.py b/algorithm_distillation/sb3_util/__init__.py index 09fe6a3..03cd976 100644 --- a/algorithm_distillation/sb3_util/__init__.py +++ b/algorithm_distillation/sb3_util/__init__.py @@ -1,4 +1,3 @@ from .logger import CustomLogger, configure - -__all__ = ['CustomLogger', 'configure'] +__all__ = ["CustomLogger", "configure"] diff --git a/algorithm_distillation/sb3_util/logger.py b/algorithm_distillation/sb3_util/logger.py index 63a54fa..0565d61 100644 --- a/algorithm_distillation/sb3_util/logger.py +++ b/algorithm_distillation/sb3_util/logger.py @@ -2,7 +2,8 @@ import os import tempfile from collections import defaultdict -from typing import Any, Optional, Union, Tuple, List 
+from typing import Any, List, Optional, Tuple, Union + from stable_baselines3.common.logger import Logger, make_output_format @@ -17,18 +18,30 @@ def __init__(self, *args, **kwargs): self.history_value = defaultdict(list) self.history_mean_value = defaultdict(list) - def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record( + self, + key: str, + value: Any, + exclude: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> None: super().record(key, value, exclude) self.history_value[key].append(value) - def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record_mean( + self, + key: str, + value: Any, + exclude: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> None: super().record_mean(key, value, exclude) self.history_mean_value[key].append(self.name_to_value[key]) -def configure(folder: Optional[str] = None, - format_strings: Optional[List[str]] = None, - logger_class=Logger) -> Logger: +def configure( + folder: Optional[str] = None, + format_strings: Optional[List[str]] = None, + logger_class=Logger, +) -> Logger: """ (This is an adaptation of SB3's logger configuration helper function. The change was very minor, only one line towards the end to allow for a customized logger class.) @@ -43,7 +56,10 @@ def configure(folder: Optional[str] = None, if folder is None: folder = os.getenv("SB3_LOGDIR") if folder is None: - folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f")) + folder = os.path.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"), + ) assert isinstance(folder, str) os.makedirs(folder, exist_ok=True) diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index c2d9d7a..1792677 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -6,7 +6,6 @@ import gym import numpy as np import stable_baselines3 -import torch.nn.functional from stable_baselines3.common.buffers import ReplayBuffer @@ -18,6 +17,8 @@ class Task(abc.ABC): # `obs_dim` marks the total dimension of the observations obs_dim: int + act_dim: int + env: object @abc.abstractmethod def __init__(self, **kwargs): @@ -134,7 +135,7 @@ def _get_act_specs(self) -> tuple: if isinstance(act_space, gym.spaces.Discrete): return act_space.n, "discrete" - elif isinstance(act_space, gym.spaces.Box): + else: raise NotImplementedError( f"The observation space does not support {type(act_space)}." ) @@ -196,7 +197,9 @@ def _get_most_recent_history( """ pos = buffer.pos total_length = length * (skip + 1) - assert buffer.buffer_size > total_length, "Replay buffer size must be larger than the sequence length." + assert ( + buffer.buffer_size > total_length + ), "Replay buffer size must be larger than the sequence length." start = (pos - total_length + buffer.buffer_size) % buffer.buffer_size end = pos @@ -215,7 +218,9 @@ def _randomly_sample_buffer( :return: an (observations, actions, rewards) tuple. """ total_length = length * (skip + 1) - assert buffer.buffer_size > total_length, "Replay buffer size must be larger than the sequence length." + assert ( + buffer.buffer_size > total_length + ), "Replay buffer size must be larger than the sequence length." 
if not buffer.full: start = random.randint(0, buffer.pos - total_length) @@ -227,7 +232,9 @@ def _randomly_sample_buffer( return self._get_obs_act_rew(buffer, start, end, skip) @staticmethod - def _get_range(array: np.ndarray, start: int, end: int, interval: int) -> np.ndarray: + def _get_range( + array: np.ndarray, start: int, end: int, interval: int + ) -> np.ndarray: """ A helper function to either slice array[start:end:interval] or combine array[start::interval] and array[:end:interval] depending on whether start < end. @@ -240,7 +247,9 @@ def _get_range(array: np.ndarray, start: int, end: int, interval: int) -> np.nda if start < end: return array[start:end:interval] else: - return np.concatenate([array[start::interval], array[:end:interval]], axis=0) + return np.concatenate( + [array[start::interval], array[:end:interval]], axis=0 + ) def _get_obs_act_rew(self, buffer: ReplayBuffer, start: int, end: int, skip: int): """ @@ -261,4 +270,4 @@ def _get_obs_act_rew(self, buffer: ReplayBuffer, start: int, end: int, skip: int self.rew_post_process( self._get_range(buffer.rewards, start, end, skip + 1) ), - ) \ No newline at end of file + ) diff --git a/algorithm_distillation/task_manager.py b/algorithm_distillation/task_manager.py index c893e80..02f9c45 100644 --- a/algorithm_distillation/task_manager.py +++ b/algorithm_distillation/task_manager.py @@ -16,8 +16,8 @@ def __init__(self, tasks: List[Task]): if not tasks: raise ValueError("The task list cannot be empty.") for task in tasks[1:]: - for st in ['obs', 'act']: - if getattr(task, f'{st}_dim') != getattr(tasks[0], f'{st}_dim'): + for st in ["obs", "act"]: + if getattr(task, f"{st}_dim") != getattr(tasks[0], f"{st}_dim"): raise ValueError(f"All tasks must have the same {st}_dim.") self.tasks = tasks From dcc13505536401e82492f0c5825d93f2a4aec86a Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Mon, 19 Dec 2022 18:58:02 +0000 Subject: [PATCH 06/17] Added verbose and return more logging info. Added a couple paragraphs for README.md. --- README.md | 11 ++++++- algorithm_distillation/ad.py | 59 +++++++++++++++++++++++++++++------- requirements.txt | 3 +- tests/test_ad.py | 6 ++-- 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index fea9271..03f91e1 100644 --- a/README.md +++ b/README.md @@ -1 +1,10 @@ -# Algorithm-Distillation-RLHF \ No newline at end of file +# Algorithm-Distillation-RLHF + +The current state of this repo is a preliminary version of a replication of the Algorithmic Distillation algorithm +described in the paper [In-context Reinforcement Learning with Algorithm Distillation](https://arxiv.org/abs/2210.14215). +We also aim to try ideas beyond. + +# Quick start +A test script is not provided yet, but the unit test `tests/test_ad.py` provides a complete routine of applying the +AD transformer to the histories of a single toy task "FrozenLake-v1". Please [take a look](tests/test_ad.py) and feel +free to plug in your own gym env. 
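+
+If you just want a rough idea of how the pieces fit together, the sketch below is adapted from that unit test.
+The argument values are illustrative and the API may still change:
+
+```python
+import gym
+
+from algorithm_distillation import GymAD, GymTask, TaskManager
+from algorithm_distillation.models import GPT2AD
+
+# Any gym environment with discrete observation and action spaces should work here.
+env = gym.make("FrozenLake-v1")
+
+# A source RL agent (stable-baselines3 DQN) whose learning history will be distilled.
+config = {'learning_starts': 10, 'buffer_size': 1000, 'policy': 'MlpPolicy'}
+task = GymTask(env, 'DQN', config)
+task_manager = TaskManager([task])
+task_manager.train(100)  # fill the replay buffer with RL training history
+
+# Train a GPT2-based AD transformer on the collected histories.
+model = GPT2AD(env.observation_space.n, env.action_space.n, 12)
+ad = GymAD(model)
+ad.train(task_manager, 100, 10, skip=0, batch_size=8)
+
+# Roll out the distilled in-context learner on the environment.
+obs, act, rew, term = ad.rollout(task, 100, 0)
+```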
diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/ad.py index 46529c6..3b79e33 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/ad.py @@ -1,7 +1,9 @@ import abc +import logging import numpy as np import torch +from tqdm import tqdm from algorithm_distillation.models.ad_transformer import ADTransformer @@ -16,6 +18,7 @@ class AlgorithmDistillation(abc.ABC): def __init__(self, model: ADTransformer): self.model = model + self.logger = logging.getLogger(__name__) @abc.abstractmethod def train( @@ -26,7 +29,7 @@ def train( skip: int, batch_size: int, **config, - ): + ) -> list: pass @abc.abstractmethod @@ -41,9 +44,11 @@ def train( steps: int, length: int, skip: int, - batch_size: int, + batch_size: int = 32, + lr: float = 1e-4, + verbose: int = 0, **config, - ): + ) -> list: """ Collect samples and train `steps` amount of gradient steps. @@ -51,24 +56,36 @@ def train( :param steps: the amount of gradient steps to train. :param length: the step-length of sampled sequences (not the sequence length which is 3x). :param skip: the amount of states to skip between two consecutive ones. - :param batch_size: the batch size. + :param batch_size: (Optional) the batch size. + :param lr: (Optional) the learning rate. + :param verbose: (Optional) verbose level. Nonzero => showing progress bar and certain logs. :param config: the extra config that goes into transformer training. - :return: None + :return: a list of losses """ + # Combine the config and the direct args. + # Note: direct args `batch_size` and `lr` override the config dict! + cfg = {**config, "batch_size": batch_size, "lr": lr} + # We implement a PyTorch training loop. # Use GPU if exists. device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) + if verbose: + self.logger.info(f"Device: {device.type}") + data_iter = self._get_data_iter( - steps, batch_size, task_manager, length, skip, device=device + steps, cfg["batch_size"], task_manager, length, skip, device=device ) self.model.to(device) - optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-2) + optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg["lr"]) self.model.train() # Set to train mode so that dropout and batch norm would update. losses = [] - for step, sample in enumerate(data_iter): + + _tqdm_iter = tqdm(enumerate(data_iter), total=steps) + + for step, sample in _tqdm_iter if verbose else enumerate(data_iter): optimizer.zero_grad() obs, actions, rewards = sample one_hot_actions = torch.nn.functional.one_hot( @@ -81,15 +98,26 @@ def train( losses.append(loss.item()) optimizer.step() + if verbose: # Update loss if verbose is on + _tqdm_iter.set_postfix(ordered_dict={"loss": losses[-1]}) + self.model.eval() # By default, set to eval mode outside training. + return losses - def rollout(self, task: GymTask, steps: int, skip: int) -> tuple: + def rollout( + self, + task: GymTask, + steps: int, + skip: int, + verbose: int = 0, + ) -> tuple: """ Roll out for `steps` amount of steps (ignore the policy embedded in `task` and only uses its _env). :param task: the task to perform rollout on. :param steps: the amount of steps to roll out. :param skip: the amount of steps to skip (normally should be the same as `skip` during training). + :param verbose: (Optional) verbose level. Nonzero => showing progress bar and certain logs. :return: the full sequences (observations, actions, rewards), each of length `steps`. 
""" device = ( @@ -119,9 +147,13 @@ def rollout(self, task: GymTask, steps: int, skip: int) -> tuple: ) rewards = torch.zeros((steps, 1), device=device, dtype=torch.float) + terminals = torch.zeros((steps,), device=device, dtype=torch.bool) obs, done = None, True - for step in range(steps): + cum_reward = 0.0 + + _tqdm_iter = tqdm(range(steps)) + for step in _tqdm_iter if verbose else range(steps): if done: obs, done = ( torch.tensor( @@ -157,8 +189,13 @@ def rollout(self, task: GymTask, steps: int, skip: int) -> tuple: task.rew_post_process(np.array([rew])), device=device, dtype=torch.float ) rewards[step] = rew[0] + terminals[step] = done + cum_reward += rew[0] + + if verbose: # Update loss if verbose is on + _tqdm_iter.set_postfix(ordered_dict={"cum_reward": cum_reward}) - return observations, actions, rewards + return observations, actions, rewards, terminals @staticmethod def _get_data_iter( diff --git a/requirements.txt b/requirements.txt index 029afe5..8303bfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ transformers>=4.24.0 torch>=1.12.1 pytest>=7.2.0 gym==0.21.0 -numpy>=1.23.5 \ No newline at end of file +numpy +tqdm \ No newline at end of file diff --git a/tests/test_ad.py b/tests/test_ad.py index 68266b1..c055bb3 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -29,7 +29,7 @@ def test_ad(): assert len(task.agent.logger.history_value['rollout/exploration_rate']) == 100 ad = GymAD(model) - ad.train(task_manager, 100, 10, skip=0, batch_size=8) + ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1) - obs, act, rew = ad.rollout(task, 16, 0) - print(obs, act, rew) + obs, act, rew, term = ad.rollout(task, 16, 0, verbose=1) + print(obs, act, rew, term) From 3caa347ef678f6b214816fc1d5f8f63881878f3a Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Mon, 19 Dec 2022 19:20:35 +0000 Subject: [PATCH 07/17] Fix tqdm. --- algorithm_distillation/ad.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/ad.py index 3b79e33..f8befe5 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/ad.py @@ -83,9 +83,8 @@ def train( self.model.train() # Set to train mode so that dropout and batch norm would update. losses = [] - _tqdm_iter = tqdm(enumerate(data_iter), total=steps) - - for step, sample in _tqdm_iter if verbose else enumerate(data_iter): + _tqdm_iter = tqdm(enumerate(data_iter), total=steps, disable=(verbose == 0)) + for step, sample in _tqdm_iter: optimizer.zero_grad() obs, actions, rewards = sample one_hot_actions = torch.nn.functional.one_hot( @@ -152,8 +151,8 @@ def rollout( obs, done = None, True cum_reward = 0.0 - _tqdm_iter = tqdm(range(steps)) - for step in _tqdm_iter if verbose else range(steps): + _tqdm_iter = tqdm(range(steps), disable=(verbose == 0)) + for step in _tqdm_iter: if done: obs, done = ( torch.tensor( From 1bb95fc19128c919f2c55e3b60703c61ef850ee7 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Wed, 21 Dec 2022 19:27:35 -0500 Subject: [PATCH 08/17] Fix sampling behavior. 
--- algorithm_distillation/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index 1792677..86fdfc4 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -226,7 +226,7 @@ def _randomly_sample_buffer( start = random.randint(0, buffer.pos - total_length) end = start + total_length else: - start = random.randint(0, buffer.buffer_size) + start = random.randint(0, buffer.buffer_size - total_length) + buffer.pos end = (start + total_length) % buffer.buffer_size return self._get_obs_act_rew(buffer, start, end, skip) From de320ee99cf95b9dd87146ac71d683823b7da4f8 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sat, 24 Dec 2022 17:47:58 +0000 Subject: [PATCH 09/17] Fix rollout behavior when skip is more than 0. --- algorithm_distillation/ad.py | 37 +++++++++------- .../models/ad_transformer.py | 3 ++ algorithm_distillation/models/gpt2.py | 39 +++++++++-------- algorithm_distillation/models/util.py | 43 +++++++++++++++++++ tests/test_ad.py | 8 +++- tests/test_module.py | 2 +- tests/test_util.py | 24 ++++++++++- 7 files changed, 120 insertions(+), 36 deletions(-) diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/ad.py index f8befe5..f199194 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/ad.py @@ -6,6 +6,7 @@ from tqdm import tqdm from algorithm_distillation.models.ad_transformer import ADTransformer +from .models.util import get_sequence from .task import GymTask from .task_manager import TaskManager @@ -130,10 +131,15 @@ def rollout( f"The task must have observation dimension {self.model.obs_dim}" ) env = task.env + + # The max_len of history should be 1 less than max_step_len (leave room for current_obs) + # (recall that max sequence length for the transformer is `model.max_step_len * 3`) + max_len = self.model.max_step_len - 1 + + # Prepare sequential inputs/outputs for the transformer observations = torch.zeros( (steps, task.obs_dim), device=device, dtype=torch.float ) - # Predicted action logits action_logits = torch.zeros( (steps, task.act_dim), device=device, dtype=torch.float @@ -144,7 +150,6 @@ def rollout( actions_one_hot = torch.zeros( (steps, task.act_dim), device=device, dtype=torch.float ) - rewards = torch.zeros((steps, 1), device=device, dtype=torch.float) terminals = torch.zeros((steps,), device=device, dtype=torch.bool) @@ -153,7 +158,7 @@ def rollout( _tqdm_iter = tqdm(range(steps), disable=(verbose == 0)) for step in _tqdm_iter: - if done: + if done: # Last step was terminal. Reset. obs, done = ( torch.tensor( task.obs_post_process(np.array([env.reset()])), @@ -165,31 +170,33 @@ def rollout( # TODO: can be optimized using cache with torch.inference_mode(): + # The input of the model is collected from the index `step` (exclusive) going backwards + # with interval `skip + 1`. It goes as far back as possible until either the beginning or + # `max_len` steps (the maximal number of steps that can fit into the transformer) are + # collected. We then take the argmax of the prediction of the next action and perform the + # rollout. 
action_logits[step] = self.model( - None if step < skip + 1 else observations[None, : step : skip + 1], - None - if step < skip + 1 - else actions_one_hot[None, : step : skip + 1], - None if step < skip + 1 else rewards[None, : step : skip + 1], + get_sequence(observations, max_len, step, skip + 1)[None, :], + get_sequence(actions_one_hot, max_len, step, skip + 1)[None, :], + get_sequence(rewards, max_len, step, skip + 1)[None, :], current_obs=obs[None, 0], action_only=True, - )[0, step] + )[0, min(step // (skip + 1), max_len - 1)] + actions[step] = torch.argmax(action_logits[step]).type(torch.long) actions_one_hot[step] = torch.nn.functional.one_hot( actions[step], num_classes=task.act_dim ).type(torch.float) observations[step] = obs[None, 0] - obs, rew, done, _ = env.step(actions[step].item()) + obs, rew, done, _ = env.step(actions[step].item()) # are still np.ndarray obs = torch.tensor( task.obs_post_process(np.array([obs])), device=device, dtype=torch.float ) - rew = torch.tensor( - task.rew_post_process(np.array([rew])), device=device, dtype=torch.float - ) - rewards[step] = rew[0] + + rewards[step] = rew terminals[step] = done - cum_reward += rew[0] + cum_reward += rew if verbose: # Update loss if verbose is on _tqdm_iter.set_postfix(ordered_dict={"cum_reward": cum_reward}) diff --git a/algorithm_distillation/models/ad_transformer.py b/algorithm_distillation/models/ad_transformer.py index 334342b..bdb398b 100644 --- a/algorithm_distillation/models/ad_transformer.py +++ b/algorithm_distillation/models/ad_transformer.py @@ -6,6 +6,9 @@ class ADTransformer(abc.ABC, torch.nn.Module): obs_dim: int act_dim: int + # The maximal amount of steps in the input. + # Note: the corresponding max sequence length is `max_step_len * 3`. + max_step_len: int @abc.abstractmethod def __init__(self, **kwargs): diff --git a/algorithm_distillation/models/gpt2.py b/algorithm_distillation/models/gpt2.py index a9a1525..5a16ddf 100644 --- a/algorithm_distillation/models/gpt2.py +++ b/algorithm_distillation/models/gpt2.py @@ -35,7 +35,7 @@ def __init__( obs_dim, act_dim, hidden_size, - max_ep_len=4096, + max_step_len=1024, action_tanh=True, obs_emb_cls=torch.nn.Linear, act_emb_cls=torch.nn.Linear, @@ -48,7 +48,7 @@ def __init__( :param obs_dim: observation dimension (as a flattened tensor) :param act_dim: action dimension (as a flattened tensor) :param hidden_size: the dimension of the embedding space - :param max_ep_len: (Optional) maximal episode length + :param max_step_len: (Optional) maximal episode length :param action_tanh: (Optional) apply tanh activation function on the action :param obs_emb_cls: (Optional) the nn.Module class for observation embedding :param act_emb_cls: (Optional) the nn.Module class for action embedding @@ -61,13 +61,17 @@ def __init__( self.act_dim = act_dim self.hidden_size = hidden_size self.action_tanh = action_tanh + self.max_step_len = max_step_len # Generate the most basic GPT2 config - config = transformers.GPT2Config(vocab_size=1, n_embd=hidden_size, **kwargs) + config = transformers.GPT2Config(vocab_size=1, + n_embd=hidden_size, + n_positions=max_step_len * 3, + **kwargs) self.transformers = transformers.GPT2Model(config) # Remove the position embedding by replacing it with a dummy. self.transformers.wpe = ZeroDummy((hidden_size,)) # This is our position embedding based on steps. 
- self.step_embedding = torch.nn.Embedding(max_ep_len, self.hidden_size) + self.step_embedding = torch.nn.Embedding(max_step_len, self.hidden_size) # The embedding layers self.obs_embedding = obs_emb_cls(self.obs_dim, self.hidden_size) @@ -122,37 +126,38 @@ def forward( :param action_only: (Optional) return predicted actions only. :return: predicted action logits (if action_only) or predicted action logits, rewards, and obs. """ - if obs is None: + if obs is None or obs.numel() == 0: # None or empty tensor are regarded as no history assert current_obs is not None, "Empty input." device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) - batch_size, timestep = current_obs.shape[0], 0 + batch_size, history_length = current_obs.shape[0], 0 else: device = obs.device - batch_size, timestep, _ = obs.shape + batch_size, history_length, _ = obs.shape if current_step_id is None: if step_ids is not None: logger.warning( - "'current_step_id' defaults to the number of steps. But it may conflict with the given 'step_ids'." + "'current_step_id' will default to the number of history steps. " + "But it may conflict with the given 'step_ids'." ) - current_step_id = timestep + current_step_id = history_length - if step_ids is None and timestep > 0: + if step_ids is None and history_length > 0: step_ids = ( - torch.arange(0, timestep, dtype=torch.long, device=device) - .view(1, timestep) + torch.arange(0, history_length, dtype=torch.long, device=device) + .view(1, history_length) .repeat((batch_size, 1)) ) - if timestep == 0: + if history_length == 0: embedded_obs, embedded_act, embedded_rew = None, None, None else: embedded_steps = self.step_embedding(step_ids).view( - batch_size, timestep, self.hidden_size + batch_size, history_length, self.hidden_size ) embedded_obs = self.obs_embedding(obs) + embedded_steps embedded_act = self.act_embedding(actions) + embedded_steps @@ -179,14 +184,14 @@ def forward( input_seq = stack_seq(embedded_obs, embedded_act, embedded_rew, extra) input_seq = self.layer_norm_in_embedding(input_seq) - if timestep == 0: + if history_length == 0: # Empty history -> no need to concatenate with history. attention_mask = torch.ones( (batch_size, num_extra), dtype=torch.float, device=device ) - else: + else: # Nonempty history -> concatenate history sequence with current state info. if attention_mask is None: attention_mask = torch.ones( - (batch_size, timestep), dtype=torch.float, device=device + (batch_size, history_length), dtype=torch.float, device=device ) attention_mask = ( attention_mask.unsqueeze(-1).repeat((1, 1, 3)).view(batch_size, -1) diff --git a/algorithm_distillation/models/util.py b/algorithm_distillation/models/util.py index 541d970..d4eb49a 100644 --- a/algorithm_distillation/models/util.py +++ b/algorithm_distillation/models/util.py @@ -5,6 +5,7 @@ def stack_seq(obs, act, rew, extra=None) -> torch.Tensor: """ Stack up into a sequence (obs, act, rew, obs, act, rew, ...) in axis 1, and append extra in the end. + :param obs: shape (b, t, hidden_size) :param act: shape (b, t, hidden_size) :param rew: shape (b, t, hidden_size) @@ -21,3 +22,45 @@ def stack_seq(obs, act, rew, extra=None) -> torch.Tensor: return stacked else: return torch.concat([stacked, extra], dim=1) + + +def get_sequence(arr: torch.Tensor, + num_items: int, + end_idx: int, + interval: int): + """ + Get the subsequence of indices 'end_idx - num_items * interval, ..., end_idx - interval' from + the PyTorch tensor `arr`. 
+ Note: + - While `end_idx` is excluded, the intervals start backwards from it. + - Negative indices are ignored (meaning the return could be shorter). + - If the length of `arr` is less than `end_idx`, treat `arr` as a circular buffer. + Example: + arr = [7, 8, 9], length is 3 + num_items = 3, end_idx = 4, interval = 2 + -> [7, 9] + num_items = 1, end_idx = 3, interval = 2 + -> [8] + num_items = 1, end_idx = 2, interval = 2 + -> [7] + + :param arr: a tensor whose first dimension is the index we are concerned with. + :param num_items: the max number of items. + :param end_idx: the end index. + :param interval: the interval length. + :return: a subsequence of `arr` according to the description. + """ + length = arr.size(0) + # The size of the circular buffer determines max how many items it can return + num_items = min(num_items, (length + interval - 1) // interval) + # Get the actual start index (inclusive) + start_idx = max(end_idx - num_items * interval, end_idx % interval) + + if end_idx >= length: + # The subseq cuts in the middle. Update `end_idx` to the actual end index on the second half. + end_idx = max(0, end_idx % length - interval + 1) + return torch.concat( + [arr[start_idx % length :: interval], arr[:end_idx:interval]], dim=0 + ) + else: + return arr[start_idx:end_idx:interval] diff --git a/tests/test_ad.py b/tests/test_ad.py index c055bb3..f39e527 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -11,7 +11,7 @@ def test_ad(): 'buffer_size': 1000, 'policy': 'MlpPolicy' } - model = GPT2AD(env.observation_space.n, env.action_space.n, 12, max_ep_len=16) + model = GPT2AD(env.observation_space.n, env.action_space.n, 12, max_step_len=16) task = GymTask(env, 'DQN', config) # Inject a customized logger @@ -31,5 +31,9 @@ def test_ad(): ad = GymAD(model) ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1) - obs, act, rew, term = ad.rollout(task, 16, 0, verbose=1) + obs, act, rew, term = ad.rollout(task, 100, 0, verbose=1) + assert obs.size(0) == 100 + print(obs, act, rew, term) + obs, act, rew, term = ad.rollout(task, 100, 2, verbose=1) + assert obs.size(0) == 100 print(obs, act, rew, term) diff --git a/tests/test_module.py b/tests/test_module.py index 2671c7e..abd6ef2 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize("n", [1, 2, 3]) def test_GPT2AD(n): - model = GPT2AD(2, n, 12, max_ep_len=16) + model = GPT2AD(2, n, 12, max_step_len=16) model.eval() # Disable dropout sample_obs = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.float) # two samples, two steps in each trajectory: diff --git a/tests/test_util.py b/tests/test_util.py index 28b825d..e70fbc3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,5 +1,5 @@ import torch -from algorithm_distillation.models.util import stack_seq +from algorithm_distillation.models.util import stack_seq, get_sequence def test_stack_seq(): @@ -34,3 +34,25 @@ def test_stack_seq(): # when batch size is 0, all of obs, act, rew should be None. It should just return the extra. 
result = stack_seq(None, None, None, extra) assert torch.all(result.isclose(extra)) + + +def test_get_sequence(): + arr = torch.tensor([7, 8, 9]) + n = 3 + end_idx = 4 + interval = 2 + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([7, 9])) + n = 1 + end_idx = 3 + interval = 2 + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([8])) + n = 1 + end_idx = 2 + interval = 2 + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([7])) + + n = 3 + end_idx = 5 + interval = 2 + # can only fetch a max elem that the circular buffer allows (i.e., 2 in this case) + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([8, 7])) From 305de88266197c2802c2db330ce93ba4c44368b6 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Mon, 26 Dec 2022 02:18:17 +0000 Subject: [PATCH 10/17] Enforce types for rewards/terminals in ad rollout (some sloppy envs might not return the right types). --- algorithm_distillation/ad.py | 14 +++++++------- algorithm_distillation/models/gpt2.py | 11 ++++++----- algorithm_distillation/models/util.py | 9 +++------ 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/ad.py index f199194..0fd3642 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/ad.py @@ -6,8 +6,8 @@ from tqdm import tqdm from algorithm_distillation.models.ad_transformer import ADTransformer -from .models.util import get_sequence +from .models.util import get_sequence from .task import GymTask from .task_manager import TaskManager @@ -168,12 +168,12 @@ def rollout( False, ) - # TODO: can be optimized using cache + # TODO: can probably be optimized using kv cache with torch.inference_mode(): # The input of the model is collected from the index `step` (exclusive) going backwards # with interval `skip + 1`. It goes as far back as possible until either the beginning or - # `max_len` steps (the maximal number of steps that can fit into the transformer) are - # collected. We then take the argmax of the prediction of the next action and perform the + # `max_len` number of steps (the maximal number of steps that can fit into the transformer) + # We then take the argmax of the prediction of the next action and perform the # rollout. action_logits[step] = self.model( get_sequence(observations, max_len, step, skip + 1)[None, :], @@ -194,9 +194,9 @@ def rollout( task.obs_post_process(np.array([obs])), device=device, dtype=torch.float ) - rewards[step] = rew - terminals[step] = done - cum_reward += rew + rewards[step] = float(rew) + terminals[step] = bool(done) + cum_reward += float(rew) if verbose: # Update loss if verbose is on _tqdm_iter.set_postfix(ordered_dict={"cum_reward": cum_reward}) diff --git a/algorithm_distillation/models/gpt2.py b/algorithm_distillation/models/gpt2.py index 5a16ddf..256d11e 100644 --- a/algorithm_distillation/models/gpt2.py +++ b/algorithm_distillation/models/gpt2.py @@ -63,10 +63,9 @@ def __init__( self.action_tanh = action_tanh self.max_step_len = max_step_len # Generate the most basic GPT2 config - config = transformers.GPT2Config(vocab_size=1, - n_embd=hidden_size, - n_positions=max_step_len * 3, - **kwargs) + config = transformers.GPT2Config( + vocab_size=1, n_embd=hidden_size, n_positions=max_step_len * 3, **kwargs + ) self.transformers = transformers.GPT2Model(config) # Remove the position embedding by replacing it with a dummy. 
self.transformers.wpe = ZeroDummy((hidden_size,)) @@ -126,7 +125,9 @@ def forward( :param action_only: (Optional) return predicted actions only. :return: predicted action logits (if action_only) or predicted action logits, rewards, and obs. """ - if obs is None or obs.numel() == 0: # None or empty tensor are regarded as no history + if ( + obs is None or obs.numel() == 0 + ): # None or empty tensor are regarded as no history assert current_obs is not None, "Empty input." device = ( torch.device("cuda") diff --git a/algorithm_distillation/models/util.py b/algorithm_distillation/models/util.py index d4eb49a..ee5c8d1 100644 --- a/algorithm_distillation/models/util.py +++ b/algorithm_distillation/models/util.py @@ -24,10 +24,7 @@ def stack_seq(obs, act, rew, extra=None) -> torch.Tensor: return torch.concat([stacked, extra], dim=1) -def get_sequence(arr: torch.Tensor, - num_items: int, - end_idx: int, - interval: int): +def get_sequence(arr: torch.Tensor, num_items: int, end_idx: int, interval: int): """ Get the subsequence of indices 'end_idx - num_items * interval, ..., end_idx - interval' from the PyTorch tensor `arr`. @@ -60,7 +57,7 @@ def get_sequence(arr: torch.Tensor, # The subseq cuts in the middle. Update `end_idx` to the actual end index on the second half. end_idx = max(0, end_idx % length - interval + 1) return torch.concat( - [arr[start_idx % length :: interval], arr[:end_idx:interval]], dim=0 - ) + [arr[start_idx % length :: interval], arr[:end_idx:interval]], dim=0 + ) else: return arr[start_idx:end_idx:interval] From 2ca2eb8840d00dc818230296364fcc683113d16b Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Mon, 26 Dec 2022 20:17:59 +0000 Subject: [PATCH 11/17] Supports PPO and A2C. Expands environment to continuous observation like CartPole. --- algorithm_distillation/sb3_util/callback.py | 63 ++++++++++++++++++ algorithm_distillation/task.py | 73 ++++++++++++++++----- algorithm_distillation/task_manager.py | 5 +- tests/test_ad.py | 41 ++++++++---- tests/test_env.py | 2 +- tests/test_sb3_util.py | 35 ++++++++++ 6 files changed, 189 insertions(+), 30 deletions(-) create mode 100644 algorithm_distillation/sb3_util/callback.py create mode 100644 tests/test_sb3_util.py diff --git a/algorithm_distillation/sb3_util/callback.py b/algorithm_distillation/sb3_util/callback.py new file mode 100644 index 0000000..c0772d1 --- /dev/null +++ b/algorithm_distillation/sb3_util/callback.py @@ -0,0 +1,63 @@ +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm + + +class RolloutCallback(BaseCallback): + """ + This is a custom callback that collects rollouts from an on-policy algorithm. + :param buffer: The external replay buffer to save rollouts. + :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages. + """ + + def __init__( + self, agent: OnPolicyAlgorithm, buffer: ReplayBuffer, verbose: int = 0 + ): + super().__init__(verbose) + self.agent = agent + self.buffer = buffer + assert ( + self.buffer.buffer_size >= self.agent.rollout_buffer.buffer_size + ), "External replay buffer must be larger than the agent rollout buffer." 
+ self._enforce_type() + + def _on_step(self) -> bool: + return True + + def _on_rollout_end(self) -> None: + start = self.buffer.pos + end = start + self.agent.rollout_buffer.buffer_size + + buffer = self.buffer + tgt_buffer = self.agent.rollout_buffer + + ranges = [(start, min(buffer.buffer_size, end))] + tgt_start = [0] + if buffer.buffer_size < end: # If overflown, wrap around to the beginning + ranges.append((0, end - buffer.buffer_size)) + tgt_start.append(buffer.buffer_size) + buffer.full = True + + for (st, ed), tst in zip(ranges, tgt_start): + buffer.observations[st:ed] = tgt_buffer.observations[ + tst : tst + ed - st + ].copy() + buffer.actions[st:ed] = tgt_buffer.actions[tst : tst + ed - st].copy() + buffer.rewards[st:ed] = tgt_buffer.rewards[tst : tst + ed - st].copy() + self.buffer.pos = ed # Update the pointer to the last ending range + + def _enforce_type(self) -> None: + # Rollout buffer's observation and action are float and it can cause inconsistency. + # So we force the types of buffer to be the same (should we emit a warning?). + buffer = self.buffer + tgt_buffer = self.agent.rollout_buffer + if tgt_buffer.observations is None: + raise RuntimeError("The rollout buffer is not initialized.") + + obs_type = tgt_buffer.observations.dtype + action_type = tgt_buffer.actions.dtype + reward_type = tgt_buffer.rewards.dtype + + buffer.observations = buffer.observations.astype(obs_type) + buffer.actions = buffer.actions.astype(action_type) + buffer.rewards = buffer.rewards.astype(reward_type) diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index 86fdfc4..97ba93e 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -8,6 +8,8 @@ import stable_baselines3 from stable_baselines3.common.buffers import ReplayBuffer +from algorithm_distillation.sb3_util.callback import RolloutCallback + class Task(abc.ABC): """ @@ -25,11 +27,12 @@ def __init__(self, **kwargs): pass @abc.abstractmethod - def train(self, steps: int): + def train(self, steps: int, **kwargs): """ Train a number of steps. - :param steps: the number of steps to train + :param steps: the number of steps to train. + :param kwargs: (Optional) extra parameters for SB3 `learn` method. :return: None """ pass @@ -49,9 +52,12 @@ def sample_history(self, length: int, skip: int = 0) -> tuple: class GymTask(Task): - # TODO: Still need some work to set up the on-policy algorithms due to problems in their rollout buffers. - _algorithms = {"DQN": stable_baselines3.DQN} - _on_policy = ("PPO",) + # TODO: support TD3 for continuous action space? + _algorithms = {"DQN": stable_baselines3.DQN, + "PPO": stable_baselines3.PPO, + "A2C": stable_baselines3.A2C, + } + _on_policy = ("PPO", "A2C") # If the environment has a discrete obs space of n classes, return an (n,)-shape array for each obs. # If the environment has a continuous obs space of certain shape, return a flattened array for each obs. @@ -61,12 +67,20 @@ class GymTask(Task): obs_cls: str act_cls: str - def __init__(self, env, algorithm: str, config: Optional[dict] = None): + def __init__( + self, + env, + algorithm: str, + buffer_size: int = 10_000, + config: Optional[dict] = None, + ): """ GymTask takes in a gym environment and set up a stable-baselines3 algorithm to train a policy network. :param env: the gym environment :param algorithm: the stable-baselines3 algorithm + :param buffer_size: (Optional) the buffer size. 
This refers to the stored history length that the + transformer samples from, and applies to both on-policy and off-policy algorithms. :param config: (Optional) the stable-baselines3 algorithm config (except for the gym environment) """ self._env = env
@@ -78,19 +92,49 @@ def __init__(self, env, algorithm: str, config: Optional[dict] = None): raise ValueError( f"Input must be one of {self._algorithms.keys()}. Got {algorithm} instead." ) + self.buffer_size = buffer_size + # Set up the agent self.config = {} if config is None else config.copy() self.config["env"] = self.env - # Default to MultiInputPolicy which is applicable most of the time. - if "policy" not in self.config: - self.config["policy"] = "MultiInputPolicy" - + if "buffer_size" in self.config: + self.config[ + "buffer_size" + ] = self.buffer_size # Overwrite the buffer size if provided + + # Set up the defaults + defaults = { + "policy": "MultiInputPolicy", # a common policy applicable most of the time + } + # Set up the defaults specific to on-policy or off-policy algorithms + if self.algorithm in self._on_policy: + policy_level_defaults = { + "n_steps": 2048, + } + else: + policy_level_defaults = {} + self.config = {**defaults, **policy_level_defaults, **self.config} + # Set up the agent self.agent = self._algorithms[algorithm](**self.config) + + # If on-policy, set up the callback to collect rollouts + if self.algorithm in self._on_policy: + buffer_config = { + "buffer_size": self.buffer_size, + "observation_space": self._env.observation_space, + "action_space": self._env.action_space, + } + self.buffer = ReplayBuffer(**buffer_config) + self.callback = RolloutCallback(self.agent, self.buffer) + else: + self.buffer = self.agent.replay_buffer + self.callback = None + self.obs_dim, self.obs_cls = self._get_obs_specs() self.act_dim, self.act_cls = self._get_act_specs() - def train(self, steps: int): - self.agent.learn(total_timesteps=steps) + def train(self, steps: int, **kwargs): + self.agent.learn(total_timesteps=steps, callback=self.callback, **kwargs) def sample_history( self, length: int, skip: int = 0, most_recent: bool = False
@@ -105,10 +149,7 @@ def sample_history( :param most_recent: (Optional) get the most recent histories. False to sample randomly. :return: a tuple of (observations, actions, rewards). Each is a tensor. """ - if self.algorithm in self._on_policy: - raise NotImplementedError("Not supporting on-policy algorithms yet.") - - buffer = self.agent.replay_buffer + buffer = self.buffer if buffer.n_envs != 1: raise NotImplementedError("Not supporting parallel environments yet.")
diff --git a/algorithm_distillation/task_manager.py b/algorithm_distillation/task_manager.py index 02f9c45..523b9e3 100644 --- a/algorithm_distillation/task_manager.py +++ b/algorithm_distillation/task_manager.py
@@ -22,15 +22,16 @@ def __init__(self, tasks: List[Task]): self.tasks = tasks - def train(self, steps: int): + def train(self, steps: int, **kwargs): """ Train `steps` amount of steps for all the tasks. :param steps: the amount of gradient steps to train + :param kwargs: (Optional) extra parameters for SB3 `learn` method.
:return: None """ for task in self.tasks: - task.train(steps) + task.train(steps, **kwargs) def sample_history(self, length: int, skip: int = 0) -> tuple:
diff --git a/tests/test_ad.py b/tests/test_ad.py index f39e527..d3f455c 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py
@@ -1,32 +1,51 @@ import gym +import pytest from algorithm_distillation import GymTask, TaskManager, GymAD from algorithm_distillation.models import GPT2AD from algorithm_distillation.sb3_util import CustomLogger, configure -def test_ad(): - env = gym.make('FrozenLake-v1') - config = {'learning_starts': 10, - 'buffer_size': 1000, - 'policy': 'MlpPolicy' - } - model = GPT2AD(env.observation_space.n, env.action_space.n, 12, max_step_len=16) +@pytest.mark.parametrize("policy", ['DQN', 'PPO', 'A2C']) +@pytest.mark.parametrize("env_name", ['FrozenLake-v1', 'CartPole-v1']) +def test_ad(policy: str, env_name: str): + env = gym.make(env_name) + + if policy in ['DQN', 'TD3']: + config = {'learning_starts': 10, + 'policy': 'MlpPolicy' + } + else: + config = {'n_steps': 100, + 'batch_size': 10, + 'policy': 'MlpPolicy' + } + if policy == 'A2C': + config.pop('batch_size') + + task = GymTask(env, policy, buffer_size=1000, config=config) + model = GPT2AD(task.obs_dim, task.act_dim, 12, max_step_len=16) - task = GymTask(env, 'DQN', config) # Inject a customized logger logger = configure(None, None, CustomLogger) task.agent.set_logger(logger) task_manager = TaskManager([task]) - task_manager.train(100) + task_manager.train(100, log_interval=1) assert 'history_value' in task.agent.logger.__dir__() # 100 total time-steps, but training only happens upon the finish of an episode. We don't know how many gradient # steps are trained, but we are sure it is nonzero. - assert len(task.agent.logger.history_value['train/loss']) != 0 + loss_key = 'train/policy_loss' if policy == 'A2C' else 'train/loss' + assert len(task.agent.logger.history_value[loss_key]) != 0 # But we are sure that rollout happens 100 times. - assert len(task.agent.logger.history_value['rollout/exploration_rate']) == 100 + if policy in ['DQN']: + assert len(task.agent.logger.history_value['rollout/exploration_rate']) == 100 + else: + # Only logged once, because the number of training steps (100) equals the rollout length. + # SB3 on-policy algorithms first collect rollouts, accumulating `n_steps` steps until the rollout buffer is full, + # and only then check whether training continues.
+ assert len(task.agent.logger.history_value['rollout/ep_rew_mean']) == 1 ad = GymAD(model) ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1) diff --git a/tests/test_env.py b/tests/test_env.py index ff049f2..375c21d 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -8,7 +8,7 @@ def test_gym_task(): 'buffer_size': 1000, 'policy': 'MlpPolicy' } - task = GymTask(env, 'DQN', config) + task = GymTask(env, 'DQN', config=config) assert task.obs_cls == 'discrete' assert task.obs_dim == 16 # The default observation space of FrozenLake is discrete 16 labels diff --git a/tests/test_sb3_util.py b/tests/test_sb3_util.py new file mode 100644 index 0000000..d5a4559 --- /dev/null +++ b/tests/test_sb3_util.py @@ -0,0 +1,35 @@ +import gym +import numpy as np +import stable_baselines3 +from stable_baselines3.common.buffers import ReplayBuffer + +from algorithm_distillation import GymTask +from algorithm_distillation.sb3_util.callback import RolloutCallback + + +def test_callback(): + env = gym.make('FrozenLake-v1') + config = {'batch_size': 10, + 'n_steps': 100, # determines the rollout size + 'policy': 'MlpPolicy' + } + buffer_config = {'buffer_size': 1000, + 'observation_space': env.observation_space, + 'action_space': env.action_space} + buffer = ReplayBuffer(**buffer_config) + agent = stable_baselines3.PPO(env=env, **config) + cb = RolloutCallback(agent, buffer) + + assert buffer.pos == 0 + agent.learn(200, callback=cb) + assert buffer.pos == 200 + # Note: There is a subtlety in comparing shapes of two buffers after the learning is finished. + # RolloutBuffer.get changes the shapes of everything. It will call swap_and_flatten which flattens + # the first two dimensions: (buffer_size, n_env, ...) -> (buffer_size * n_env, ...) + assert np.all(np.isclose(buffer.observations[100:200].flatten(), + agent.rollout_buffer.observations.flatten())) + assert np.all(np.isclose(buffer.actions[100:200].flatten(), + agent.rollout_buffer.actions.flatten())) + assert np.all(np.isclose(buffer.rewards[100:200].flatten(), + agent.rollout_buffer.rewards.flatten())) + From fd58ec7ace8587c60a064ffd604bbb7a09ef56a9 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Mon, 26 Dec 2022 20:18:45 +0000 Subject: [PATCH 12/17] Fix format. --- algorithm_distillation/task.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index 97ba93e..5734ae8 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -53,10 +53,11 @@ def sample_history(self, length: int, skip: int = 0) -> tuple: class GymTask(Task): # TODO: support TD3 for continuous action space? - _algorithms = {"DQN": stable_baselines3.DQN, - "PPO": stable_baselines3.PPO, - "A2C": stable_baselines3.A2C, - } + _algorithms = { + "DQN": stable_baselines3.DQN, + "PPO": stable_baselines3.PPO, + "A2C": stable_baselines3.A2C, + } _on_policy = ("PPO", "A2C") # If the environment has a discrete obs space of n classes, return an (n,)-shape array for each obs. From e2ae0597034724f34dab2c83cc6fed2357df38ed Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Mon, 26 Dec 2022 20:36:28 +0000 Subject: [PATCH 13/17] README and docstrings. 
--- README.md | 8 ++++---- algorithm_distillation/sb3_util/callback.py | 1 + algorithm_distillation/sb3_util/logger.py | 9 +++++---- algorithm_distillation/task.py | 2 ++ 4 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/README.md b/README.md index 03f91e1..6255ade 100644 --- a/README.md +++ b/README.md
@@ -2,9 +2,9 @@ The current state of this repo is a preliminary version of a replication of the Algorithmic Distillation algorithm described in the paper [In-context Reinforcement Learning with Algorithm Distillation](https://arxiv.org/abs/2210.14215). -We also aim to try ideas beyond. +We also aim to make the code general enough to try ideas beyond the paper. # Quick start -A test script is not provided yet, but the unit test `tests/test_ad.py` provides a complete routine of applying the -AD transformer to the histories of a single toy task "FrozenLake-v1". Please [take a look](tests/test_ad.py) and feel -free to plug in your own gym env. +A demo script/notebook is not provided yet, but the unit test `tests/test_ad.py` provides a complete routine of applying +the transformer to the histories of the toy tasks "FrozenLake-v1" and "CartPole-v1". Please [take a look](tests/test_ad.py) +and feel free to plug in your own gym env.
diff --git a/algorithm_distillation/sb3_util/callback.py b/algorithm_distillation/sb3_util/callback.py index c0772d1..c72b5e5 100644 --- a/algorithm_distillation/sb3_util/callback.py +++ b/algorithm_distillation/sb3_util/callback.py
@@ -6,6 +6,7 @@ class RolloutCallback(BaseCallback): """ This is a custom callback that collects rollouts from an on-policy algorithm. + :param buffer: The external replay buffer to save rollouts. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages. """
diff --git a/algorithm_distillation/sb3_util/logger.py b/algorithm_distillation/sb3_util/logger.py index 0565d61..fb1150c 100644 --- a/algorithm_distillation/sb3_util/logger.py +++ b/algorithm_distillation/sb3_util/logger.py
@@ -9,8 +9,8 @@ class CustomLogger(Logger): """ - A logger object can be plugged into an SB3 agent to record the metric. Here we customize it by overriding - in order to add a history. One can further customize to connect with wandb. + A logger object can be plugged into an SB3 agent to record the metrics. Here we customize it to save metric + histories. One can further customize it to, for example, connect with wandb. """ def __init__(self, *args, **kwargs):
@@ -43,9 +43,10 @@ def configure( logger_class=Logger, ) -> Logger: """ - (This is an adaptation of SB3's logger configuration helper function. The change was very minor, only one line - towards the end to allow for a customized logger class.) Configure the current logger. + (This is almost the same as SB3's logger configuration helper function, except for one line in the parameters and + another line towards the end that allow for customized logger classes.) + :param folder: the save location (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time]) :param format_strings: the output logging format
diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index 5734ae8..8ec3ae9 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py
@@ -280,6 +280,7 @@ def _get_range( """ A helper function to either slice array[start:end:interval] or combine array[start::interval] and array[:end:interval] depending on whether start < end. + :param array: the sliced array. :param start: the starting index.
:param end: the ending index (exclusive). @@ -296,6 +297,7 @@ def _get_range( def _get_obs_act_rew(self, buffer: ReplayBuffer, start: int, end: int, skip: int): """ Return a tuple (obs, act, rew) sampled according to the buffer and the parameters. + :param buffer: the replay buffer. :param start: the starting index. :param end: the ending index. From 9e07368068bec2e940d12af37b7fd0268f0a0d1d Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 29 Dec 2022 00:47:20 +0000 Subject: [PATCH 14/17] Fix various bugs (length size issue, callback collection length issue, ...). --- algorithm_distillation/sb3_util/callback.py | 2 +- algorithm_distillation/task.py | 24 ++++++------ tests/test_ad.py | 8 ++-- tests/test_env.py | 43 +++++++++++++-------- 4 files changed, 43 insertions(+), 34 deletions(-) diff --git a/algorithm_distillation/sb3_util/callback.py b/algorithm_distillation/sb3_util/callback.py index c72b5e5..f1de03f 100644 --- a/algorithm_distillation/sb3_util/callback.py +++ b/algorithm_distillation/sb3_util/callback.py @@ -36,7 +36,7 @@ def _on_rollout_end(self) -> None: tgt_start = [0] if buffer.buffer_size < end: # If overflown, wrap around to the beginning ranges.append((0, end - buffer.buffer_size)) - tgt_start.append(buffer.buffer_size) + tgt_start.append(buffer.buffer_size - start) buffer.full = True for (st, ed), tst in zip(ranges, tgt_start): diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py index 8ec3ae9..ee740fd 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/task.py @@ -98,10 +98,9 @@ def __init__( # Set up the agent self.config = {} if config is None else config.copy() self.config["env"] = self.env - if "buffer_size" in self.config: - self.config[ - "buffer_size" - ] = self.buffer_size # Overwrite policy buffer size if exists + if "buffer_size" in self.config or algorithm not in self._on_policy: + # Overwrite policy buffer size if exists/doing off-policy + self.config["buffer_size"] = self.buffer_size # Set up the defaults defaults = { @@ -150,15 +149,13 @@ def sample_history( :param most_recent: (Optional) get the most recent histories. False to sample randomly. :return: a tuple of (observations, actions, rewards). Each is a tensor. """ - buffer = self.buffer - - if buffer.n_envs != 1: + if self.buffer.n_envs != 1: raise NotImplementedError("Not supporting parallel environments yet.") if most_recent: - return self._get_most_recent_history(buffer, length, skip) + return self._get_most_recent_history(self.buffer, length, skip) else: - return self._randomly_sample_buffer(buffer, length, skip) + return self._randomly_sample_buffer(self.buffer, length, skip) def _get_obs_specs(self) -> tuple: obs_space = self.env.observation_space @@ -238,7 +235,8 @@ def _get_most_recent_history( :return: an (observations, actions, rewards) tuple. """ pos = buffer.pos - total_length = length * (skip + 1) + total_length = (length - 1) * (skip + 1) + 1 + assert ( buffer.buffer_size > total_length ), "Replay buffer size must be larger than the sequence length." @@ -259,7 +257,7 @@ def _randomly_sample_buffer( :param skip: the amount to skip between states. :return: an (observations, actions, rewards) tuple. """ - total_length = length * (skip + 1) + total_length = (length - 1) * (skip + 1) + 1 assert ( buffer.buffer_size > total_length ), "Replay buffer size must be larger than the sequence length." 
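As an illustration (not part of the patch series): a minimal sketch of the wrap-around sampling used above, assuming a plain 1-D numpy array as a stand-in for the replay buffer's storage. It shows how `_get_range`-style slicing, combined with the `total_length = (length - 1) * (skip + 1) + 1` arithmetic, yields exactly `length` items even when the sampled window crosses the circular-buffer boundary.

    import numpy as np

    def get_range(array, start, end, interval):
        # Mirrors the _get_range helper: a plain slice when start < end, otherwise
        # stitch together the tail and the head of the circular buffer.
        if start < end:
            return array[start:end:interval]
        return np.concatenate([array[start::interval], array[:end:interval]], axis=0)

    buffer = np.arange(10)                          # stand-in for a buffer of size 10
    length, skip = 4, 1                             # want 4 samples, skipping every other step
    total_length = (length - 1) * (skip + 1) + 1    # 7 raw buffer slots are spanned
    start = 8                                       # window wraps past the end of the buffer
    end = (start + total_length) % len(buffer)      # (8 + 7) % 10 = 5
    print(get_range(buffer, start, end, skip + 1))  # -> [8 0 2 4], exactly `length` items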
@@ -268,7 +266,9 @@ def _randomly_sample_buffer( start = random.randint(0, buffer.pos - total_length) end = start + total_length else: - start = random.randint(0, buffer.buffer_size - total_length) + buffer.pos + start = ( + random.randint(0, buffer.buffer_size - total_length) + buffer.pos + ) % buffer.buffer_size end = (start + total_length) % buffer.buffer_size return self._get_obs_act_rew(buffer, start, end, skip) diff --git a/tests/test_ad.py b/tests/test_ad.py index d3f455c..d9dbbf3 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -16,14 +16,14 @@ def test_ad(policy: str, env_name: str): 'policy': 'MlpPolicy' } else: - config = {'n_steps': 100, + config = {'n_steps': 30, 'batch_size': 10, 'policy': 'MlpPolicy' } if policy == 'A2C': config.pop('batch_size') - task = GymTask(env, policy, buffer_size=1000, config=config) + task = GymTask(env, policy, buffer_size=100, config=config) model = GPT2AD(task.obs_dim, task.act_dim, 12, max_step_len=16) # Inject a customized logger @@ -45,14 +45,12 @@ def test_ad(policy: str, env_name: str): # Only logged once, because training step (100) equals the rollout length. # SB3 off-policy algorithms first collect rollouts accumulating `n_steps` until rollout buffer is full, # and then check if training continues. - assert len(task.agent.logger.history_value['rollout/ep_rew_mean']) == 1 + assert len(task.agent.logger.history_value['rollout/ep_rew_mean']) == 4 ad = GymAD(model) ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1) obs, act, rew, term = ad.rollout(task, 100, 0, verbose=1) assert obs.size(0) == 100 - print(obs, act, rew, term) obs, act, rew, term = ad.rollout(task, 100, 2, verbose=1) assert obs.size(0) == 100 - print(obs, act, rew, term) diff --git a/tests/test_env.py b/tests/test_env.py index 375c21d..67c118e 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -1,29 +1,40 @@ import gym +import pytest + from algorithm_distillation import GymTask -def test_gym_task(): +@pytest.mark.parametrize("policy", ['DQN', 'PPO', 'A2C']) +def test_gym_task(policy: str): env = gym.make('FrozenLake-v1') - config = {'learning_starts': 10, - 'buffer_size': 1000, - 'policy': 'MlpPolicy' - } - task = GymTask(env, 'DQN', config=config) + if policy in ['DQN', 'TD3']: + config = {'learning_starts': 10, + 'policy': 'MlpPolicy' + } + else: + config = {'n_steps': 3, + 'batch_size': 10, + 'policy': 'MlpPolicy' + } + if policy == 'A2C': + config.pop('batch_size') + task = GymTask(env, policy, buffer_size=100, config=config) assert task.obs_cls == 'discrete' assert task.obs_dim == 16 # The default observation space of FrozenLake is discrete 16 labels - task.train(100) + task.train(150) for most_recent in [True, False]: - sample = task.sample_history(10, most_recent=most_recent) - assert sample[0].shape == (10, 16) # Observations are discrete classes - assert sample[1].shape == (10, 1) # Actions are discrete classes - assert sample[2].shape == (10, 1) # Rewards. - - sample = task.sample_history(10, skip=2, most_recent=most_recent) - assert sample[0].shape == (10, 16) # Observations are discrete classes - assert sample[1].shape == (10, 1) # Actions are discrete classes - assert sample[2].shape == (10, 1) # Rewards. + for _ in range(100): # There is a bit of randomness. Try many times. + sample = task.sample_history(10, most_recent=most_recent) + assert sample[0].shape == (10, 16) # Observations are discrete classes + assert sample[1].shape == (10, 1) # Actions are discrete classes + assert sample[2].shape == (10, 1) # Rewards. 
+ + sample = task.sample_history(10, skip=3, most_recent=most_recent) + assert sample[0].shape == (10, 16) # Observations are discrete classes + assert sample[1].shape == (10, 1) # Actions are discrete classes + assert sample[2].shape == (10, 1) # Rewards. # More than buffer try: From 89d91c10dd3a253ac9e14529762c3ad7e8bd5862 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 29 Dec 2022 05:50:03 -0500 Subject: [PATCH 15/17] A2C steps seem undeterministic... leave it alone for now. --- tests/test_ad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ad.py b/tests/test_ad.py index d9dbbf3..331cd14 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -45,7 +45,7 @@ def test_ad(policy: str, env_name: str): # Only logged once, because training step (100) equals the rollout length. # SB3 off-policy algorithms first collect rollouts accumulating `n_steps` until rollout buffer is full, # and then check if training continues. - assert len(task.agent.logger.history_value['rollout/ep_rew_mean']) == 4 + assert len(task.agent.logger.history_value['rollout/ep_rew_mean']) > 0 # A2C steps not deterministic?? ad = GymAD(model) ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1) From 2e0d95a244343a1d5b828bc5a081df3faa355b6c Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 12 Jan 2023 18:33:27 +0000 Subject: [PATCH 16/17] Refactored folder structures to pave the way for a merge. --- algorithm_distillation/__init__.py | 6 ++---- algorithm_distillation/models/__init__.py | 2 +- algorithm_distillation/{ => models}/sb3_util/__init__.py | 0 algorithm_distillation/{ => models}/sb3_util/callback.py | 0 algorithm_distillation/{ => models}/sb3_util/logger.py | 0 algorithm_distillation/rl_tasks/__init__.py | 5 +++++ algorithm_distillation/{ => rl_tasks}/ad.py | 2 +- algorithm_distillation/{ => rl_tasks}/task.py | 2 +- algorithm_distillation/{ => rl_tasks}/task_manager.py | 0 tests/test_ad.py | 4 ++-- tests/test_env.py | 2 +- tests/test_sb3_util.py | 3 +-- 12 files changed, 14 insertions(+), 12 deletions(-) rename algorithm_distillation/{ => models}/sb3_util/__init__.py (100%) rename algorithm_distillation/{ => models}/sb3_util/callback.py (100%) rename algorithm_distillation/{ => models}/sb3_util/logger.py (100%) create mode 100644 algorithm_distillation/rl_tasks/__init__.py rename algorithm_distillation/{ => rl_tasks}/ad.py (99%) rename algorithm_distillation/{ => rl_tasks}/task.py (99%) rename algorithm_distillation/{ => rl_tasks}/task_manager.py (100%) diff --git a/algorithm_distillation/__init__.py b/algorithm_distillation/__init__.py index 0a31263..6145a3d 100644 --- a/algorithm_distillation/__init__.py +++ b/algorithm_distillation/__init__.py @@ -1,5 +1,3 @@ -from .ad import AlgorithmDistillation, GymAD -from .task import GymTask, Task -from .task_manager import TaskManager +from algorithm_distillation import rl_tasks -__all__ = ["AlgorithmDistillation", "GymAD", "Task", "GymTask", "TaskManager"] +__all__ = ["rl_tasks"] diff --git a/algorithm_distillation/models/__init__.py b/algorithm_distillation/models/__init__.py index 6a82b3e..c6300e9 100644 --- a/algorithm_distillation/models/__init__.py +++ b/algorithm_distillation/models/__init__.py @@ -1,3 +1,3 @@ from .gpt2 import GPT2AD -__all__ = ["util", "GPT2AD"] +__all__ = ["sb3_util", "util", "GPT2AD"] diff --git a/algorithm_distillation/sb3_util/__init__.py b/algorithm_distillation/models/sb3_util/__init__.py similarity index 100% rename from algorithm_distillation/sb3_util/__init__.py rename to 
algorithm_distillation/models/sb3_util/__init__.py diff --git a/algorithm_distillation/sb3_util/callback.py b/algorithm_distillation/models/sb3_util/callback.py similarity index 100% rename from algorithm_distillation/sb3_util/callback.py rename to algorithm_distillation/models/sb3_util/callback.py diff --git a/algorithm_distillation/sb3_util/logger.py b/algorithm_distillation/models/sb3_util/logger.py similarity index 100% rename from algorithm_distillation/sb3_util/logger.py rename to algorithm_distillation/models/sb3_util/logger.py diff --git a/algorithm_distillation/rl_tasks/__init__.py b/algorithm_distillation/rl_tasks/__init__.py new file mode 100644 index 0000000..e2a0157 --- /dev/null +++ b/algorithm_distillation/rl_tasks/__init__.py @@ -0,0 +1,5 @@ +from algorithm_distillation.rl_tasks.ad import AlgorithmDistillation, GymAD +from algorithm_distillation.rl_tasks.task import GymTask, Task +from algorithm_distillation.rl_tasks.task_manager import TaskManager + +__all__ = ["AlgorithmDistillation", "GymAD", "Task", "GymTask", "TaskManager"] diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/rl_tasks/ad.py similarity index 99% rename from algorithm_distillation/ad.py rename to algorithm_distillation/rl_tasks/ad.py index 0fd3642..8aa5078 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/rl_tasks/ad.py @@ -6,8 +6,8 @@ from tqdm import tqdm from algorithm_distillation.models.ad_transformer import ADTransformer +from algorithm_distillation.models.util import get_sequence -from .models.util import get_sequence from .task import GymTask from .task_manager import TaskManager diff --git a/algorithm_distillation/task.py b/algorithm_distillation/rl_tasks/task.py similarity index 99% rename from algorithm_distillation/task.py rename to algorithm_distillation/rl_tasks/task.py index ee740fd..801df0c 100644 --- a/algorithm_distillation/task.py +++ b/algorithm_distillation/rl_tasks/task.py @@ -8,7 +8,7 @@ import stable_baselines3 from stable_baselines3.common.buffers import ReplayBuffer -from algorithm_distillation.sb3_util.callback import RolloutCallback +from algorithm_distillation.models.sb3_util.callback import RolloutCallback class Task(abc.ABC): diff --git a/algorithm_distillation/task_manager.py b/algorithm_distillation/rl_tasks/task_manager.py similarity index 100% rename from algorithm_distillation/task_manager.py rename to algorithm_distillation/rl_tasks/task_manager.py diff --git a/tests/test_ad.py b/tests/test_ad.py index d9dbbf3..0058ed2 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -1,9 +1,9 @@ import gym import pytest -from algorithm_distillation import GymTask, TaskManager, GymAD +from algorithm_distillation.rl_tasks import GymTask, TaskManager, GymAD from algorithm_distillation.models import GPT2AD -from algorithm_distillation.sb3_util import CustomLogger, configure +from algorithm_distillation.models.sb3_util import CustomLogger, configure @pytest.mark.parametrize("policy", ['DQN', 'PPO', 'A2C']) diff --git a/tests/test_env.py b/tests/test_env.py index 67c118e..ea49351 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -1,7 +1,7 @@ import gym import pytest -from algorithm_distillation import GymTask +from algorithm_distillation.rl_tasks import GymTask @pytest.mark.parametrize("policy", ['DQN', 'PPO', 'A2C']) diff --git a/tests/test_sb3_util.py b/tests/test_sb3_util.py index d5a4559..dfa0604 100644 --- a/tests/test_sb3_util.py +++ b/tests/test_sb3_util.py @@ -3,8 +3,7 @@ import stable_baselines3 from 
stable_baselines3.common.buffers import ReplayBuffer -from algorithm_distillation import GymTask -from algorithm_distillation.sb3_util.callback import RolloutCallback +from algorithm_distillation.models.sb3_util.callback import RolloutCallback def test_callback(): From 75029b28f6d234a827b8ede6ea5e90a71ce3adf2 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 12 Jan 2023 18:47:05 +0000 Subject: [PATCH 17/17] Refactored folder structures to pave the way for a merge. --- algorithm_distillation/__init__.py | 5 +++-- algorithm_distillation/{rl_tasks => }/ad.py | 5 ++--- algorithm_distillation/rl_tasks/__init__.py | 5 ----- algorithm_distillation/tasks/__init__.py | 0 algorithm_distillation/tasks/rl/__init__.py | 4 ++++ algorithm_distillation/{rl_tasks => tasks/rl}/task.py | 0 .../{rl_tasks => tasks/rl}/task_manager.py | 0 tests/test_ad.py | 3 ++- tests/test_env.py | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) rename algorithm_distillation/{rl_tasks => }/ad.py (98%) delete mode 100644 algorithm_distillation/rl_tasks/__init__.py create mode 100644 algorithm_distillation/tasks/__init__.py create mode 100644 algorithm_distillation/tasks/rl/__init__.py rename algorithm_distillation/{rl_tasks => tasks/rl}/task.py (100%) rename algorithm_distillation/{rl_tasks => tasks/rl}/task_manager.py (100%) diff --git a/algorithm_distillation/__init__.py b/algorithm_distillation/__init__.py index 6145a3d..cfde3d4 100644 --- a/algorithm_distillation/__init__.py +++ b/algorithm_distillation/__init__.py @@ -1,3 +1,4 @@ -from algorithm_distillation import rl_tasks +from algorithm_distillation import models, tasks +from algorithm_distillation.ad import AlgorithmDistillation, GymAD -__all__ = ["rl_tasks"] +__all__ = ["models", "tasks", "AlgorithmDistillation", "GymAD"] diff --git a/algorithm_distillation/rl_tasks/ad.py b/algorithm_distillation/ad.py similarity index 98% rename from algorithm_distillation/rl_tasks/ad.py rename to algorithm_distillation/ad.py index 8aa5078..fa88090 100644 --- a/algorithm_distillation/rl_tasks/ad.py +++ b/algorithm_distillation/ad.py @@ -7,9 +7,8 @@ from algorithm_distillation.models.ad_transformer import ADTransformer from algorithm_distillation.models.util import get_sequence - -from .task import GymTask -from .task_manager import TaskManager +from algorithm_distillation.tasks.rl.task import GymTask +from algorithm_distillation.tasks.rl.task_manager import TaskManager class AlgorithmDistillation(abc.ABC): diff --git a/algorithm_distillation/rl_tasks/__init__.py b/algorithm_distillation/rl_tasks/__init__.py deleted file mode 100644 index e2a0157..0000000 --- a/algorithm_distillation/rl_tasks/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from algorithm_distillation.rl_tasks.ad import AlgorithmDistillation, GymAD -from algorithm_distillation.rl_tasks.task import GymTask, Task -from algorithm_distillation.rl_tasks.task_manager import TaskManager - -__all__ = ["AlgorithmDistillation", "GymAD", "Task", "GymTask", "TaskManager"] diff --git a/algorithm_distillation/tasks/__init__.py b/algorithm_distillation/tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/tasks/rl/__init__.py b/algorithm_distillation/tasks/rl/__init__.py new file mode 100644 index 0000000..7baca5f --- /dev/null +++ b/algorithm_distillation/tasks/rl/__init__.py @@ -0,0 +1,4 @@ +from algorithm_distillation.tasks.rl.task import GymTask, Task +from algorithm_distillation.tasks.rl.task_manager import TaskManager + +__all__ = ["Task", "GymTask", "TaskManager"] diff --git 
a/algorithm_distillation/rl_tasks/task.py b/algorithm_distillation/tasks/rl/task.py similarity index 100% rename from algorithm_distillation/rl_tasks/task.py rename to algorithm_distillation/tasks/rl/task.py diff --git a/algorithm_distillation/rl_tasks/task_manager.py b/algorithm_distillation/tasks/rl/task_manager.py similarity index 100% rename from algorithm_distillation/rl_tasks/task_manager.py rename to algorithm_distillation/tasks/rl/task_manager.py diff --git a/tests/test_ad.py b/tests/test_ad.py index ec55d56..f7b9c60 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -1,7 +1,8 @@ import gym import pytest -from algorithm_distillation.rl_tasks import GymTask, TaskManager, GymAD +from algorithm_distillation import GymAD +from algorithm_distillation.tasks.rl import GymTask, TaskManager from algorithm_distillation.models import GPT2AD from algorithm_distillation.models.sb3_util import CustomLogger, configure diff --git a/tests/test_env.py b/tests/test_env.py index ea49351..3a160e1 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -1,7 +1,7 @@ import gym import pytest -from algorithm_distillation.rl_tasks import GymTask +from algorithm_distillation.tasks.rl import GymTask @pytest.mark.parametrize("policy", ['DQN', 'PPO', 'A2C'])
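To show where the refactor lands (not part of the patch series), here is a rough end-to-end usage sketch assembled from `tests/test_ad.py`: the import paths follow patches 16-17, and the argument values are simply the ones used in the tests, not a tuned configuration.

    import gym

    from algorithm_distillation import GymAD
    from algorithm_distillation.models import GPT2AD
    from algorithm_distillation.models.sb3_util import CustomLogger, configure
    from algorithm_distillation.tasks.rl import GymTask, TaskManager

    # Wrap a gym environment in a task driven by an SB3 algorithm (DQN/PPO/A2C).
    env = gym.make('FrozenLake-v1')
    task = GymTask(env, 'PPO', buffer_size=100,
                   config={'n_steps': 30, 'batch_size': 10, 'policy': 'MlpPolicy'})
    task.agent.set_logger(configure(None, None, CustomLogger))  # optional: keep metric histories

    # Size the transformer from the task's observation/action dimensions.
    model = GPT2AD(task.obs_dim, task.act_dim, 12, max_step_len=16)

    task_manager = TaskManager([task])
    task_manager.train(100)                                           # run the source RL algorithm
    ad = GymAD(model)
    ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1)  # distill its training history
    obs, act, rew, term = ad.rollout(task, 100, 0, verbose=1)         # in-context rollout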