diff --git a/pfrl/agents/__init__.py b/pfrl/agents/__init__.py index 1f225d9ec..1caac0279 100644 --- a/pfrl/agents/__init__.py +++ b/pfrl/agents/__init__.py @@ -17,3 +17,5 @@ from pfrl.agents.state_q_function_actor import StateQFunctionActor # NOQA from pfrl.agents.td3 import TD3 # NOQA from pfrl.agents.trpo import TRPO # NOQA +from pfrl.agents.hybrid_soft_actor_critic import HybridSoftActorCritic # NOQA +from pfrl.agents.hybrid_ppo import HybridPPO # NOQA \ No newline at end of file diff --git a/pfrl/agents/dqn_rnn.py b/pfrl/agents/dqn_rnn.py new file mode 100644 index 000000000..84773db34 --- /dev/null +++ b/pfrl/agents/dqn_rnn.py @@ -0,0 +1,884 @@ +import collections +import copy +import ctypes +import multiprocessing as mp +import multiprocessing.synchronize +import os +import time +import typing +from logging import Logger, getLogger +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +import pfrl +from pfrl import agent +from pfrl.action_value import ActionValue +from pfrl.explorer import Explorer +from pfrl.replay_buffer import ( + AbstractEpisodicReplayBuffer, + ReplayUpdater, + batch_experiences, + batch_recurrent_experiences, +) +from pfrl.replay_buffers import PrioritizedReplayBuffer +from pfrl.utils.batch_states import batch_states +from pfrl.utils.contexts import evaluating +from pfrl.utils.copy_param import synchronize_parameters +from pfrl.utils.recurrent import ( + get_recurrent_state_at, + mask_recurrent_state_at, + one_step_forward, + pack_and_forward, + recurrent_state_as_numpy, +) + + +def _mean_or_nan(xs: Sequence[float]) -> float: + """Return its mean a non-empty sequence, numpy.nan for a empty one.""" + return typing.cast(float, np.mean(xs)) if xs else np.nan + + +def compute_value_loss( + y: torch.Tensor, + t: torch.Tensor, + clip_delta: bool = True, + batch_accumulator: str = "mean", +) -> torch.Tensor: + """Compute a loss for value prediction 
problem. + + Args: + y (torch.Tensor): Predicted values. + t (torch.Tensor): Target values. + clip_delta (bool): Use the Huber loss function with delta=1 if set True. + batch_accumulator (str): 'mean' or 'sum'. 'mean' will use the mean of + the loss values in a batch. 'sum' will use the sum. + Returns: + (torch.Tensor) scalar loss + """ + assert batch_accumulator in ("mean", "sum") + y = y.reshape(-1, 1) + t = t.reshape(-1, 1) + if clip_delta: + return F.smooth_l1_loss(y, t, reduction=batch_accumulator) + else: + return F.mse_loss(y, t, reduction=batch_accumulator) / 2 + + +def compute_weighted_value_loss( + y: torch.Tensor, + t: torch.Tensor, + weights: torch.Tensor, + clip_delta: bool = True, + batch_accumulator: str = "mean", +) -> torch.Tensor: + """Compute a loss for value prediction problem. + + Args: + y (torch.Tensor): Predicted values. + t (torch.Tensor): Target values. + weights (torch.Tensor): Weights for y, t. + clip_delta (bool): Use the Huber loss function with delta=1 if set True. + batch_accumulator (str): 'mean' will divide loss by batchsize + Returns: + (torch.Tensor) scalar loss + """ + assert batch_accumulator in ("mean", "sum") + y = y.reshape(-1, 1) + t = t.reshape(-1, 1) + if clip_delta: + losses = F.smooth_l1_loss(y, t, reduction="none") + else: + losses = F.mse_loss(y, t, reduction="none") / 2 + losses = losses.reshape( + -1, + ) + weights = weights.to(losses.device) + loss_sum = torch.sum(losses * weights) + if batch_accumulator == "mean": + loss = loss_sum / y.shape[0] + elif batch_accumulator == "sum": + loss = loss_sum + return loss + + +def _batch_reset_recurrent_states_when_episodes_end( + batch_done: Sequence[bool], batch_reset: Sequence[bool], recurrent_states: Any +) -> Any: + """Reset recurrent states when episodes end. + + Args: + batch_done (array-like of bool): True iff episodes are terminal. + batch_reset (array-like of bool): True iff episodes will be reset. + recurrent_states (object): Recurrent state. 
+ + Returns: + object: New recurrent states. + """ + indices_that_ended = [ + i + for i, (done, reset) in enumerate(zip(batch_done, batch_reset)) + if done or reset + ] + if indices_that_ended: + return mask_recurrent_state_at(recurrent_states, indices_that_ended) + else: + return recurrent_states + + +def make_target_model_as_copy(model: torch.nn.Module) -> torch.nn.Module: + target_model = copy.deepcopy(model) + + def flatten_parameters(mod): + if isinstance(mod, torch.nn.RNNBase): + mod.flatten_parameters() + + # RNNBase.flatten_parameters must be called again after deep-copy. + # See: https://discuss.pytorch.org/t/why-do-we-need-flatten-parameters-when-using-rnn-with-dataparallel/46506 # NOQA + target_model.apply(flatten_parameters) + # set target n/w to evaluate only. + target_model.eval() + return target_model + + +class DQN(agent.AttributeSavingMixin, agent.BatchAgent): + """Deep Q-Network algorithm. + + Args: + q_function (StateQFunction): Q-function + optimizer (Optimizer): Optimizer that is already setup + replay_buffer (ReplayBuffer): Replay buffer + gamma (float): Discount factor + explorer (Explorer): Explorer that specifies an exploration strategy. + gpu (int): GPU device id if not None nor negative. + replay_start_size (int): if the replay buffer's size is less than + replay_start_size, skip update + minibatch_size (int): Minibatch size + update_interval (int): Model update interval in step + target_update_interval (int): Target model update interval in step + clip_delta (bool): Clip delta if set True + phi (callable): Feature extractor applied to observations + target_update_method (str): 'hard' or 'soft'. + soft_update_tau (float): Tau of soft target update. 
+ n_times_update (int): Number of repetition of update + batch_accumulator (str): 'mean' or 'sum' + episodic_update_len (int or None): Subsequences of this length are used + for update if set int and episodic_update=True + logger (Logger): Logger used + batch_states (callable): method which makes a batch of observations. + default is `pfrl.utils.batch_states.batch_states` + recurrent (bool): If set to True, `model` is assumed to implement + `pfrl.nn.Recurrent` and is updated in a recurrent + manner. + max_grad_norm (float or None): Maximum L2 norm of the gradient used for + gradient clipping. If set to None, the gradient is not clipped. + """ + + saved_attributes = ("model", "target_model", "optimizer") + + def __init__( + self, + q_function: torch.nn.Module, + optimizer: torch.optim.Optimizer, # type: ignore # somehow mypy complains + replay_buffer: pfrl.replay_buffer.AbstractReplayBuffer, + gamma: float, + explorer: Explorer, + gpu: Optional[int] = None, + replay_start_size: int = 50000, + minibatch_size: int = 32, + update_interval: int = 1, + target_update_interval: int = 10000, + clip_delta: bool = True, + phi: Callable[[Any], Any] = lambda x: x, + target_update_method: str = "hard", + soft_update_tau: float = 1e-2, + n_times_update: int = 1, + batch_accumulator: str = "mean", + episodic_update_len: Optional[int] = None, + logger: Logger = getLogger(__name__), + batch_states: Callable[ + [Sequence[Any], torch.device, Callable[[Any], Any]], Any + ] = batch_states, + recurrent: bool = False, + max_grad_norm: Optional[float] = None, + burnin: int = 10, # Number of initial time steps to exclude from loss computation + ): + self.model = q_function + if gpu is not None and gpu >= 0: + assert torch.cuda.is_available() + self.device = torch.device("cuda:{}".format(gpu)) + self.model.to(self.device) + else: + self.device = torch.device("cpu") + self.burnin = burnin + self.replay_buffer = replay_buffer + self.optimizer = optimizer + self.gamma = gamma + self.explorer = 
explorer + self.gpu = gpu + self.target_update_interval = target_update_interval + self.clip_delta = clip_delta + self.phi = phi + self.target_update_method = target_update_method + self.soft_update_tau = soft_update_tau + self.batch_accumulator = batch_accumulator + assert batch_accumulator in ("mean", "sum") + self.logger = logger + self.batch_states = batch_states + self.recurrent = recurrent + update_func: Callable[..., None] + if self.recurrent: + assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer) + update_func = self.update_from_episodes + else: + update_func = self.update + self.replay_updater = ReplayUpdater( + replay_buffer=replay_buffer, + update_func=update_func, + batchsize=minibatch_size, + episodic_update=recurrent, + episodic_update_len=episodic_update_len, + n_times_update=n_times_update, + replay_start_size=replay_start_size, + update_interval=update_interval, + ) + self.minibatch_size = minibatch_size + self.episodic_update_len = episodic_update_len + self.replay_start_size = replay_start_size + self.update_interval = update_interval + self.max_grad_norm = max_grad_norm + self.valid_indices = [] + for i in range(self.minibatch_size): + start = i * self.episodic_update_len + self.burnin + end = (i + 1) * self.episodic_update_len + self.valid_indices.extend(range(start, end)) + assert ( + target_update_interval % update_interval == 0 + ), "target_update_interval should be a multiple of update_interval" + + self.t = 0 + self.optim_t = 0 # Compensate pytorch optim not having `t` + self._cumulative_steps = 0 + self.target_model = make_target_model_as_copy(self.model) + + # Statistics + self.q_record: collections.deque = collections.deque(maxlen=1000) + self.loss_record: collections.deque = collections.deque(maxlen=100) + + # Recurrent states of the model + self.train_recurrent_states: Any = None + self.train_prev_recurrent_states: Any = None + self.test_recurrent_states: Any = None + + # Error checking + if ( + 
self.replay_buffer.capacity is not None + and self.replay_buffer.capacity < self.replay_updater.replay_start_size + ): + raise ValueError("Replay start size cannot exceed replay buffer capacity.") + + @property + def cumulative_steps(self) -> int: + # cumulative_steps counts the overall steps during the training. + return self._cumulative_steps + + def _setup_actor_learner_training( + self, + n_actors: int, + actor_update_interval: int, + update_counter: Any, + ) -> Tuple[ + torch.nn.Module, + Sequence[mp.connection.Connection], + Sequence[mp.connection.Connection], + ]: + assert actor_update_interval > 0 + + self.actor_update_interval = actor_update_interval + self.update_counter = update_counter + + # Make a copy on shared memory and share among actors and the poller + shared_model = copy.deepcopy(self.model).cpu() + shared_model.share_memory() + + # Pipes are used for infrequent communication + learner_pipes, actor_pipes = list(zip(*[mp.Pipe() for _ in range(n_actors)])) + + return (shared_model, learner_pipes, actor_pipes) + + def sync_target_network(self) -> None: + """Synchronize target network with current network.""" + synchronize_parameters( + src=self.model, + dst=self.target_model, + method=self.target_update_method, + tau=self.soft_update_tau, + ) + + def update( + self, experiences: List[List[Dict[str, Any]]], errors_out: Optional[list] = None + ) -> None: + """Update the model from experiences + + Args: + experiences (list): List of lists of dicts. + For DQN, each dict must contains: + - state (object): State + - action (object): Action + - reward (float): Reward + - is_state_terminal (bool): True iff next state is terminal + - next_state (object): Next state + - weight (float, optional): Weight coefficient. It can be + used for importance sampling. + errors_out (list or None): If set to a list, then TD-errors + computed from the given experiences are appended to the list. 
+ + Returns: + None + """ + has_weight = "weight" in experiences[0][0] + exp_batch = batch_experiences( + experiences, + device=self.device, + phi=self.phi, + gamma=self.gamma, + batch_states=self.batch_states, + ) + if has_weight: + exp_batch["weights"] = torch.tensor( + [elem[0]["weight"] for elem in experiences], + device=self.device, + dtype=torch.float32, + ) + if errors_out is None: + errors_out = [] + + loss = self._compute_loss(exp_batch, errors_out=errors_out) + if has_weight: + assert isinstance(self.replay_buffer, PrioritizedReplayBuffer) + self.replay_buffer.update_errors(errors_out) + + self.loss_record.append(float(loss.detach().cpu().numpy())) + + self.optimizer.zero_grad() + loss.backward() + if self.max_grad_norm is not None: + pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.optimizer.step() + self.optim_t += 1 + + def update_from_episodes( + self, episodes: List[List[Dict[str, Any]]], errors_out: Optional[list] = None + ) -> None: + assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer" + episodes = sorted(episodes, key=len, reverse=True) + exp_batch = batch_recurrent_experiences( + episodes, + device=self.device, + phi=self.phi, + gamma=self.gamma, + batch_states=self.batch_states, + ) + loss = self._compute_loss(exp_batch, errors_out=None) + self.loss_record.append(float(loss.detach().cpu().numpy())) + self.optimizer.zero_grad() + loss.backward() + if self.max_grad_norm is not None: + pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.optimizer.step() + self.optim_t += 1 + + + def _compute_target_values(self, exp_batch: Dict[str, Any]) -> torch.Tensor: + batch_next_state = [exp_batch["next_state"], exp_batch["action"], exp_batch["reward"]] + + if self.recurrent: + target_next_qout, _ = pack_and_forward( + self.target_model, + batch_next_state, + exp_batch["next_recurrent_state"], + ) + else: + target_next_qout = self.target_model(batch_next_state) 
############################# + next_q_max = target_next_qout.max + + batch_rewards = exp_batch["reward"] + batch_terminal = exp_batch["is_state_terminal"] + discount = exp_batch["discount"] + + return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max + + def _compute_y_and_t( + self, exp_batch: Dict[str, Any] + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = exp_batch["reward"].shape[0] + + # Compute Q-values for current states + # For the input to the model, we want batch_state, action, and reward + batch_state = [exp_batch["state"], exp_batch["action"], exp_batch["reward"]] + + if self.recurrent: + qout, _ = pack_and_forward( + self.model, batch_state, exp_batch["recurrent_state"] + ) + else: + qout = self.model(batch_state) + + batch_actions = exp_batch["action"] + batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1)) + + with torch.no_grad(): + batch_q_target = torch.reshape( + self._compute_target_values(exp_batch), (batch_size, 1) + ) + + return batch_q, batch_q_target + + # def _compute_loss( + # self, exp_batch: Dict[str, Any], errors_out: Optional[list] = None + # ) -> torch.Tensor: + # """Compute the Q-learning loss for a batch of experiences + + + # Args: + # exp_batch (dict): A dict of batched arrays of transitions + # Returns: + # Computed loss from the minibatch of experiences + # """ + + # y, t = self._compute_y_and_t(exp_batch) + + # self.q_record.extend(y.detach().cpu().numpy().ravel()) + + # if errors_out is not None: + # del errors_out[:] + # delta = torch.abs(y - t) + # if delta.ndim == 2: + # delta = torch.sum(delta, dim=1) + # delta = delta.detach().cpu().numpy() + # for e in delta: + # errors_out.append(e) + # if "weights" in exp_batch: + # return compute_weighted_value_loss( + # y, + # t, + # exp_batch["weights"], + # clip_delta=self.clip_delta, + # batch_accumulator=self.batch_accumulator, + # ) + # else: + # return compute_value_loss( + # y, + # t, + # clip_delta=self.clip_delta, + # 
batch_accumulator=self.batch_accumulator,
+    #         )
+    def _compute_loss(
+        self,
+        exp_batch: Dict[str, Any],
+        errors_out: Optional[list] = None,
+    ) -> torch.Tensor:
+        """Compute the Q-learning loss for a batch of experiences, with burn-in support.
+
+        The first ``self.burnin`` time steps of every sampled subsequence are
+        excluded from the loss (R2D2-style burn-in), using the flattened
+        ``self.valid_indices`` precomputed in ``__init__``.
+
+        Args:
+            exp_batch (dict): A dict of batched arrays of transitions.
+            errors_out (list or None): List to output TD errors (used in
+                prioritized replay).
+        Returns:
+            Computed loss from the minibatch of experiences.
+        """
+
+        y, t = self._compute_y_and_t(exp_batch)
+
+        # Drop the burn-in steps so the recurrent state can warm up before
+        # the predictions contribute to the loss.
+        if self.burnin > 0:
+            y = y[self.valid_indices, :]
+            t = t[self.valid_indices, :]
+
+        self.q_record.extend(y.detach().cpu().numpy().ravel())
+
+        if errors_out is not None:
+            del errors_out[:]
+            delta = torch.abs(y - t)
+            if delta.ndim == 2:
+                delta = torch.sum(delta, dim=1)
+            delta = delta.detach().cpu().numpy()
+            for e in delta:
+                errors_out.append(e)
+
+        if "weights" in exp_batch:
+            # `weights` is a flat 1-D tensor (see `update`); slice it with the
+            # same flattened indices as y/t so the elementwise product in
+            # compute_weighted_value_loss stays aligned after burn-in removal.
+            weights = exp_batch["weights"]
+            if self.burnin > 0:
+                weights = weights[self.valid_indices]
+            return compute_weighted_value_loss(
+                y,
+                t,
+                weights,
+                clip_delta=self.clip_delta,
+                batch_accumulator=self.batch_accumulator,
+            )
+        else:
+            return compute_value_loss(
+                y,
+                t,
+                clip_delta=self.clip_delta,
+                batch_accumulator=self.batch_accumulator,
+            )
+
+
+    def _evaluate_model_and_update_recurrent_states(
+        self, batch_obs: Sequence[Any], batch_action: Sequence[Any], batch_reward: Sequence[Any]
+    ) -> ActionValue:
+        # The model consumes (observation, previous action, reward) triples.
+        batch_xs = [self.batch_states(batch_obs, self.device, self.phi), batch_action, batch_reward]
+        if self.recurrent:
+            if self.training:
+                self.train_prev_recurrent_states = self.train_recurrent_states
+                batch_av, self.train_recurrent_states = one_step_forward(
+                    self.model, batch_xs, self.train_recurrent_states
+                )
+            else:
+                batch_av, self.test_recurrent_states = one_step_forward(
+                    self.model, batch_xs, self.test_recurrent_states
+                )
+        else:
+            batch_av = self.model(batch_xs)
+        return batch_av
+
+    def batch_act(self, batch_obs: Sequence[Any], batch_action: Sequence[Any], batch_reward: Sequence[Any]) -> Sequence[Any]:
+        with torch.no_grad(), evaluating(self.model):
+            batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs, batch_action, batch_reward)
+            batch_argmax = batch_av.greedy_actions.detach().cpu().numpy()
+        if self.training:
+            batch_action = [
+                self.explorer.select_action(
+                    self.t,
+                    lambda: batch_argmax[i],
+                    action_value=batch_av[i : i + 1],
+                )
+                for i in range(len(batch_obs))
+            ]
+            self.batch_last_obs = list(batch_obs)
+            self.batch_last_action = list(batch_action)
+        else:
+            batch_action = batch_argmax
+        return batch_action
+
+    def _batch_observe_train(
+        self,
+        batch_obs: Sequence[Any],
+        batch_reward: Sequence[float],
+        batch_done: Sequence[bool],
+        batch_reset: Sequence[bool],
+    ) -> None:
+        for i in range(len(batch_obs)):
+            self.t += 1
+            self._cumulative_steps += 1
+            # Update the target network
+            if self.t % 
self.target_update_interval == 0: + self.sync_target_network() + + # import pdb; pdb.set_trace() + if self.batch_last_obs[i] is not None: + assert self.batch_last_action[i] is not None + # Add a transition to the replay buffer + transition = { + "state": self.batch_last_obs[i], + "action": self.batch_last_action[i], + "reward": batch_reward[i], + "next_state": batch_obs[i], + "next_action": None, + "is_state_terminal": batch_done[i], + } + if self.recurrent: + transition["recurrent_state"] = recurrent_state_as_numpy( + get_recurrent_state_at( + self.train_prev_recurrent_states, i, detach=True + ) + ) + transition["next_recurrent_state"] = recurrent_state_as_numpy( + get_recurrent_state_at( + self.train_recurrent_states, i, detach=True + ) + ) + + self.replay_buffer.append(env_id=i, **transition) + if batch_reset[i] or batch_done[i]: + self.batch_last_obs[i] = None + self.batch_last_action[i] = None + self.replay_buffer.stop_current_episode(env_id=i) + self.replay_updater.update_if_necessary(self.t) + + if self.recurrent: + # Reset recurrent states when episodes end + self.train_prev_recurrent_states = None + self.train_recurrent_states = ( + _batch_reset_recurrent_states_when_episodes_end( # NOQA + batch_done=batch_done, + batch_reset=batch_reset, + recurrent_states=self.train_recurrent_states, + ) + ) + + def _batch_observe_eval( + self, + batch_obs: Sequence[Any], + batch_reward: Sequence[float], + batch_done: Sequence[bool], + batch_reset: Sequence[bool], + ) -> None: + if self.recurrent: + # Reset recurrent states when episodes end + self.test_recurrent_states = ( + _batch_reset_recurrent_states_when_episodes_end( # NOQA + batch_done=batch_done, + batch_reset=batch_reset, + recurrent_states=self.test_recurrent_states, + ) + ) + + def batch_observe( + self, + batch_obs: Sequence[Any], + batch_reward: Sequence[float], + batch_done: Sequence[bool], + batch_reset: Sequence[bool], + ) -> None: + if self.training: + return self._batch_observe_train( + batch_obs, 
batch_reward, batch_done, batch_reset + ) + else: + return self._batch_observe_eval( + batch_obs, batch_reward, batch_done, batch_reset + ) + + def _can_start_replay(self) -> bool: + + if len(self.replay_buffer) < self.replay_start_size: + return False + if self.recurrent: + assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer) + if self.replay_buffer.n_episodes < self.minibatch_size: + return False + return True + + def _poll_pipe( + self, + actor_idx: int, + pipe: mp.connection.Connection, + replay_buffer_lock: mp.synchronize.Lock, + exception_event: mp.synchronize.Event, + ) -> None: + if pipe.closed: + return + try: + while pipe.poll() and not exception_event.is_set(): + cmd, data = pipe.recv() + if cmd == "get_statistics": + assert data is None + with replay_buffer_lock: + stats = self.get_statistics() + pipe.send(stats) + elif cmd == "load": + self.load(data) + pipe.send(None) + elif cmd == "save": + self.save(data) + pipe.send(None) + elif cmd == "transition": + with replay_buffer_lock: + if "env_id" not in data: + data["env_id"] = actor_idx + self.replay_buffer.append(**data) + self._cumulative_steps += 1 + elif cmd == "stop_episode": + idx = actor_idx if data is None else data + with replay_buffer_lock: + self.replay_buffer.stop_current_episode(env_id=idx) + stats = self.get_statistics() + pipe.send(stats) + + else: + raise RuntimeError("Unknown command from actor: {}".format(cmd)) + except EOFError: + pipe.close() + except Exception: + self.logger.exception("Poller loop failed. 
Exiting") + exception_event.set() + + def _learner_loop( + self, + shared_model: torch.nn.Module, + pipes: Sequence[mp.connection.Connection], + replay_buffer_lock: mp.synchronize.Lock, + stop_event: mp.synchronize.Event, + exception_event: mp.synchronize.Event, + n_updates: Optional[int] = None, + step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [], + optimizer_step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [], + ) -> None: + try: + update_counter = 0 + # To stop this loop, call stop_event.set() + while not stop_event.is_set(): + # Update model if possible + if not self._can_start_replay(): + continue + if n_updates is not None: + assert self.optim_t <= n_updates + if self.optim_t == n_updates: + stop_event.set() + break + + if self.recurrent: + assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer) + with replay_buffer_lock: + episodes = self.replay_buffer.sample_episodes( + self.minibatch_size, self.episodic_update_len #+ burnin + ) + self.update_from_episodes(episodes) + else: + with replay_buffer_lock: + transitions = self.replay_buffer.sample(self.minibatch_size) + + self.update(transitions) + + # Update the shared model. This can be expensive if GPU is used + # since this is a DtoH copy, so it is updated only at regular + # intervals. + update_counter += 1 + if update_counter % self.actor_update_interval == 0: + with self.update_counter.get_lock(): + self.update_counter.value += 1 + shared_model.load_state_dict(self.model.state_dict()) + + # To keep the ratio of target updates to model updates, + # here we calculate back the effective current timestep + # from update_interval and number of updates so far. 
+ effective_timestep = self.optim_t * self.update_interval + # We can safely assign self.t since in the learner + # it isn't updated by any other method + self.t = effective_timestep + + for hook in optimizer_step_hooks: + hook(None, self, self.optim_t) + + for hook in step_hooks: + hook(None, self, effective_timestep) + + if effective_timestep % self.target_update_interval == 0: + self.sync_target_network() + except Exception: + self.logger.exception("Learner loop failed. Exiting") + exception_event.set() + + def _poller_loop( + self, + shared_model: torch.nn.Module, + pipes: Sequence[mp.connection.Connection], + replay_buffer_lock: mp.synchronize.Lock, + stop_event: mp.synchronize.Event, + exception_event: mp.synchronize.Event, + ) -> None: + # To stop this loop, call stop_event.set() + while not stop_event.is_set() and not exception_event.is_set(): + time.sleep(1e-6) + # Poll actors for messages + for i, pipe in enumerate(pipes): + self._poll_pipe(i, pipe, replay_buffer_lock, exception_event) + + def setup_actor_learner_training( + self, + n_actors: int, + update_counter: Optional[Any] = None, + n_updates: Optional[int] = None, + actor_update_interval: int = 8, + step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [], + optimizer_step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [], + ): + if update_counter is None: + update_counter = mp.Value(ctypes.c_ulong) + + (shared_model, learner_pipes, actor_pipes) = self._setup_actor_learner_training( + n_actors, actor_update_interval, update_counter + ) + exception_event = mp.Event() + + def make_actor(i): + return pfrl.agents.StateQFunctionActor( + pipe=actor_pipes[i], + model=shared_model, + explorer=self.explorer, + phi=self.phi, + batch_states=self.batch_states, + logger=self.logger, + recurrent=self.recurrent, + ) + + replay_buffer_lock = mp.Lock() + + poller_stop_event = mp.Event() + poller = pfrl.utils.StoppableThread( + target=self._poller_loop, + kwargs=dict( + shared_model=shared_model, + 
pipes=learner_pipes, + replay_buffer_lock=replay_buffer_lock, + stop_event=poller_stop_event, + exception_event=exception_event, + ), + stop_event=poller_stop_event, + ) + + learner_stop_event = mp.Event() + learner = pfrl.utils.StoppableThread( + target=self._learner_loop, + kwargs=dict( + shared_model=shared_model, + pipes=learner_pipes, + replay_buffer_lock=replay_buffer_lock, + stop_event=learner_stop_event, + n_updates=n_updates, + exception_event=exception_event, + step_hooks=step_hooks, + optimizer_step_hooks=optimizer_step_hooks, + ), + stop_event=learner_stop_event, + ) + + return make_actor, learner, poller, exception_event + + def stop_episode(self) -> None: + if self.recurrent: + self.test_recurrent_states = None + + def save_snapshot(self, dirname: str) -> None: + self.save(dirname) + torch.save(self.t, os.path.join(dirname, "t.pt")) + torch.save(self.optim_t, os.path.join(dirname, "optim_t.pt")) + torch.save( + self._cumulative_steps, os.path.join(dirname, "_cumulative_steps.pt") + ) + self.replay_buffer.save(os.path.join(dirname, "replay_buffer.pkl")) + + def load_snapshot(self, dirname: str) -> None: + self.load(dirname) + self.t = torch.load(os.path.join(dirname, "t.pt")) + self.optim_t = torch.load(os.path.join(dirname, "optim_t.pt")) + self._cumulative_steps = torch.load( + os.path.join(dirname, "_cumulative_steps.pt") + ) + self.replay_buffer.load(os.path.join(dirname, "replay_buffer.pkl")) + + def get_statistics(self): + return [ + ("average_q", _mean_or_nan(self.q_record)), + ("average_loss", _mean_or_nan(self.loss_record)), + ("cumulative_steps", self.cumulative_steps), + ("n_updates", self.optim_t), + ("rlen", len(self.replay_buffer)), + ] + \ No newline at end of file diff --git a/pfrl/agents/hybrid_ddpg.py b/pfrl/agents/hybrid_ddpg.py new file mode 100644 index 000000000..a7796a220 --- /dev/null +++ b/pfrl/agents/hybrid_ddpg.py @@ -0,0 +1,331 @@ +import collections +import copy +from logging import getLogger + +import numpy as np +import 
torch +from torch import nn +from torch.nn import functional as F + +from pfrl.agent import AttributeSavingMixin, BatchAgent +from pfrl.replay_buffer import ReplayUpdater, hybrid_batch_experiences +from pfrl.utils.batch_states import batch_states +from pfrl.utils.contexts import evaluating +from pfrl.utils.copy_param import synchronize_parameters + + +def _mean_or_nan(xs): + """Return its mean a non-empty sequence, numpy.nan for a empty one.""" + return np.mean(xs) if xs else np.nan + + +class HybridDDPG(AttributeSavingMixin, BatchAgent): + """Deep Deterministic Policy Gradients. + + This can be used as SVG(0) by specifying a Gaussian policy instead of a + deterministic policy. + + Args: + policy (torch.nn.Module): Policy + q_func (torch.nn.Module): Q-function + actor_optimizer (Optimizer): Optimizer setup with the policy + critic_optimizer (Optimizer): Optimizer setup with the Q-function + replay_buffer (ReplayBuffer): Replay buffer + gamma (float): Discount factor + explorer (Explorer): Explorer that specifies an exploration strategy. + gpu (int): GPU device id if not None nor negative. + replay_start_size (int): if the replay buffer's size is less than + replay_start_size, skip update + minibatch_size (int): Minibatch size + update_interval (int): Model update interval in step + target_update_interval (int): Target model update interval in step + phi (callable): Feature extractor applied to observations + target_update_method (str): 'hard' or 'soft'. + soft_update_tau (float): Tau of soft target update. + n_times_update (int): Number of repetition of update + batch_accumulator (str): 'mean' or 'sum' + episodic_update (bool): Use full episodes for update if set True + episodic_update_len (int or None): Subsequences of this length are used + for update if set int and episodic_update=True + logger (Logger): Logger used + batch_states (callable): method which makes a batch of observations. 
+ default is `pfrl.utils.batch_states.batch_states` + burnin_action_func (callable or None): If not None, this callable + object is used to select actions before the model is updated + one or more times during training. + """ + + saved_attributes = ("model", "target_model", "actor_optimizer", "critic_optimizer") + + def __init__( + self, + policy, + q_func, + actor_optimizer, + critic_optimizer, + replay_buffer, + gamma, + explorer, + gpu=None, + replay_start_size=50000, + minibatch_size=32, + update_interval=1, + target_update_interval=10000, + phi=lambda x: x, + target_update_method="hard", + soft_update_tau=1e-2, + n_times_update=1, + recurrent=False, + episodic_update_len=None, + logger=getLogger(__name__), + batch_states=batch_states, + burnin_action_func=None, + ): + self.model = nn.ModuleList([policy, q_func]) + if gpu is not None and gpu >= 0: + assert torch.cuda.is_available() + self.device = torch.device("cuda:{}".format(gpu)) + self.model.to(self.device) + else: + self.device = torch.device("cpu") + + self.replay_buffer = replay_buffer + self.gamma = gamma + self.explorer = explorer + self.gpu = gpu + self.target_update_interval = target_update_interval + self.phi = phi + self.target_update_method = target_update_method + self.soft_update_tau = soft_update_tau + self.logger = logger + self.actor_optimizer = actor_optimizer + self.critic_optimizer = critic_optimizer + self.recurrent = recurrent + assert not self.recurrent, "recurrent=True is not yet implemented" + if self.recurrent: + update_func = self.update_from_episodes + else: + update_func = self.update + self.replay_updater = ReplayUpdater( + replay_buffer=replay_buffer, + update_func=update_func, + batchsize=minibatch_size, + episodic_update=recurrent, + episodic_update_len=episodic_update_len, + n_times_update=n_times_update, + replay_start_size=replay_start_size, + update_interval=update_interval, + ) + self.batch_states = batch_states + self.burnin_action_func = burnin_action_func + + self.t = 
0 + self.last_state = None + self.last_action = None + self.target_model = copy.deepcopy(self.model) + self.target_model.eval() + self.q_record = collections.deque(maxlen=1000) + self.actor_loss_record = collections.deque(maxlen=100) + self.critic_loss_record = collections.deque(maxlen=100) + self.n_updates = 0 + + # Aliases for convenience + self.policy, self.q_function = self.model + self.target_policy, self.target_q_function = self.target_model + + self.sync_target_network() + + def sync_target_network(self): + """Synchronize target network with current network.""" + synchronize_parameters( + src=self.model, + dst=self.target_model, + method=self.target_update_method, + tau=self.soft_update_tau, + ) + + # Update Q-function + def compute_critic_loss(self, batch): + """Compute loss for critic.""" + + batch_next_state = batch["next_state"] + batch_rewards = batch["reward"] + batch_terminal = batch["is_state_terminal"] + batch_state = batch["state"] + # batch_actions = batch["action"] + c_actions = batch["c_action"] + d_actions = batch["d_action"] + batchsize = len(batch_rewards) + + with torch.no_grad(): + assert not self.recurrent + # next_actions = self.target_policy(batch_next_state).sample() + c_next_action_distrib, d_next_action_distrib = self.policy(batch_next_state) + c_next_action = c_next_action_distrib.sample() + d_next_action = d_next_action_distrib.sample() + # d_next_action = F.gumbel_softmax(d_next_action_distrib.logits, tau=0.5, hard=True).argmax(dim=-1) + + next_q = self.target_q_function((batch_next_state, (c_next_action, d_next_action))) + target_q = batch_rewards + self.gamma * ( + 1.0 - batch_terminal + ) * next_q.reshape((batchsize,)) + + predict_q = self.q_function((batch_state, (c_actions, d_actions))).reshape((batchsize,)) + + loss = F.mse_loss(target_q, predict_q) + + # Update stats + self.critic_loss_record.append(loss.item()) + + return loss + + def compute_actor_loss(self, batch): + """Compute loss for actor.""" + + batch_state = 
batch["state"] + c_action_distrib, d_action_distrib = self.policy(batch_state) + c_actions = c_action_distrib.rsample() + d_actions = d_action_distrib.sample() + # onpolicy_actions = self.policy(batch_state).rsample() + + q = self.q_function((batch_state, (c_actions, d_actions))) + loss = -q.mean() + + # Update stats + self.q_record.extend(q.detach().cpu().numpy()) + self.actor_loss_record.append(loss.item()) + + return loss + + def update(self, experiences, errors_out=None): + """Update the model from experiences""" + + batch = hybrid_batch_experiences(experiences, self.device, self.phi, self.gamma) + + self.critic_optimizer.zero_grad() + self.compute_critic_loss(batch).backward() + self.critic_optimizer.step() + + self.actor_optimizer.zero_grad() + self.compute_actor_loss(batch).backward() + self.actor_optimizer.step() + + self.n_updates += 1 + + def update_from_episodes(self, episodes, errors_out=None): + raise NotImplementedError + + # Sort episodes desc by their lengths + sorted_episodes = list(reversed(sorted(episodes, key=len))) + max_epi_len = len(sorted_episodes[0]) + + # Precompute all the input batches + batches = [] + for i in range(max_epi_len): + transitions = [] + for ep in sorted_episodes: + if len(ep) <= i: + break + transitions.append([ep[i]]) + batch = batch_experiences( + transitions, xp=self.device, phi=self.phi, gamma=self.gamma + ) + batches.append(batch) + + with self.model.state_reset(), self.target_model.state_reset(): + # Since the target model is evaluated one-step ahead, + # its internal states need to be updated + self.target_q_function.update_state( + batches[0]["state"], batches[0]["action"] + ) + self.target_policy(batches[0]["state"]) + + # Update critic through time + critic_loss = 0 + for batch in batches: + critic_loss += self.compute_critic_loss(batch) + self.critic_optimizer.update(lambda: critic_loss / max_epi_len) + + with self.model.state_reset(): + # Update actor through time + actor_loss = 0 + for batch in batches: + 
actor_loss += self.compute_actor_loss(batch) + self.actor_optimizer.update(lambda: actor_loss / max_epi_len) + + def batch_act(self, batch_obs): + if self.training: + return self._batch_act_train(batch_obs) + else: + return self._batch_act_eval(batch_obs) + + def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset): + if self.training: + self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset) + + def _batch_select_greedy_actions(self, batch_obs): + with torch.no_grad(), evaluating(self.policy): + batch_xs = self.batch_states(batch_obs, self.device, self.phi) + # batch_action = self.policy(batch_xs).sample() + c_action_distrib, d_action_distrib = self.policy(batch_xs) + c_batch_action = c_action_distrib.sample().cpu().numpy() + d_batch_action = d_action_distrib.sample().cpu().numpy() + + return [c_batch_action, d_batch_action] + + def _batch_act_eval(self, batch_obs): + assert not self.training + c_action, d_action = self._batch_select_greedy_actions(batch_obs) + return [(c_action[0], d_action[0])] + + def _batch_act_train(self, batch_obs): + assert self.training + if self.burnin_action_func is not None and self.n_updates == 0: + batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))] + else: + batch_greedy_action = self._batch_select_greedy_actions(batch_obs) + batch_action = [ + ( + self.explorer.select_action(self.t, lambda: batch_greedy_action[0][i]), + batch_greedy_action[1][i], + ) + for i in range(len(batch_obs)) + ] + + + self.batch_last_obs = list(batch_obs) + self.batch_last_action = list(batch_action) + return batch_action + + def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset): + assert self.training + for i in range(len(batch_obs)): + self.t += 1 + # Update the target network + if self.t % self.target_update_interval == 0: + self.sync_target_network() + if self.batch_last_obs[i] is not None: + assert self.batch_last_action[i] is not None + # Add a transition to the 
replay buffer + self.replay_buffer.append( + state=self.batch_last_obs[i], + action=self.batch_last_action[i], + reward=batch_reward[i], + next_state=batch_obs[i], + next_action=None, + is_state_terminal=batch_done[i], + env_id=i, + ) + if batch_reset[i] or batch_done[i]: + self.batch_last_obs[i] = None + self.batch_last_action[i] = None + self.replay_buffer.stop_current_episode(env_id=i) + self.replay_updater.update_if_necessary(self.t) + + def get_statistics(self): + return [ + ("average_q", _mean_or_nan(self.q_record)), + ("average_actor_loss", _mean_or_nan(self.actor_loss_record)), + ("average_critic_loss", _mean_or_nan(self.critic_loss_record)), + ("n_updates", self.n_updates), + ] diff --git a/pfrl/agents/hybrid_ppo.py b/pfrl/agents/hybrid_ppo.py new file mode 100644 index 000000000..f8bab39a6 --- /dev/null +++ b/pfrl/agents/hybrid_ppo.py @@ -0,0 +1,920 @@ +import collections +import itertools +import random + +import numpy as np +import torch +import torch.nn.functional as F + +import pfrl +from pfrl import agent +from pfrl.utils.batch_states import batch_states +from pfrl.utils.mode_of_distribution import mode_of_distribution +from pfrl.utils.recurrent import ( + concatenate_recurrent_states, + flatten_sequences_time_first, + get_recurrent_state_at, + mask_recurrent_state_at, + one_step_forward, + pack_and_forward, +) + + +def _mean_or_nan(xs): + """Return its mean a non-empty sequence, numpy.nan for a empty one.""" + return np.mean(xs) if xs else np.nan + + +def _elementwise_clip(x, x_min, x_max): + """Elementwise clipping + + Note: torch.clamp supports clipping to constant intervals + """ + return torch.min(torch.max(x, x_min), x_max) + + +def _add_advantage_and_value_target_to_episode(episode, gamma, lambd): + """Add advantage and value target values to an episode.""" + adv = 0.0 + for transition in reversed(episode): + td_err = ( + transition["reward"] + + (gamma * transition["nonterminal"] * transition["next_v_pred"]) + - transition["v_pred"] + ) + adv 
= td_err + gamma * lambd * adv + transition["adv"] = adv + transition["v_teacher"] = adv + transition["v_pred"] + + +def _add_advantage_and_value_target_to_episodes(episodes, gamma, lambd): + """Add advantage and value target values to a list of episodes.""" + for episode in episodes: + _add_advantage_and_value_target_to_episode(episode, gamma=gamma, lambd=lambd) + + +def _add_log_prob_and_value_to_episodes_recurrent( + episodes, + model, + phi, + batch_states, + obs_normalizer, + device, +): + # Sort desc by lengths so that pack_sequence does not change the order + episodes = sorted(episodes, key=len, reverse=True) + + # Prepare data for a recurrent model + seqs_states = [] + seqs_next_states = [] + for ep in episodes: + states = batch_states([transition["state"] for transition in ep], device, phi) + next_states = batch_states( + [transition["next_state"] for transition in ep], device, phi + ) + if obs_normalizer: + states = obs_normalizer(states, update=False) + next_states = obs_normalizer(next_states, update=False) + seqs_states.append(states) + seqs_next_states.append(next_states) + + flat_transitions = flatten_sequences_time_first(episodes) + + # Predict values using a recurrent model + with torch.no_grad(), pfrl.utils.evaluating(model): + rs = concatenate_recurrent_states([ep[0]["recurrent_state"] for ep in episodes]) + next_rs = concatenate_recurrent_states( + [ep[0]["next_recurrent_state"] for ep in episodes] + ) + assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs)) + + (flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs) + (_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states, next_rs) + + # flat_actions = torch.tensor( + # [b["action"] for b in flat_transitions], device=device + # ) + flat_actions_c = torch.tensor( + [b["action"][0] for b in flat_transitions], device=device + ) + flat_actions_d = torch.tensor( + [b["action"][1] for b in flat_transitions], device=device + ) + # flat_log_probs = 
flat_distribs.log_prob(flat_actions).cpu().numpy() + c_flat_log_probs = flat_distribs[0].log_prob(flat_actions_c).cpu().numpy() + d_flat_log_probs = flat_distribs[1].log_prob(flat_actions_d).cpu().numpy() + + flat_vs = flat_vs.cpu().numpy() + + flat_next_vs = flat_next_vs.cpu().numpy() + + # Add predicted values to transitions + for transition, c_log_prob, d_log_probs, v, next_v in zip( + flat_transitions, c_flat_log_probs, d_flat_log_probs, flat_vs, flat_next_vs + ): + # transition["log_prob"] = float(log_prob) + transition["c_log_prob"] = float(c_log_prob) + transition["d_log_prob"] = float(d_log_probs) + transition["v_pred"] = float(v) + transition["next_v_pred"] = float(next_v) + + +def _add_log_prob_and_value_to_episodes( + episodes, + model, + phi, + batch_states, + obs_normalizer, + device, +): + dataset = list(itertools.chain.from_iterable(episodes)) + + # Compute v_pred and next_v_pred + states = batch_states([b["state"] for b in dataset], device, phi) + next_states = batch_states([b["next_state"] for b in dataset], device, phi) + + if obs_normalizer: + states = obs_normalizer(states, update=False) + next_states = obs_normalizer(next_states, update=False) + + with torch.no_grad(), pfrl.utils.evaluating(model): + distribs, vs_pred = model(states) + _, next_vs_pred = model(next_states) + # actions = torch.tensor([b["action"] for b in dataset], device=device) + # log_probs = distribs.log_prob(actions).cpu().numpy() + #Hybrid_actions + c_actions = torch.tensor(np.array( [b["action"][0] for b in dataset] ), device=device) + d_actions = torch.tensor(np.array( [b["action"][1] for b in dataset] ), device=device) + + c_log_probs = distribs[0].log_prob(c_actions).cpu().numpy() + d_log_probs = distribs[1].log_prob(d_actions).cpu().numpy() + # log_probs = log_probs_c + log_probs_d + + vs_pred = vs_pred.cpu().numpy().ravel() + next_vs_pred = next_vs_pred.cpu().numpy().ravel() + + for transition, c_log_prob, d_log_prob, v_pred, next_v_pred in zip( + dataset, 
c_log_probs, d_log_probs, vs_pred, next_vs_pred + ): + transition["c_log_prob"] = c_log_prob + transition["d_log_prob"] = d_log_prob + transition["v_pred"] = v_pred + transition["next_v_pred"] = next_v_pred + + +def _limit_sequence_length(sequences, max_len): + assert max_len > 0 + new_sequences = [] + for sequence in sequences: + while len(sequence) > max_len: + new_sequences.append(sequence[:max_len]) + sequence = sequence[max_len:] + assert 0 < len(sequence) <= max_len + new_sequences.append(sequence) + return new_sequences + + +def _yield_subset_of_sequences_with_fixed_number_of_items(sequences, n_items): + assert n_items > 0 + stack = list(reversed(sequences)) + while stack: + subset = [] + count = 0 + while count < n_items and stack: + sequence = stack.pop() + subset.append(sequence) + count += len(sequence) + if count > n_items: + # Split last sequence + sequence_to_split = subset[-1] + n_exceeds = count - n_items + assert n_exceeds > 0 + subset[-1] = sequence_to_split[:-n_exceeds] + stack.append(sequence_to_split[-n_exceeds:]) + if sum(len(seq) for seq in subset) == n_items: + yield subset + else: + # This ends the while loop. + assert len(stack) == 0 + + +def _compute_explained_variance(transitions): + """Compute 1 - Var[return - v]/Var[return]. + + This function computes the fraction of variance that value predictions can + explain about returns. 
+ """ + t = np.array([tr["v_teacher"] for tr in transitions]) + y = np.array([tr["v_pred"] for tr in transitions]) + vart = np.var(t) + if vart == 0: + return np.nan + else: + return float(1 - np.var(t - y) / vart) + + +def _make_dataset_recurrent( + episodes, + model, + phi, + batch_states, + obs_normalizer, + gamma, + lambd, + max_recurrent_sequence_len, + device, +): + """Make a list of sequences with necessary information.""" + + _add_log_prob_and_value_to_episodes_recurrent( + episodes=episodes, + model=model, + phi=phi, + batch_states=batch_states, + obs_normalizer=obs_normalizer, + device=device, + ) + + _add_advantage_and_value_target_to_episodes(episodes, gamma=gamma, lambd=lambd) + + if max_recurrent_sequence_len is not None: + dataset = _limit_sequence_length(episodes, max_recurrent_sequence_len) + else: + dataset = list(episodes) + + return dataset + + +def _make_dataset( + episodes, model, phi, batch_states, obs_normalizer, gamma, lambd, device +): + """Make a list of transitions with necessary information.""" + + _add_log_prob_and_value_to_episodes( + episodes=episodes, + model=model, + phi=phi, + batch_states=batch_states, + obs_normalizer=obs_normalizer, + device=device, + ) + + _add_advantage_and_value_target_to_episodes(episodes, gamma=gamma, lambd=lambd) + + return list(itertools.chain.from_iterable(episodes)) + + +def _yield_minibatches(dataset, minibatch_size, num_epochs): + assert dataset + buf = [] + n = 0 + while n < len(dataset) * num_epochs: + while len(buf) < minibatch_size: + buf = random.sample(dataset, k=len(dataset)) + buf + assert len(buf) >= minibatch_size + yield buf[-minibatch_size:] + n += minibatch_size + buf = buf[:-minibatch_size] + + +class HybridPPO(agent.AttributeSavingMixin, agent.BatchAgent): + """Proximal Policy Optimization + + See https://arxiv.org/abs/1707.06347 + + Args: + model (torch.nn.Module): Model to train (including recurrent models) + state s |-> (pi(s, _), v(s)) + optimizer (torch.optim.Optimizer): Optimizer 
used to train the model + gpu (int): GPU device id if not None nor negative + gamma (float): Discount factor [0, 1] + lambd (float): Lambda-return factor [0, 1] + phi (callable): Feature extractor function + value_func_coef (float): Weight coefficient for loss of + value function (0, inf) + entropy_coef (float): Weight coefficient for entropy bonus [0, inf) + update_interval (int): Model update interval in step + minibatch_size (int): Minibatch size + epochs (int): Training epochs in an update + clip_eps (float): Epsilon for pessimistic clipping of likelihood ratio + to update policy + clip_eps_vf (float): Epsilon for pessimistic clipping of value + to update value function. If it is ``None``, value function is not + clipped on updates. + standardize_advantages (bool): Use standardized advantages on updates + recurrent (bool): If set to True, `model` is assumed to implement + `pfrl.nn.Recurrent` and update in a recurrent + manner. + max_recurrent_sequence_len (int): Maximum length of consecutive + sequences of transitions in a minibatch for updating the model. + This value is used only when `recurrent` is True. A smaller value + will encourage a minibatch to contain more and shorter sequences. + act_deterministically (bool): If set to True, choose most probable + actions in the act method instead of sampling from distributions. + max_grad_norm (float or None): Maximum L2 norm of the gradient used for + gradient clipping. If set to None, the gradient is not clipped. + value_stats_window (int): Window size used to compute statistics + of value predictions. + entropy_stats_window (int): Window size used to compute statistics + of entropy of action distributions. + value_loss_stats_window (int): Window size used to compute statistics + of loss values regarding the value function. + policy_loss_stats_window (int): Window size used to compute statistics + of loss values regarding the policy. 
+ + Statistics: + average_value: Average of value predictions on non-terminal states. + It's updated on (batch_)act_and_train. + average_entropy: Average of entropy of action distributions on + non-terminal states. It's updated on (batch_)act_and_train. + average_value_loss: Average of losses regarding the value function. + It's updated after the model is updated. + average_policy_loss: Average of losses regarding the policy. + It's updated after the model is updated. + n_updates: Number of model updates so far. + explained_variance: Explained variance computed from the last batch. + """ + + saved_attributes = ("model", "optimizer", "obs_normalizer") + + def __init__( + self, + model, + optimizer, + obs_normalizer=None, + gpu=None, + gamma=0.99, + lambd=0.95, + phi=lambda x: x, + value_func_coef=1.0, + entropy_coef=0.01, + update_interval=2048, + minibatch_size=64, + epochs=10, + clip_eps=0.2, + clip_eps_vf=None, + standardize_advantages=True, + batch_states=batch_states, + recurrent=False, + max_recurrent_sequence_len=None, + act_deterministically=False, + max_grad_norm=None, + value_stats_window=1000, + entropy_stats_window=1000, + value_loss_stats_window=100, + policy_loss_stats_window=100, + ): + self.model = model + self.optimizer = optimizer + self.obs_normalizer = obs_normalizer + + if gpu is not None and gpu >= 0: + assert torch.cuda.is_available() + self.device = torch.device("cuda:{}".format(gpu)) + self.model.to(self.device) + if self.obs_normalizer is not None: + self.obs_normalizer.to(self.device) + else: + self.device = torch.device("cpu") + + self.gamma = gamma + self.lambd = lambd + self.phi = phi + self.value_func_coef = value_func_coef + self.entropy_coef = entropy_coef + self.update_interval = update_interval + self.minibatch_size = minibatch_size + self.epochs = epochs + self.clip_eps = clip_eps + self.clip_eps_vf = clip_eps_vf + self.standardize_advantages = standardize_advantages + self.batch_states = batch_states + self.recurrent = recurrent 
+ self.max_recurrent_sequence_len = max_recurrent_sequence_len + self.act_deterministically = act_deterministically + self.max_grad_norm = max_grad_norm + + # Contains episodes used for next update iteration + self.memory = [] + + # Contains transitions of the last episode not moved to self.memory yet + self.last_episode = [] + self.last_state = None + self.last_action = None + + # Batch versions of last_episode, last_state, and last_action + self.batch_last_episode = None + self.batch_last_state = None + self.batch_last_action = None + + # Recurrent states of the model + self.train_recurrent_states = None + self.train_prev_recurrent_states = None + self.test_recurrent_states = None + + self.value_record = collections.deque(maxlen=value_stats_window) + self.c_entropy_record = collections.deque(maxlen=entropy_stats_window) + self.d_entropy_record = collections.deque(maxlen=entropy_stats_window) + self.value_loss_record = collections.deque(maxlen=value_loss_stats_window) + self.policy_loss_record = collections.deque(maxlen=policy_loss_stats_window) + self.explained_variance = np.nan + self.n_updates = 0 + + def _initialize_batch_variables(self, num_envs): + self.batch_last_episode = [[] for _ in range(num_envs)] + self.batch_last_state = [None] * num_envs + self.batch_last_action = [None] * num_envs + + def _update_if_dataset_is_ready(self): + dataset_size = ( + sum(len(episode) for episode in self.memory) + + len(self.last_episode) + + ( + 0 + if self.batch_last_episode is None + else sum(len(episode) for episode in self.batch_last_episode) + ) + ) + if dataset_size >= self.update_interval: + self._flush_last_episode() + if self.recurrent: + dataset = _make_dataset_recurrent( + episodes=self.memory, + model=self.model, + phi=self.phi, + batch_states=self.batch_states, + obs_normalizer=self.obs_normalizer, + gamma=self.gamma, + lambd=self.lambd, + max_recurrent_sequence_len=self.max_recurrent_sequence_len, + device=self.device, + ) + self._update_recurrent(dataset) + 
else: + dataset = _make_dataset( + episodes=self.memory, + model=self.model, + phi=self.phi, + batch_states=self.batch_states, + obs_normalizer=self.obs_normalizer, + gamma=self.gamma, + lambd=self.lambd, + device=self.device, + ) + assert len(dataset) == dataset_size + self._update(dataset) + self.explained_variance = _compute_explained_variance( + list(itertools.chain.from_iterable(self.memory)) + ) + self.memory = [] + + def _flush_last_episode(self): + if self.last_episode: + self.memory.append(self.last_episode) + self.last_episode = [] + if self.batch_last_episode: + for i, episode in enumerate(self.batch_last_episode): + if episode: + self.memory.append(episode) + self.batch_last_episode[i] = [] + + def _update_obs_normalizer(self, dataset): + assert self.obs_normalizer + states = self.batch_states([b["state"] for b in dataset], self.device, self.phi) + self.obs_normalizer.experience(states) + + def _update(self, dataset): + """Update both the policy and the value function.""" + + device = self.device + + if self.obs_normalizer: + self._update_obs_normalizer(dataset) + + assert "state" in dataset[0] + assert "v_teacher" in dataset[0] + + if self.standardize_advantages: + all_advs = torch.tensor([b["adv"] for b in dataset], device=device) + std_advs, mean_advs = torch.std_mean(all_advs, unbiased=False) + + for batch in _yield_minibatches( + dataset, minibatch_size=self.minibatch_size, num_epochs=self.epochs + ): + states = self.batch_states( + [b["state"] for b in batch], self.device, self.phi + ) + if self.obs_normalizer: + states = self.obs_normalizer(states, update=False) + # actions = torch.tensor([b["action"] for b in batch], device=device) + #Hybrid_actions + c_actions = torch.tensor([b["action"][0] for b in batch], device=device) + d_actions = torch.tensor([b["action"][1] for b in batch], device=device) + + distribs, vs_pred = self.model(states) + + + #Hybrid_actions + c_log_probs = distribs[0].log_prob(c_actions) + d_log_probs = 
distribs[1].log_prob(d_actions) + + total_entropy = distribs[0].entropy() + distribs[1].entropy() + + advs = torch.tensor( + [b["adv"] for b in batch], dtype=torch.float32, device=device + ) + if self.standardize_advantages: + advs = (advs - mean_advs) / (std_advs + 1e-8) + + c_log_probs_old = torch.tensor( + [b["c_log_prob"] for b in batch], + dtype=torch.float, + device=device, + ) + d_log_probs_old = torch.tensor( + [b["d_log_prob"] for b in batch], + dtype=torch.float, + device=device, + ) + vs_pred_old = torch.tensor( + [b["v_pred"] for b in batch], + dtype=torch.float, + device=device, + ) + vs_teacher = torch.tensor( + [b["v_teacher"] for b in batch], + dtype=torch.float, + device=device, + ) + # Same shape as vs_pred: (batch_size, 1) + vs_pred_old = vs_pred_old[..., None] + vs_teacher = vs_teacher[..., None] + + self.model.zero_grad() + loss = self._lossfun( + # distribs.entropy(), + total_entropy, + vs_pred, + #HYBRID + c_log_probs=c_log_probs, + d_log_probs=d_log_probs, + # distribs.log_prob(actions), + vs_pred_old=vs_pred_old, + #HYBRID + c_log_probs_old=c_log_probs_old, + d_log_probs_old=d_log_probs_old, + # log_probs_old=log_probs_old, + advs=advs, + vs_teacher=vs_teacher, + ) + loss.backward() + if self.max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.max_grad_norm + ) + self.optimizer.step() + self.n_updates += 1 + + def _update_once_recurrent(self, episodes, mean_advs, std_advs): + assert std_advs is None or std_advs > 0 + + device = self.device + + # Sort desc by lengths so that pack_sequence does not change the order + episodes = sorted(episodes, key=len, reverse=True) + + flat_transitions = flatten_sequences_time_first(episodes) + + # Prepare data for a recurrent model + seqs_states = [] + for ep in episodes: + states = self.batch_states( + [transition["state"] for transition in ep], + self.device, + self.phi, + ) + if self.obs_normalizer: + states = self.obs_normalizer(states, update=False) + 
seqs_states.append(states) + + # flat_actions = torch.tensor( + # [transition["action"] for transition in flat_transitions], + # device=device, + # ) + c_flat_actions = torch.tensor( + [transition["action"][0] for transition in flat_transitions], + device=device, + ) + d_flat_actions = torch.tensor( + [transition["action"][1] for transition in flat_transitions], + device=device, + ) + + flat_advs = torch.tensor( + [transition["adv"] for transition in flat_transitions], + dtype=torch.float, + device=device, + ) + if self.standardize_advantages: + flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8) + # flat_log_probs_old = torch.tensor( + # [transition["log_prob"] for transition in flat_transitions], + # dtype=torch.float, + # device=device, + # ) + c_flat_log_probs_old = torch.tensor( + [transition["c_log_prob"] for transition in flat_transitions], + dtype=torch.float, + device=device, + ) + d_flat_log_probs_old = torch.tensor( + [transition["d_log_prob"] for transition in flat_transitions], + dtype=torch.float, + device=device, + ) + flat_vs_pred_old = torch.tensor( + [[transition["v_pred"]] for transition in flat_transitions], + dtype=torch.float, + device=device, + ) + flat_vs_teacher = torch.tensor( + [[transition["v_teacher"]] for transition in flat_transitions], + dtype=torch.float, + device=device, + ) + + with torch.no_grad(), pfrl.utils.evaluating(self.model): + rs = concatenate_recurrent_states( + [ep[0]["recurrent_state"] for ep in episodes] + ) + + (flat_distribs, flat_vs_pred), _ = pack_and_forward(self.model, seqs_states, rs) + c_flat_log_probs = flat_distribs[0].log_prob(c_flat_actions) + d_flat_log_probs = flat_distribs[1].log_prob(d_flat_actions) + flat_entropy = flat_distribs[0].entropy() + flat_distribs[1].entropy() + + # flat_log_probs = flat_distribs.log_prob(flat_actions) + # flat_entropy = flat_distribs.entropy() + + self.model.zero_grad() + loss = self._lossfun( + entropy=flat_entropy, + vs_pred=flat_vs_pred, + # log_probs=flat_log_probs, 
+ c_log_probs=c_flat_log_probs, + d_log_probs=d_flat_log_probs, + #HYBRID + vs_pred_old=flat_vs_pred_old, + # log_probs_old=flat_log_probs_old, + c_log_probs_old=c_flat_log_probs_old, + d_log_probs_old=d_flat_log_probs_old, + #HYBRID + advs=flat_advs, + vs_teacher=flat_vs_teacher, + ) + loss.backward() + if self.max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.optimizer.step() + self.n_updates += 1 + + def _update_recurrent(self, dataset): + """Update both the policy and the value function.""" + + device = self.device + + flat_dataset = list(itertools.chain.from_iterable(dataset)) + if self.obs_normalizer: + self._update_obs_normalizer(flat_dataset) + + assert "state" in flat_dataset[0] + assert "v_teacher" in flat_dataset[0] + + if self.standardize_advantages: + all_advs = torch.tensor([b["adv"] for b in flat_dataset], device=device) + std_advs, mean_advs = torch.std_mean(all_advs, unbiased=False) + else: + mean_advs = None + std_advs = None + + for _ in range(self.epochs): + random.shuffle(dataset) + for minibatch in _yield_subset_of_sequences_with_fixed_number_of_items( + dataset, self.minibatch_size + ): + self._update_once_recurrent(minibatch, mean_advs, std_advs) + + def _lossfun( + self, entropy, vs_pred, c_log_probs, d_log_probs, vs_pred_old, c_log_probs_old, d_log_probs_old, advs, vs_teacher + ): + # prob_ratio = torch.exp(log_probs - log_probs_old) + c_prob_ratio = torch.exp(c_log_probs - c_log_probs_old) + d_prob_ratio = torch.exp(d_log_probs - d_log_probs_old) + + # loss_policy = -torch.mean( + # torch.min( + # prob_ratio * advs, + # torch.clamp(prob_ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advs, + # ), + # ) + + #Hybrid + c_loss_policy = -torch.mean( + torch.min( + c_prob_ratio * advs, + torch.clamp(c_prob_ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advs, + ), + ) + d_loss_policy = -torch.mean( + torch.min( + d_prob_ratio * advs, + torch.clamp(d_prob_ratio, 1 - self.clip_eps, 1 + 
self.clip_eps) * advs, + ), + ) + + loss_policy = c_loss_policy + d_loss_policy + + if self.clip_eps_vf is None: + loss_value_func = F.mse_loss(vs_pred, vs_teacher) + else: + clipped_vs_pred = _elementwise_clip( + vs_pred, + vs_pred_old - self.clip_eps_vf, + vs_pred_old + self.clip_eps_vf, + ) + loss_value_func = torch.mean( + torch.max( + F.mse_loss(vs_pred, vs_teacher, reduction="none"), + F.mse_loss(clipped_vs_pred, vs_teacher, reduction="none"), + ) + ) + loss_entropy = -torch.mean(entropy) + + self.value_loss_record.append(float(loss_value_func)) + self.policy_loss_record.append(float(loss_policy)) + + loss = ( + loss_policy + + self.value_func_coef * loss_value_func + + self.entropy_coef * loss_entropy + ) + + return loss + + def batch_act(self, batch_obs): + if self.training: + return self._batch_act_train(batch_obs) + else: + return self._batch_act_eval(batch_obs) + + def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset): + if self.training: + self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset) + else: + self._batch_observe_eval(batch_obs, batch_reward, batch_done, batch_reset) + + def _batch_act_eval(self, batch_obs): + assert not self.training + b_state = self.batch_states(batch_obs, self.device, self.phi) + + if self.obs_normalizer: + b_state = self.obs_normalizer(b_state, update=False) + + with torch.no_grad(), pfrl.utils.evaluating(self.model): + if self.recurrent: + (action_distrib, _), self.test_recurrent_states = one_step_forward( + self.model, b_state, self.test_recurrent_states + ) + else: + action_distrib, _ = self.model(b_state) + if self.act_deterministically: + action = mode_of_distribution(action_distrib).cpu().numpy() + else: + # action = action_distrib.sample().cpu().numpy() + action = [(action_distrib[0].sample().cpu().numpy()[0], action_distrib[1].sample().cpu().numpy()[0])] + + return action + + def _batch_act_train(self, batch_obs): + assert self.training + b_state = 
self.batch_states(batch_obs, self.device, self.phi) + + if self.obs_normalizer: + b_state = self.obs_normalizer(b_state, update=False) + + num_envs = len(batch_obs) + if self.batch_last_episode is None: + self._initialize_batch_variables(num_envs) + assert len(self.batch_last_episode) == num_envs + assert len(self.batch_last_state) == num_envs + assert len(self.batch_last_action) == num_envs + + # action_distrib will be recomputed when computing gradients + with torch.no_grad(), pfrl.utils.evaluating(self.model): + if self.recurrent: + assert self.train_prev_recurrent_states is None + self.train_prev_recurrent_states = self.train_recurrent_states + ( + (action_distrib, batch_value), + self.train_recurrent_states, + ) = one_step_forward( + self.model, b_state, self.train_prev_recurrent_states + ) + else: + action_distrib, batch_value = self.model(b_state) + + self.value_record.extend(batch_value.cpu().numpy()) + if isinstance(action_distrib, tuple): + continous_action_distrib, discrete_action_distrib = action_distrib + batch_action = [( + continous_action_distrib.sample().cpu().numpy()[0], + discrete_action_distrib.sample().cpu().numpy()[0], + )] + + self.c_entropy_record.extend(continous_action_distrib.entropy().cpu().numpy()) + self.d_entropy_record.extend(discrete_action_distrib.entropy().cpu().numpy()) + else: + batch_action = action_distrib.sample().cpu().numpy() + self.entropy_record.extend(action_distrib.entropy().cpu().numpy()) + + self.batch_last_state = list(batch_obs) + self.batch_last_action = list(batch_action) + return batch_action + + def _batch_observe_eval(self, batch_obs, batch_reward, batch_done, batch_reset): + assert not self.training + if self.recurrent: + # Reset recurrent states when episodes end + indices_that_ended = [ + i + for i, (done, reset) in enumerate(zip(batch_done, batch_reset)) + if done or reset + ] + if indices_that_ended: + self.test_recurrent_states = mask_recurrent_state_at( + self.test_recurrent_states, indices_that_ended + 
) + + def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset): + assert self.training + + for i, (state, action, reward, next_state, done, reset) in enumerate( + zip( + self.batch_last_state, + self.batch_last_action, + batch_reward, + batch_obs, + batch_done, + batch_reset, + ) + ): + if state is not None: + assert action is not None + transition = { + "state": state, + "action": action, + "reward": reward, + "next_state": next_state, + "nonterminal": 0.0 if done else 1.0, + } + if self.recurrent: + transition["recurrent_state"] = get_recurrent_state_at( + self.train_prev_recurrent_states, i, detach=True + ) + transition["next_recurrent_state"] = get_recurrent_state_at( + self.train_recurrent_states, i, detach=True + ) + self.batch_last_episode[i].append(transition) + if done or reset: + assert self.batch_last_episode[i] + self.memory.append(self.batch_last_episode[i]) + self.batch_last_episode[i] = [] + self.batch_last_state[i] = None + self.batch_last_action[i] = None + + self.train_prev_recurrent_states = None + + if self.recurrent: + # Reset recurrent states when episodes end + indices_that_ended = [ + i + for i, (done, reset) in enumerate(zip(batch_done, batch_reset)) + if done or reset + ] + if indices_that_ended: + self.train_recurrent_states = mask_recurrent_state_at( + self.train_recurrent_states, indices_that_ended + ) + + self._update_if_dataset_is_ready() + + def get_statistics(self): + return [ + ("average_value", _mean_or_nan(self.value_record)), + ("c_average_entropy", _mean_or_nan(self.c_entropy_record)), + ("d_average_entropy", _mean_or_nan(self.d_entropy_record)), + ("average_value_loss", _mean_or_nan(self.value_loss_record)), + ("average_policy_loss", _mean_or_nan(self.policy_loss_record)), + ("n_updates", self.n_updates), + ("explained_variance", self.explained_variance), + ] diff --git a/pfrl/agents/hybrid_soft_actor_critic.py b/pfrl/agents/hybrid_soft_actor_critic.py new file mode 100644 index 000000000..19985e626 --- 
import collections
import copy
from logging import getLogger

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

import pfrl
from pfrl.agent import AttributeSavingMixin, BatchAgent
from pfrl.replay_buffer import ReplayUpdater, batch_experiences, hybrid_batch_experiences
from pfrl.utils import clip_l2_grad_norm_
from pfrl.utils.batch_states import batch_states
from pfrl.utils.copy_param import synchronize_parameters
from pfrl.utils.mode_of_distribution import mode_of_distribution


def _mean_or_nan(xs):
    """Return the mean of a non-empty sequence, numpy.nan for an empty one."""
    return np.mean(xs) if xs else np.nan


class TemperatureHolder(nn.Module):
    """Module that holds a temperature as a learnable value.

    Args:
        initial_log_temperature (float): Initial value of log(temperature).
    """

    def __init__(self, initial_log_temperature=0):
        super().__init__()
        self.log_temperature = nn.Parameter(
            torch.tensor(initial_log_temperature, dtype=torch.float32)
        )

    def forward(self):
        """Return a temperature as a torch.Tensor."""
        return torch.exp(self.log_temperature)


class HybridSoftActorCritic(AttributeSavingMixin, BatchAgent):
    """Soft Actor-Critic (SAC) for hybrid continuous/discrete action spaces.

    Based on SAC (https://arxiv.org/abs/1812.05905); the policy outputs a
    pair of distributions (continuous, discrete) and separate temperatures
    are maintained for the continuous and discrete entropy terms.

    Args:
        policy (Policy): Policy that returns a pair of distributions
            (continuous, discrete) for a batch of states.
        q_func1 (Module): First Q-function that takes (state, (c_action,
            d_action)) pairs and outputs predicted Q-values.
        q_func2 (Module): Second Q-function, same interface as ``q_func1``.
        policy_optimizer (Optimizer): Optimizer set up with the policy.
        q_func1_optimizer (Optimizer): Optimizer for the first Q-function.
        q_func2_optimizer (Optimizer): Optimizer for the second Q-function.
        replay_buffer (ReplayBuffer): Replay buffer.
        gamma (float): Discount factor.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): Skip updates while the replay buffer holds
            fewer than this many transitions.
        minibatch_size (int): Minibatch size.
        update_interval (int): Model update interval in steps.
        phi (callable): Feature extractor applied to observations.
        soft_update_tau (float): Tau of soft target update.
        max_grad_norm (float or None): If not None, clip gradients by this
            L2 norm before each optimizer step.
        logger (Logger): Logger used.
        batch_states (callable): Method that makes a batch of observations;
            default is ``pfrl.utils.batch_states.batch_states``.
        burnin_action_func (callable or None): If not None, used to select
            actions until the first policy update has happened.
        initial_temperature (float): Initial temperature value. If the
            corresponding entropy target is None, that temperature stays
            fixed at this value.
        c_entropy_target (float or None): Entropy target for the continuous
            temperature; if None the continuous temperature is fixed.
        d_entropy_target (float or None): Entropy target for the discrete
            temperature; if None the discrete temperature is fixed.
        temperature_optimizer_lr (float or None): Learning rate of the
            temperature optimizers; Adam defaults are used when None.
        act_deterministically (bool): If True, take the distribution modes
            in eval mode instead of sampling.
    """

    saved_attributes = (
        "policy",
        "q_func1",
        "q_func2",
        "target_q_func1",
        "target_q_func2",
        "policy_optimizer",
        "q_func1_optimizer",
        "q_func2_optimizer",
        "c_temperature_holder",
        "c_temperature_optimizer",
        "d_temperature_holder",
        "d_temperature_optimizer",
    )

    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        max_grad_norm=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        initial_temperature=1.0,
        c_entropy_target=None,
        d_entropy_target=None,
        temperature_optimizer_lr=None,
        act_deterministically=True,
    ):
        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.policy.to(self.device)
            self.q_func1.to(self.device)
            self.q_func2.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.max_grad_norm = max_grad_norm
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.initial_temperature = initial_temperature
        self.c_entropy_target = c_entropy_target
        self.d_entropy_target = d_entropy_target
        # One learnable temperature per action component (continuous and
        # discrete); each stays None when its entropy target is None.
        (
            self.c_temperature_holder,
            self.c_temperature_optimizer,
        ) = self._setup_temperature(
            c_entropy_target, initial_temperature, temperature_optimizer_lr
        )
        (
            self.d_temperature_holder,
            self.d_temperature_optimizer,
        ) = self._setup_temperature(
            d_entropy_target, initial_temperature, temperature_optimizer_lr
        )

        self.act_deterministically = act_deterministically

        self.t = 0

        # Target model
        self.target_q_func1 = copy.deepcopy(self.q_func1).eval().requires_grad_(False)
        self.target_q_func2 = copy.deepcopy(self.q_func2).eval().requires_grad_(False)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.c_entropy_record = collections.deque(maxlen=1000)
        self.d_entropy_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)
        self.n_policy_updates = 0

    def _setup_temperature(self, entropy_target, initial_temperature, lr):
        """Return a (holder, optimizer) pair for one learnable temperature.

        Returns (None, None) when ``entropy_target`` is None, i.e. the
        temperature is kept fixed at ``initial_temperature``.
        """
        if entropy_target is None:
            return None, None
        holder = TemperatureHolder(
            initial_log_temperature=np.log(initial_temperature)
        )
        # Moving to the device is a no-op on CPU, so no gpu guard is needed.
        holder.to(self.device)
        if lr is not None:
            optimizer = torch.optim.Adam(holder.parameters(), lr=lr)
        else:
            optimizer = torch.optim.Adam(holder.parameters())
        return holder, optimizer

    @property
    def c_temperature(self):
        """Current continuous-action temperature as a Python float."""
        if self.c_entropy_target is None:
            return self.initial_temperature
        with torch.no_grad():
            return float(self.c_temperature_holder())

    @property
    def d_temperature(self):
        """Current discrete-action temperature as a Python float."""
        if self.d_entropy_target is None:
            return self.initial_temperature
        with torch.no_grad():
            return float(self.d_temperature_holder())

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.q_func1,
            dst=self.target_q_func1,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func2,
            dst=self.target_q_func2,
            method="soft",
            tau=self.soft_update_tau,
        )

    def update_q_func(self, batch):
        """Compute the soft Bellman targets and update both Q-functions."""
        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_c_actions = batch["c_action"]
        batch_d_actions = batch["d_action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
            self.target_q_func1
        ), pfrl.utils.evaluating(self.target_q_func2):
            c_next_action_distrib, d_next_action_distrib = self.policy(batch_next_state)
            c_next_action = c_next_action_distrib.sample()
            d_next_action = d_next_action_distrib.sample()
            c_next_log_prob = c_next_action_distrib.log_prob(c_next_action)
            d_next_log_prob = d_next_action_distrib.log_prob(d_next_action)
            d_next_prob = torch.exp(d_next_log_prob)
            next_q1 = self.target_q_func1(
                (batch_next_state, (c_next_action, d_next_action))
            )
            next_q2 = self.target_q_func2(
                (batch_next_state, (c_next_action, d_next_action))
            )
            next_q = torch.min(next_q1, next_q2)
            # Hybrid entropy bonus: the continuous term is weighted by the
            # probability of the sampled discrete action.
            entropy_term = (
                self.c_temperature * c_next_log_prob * d_next_prob
                + self.d_temperature * d_next_log_prob
            )
            assert next_q.shape == entropy_term.shape

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal
            ) * torch.flatten(next_q - entropy_term) * d_next_prob

        predict_q1 = self.q_func1((batch_state, (batch_c_actions, batch_d_actions)))
        predict_q2 = self.q_func2((batch_state, (batch_c_actions, batch_d_actions)))

        loss1 = 0.5 * F.mse_loss(target_q, predict_q1)
        loss2 = 0.5 * F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(loss1.item())
        self.q_func2_loss_record.append(loss2.item())

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

    def _update_temperature(self, holder, optimizer, entropy_target, log_prob):
        """One gradient step on a temperature toward its entropy target."""
        assert not log_prob.requires_grad
        loss = -torch.mean(holder() * (log_prob + entropy_target))
        optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(holder.parameters(), self.max_grad_norm)
        optimizer.step()

    def c_update_temperature(self, log_prob):
        """Update the continuous-action temperature."""
        self._update_temperature(
            self.c_temperature_holder,
            self.c_temperature_optimizer,
            self.c_entropy_target,
            log_prob,
        )

    def d_update_temperature(self, log_prob):
        """Update the discrete-action temperature."""
        self._update_temperature(
            self.d_temperature_holder,
            self.d_temperature_optimizer,
            self.d_entropy_target,
            log_prob,
        )

    def update_policy_and_temperature(self, batch):
        """Update the policy and, if enabled, both temperatures."""
        batch_state = batch["state"]

        c_action_distrib, d_action_distrib = self.policy(batch_state)
        # rsample keeps the continuous pathwise gradient; the discrete
        # component is scored via its probability instead.
        c_actions = c_action_distrib.rsample()
        d_actions = d_action_distrib.sample()
        c_log_prob = c_action_distrib.log_prob(c_actions)
        d_log_prob = d_action_distrib.log_prob(d_actions)
        d_prob = torch.exp(d_log_prob)

        q1 = self.q_func1((batch_state, (c_actions, d_actions)))
        q2 = self.q_func2((batch_state, (c_actions, d_actions)))
        q = torch.min(q1, q2)

        c_entropy_term = self.c_temperature * c_log_prob * d_prob
        d_entropy_term = self.d_temperature * d_log_prob
        loss_d = torch.mean(d_prob * (d_entropy_term - q))
        loss_c = torch.mean(d_prob * (c_entropy_term - q))
        loss = loss_c + loss_d

        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()

        self.n_policy_updates += 1

        if self.c_entropy_target is not None:
            self.c_update_temperature(c_log_prob.detach())
        if self.d_entropy_target is not None:
            self.d_update_temperature(d_log_prob.detach())

        # Record entropy
        with torch.no_grad():
            try:
                self.c_entropy_record.extend(
                    c_action_distrib.entropy().detach().cpu().numpy()
                )
                self.d_entropy_record.extend(
                    d_action_distrib.entropy().detach().cpu().numpy()
                )
            except NotImplementedError:
                # Record -log p(x) instead
                self.c_entropy_record.extend(-c_log_prob.detach().cpu().numpy())
                self.d_entropy_record.extend(-d_log_prob.detach().cpu().numpy())

    def update(self, experiences, errors_out=None):
        """Update the model from experiences."""
        batch = hybrid_batch_experiences(experiences, self.device, self.phi, self.gamma)
        self.update_q_func(batch)
        self.update_policy_and_temperature(batch)
        self.sync_target_network()

    def batch_select_greedy_action(self, batch_obs, deterministic=False):
        """Select one (c_action, d_action) pair per observation.

        Bug fix: the previous implementation returned only the first
        element of the sampled batch, so with more than one env every env
        received the same single action and bookkeeping lists got the
        wrong length. Now one action pair is returned per observation.
        """
        with torch.no_grad(), pfrl.utils.evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            c_distrib, d_distrib = self.policy(batch_xs)
            if deterministic:
                c_actions = mode_of_distribution(c_distrib).cpu().numpy()
                d_actions = mode_of_distribution(d_distrib).cpu().numpy()
            else:
                c_actions = c_distrib.sample().cpu().numpy()
                d_actions = d_distrib.sample().cpu().numpy()
        return list(zip(c_actions, d_actions))

    def batch_act(self, batch_obs):
        """Select a batch of actions for training or evaluation."""
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        """Observe consequences of the last batch of actions."""
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self.batch_select_greedy_action(
            batch_obs, deterministic=self.act_deterministically
        )

    def _batch_act_train(self, batch_obs):
        assert self.training
        # Use the burn-in action function until the first policy update.
        if self.burnin_action_func is not None and self.n_policy_updates == 0:
            batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))]
        else:
            batch_action = self.batch_select_greedy_action(batch_obs)
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        """Return training statistics as a list of (name, value) pairs."""
        return [
            ("average_q1", _mean_or_nan(self.q1_record)),
            ("average_q2", _mean_or_nan(self.q2_record)),
            ("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)),
            ("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)),
            ("n_updates", self.n_policy_updates),
            ("c_average_entropy", _mean_or_nan(self.c_entropy_record)),
            ("d_average_entropy", _mean_or_nan(self.d_entropy_record)),
            ("c_temperature", self.c_temperature),
            ("d_temperature", self.d_temperature),
        ]
import logging
import multiprocessing as mp
import os
import statistics
import time

import numpy as np
import torch

import pfrl


def _run_episodes(
    env,
    agent,
    n_steps,
    n_episodes,
    max_episode_len=None,
    logger=None,
):
    """Run episodes with a recurrent agent and return (scores, lengths).

    The agent's ``act`` receives the previous action and previous reward in
    addition to the observation so recurrent policies can condition on them.

    Args:
        env (Environment): Environment used for evaluation.
        agent: Agent with ``act(obs, prev_act, prev_reward)`` and
            ``observe`` methods.
        n_steps (int or None): Total timestep budget; exactly one of
            ``n_steps``/``n_episodes`` must be None.
        n_episodes (int or None): Number of episodes to run.
        max_episode_len (int or None): Truncate episodes longer than this.
        logger (Logger or None): Logger for per-episode results.

    Returns:
        Tuple (list of returns, list of episode lengths).
    """
    assert (n_steps is None) != (n_episodes is None)

    logger = logger or logging.getLogger(__name__)
    scores = []
    lengths = []
    terminated = False
    timestep = 0

    reset = True
    # Previous action/reward fed back to the recurrent agent.
    prev_act = torch.zeros(1, 1)
    prev_reward = torch.zeros(1)
    while not terminated:
        if reset:
            obs, info = env.reset()
            terminated = False
            test_r = 0
            episode_len = 0
            # Bug fix: a fresh episode must not be conditioned on the last
            # action/reward of the previous episode.
            prev_act = torch.zeros(1, 1)
            prev_reward = torch.zeros(1)
        a = agent.act(obs, prev_act, prev_reward)
        # Reshape to (1, 1) for RNN input.
        prev_act = torch.tensor(a).unsqueeze(0).unsqueeze(0)

        obs, r, terminated, truncated, info = env.step(a)
        test_r += r
        prev_reward = torch.tensor(r).unsqueeze(0)
        episode_len += 1
        timestep += 1
        reset = (
            terminated
            or episode_len == max_episode_len
            or info.get("needs_reset", False)
            or truncated
        )
        agent.observe(obs, r, terminated, reset)
        if reset:
            logger.info(
                "evaluation episode %s length:%s R:%s",
                len(scores),
                episode_len,
                test_r,
            )
            # As mixing float and numpy float causes errors in statistics
            # functions, here every score is cast to float.
            scores.append(float(test_r))
            lengths.append(float(episode_len))
        # Check the budget every step so a long episode cannot overrun
        # n_steps before the next reset.
        if n_steps is None:
            terminated = len(scores) >= n_episodes
        else:
            terminated = timestep >= n_steps
    # If all steps were used for a single unfinished episode
    if len(scores) == 0:
        scores.append(float(test_r))
        lengths.append(float(episode_len))
        logger.info(
            "evaluation episode %s length:%s R:%s", len(scores), episode_len, test_r
        )
    return scores, lengths


def run_evaluation_episodes(
    env,
    agent,
    n_steps,
    n_episodes,
    max_episode_len=None,
    logger=None,
):
    """Run multiple evaluation episodes and return returns.

    Args:
        env (Environment): Environment used for evaluation
        agent (Agent): Agent to evaluate.
        n_steps (int): Number of timesteps to evaluate for.
        n_episodes (int): Number of evaluation runs.
        max_episode_len (int or None): If specified, episodes longer than this
            value will be truncated.
        logger (Logger or None): If specified, the given Logger object will be
            used for logging results. If not specified, the default logger of
            this module will be used.
    Returns:
        List of returns of evaluation runs.
    """
    with agent.eval_mode():
        return _run_episodes(
            env=env,
            agent=agent,
            n_steps=n_steps,
            n_episodes=n_episodes,
            max_episode_len=max_episode_len,
            logger=logger,
        )
+ """ + with agent.eval_mode(): + return _run_episodes( + env=env, + agent=agent, + n_steps=n_steps, + n_episodes=n_episodes, + max_episode_len=max_episode_len, + logger=logger, + ) + + +def _batch_run_episodes( + env, + agent, + n_steps, + n_episodes, + max_episode_len=None, + logger=None, +): + """Run multiple episodes and return returns in a batch manner.""" + assert (n_steps is None) != (n_episodes is None) + + logger = logger or logging.getLogger(__name__) + num_envs = env.num_envs + episode_returns = dict() + episode_lengths = dict() + episode_indices = np.zeros(num_envs, dtype="i") + episode_idx = 0 + for i in range(num_envs): + episode_indices[i] = episode_idx + episode_idx += 1 + episode_r = np.zeros(num_envs, dtype=np.float64) + episode_len = np.zeros(num_envs, dtype="i") + + obss, infos = env.reset() + rs = np.zeros(num_envs, dtype="f") + + termination_conditions = False + timestep = 0 + while True: + # a_t + actions = agent.batch_act(obss) + timestep += 1 + # o_{t+1}, r_{t+1} + obss, rs, terminations, truncations, infos = env.step(actions) + episode_r += rs + episode_len += 1 + # Compute mask for done and reset + if max_episode_len is None: + resets = np.zeros(num_envs, dtype=bool) + else: + resets = episode_len == max_episode_len + resets = np.logical_or( + resets, [info.get("needs_reset", False) or truncated for truncated, info in zip(truncations, infos)] + ) + + # Make mask. 
0 if done/reset, 1 if pass + end = np.logical_or(resets, terminations) + not_end = np.logical_not(end) + + for index in range(len(end)): + if end[index]: + episode_returns[episode_indices[index]] = episode_r[index] + episode_lengths[episode_indices[index]] = episode_len[index] + # Give the new episode an a new episode index + episode_indices[index] = episode_idx + episode_idx += 1 + + episode_r[end] = 0 + episode_len[end] = 0 + + # find first unfinished episode + first_unfinished_episode = 0 + while first_unfinished_episode in episode_returns: + first_unfinished_episode += 1 + + # Check for termination conditions + eval_episode_returns = [] + eval_episode_lens = [] + if n_steps is not None: + total_time = 0 + for index in range(first_unfinished_episode): + total_time += episode_lengths[index] + # If you will run over allocated steps, quit + if total_time > n_steps: + break + else: + eval_episode_returns.append(episode_returns[index]) + eval_episode_lens.append(episode_lengths[index]) + termination_conditions = total_time >= n_steps + if not termination_conditions: + unfinished_index = np.where( + episode_indices == first_unfinished_episode + )[0] + if total_time + episode_len[unfinished_index] >= n_steps: + termination_conditions = True + if first_unfinished_episode == 0: + eval_episode_returns.append(episode_r[unfinished_index]) + eval_episode_lens.append(episode_len[unfinished_index]) + + else: + termination_conditions = first_unfinished_episode >= n_episodes + if termination_conditions: + # Get the first n completed episodes + for index in range(n_episodes): + eval_episode_returns.append(episode_returns[index]) + eval_episode_lens.append(episode_lengths[index]) + + if termination_conditions: + # If this is the last step, make sure the agent observes reset=True + resets.fill(True) + + # Agent observes the consequences. 
def batch_run_evaluation_episodes(
    env,
    agent,
    n_steps,
    n_episodes,
    max_episode_len=None,
    logger=None,
):
    """Run multiple evaluation episodes and return returns in a batch manner.

    Args:
        env (VectorEnv): Environment used for evaluation.
        agent (Agent): Agent to evaluate.
        n_steps (int): Number of total timesteps to evaluate the agent.
        n_episodes (int): Number of evaluation runs.
        max_episode_len (int or None): If specified, episodes longer than
            this value will be truncated.
        logger (Logger or None): Logger for results; the module default is
            used when None.

    Returns:
        List of returns of evaluation runs.
    """
    with agent.eval_mode():
        return _batch_run_episodes(
            env=env,
            agent=agent,
            n_steps=n_steps,
            n_episodes=n_episodes,
            max_episode_len=max_episode_len,
            logger=logger,
        )


def eval_performance(
    env, agent, n_steps, n_episodes, max_episode_len=None, logger=None
):
    """Run multiple evaluation episodes and return summary statistics.

    Args:
        env (Environment): Environment used for evaluation.
        agent (Agent): Agent to evaluate.
        n_steps (int): Number of timesteps to evaluate for.
        n_episodes (int): Number of evaluation episodes.
        max_episode_len (int or None): If specified, episodes longer than
            this value will be truncated.
        logger (Logger or None): Logger for results; the module default is
            used when None.

    Returns:
        Dict of statistics.
    """
    assert (n_steps is None) != (n_episodes is None)

    # Vectorized envs get the batched runner; everything else the serial one.
    if isinstance(env, pfrl.env.VectorEnv):
        runner = batch_run_evaluation_episodes
    else:
        runner = run_evaluation_episodes
    scores, lengths = runner(
        env,
        agent,
        n_steps,
        n_episodes,
        max_episode_len=max_episode_len,
        logger=logger,
    )
    return dict(
        episodes=len(scores),
        mean=statistics.mean(scores),
        median=statistics.median(scores),
        stdev=statistics.stdev(scores) if len(scores) >= 2 else 0.0,
        max=np.max(scores),
        min=np.min(scores),
        length_mean=statistics.mean(lengths),
        length_median=statistics.median(lengths),
        length_stdev=statistics.stdev(lengths) if len(lengths) >= 2 else 0,
        length_max=np.max(lengths),
        length_min=np.min(lengths),
    )


def record_stats(outdir, values):
    """Append one tab-separated row of values to outdir/scores.txt."""
    with open(os.path.join(outdir, "scores.txt"), "a+") as f:
        print("\t".join(str(x) for x in values), file=f)


def create_tb_writer(outdir):
    """Return a tensorboard summarywriter with a custom scalar."""
    # This conditional import will raise an error if tensorboard<1.14
    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter(log_dir=outdir)
    layout = {
        "Aggregate Charts": {
            "mean w/ min-max": [
                "Margin",
                ["eval/mean", "eval/min", "eval/max"],
            ],
            "mean +/- std": [
                "Margin",
                ["eval/mean", "extras/meanplusstdev", "extras/meanminusstdev"],
            ],
        }
    }
    tb_writer.add_custom_scalars(layout)
    return tb_writer


def record_tb_stats(summary_writer, agent_stats, eval_stats, env_stats, t):
    """Write agent/env/eval statistics to a tensorboard summary writer."""
    cur_time = time.time()

    for prefix, stats in (("agent/", agent_stats), ("env/", env_stats)):
        for stat, value in stats:
            summary_writer.add_scalar(prefix + stat, value, t, cur_time)

    for stat in ("mean", "median", "max", "min", "stdev"):
        summary_writer.add_scalar("eval/" + stat, eval_stats[stat], t, cur_time)

    summary_writer.add_scalar(
        "extras/meanplusstdev", eval_stats["mean"] + eval_stats["stdev"], t, cur_time
    )
    summary_writer.add_scalar(
        "extras/meanminusstdev", eval_stats["mean"] - eval_stats["stdev"], t, cur_time
    )

    # Flush manually so events are not lost on termination.
    summary_writer.flush()
def record_tb_stats_loop(outdir, queue, stop_event):
    """Drain (agent_stats, eval_stats, env_stats, t) tuples from ``queue``
    into tensorboard until ``stop_event`` is set and the queue is empty."""
    tb_writer = create_tb_writer(outdir)

    while not (stop_event.wait(1e-6) and queue.empty()):
        if not queue.empty():
            agent_stats, eval_stats, env_stats, t = queue.get()
            record_tb_stats(tb_writer, agent_stats, eval_stats, env_stats, t)


def save_agent(agent, t, outdir, logger, suffix=""):
    """Save the agent under ``outdir/<t><suffix>`` and log the location."""
    dirname = os.path.join(outdir, "{}{}".format(t, suffix))
    agent.save(dirname)
    logger.info("Saved the agent to %s", dirname)


def write_header(outdir, agent, env):
    """Write the tab-separated column-name header line of scores.txt."""
    # Columns that describe information about an experiment.
    basic_columns = (
        "steps",  # number of time steps taken (= number of actions taken)
        "episodes",  # number of episodes finished
        "elapsed",  # time elapsed so far (seconds)
        "mean",  # mean of returns of evaluation runs
        "median",  # median of returns of evaluation runs
        "stdev",  # stdev of returns of evaluation runs
        "max",  # maximum value of returns of evaluation runs
        "min",  # minimum value of returns of evaluation runs
    )
    with open(os.path.join(outdir, "scores.txt"), "w") as f:
        custom_columns = tuple(t[0] for t in agent.get_statistics())
        env_get_stats = getattr(env, "get_statistics", lambda: [])
        assert callable(env_get_stats)
        custom_env_columns = tuple(t[0] for t in env_get_stats())
        print("\t".join(basic_columns + custom_columns + custom_env_columns), file=f)


class Evaluator(object):
    """Object that is responsible for evaluating a given agent.

    Args:
        agent (Agent): Agent to evaluate.
        env (Env): Env to evaluate the agent on.
        n_steps (int): Number of timesteps used in each evaluation.
        n_episodes (int): Number of episodes used in each evaluation.
        eval_interval (int): Interval of evaluations in steps.
        outdir (str): Path to a directory to save things.
        max_episode_len (int): Maximum length of episodes used in evaluations.
        step_offset (int): Offset of steps used to schedule evaluations.
        evaluation_hooks (Sequence): Sequence of
            pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are
            called after each evaluation.
        save_best_so_far_agent (bool): If set to True, after each evaluation,
            if the score (= mean of returns in evaluation episodes) exceeds
            the best-so-far score, the current agent is saved.
        use_tensorboard (bool): Additionally log eval stats to tensorboard.
    """

    def __init__(
        self,
        agent,
        env,
        n_steps,
        n_episodes,
        eval_interval,
        outdir,
        max_episode_len=None,
        step_offset=0,
        evaluation_hooks=(),
        save_best_so_far_agent=True,
        logger=None,
        use_tensorboard=False,
    ):
        assert (n_steps is None) != (n_episodes is None), (
            "One of n_steps or n_episodes must be None. "
            + "Either we evaluate for a specified number "
            + "of episodes or for a specified number of timesteps."
        )
        self.agent = agent
        self.env = env
        self.max_score = np.finfo(np.float32).min
        self.start_time = time.time()
        self.n_steps = n_steps
        self.n_episodes = n_episodes
        self.eval_interval = eval_interval
        self.outdir = outdir
        self.use_tensorboard = use_tensorboard
        self.max_episode_len = max_episode_len
        self.step_offset = step_offset
        # Align the first evaluation to the eval_interval grid.
        self.prev_eval_t = self.step_offset - self.step_offset % self.eval_interval
        self.evaluation_hooks = evaluation_hooks
        self.save_best_so_far_agent = save_best_so_far_agent
        self.logger = logger or logging.getLogger(__name__)
        self.env_get_stats = getattr(self.env, "get_statistics", lambda: [])
        self.env_clear_stats = getattr(self.env, "clear_statistics", lambda: None)
        assert callable(self.env_get_stats)
        assert callable(self.env_clear_stats)

        # Write a header line first
        write_header(self.outdir, self.agent, self.env)

        if use_tensorboard:
            self.tb_writer = create_tb_writer(outdir)

    def evaluate_and_update_max_score(self, t, episodes):
        """Run one evaluation, record stats, and save the best agent."""
        self.env_clear_stats()
        eval_stats = eval_performance(
            self.env,
            self.agent,
            self.n_steps,
            self.n_episodes,
            max_episode_len=self.max_episode_len,
            logger=self.logger,
        )
        elapsed = time.time() - self.start_time
        agent_stats = self.agent.get_statistics()
        env_stats = self.env_get_stats()
        mean = eval_stats["mean"]
        row = (
            t,
            episodes,
            elapsed,
            mean,
            eval_stats["median"],
            eval_stats["stdev"],
            eval_stats["max"],
            eval_stats["min"],
        )
        row += tuple(value for _, value in agent_stats)
        row += tuple(value for _, value in env_stats)
        record_stats(self.outdir, row)

        if self.use_tensorboard:
            record_tb_stats(self.tb_writer, agent_stats, eval_stats, env_stats, t)

        for hook in self.evaluation_hooks:
            hook(
                env=self.env,
                agent=self.agent,
                evaluator=self,
                step=t,
                eval_stats=eval_stats,
                agent_stats=agent_stats,
                env_stats=env_stats,
            )

        if mean > self.max_score:
            self.logger.info(
                "The best score is updated %s -> %s", self.max_score, mean
            )
            self.max_score = mean
            if self.save_best_so_far_agent:
                save_agent(self.agent, "best", self.outdir, self.logger)
        return mean

    def evaluate_if_necessary(self, t, episodes):
        """Evaluate when ``t`` crossed the next eval_interval boundary.

        Returns the mean score when an evaluation ran, otherwise None.
        """
        if t >= self.prev_eval_t + self.eval_interval:
            score = self.evaluate_and_update_max_score(t, episodes)
            self.prev_eval_t = t - t % self.eval_interval
            return score
        return None
self.max_score: + self.logger.info("The best score is updated %s -> %s", self.max_score, mean) + self.max_score = mean + if self.save_best_so_far_agent: + save_agent(self.agent, "best", self.outdir, self.logger) + return mean + + def evaluate_if_necessary(self, t, episodes): + if t >= self.prev_eval_t + self.eval_interval: + score = self.evaluate_and_update_max_score(t, episodes) + self.prev_eval_t = t - t % self.eval_interval + return score + return None + + +class AsyncEvaluator(object): + """Object that is responsible for evaluating asynchronous multiple agents. + + Args: + n_steps (int): Number of timesteps used in each evaluation. + n_episodes (int): Number of episodes used in each evaluation. + eval_interval (int): Interval of evaluations in steps. + outdir (str): Path to a directory to save things. + max_episode_len (int): Maximum length of episodes used in evaluations. + step_offset (int): Offset of steps used to schedule evaluations. + evaluation_hooks (Sequence): Sequence of + pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are + called after each evaluation. + save_best_so_far_agent (bool): If set to True, after each evaluation, + if the score (= mean return of evaluation episodes) exceeds + the best-so-far score, the current agent is saved. + """ + + def __init__( + self, + n_steps, + n_episodes, + eval_interval, + outdir, + max_episode_len=None, + step_offset=0, + evaluation_hooks=(), + save_best_so_far_agent=True, + logger=None, + ): + assert (n_steps is None) != (n_episodes is None), ( + "One of n_steps or n_episodes must be None. " + + "Either we evaluate for a specified number " + + "of episodes or for a specified number of timesteps." 
+ ) + self.start_time = time.time() + self.n_steps = n_steps + self.n_episodes = n_episodes + self.eval_interval = eval_interval + self.outdir = outdir + self.max_episode_len = max_episode_len + self.step_offset = step_offset + self.evaluation_hooks = evaluation_hooks + self.save_best_so_far_agent = save_best_so_far_agent + self.logger = logger or logging.getLogger(__name__) + + # Values below are shared among processes + self.prev_eval_t = mp.Value( + "l", self.step_offset - self.step_offset % self.eval_interval + ) + self._max_score = mp.Value("f", np.finfo(np.float32).min) + self.wrote_header = mp.Value("b", False) + + # Create scores.txt + with open(os.path.join(self.outdir, "scores.txt"), "a"): + pass + + self.record_tb_stats_queue = None + self.record_tb_stats_thread = None + + @property + def max_score(self): + with self._max_score.get_lock(): + v = self._max_score.value + return v + + def evaluate_and_update_max_score(self, t, episodes, env, agent): + env_get_stats = getattr(env, "get_statistics", lambda: []) + env_clear_stats = getattr(env, "clear_statistics", lambda: None) + assert callable(env_get_stats) + assert callable(env_clear_stats) + env_clear_stats() + eval_stats = eval_performance( + env, + agent, + self.n_steps, + self.n_episodes, + max_episode_len=self.max_episode_len, + logger=self.logger, + ) + elapsed = time.time() - self.start_time + agent_stats = agent.get_statistics() + custom_values = tuple(tup[1] for tup in agent_stats) + env_stats = env_get_stats() + custom_env_values = tuple(tup[1] for tup in env_stats) + mean = eval_stats["mean"] + values = ( + ( + t, + episodes, + elapsed, + mean, + eval_stats["median"], + eval_stats["stdev"], + eval_stats["max"], + eval_stats["min"], + ) + + custom_values + + custom_env_values + ) + record_stats(self.outdir, values) + + if self.record_tb_stats_queue is not None: + self.record_tb_stats_queue.put([agent_stats, eval_stats, env_stats, t]) + + for hook in self.evaluation_hooks: + hook( + env=env, + 
agent=agent, + evaluator=self, + step=t, + eval_stats=eval_stats, + agent_stats=agent_stats, + env_stats=env_stats, + ) + + with self._max_score.get_lock(): + if mean > self._max_score.value: + self.logger.info( + "The best score is updated %s -> %s", self._max_score.value, mean + ) + self._max_score.value = mean + if self.save_best_so_far_agent: + save_agent(agent, "best", self.outdir, self.logger) + return mean + + def evaluate_if_necessary(self, t, episodes, env, agent): + necessary = False + with self.prev_eval_t.get_lock(): + if t >= self.prev_eval_t.value + self.eval_interval: + necessary = True + self.prev_eval_t.value += self.eval_interval + if necessary: + with self.wrote_header.get_lock(): + if not self.wrote_header.value: + write_header(self.outdir, agent, env) + self.wrote_header.value = True + return self.evaluate_and_update_max_score(t, episodes, env, agent) + return None + + def start_tensorboard_writer(self, outdir, stop_event): + self.record_tb_stats_queue = mp.Queue() + self.record_tb_stats_thread = pfrl.utils.StoppableThread( + target=record_tb_stats_loop, + args=[outdir, self.record_tb_stats_queue, stop_event], + stop_event=stop_event, + ) + self.record_tb_stats_thread.start() + + def join_tensorboard_writer(self): + self.record_tb_stats_thread.join() diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py index 81321b9ac..b545522b9 100644 --- a/pfrl/experiments/train_agent.py +++ b/pfrl/experiments/train_agent.py @@ -3,6 +3,8 @@ from pfrl.experiments.evaluator import Evaluator, save_agent from pfrl.utils.ask_yes_no import ask_yes_no +import csv +import time def save_agent_replay_buffer(agent, t, outdir, suffix="", logger=None): @@ -20,8 +22,7 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=""): ): # NOQA save_agent_replay_buffer(agent, t, outdir, suffix=suffix) - -def train_agent( +def train_agent_continuing( agent, env, steps, @@ -34,12 +35,145 @@ def train_agent( step_hooks=(), 
eval_during_episode=False, logger=None, + wandb_logging=False, + env_checkpointable=False, + buffer_checkpointable=False, + load_env_state=False, + total_reward_so_far = 0, ): + logger = logger or logging.getLogger(__name__) episode_r = 0 episode_idx = 0 + total_reward = total_reward_so_far # To calculate average reward + + # o_0, r_0 + obs , info = env.reset() + if load_env_state: + name = os.path.join(outdir, "checkpoint_{}.json".format(step_offset)) + env.load_env_state(name) + logger.info("Loaded the environment state from %s", name) + + t = step_offset + if hasattr(agent, "t"): + agent.t = step_offset + + eval_stats_history = [] # List of evaluation episode stats dict + episode_len = 0 + try: + start = time.time() + while t < steps: + # a_t + action = agent.act(obs) + # o_{t+1}, r_{t+1} + obs, r, terminated, truncated, info = env.step(action) + + t += 1 + total_reward += info['untransformed_rewards'] # Accumulate total reward + episode_len += 1 + reset = episode_len == max_episode_len or info.get("needs_reset", False) or truncated + agent.observe(obs, r, terminated, reset) + + for hook in step_hooks: + hook(env, agent, t) + + episode_idx += 1 + + + episode_end = terminated or reset or t == steps + + if t == steps or episode_end: + break + + if t % 100 == 0: # Save values every 100 steps + logger.info( + "outdir:%s step:%s episode:%s R:%s", + outdir, + t, + episode_idx, + total_reward, + ) + stats = agent.get_statistics() + logger.info("statistics:%s", stats) + print("SPS: ", episode_len / (time.time() - start)) + start = time.time() + # Save episodic reward in a CSV file + csv_filename = os.path.join(outdir, "episodic_rewards.csv") + file_exists = os.path.isfile(csv_filename) + + with open(csv_filename, mode='a', newline='') as csv_file: + fieldnames = ['step', 'episode', 'reward', 'average_reward'] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + average_reward = total_reward / (episode_idx if 
episode_idx > 0 else 1) + writer.writerow({'step': t, 'reward': total_reward, 'average_reward': average_reward}) + if wandb_logging: + import wandb + wandb.log({'step': t, 'episode': episode_idx, 'reward': total_reward, 'average_reward': average_reward}) + + if checkpoint_freq and t % checkpoint_freq == 0: + save_agent(agent, t, outdir, logger, suffix="_checkpoint") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_checkpoint')) + # Save the environment state + name = os.path.join(dirname, "checkpoint_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_checkpoint") + except (Exception, KeyboardInterrupt): + # Save the current model before being killed + save_agent(agent, t, outdir, logger, suffix="_except") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_except')) + # Save the environment state + name = os.path.join(dirname, "except_{}.json".format(t)) + env.save_env_state(name) + + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_except") + + raise + + # Save the final model + save_agent(agent, t, outdir, logger, suffix="_finish") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_finish')) + # Save the environment state + name = os.path.join(dirname, "finish_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_finish") + return eval_stats_history + + + + +def train_agent( + agent, + env, + steps, + outdir, + checkpoint_freq=None, + max_episode_len=None, + step_offset=0, + evaluator=None, + successful_score=None, + step_hooks=(), + eval_during_episode=False, + logger=None, + wandb_logging=False, + env_checkpointable=False, + buffer_checkpointable=False, + episode_idx = 0, +): + logger = logger or logging.getLogger(__name__) + + episode_r = 0 + episode_idx = episode_idx + other_bot_reward = 0 # o_0, 
r_0 obs , info = env.reset() @@ -50,13 +184,17 @@ def train_agent( eval_stats_history = [] # List of evaluation episode stats dict episode_len = 0 try: + start = time.time() while t < steps: - # a_t + # a_t action = agent.act(obs) # o_{t+1}, r_{t+1} obs, r, terminated, truncated, info = env.step(action) + t += 1 - episode_r += r + episode_r += info['untransformed_rewards'] + if "other_bot_reward" in info: + other_bot_reward += info['other_bot_reward'] episode_len += 1 reset = episode_len == max_episode_len or info.get("needs_reset", False) or truncated agent.observe(obs, r, terminated, reset) @@ -93,20 +231,66 @@ def train_agent( if episode_end: if t == steps: break + print("SPS: " , episode_len / (time.time() - start)) + start = time.time() # Start a new episode + # Save episodic reward in a CSV file + csv_filename = os.path.join(outdir, "episodic_rewards.csv") + file_exists = os.path.isfile(csv_filename) + + with open(csv_filename, mode='a', newline='') as csv_file: + if 'other_bot_reward' in info: + fieldnames = ['episode', 'steps', 'reward', 'other_bot_reward'] + else: + fieldnames = ['episode', 'steps', 'reward'] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + if 'other_bot_reward' in info: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r, 'other_bot_reward': other_bot_reward}) + else: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r}) + if wandb_logging: + import wandb + if 'other_bot_reward' in info: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r, 'other_bot_reward': other_bot_reward}) + else: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r}) + other_bot_reward = 0 episode_r = 0 episode_len = 0 obs, info = env.reset() if checkpoint_freq and t % checkpoint_freq == 0: save_agent(agent, t, outdir, logger, suffix="_checkpoint") + if env_checkpointable: + dirname = os.path.join(outdir, 
"{}{}".format(t, '_checkpoint')) + # Save the environment state + name = os.path.join(dirname, "checkpoint_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_checkpoint") except (Exception, KeyboardInterrupt): # Save the current model before being killed save_agent(agent, t, outdir, logger, suffix="_except") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_except')) + # Save the environment state + name = os.path.join(dirname, "except_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_except") raise # Save the final model save_agent(agent, t, outdir, logger, suffix="_finish") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_finish')) + # Save the environment state + name = os.path.join(dirname, "finish_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_finish") return eval_stats_history @@ -131,6 +315,13 @@ def train_agent_with_evaluation( use_tensorboard=False, eval_during_episode=False, logger=None, + wandb_logging = False, + case = "episodic", # episodic or continuing + env_checkpointable = False, + buffer_checkpointable = False, + load_env_state = False, + total_reward_so_far = 0, + episode_idx = 0, ): """Train an agent while periodically evaluating it. 
@@ -203,19 +394,44 @@ def train_agent_with_evaluation( logger=logger, ) - eval_stats_history = train_agent( - agent, - env, - steps, - outdir, - checkpoint_freq=checkpoint_freq, - max_episode_len=train_max_episode_len, - step_offset=step_offset, - evaluator=evaluator, - successful_score=successful_score, - step_hooks=step_hooks, - eval_during_episode=eval_during_episode, - logger=logger, - ) + if case == "continuing": + eval_stats_history = train_agent_continuing( + agent, + env, + steps, + outdir, + checkpoint_freq=checkpoint_freq, + max_episode_len=train_max_episode_len, + step_offset=step_offset, + evaluator=evaluator, + successful_score=successful_score, + step_hooks=step_hooks, + eval_during_episode=eval_during_episode, + logger=logger, + wandb_logging=wandb_logging, + env_checkpointable=env_checkpointable, + buffer_checkpointable=buffer_checkpointable, + load_env_state= load_env_state, + total_reward_so_far= total_reward_so_far, + ) + else: + eval_stats_history = train_agent( + agent, + env, + steps, + outdir, + checkpoint_freq=checkpoint_freq, + max_episode_len=train_max_episode_len, + step_offset=step_offset, + evaluator=evaluator, + successful_score=successful_score, + step_hooks=step_hooks, + eval_during_episode=eval_during_episode, + logger=logger, + wandb_logging=wandb_logging, + env_checkpointable=env_checkpointable, + buffer_checkpointable=buffer_checkpointable, + episode_idx=episode_idx, + ) return agent, eval_stats_history diff --git a/pfrl/experiments/train_agent_RNN.py b/pfrl/experiments/train_agent_RNN.py new file mode 100644 index 000000000..46e0dfd04 --- /dev/null +++ b/pfrl/experiments/train_agent_RNN.py @@ -0,0 +1,436 @@ +import logging +import os +from pfrl.experiments.evaluator import Evaluator, save_agent +from pfrl.utils.ask_yes_no import ask_yes_no +import csv +import time + + +def save_agent_replay_buffer(agent, t, outdir, suffix="", logger=None): + logger = logger or logging.getLogger(__name__) + filename = os.path.join(outdir, 
"{}{}.replay.pkl".format(t, suffix)) + agent.replay_buffer.save(filename) + logger.info("Saved the current replay buffer to %s", filename) + + +def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=""): + if hasattr(agent, "replay_buffer") and ask_yes_no( + "Replay buffer has {} transitions. Do you save them to a file?".format( + len(agent.replay_buffer) + ) + ): # NOQA + save_agent_replay_buffer(agent, t, outdir, suffix=suffix) + +def train_agent_continuing_RNN( + agent, + env, + steps, + outdir, + checkpoint_freq=None, + max_episode_len=None, + step_offset=0, + evaluator=None, + successful_score=None, + step_hooks=(), + eval_during_episode=False, + logger=None, + wandb_logging=False, + env_checkpointable=False, + buffer_checkpointable=False, + load_env_state=False, + total_reward_so_far = 0, +): + + logger = logger or logging.getLogger(__name__) + + episode_r = 0 + episode_idx = 0 + total_reward = total_reward_so_far # To calculate average reward + + # o_0, r_0 + obs , info = env.reset() + if load_env_state: + name = os.path.join(outdir, "checkpoint_{}.json".format(step_offset)) + env.load_env_state(name) + logger.info("Loaded the environment state from %s", name) + + t = step_offset + if hasattr(agent, "t"): + agent.t = step_offset + + eval_stats_history = [] # List of evaluation episode stats dict + episode_len = 0 + try: + start = time.time() + while t < steps: + # a_t + action = agent.act(obs) + # o_{t+1}, r_{t+1} + obs, r, terminated, truncated, info = env.step(action) + + t += 1 + total_reward += info['untransformed_rewards'] # Accumulate total reward + episode_len += 1 + reset = episode_len == max_episode_len or info.get("needs_reset", False) or truncated + agent.observe(obs, r, terminated, reset) + + for hook in step_hooks: + hook(env, agent, t) + + episode_idx += 1 + + + episode_end = terminated or reset or t == steps + + if t == steps or episode_end: + break + + if t % 100 == 0: # Save values every 100 steps + logger.info( + "outdir:%s step:%s 
episode:%s R:%s", + outdir, + t, + episode_idx, + total_reward, + ) + stats = agent.get_statistics() + logger.info("statistics:%s", stats) + print("SPS: ", episode_len / (time.time() - start)) + start = time.time() + # Save episodic reward in a CSV file + csv_filename = os.path.join(outdir, "episodic_rewards.csv") + file_exists = os.path.isfile(csv_filename) + + with open(csv_filename, mode='a', newline='') as csv_file: + fieldnames = ['step', 'episode', 'reward', 'average_reward'] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + average_reward = total_reward / (episode_idx if episode_idx > 0 else 1) + writer.writerow({'step': t, 'reward': total_reward, 'average_reward': average_reward}) + if wandb_logging: + import wandb + wandb.log({'step': t, 'episode': episode_idx, 'reward': total_reward, 'average_reward': average_reward}) + + if checkpoint_freq and t % checkpoint_freq == 0: + save_agent(agent, t, outdir, logger, suffix="_checkpoint") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_checkpoint')) + # Save the environment state + name = os.path.join(dirname, "checkpoint_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_checkpoint") + + except (Exception, KeyboardInterrupt): + # Save the current model before being killed + save_agent(agent, t, outdir, logger, suffix="_except") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_except')) + # Save the environment state + name = os.path.join(dirname, "except_{}.json".format(t)) + env.save_env_state(name) + + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_except") + + raise + + # Save the final model + save_agent(agent, t, outdir, logger, suffix="_finish") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_finish')) + # Save the environment state + name = 
os.path.join(dirname, "finish_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_finish") + return eval_stats_history + + + + +def train_agent_RNN( + agent, + env, + steps, + outdir, + checkpoint_freq=None, + max_episode_len=None, + step_offset=0, + evaluator=None, + successful_score=None, + step_hooks=(), + eval_during_episode=False, + logger=None, + wandb_logging=False, + env_checkpointable=False, + buffer_checkpointable=False, + episode_idx = 0, +): + logger = logger or logging.getLogger(__name__) + + episode_r = 0 + episode_idx = episode_idx + other_bot_reward = 0 + # o_0, r_0 + obs , info = env.reset() + + t = step_offset + if hasattr(agent, "t"): + agent.t = step_offset + + eval_stats_history = [] # List of evaluation episode stats dict + episode_len = 0 + try: + start = time.time() + while t < steps: + # a_t + action = agent.act(obs) + # o_{t+1}, r_{t+1} + obs, r, terminated, truncated, info = env.step(action) + + t += 1 + episode_r += info['untransformed_rewards'] + if "other_bot_reward" in info: + other_bot_reward += info['other_bot_reward'] + episode_len += 1 + reset = episode_len == max_episode_len or info.get("needs_reset", False) or truncated + agent.observe(obs, r, terminated, reset) + + for hook in step_hooks: + hook(env, agent, t) + + episode_end = terminated or reset or t == steps + + if episode_end: + logger.info( + "outdir:%s step:%s episode:%s R:%s", + outdir, + t, + episode_idx, + episode_r, + ) + stats = agent.get_statistics() + logger.info("statistics:%s", stats) + episode_idx += 1 + + if evaluator is not None and (episode_end or eval_during_episode): + eval_score = evaluator.evaluate_if_necessary(t=t, episodes=episode_idx) + if eval_score is not None: + eval_stats = dict(agent.get_statistics()) + eval_stats["eval_score"] = eval_score + eval_stats_history.append(eval_stats) + if ( + successful_score is not None + and evaluator.max_score >= successful_score + ): + 
break + + if episode_end: + if t == steps: + break + print("SPS: " , episode_len / (time.time() - start)) + start = time.time() + # Start a new episode + # Save episodic reward in a CSV file + csv_filename = os.path.join(outdir, "episodic_rewards.csv") + file_exists = os.path.isfile(csv_filename) + + with open(csv_filename, mode='a', newline='') as csv_file: + if 'other_bot_reward' in info: + fieldnames = ['episode', 'steps', 'reward', 'other_bot_reward'] + else: + fieldnames = ['episode', 'steps', 'reward'] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + if 'other_bot_reward' in info: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r, 'other_bot_reward': other_bot_reward}) + else: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r}) + if wandb_logging: + import wandb + if 'other_bot_reward' in info: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r, 'other_bot_reward': other_bot_reward}) + else: + writer.writerow({'episode': episode_idx,'steps': t , 'reward': episode_r}) + other_bot_reward = 0 + episode_r = 0 + episode_len = 0 + obs, info = env.reset() + if checkpoint_freq and t % checkpoint_freq == 0: + save_agent(agent, t, outdir, logger, suffix="_checkpoint") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_checkpoint')) + # Save the environment state + name = os.path.join(dirname, "checkpoint_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_checkpoint") + + except (Exception, KeyboardInterrupt): + # Save the current model before being killed + save_agent(agent, t, outdir, logger, suffix="_except") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_except')) + # Save the environment state + name = os.path.join(dirname, "except_{}.json".format(t)) + env.save_env_state(name) + if 
buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_except") + raise + + # Save the final model + save_agent(agent, t, outdir, logger, suffix="_finish") + if env_checkpointable: + dirname = os.path.join(outdir, "{}{}".format(t, '_finish')) + # Save the environment state + name = os.path.join(dirname, "finish_{}.json".format(t)) + env.save_env_state(name) + if buffer_checkpointable: + save_agent_replay_buffer(agent, t, dirname, suffix="_finish") + + return eval_stats_history + + +def train_agent_with_evaluation_RNN( + agent, + env, + steps, + eval_n_steps, + eval_n_episodes, + eval_interval, + outdir, + checkpoint_freq=None, + train_max_episode_len=None, + step_offset=0, + eval_max_episode_len=None, + eval_env=None, + successful_score=None, + step_hooks=(), + evaluation_hooks=(), + save_best_so_far_agent=True, + use_tensorboard=False, + eval_during_episode=False, + logger=None, + wandb_logging = False, + case = "episodic", # episodic or continuing + env_checkpointable = False, + buffer_checkpointable = False, + load_env_state = False, + total_reward_so_far = 0, + episode_idx = 0, +): + """Train an agent while periodically evaluating it. + + Args: + agent: A pfrl.agent.Agent + env: Environment train the agent against. + steps (int): Total number of timesteps for training. + eval_n_steps (int): Number of timesteps at each evaluation phase. + eval_n_episodes (int): Number of episodes at each evaluation phase. + eval_interval (int): Interval of evaluation. + outdir (str): Path to the directory to output data. + checkpoint_freq (int): frequency at which agents are stored. + train_max_episode_len (int): Maximum episode length during training. + step_offset (int): Time step from which training starts. + eval_max_episode_len (int or None): Maximum episode length of + evaluation runs. If None, train_max_episode_len is used instead. + eval_env: Environment used for evaluation. 
+ successful_score (float): Finish training if the mean score is greater + than or equal to this value if not None + step_hooks (Sequence): Sequence of callable objects that accepts + (env, agent, step) as arguments. They are called every step. + See pfrl.experiments.hooks. + evaluation_hooks (Sequence): Sequence of + pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are + called after each evaluation. + save_best_so_far_agent (bool): If set to True, after each evaluation + phase, if the score (= mean return of evaluation episodes) exceeds + the best-so-far score, the current agent is saved. + use_tensorboard (bool): Additionally log eval stats to tensorboard + eval_during_episode (bool): Allow running evaluation during training episodes. + This should be enabled only when `env` and `eval_env` are independent. + logger (logging.Logger): Logger used in this function. + Returns: + agent: Trained agent. + eval_stats_history: List of evaluation episode stats dict. + """ + + logger = logger or logging.getLogger(__name__) + + for hook in evaluation_hooks: + if not hook.support_train_agent: + raise ValueError( + "{} does not support train_agent_with_evaluation().".format(hook) + ) + + os.makedirs(outdir, exist_ok=True) + + if eval_env is None: + assert not eval_during_episode, ( + "To run evaluation during training episodes, you need to specify `eval_env`" + " that is independent from `env`." 
+ ) + eval_env = env + + if eval_max_episode_len is None: + eval_max_episode_len = train_max_episode_len + + evaluator = Evaluator( + agent=agent, + n_steps=eval_n_steps, + n_episodes=eval_n_episodes, + eval_interval=eval_interval, + outdir=outdir, + max_episode_len=eval_max_episode_len, + env=eval_env, + step_offset=step_offset, + evaluation_hooks=evaluation_hooks, + save_best_so_far_agent=save_best_so_far_agent, + use_tensorboard=use_tensorboard, + logger=logger, + ) + + if case == "continuing": + eval_stats_history = train_agent_continuing_RNN( + agent, + env, + steps, + outdir, + checkpoint_freq=checkpoint_freq, + max_episode_len=train_max_episode_len, + step_offset=step_offset, + evaluator=evaluator, + successful_score=successful_score, + step_hooks=step_hooks, + eval_during_episode=eval_during_episode, + logger=logger, + wandb_logging=wandb_logging, + env_checkpointable=env_checkpointable, + buffer_checkpointable=buffer_checkpointable, + load_env_state= load_env_state, + total_reward_so_far= total_reward_so_far, + ) + else: + eval_stats_history = train_agent_RNN( + agent, + env, + steps, + outdir, + checkpoint_freq=checkpoint_freq, + max_episode_len=train_max_episode_len, + step_offset=step_offset, + evaluator=evaluator, + successful_score=successful_score, + step_hooks=step_hooks, + eval_during_episode=eval_during_episode, + logger=logger, + wandb_logging=wandb_logging, + env_checkpointable=env_checkpointable, + buffer_checkpointable=buffer_checkpointable, + episode_idx=episode_idx, + ) + + return agent, eval_stats_history diff --git a/pfrl/explorers/__init__.py b/pfrl/explorers/__init__.py index 935fba103..896249629 100644 --- a/pfrl/explorers/__init__.py +++ b/pfrl/explorers/__init__.py @@ -5,3 +5,4 @@ from pfrl.explorers.epsilon_greedy import ExponentialDecayEpsilonGreedy # NOQA from pfrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy # NOQA from pfrl.explorers.greedy import Greedy # NOQA +from pfrl.explorers.epsilon_greedy import EZGreedy # 
NOQA \ No newline at end of file diff --git a/pfrl/explorers/epsilon_greedy.py b/pfrl/explorers/epsilon_greedy.py index c034d6a5c..ac87a038c 100644 --- a/pfrl/explorers/epsilon_greedy.py +++ b/pfrl/explorers/epsilon_greedy.py @@ -132,3 +132,91 @@ def select_action(self, t, greedy_action_func, action_value=None): def __repr__(self): return "ExponentialDecayEpsilonGreedy(epsilon={})".format(self.epsilon) + +def power_law_duration_sampler(mu=2.0, max_duration=10000): + """ + Sample duration n ~ z(n) ∝ n^{-μ} using inverse transform sampling. + + Args: + mu (float): Power-law exponent. + max_duration (int): Maximum cap for duration. + + Returns: + int: Sampled duration (≥1). + """ + # Sample from discrete power law using inverse CDF + u = np.random.uniform() + n = int((1 / u) ** (1 / (mu - 1))) + return min(max(1, n), max_duration) + + +class EZGreedy(explorer.Explorer): + """ + εz-Greedy exploration policy with linearly decayed epsilon and temporally extended exploration. + + Args: + start_epsilon (float): Initial value of epsilon. + end_epsilon (float): Final value of epsilon. + decay_steps (int): Number of steps over which epsilon decays. + random_action_func (callable): Function that returns a random action. + greedy_action_func (callable): Function that returns the greedy action. + duration_sampler (callable): Function that samples the exploratory duration n ~ z. + logger (logging.Logger): Logger instance for debug output. 
+ """ + + def __init__( + self, + start_epsilon, + end_epsilon, + decay_steps, + random_action_func, + # greedy_action_func, + # duration_sampler, + logger=getLogger(__name__), + ): + assert 0 <= start_epsilon <= 1 + assert 0 <= end_epsilon <= 1 + assert decay_steps >= 0 + self.start_epsilon = start_epsilon + self.end_epsilon = end_epsilon + self.decay_steps = decay_steps + self.random_action_func = random_action_func + # self.greedy_action_func = greedy_action_func + self.duration_sampler = power_law_duration_sampler + self.logger = logger + + self.epsilon = start_epsilon + self.n = 0 # Remaining duration of exploratory action + self.omega = None # Current exploratory action + + def compute_epsilon(self, t): + if t >= self.decay_steps: + return self.end_epsilon + return self.start_epsilon + (self.end_epsilon - self.start_epsilon) * (t / self.decay_steps) + + def select_action(self, t, greedy_action_func=None, action_value=None): + self.epsilon = self.compute_epsilon(t) + + if self.n == 0: + if np.random.rand() <= self.epsilon: + # Exploration phase + self.n = self.duration_sampler() + self.omega = self.random_action_func() + a = self.omega + greedy = False + else: + # Greedy phase + a = greedy_action_func() + greedy = True + else: + # Continue exploratory action + a = self.omega + self.n -= 1 + greedy = False + + # greedy_str = "greedy" if greedy else "non-greedy" + # self.logger.debug("t:%s a:%s %s n:%s", t, a, greedy_str, self.n) + return a + + def __repr__(self): + return f"EZGreedy(epsilon={self.epsilon})" diff --git a/pfrl/replay_buffer.py b/pfrl/replay_buffer.py index 03522eb73..69f77f713 100644 --- a/pfrl/replay_buffer.py +++ b/pfrl/replay_buffer.py @@ -154,7 +154,7 @@ def random_subseq(seq, subseq_len): return seq[i : i + subseq_len] -def batch_experiences(experiences, device, phi, gamma, batch_states=batch_states): +def hybrid_batch_experiences(experiences, device, phi, gamma, batch_states=batch_states): """Takes a batch of k experiences each of which 
contains j consecutive transitions and vectorizes them, where j is between 1 and n. @@ -177,9 +177,80 @@ def batch_experiences(experiences, device, phi, gamma, batch_states=batch_states batch_exp = { "state": batch_states([elem[0]["state"] for elem in experiences], device, phi), - "action": torch.as_tensor( - [elem[0]["action"] for elem in experiences], device=device + # "action": torch.from_numpy( + # np.asarray([elem[0]["action"] for elem in experiences])).to(device) + # , + "c_action": torch.from_numpy( + np.asarray([elem[0]["action"][0] for elem in experiences])).to(device) + , + "d_action": torch.from_numpy( + np.asarray([elem[0]["action"][1] for elem in experiences])).to(device) + , + "reward": torch.as_tensor( + [ + sum((gamma**i) * exp[i]["reward"] for i in range(len(exp))) + for exp in experiences + ], + dtype=torch.float32, + device=device, + ), + "next_state": batch_states( + [elem[-1]["next_state"] for elem in experiences], device, phi + ), + "is_state_terminal": torch.as_tensor( + [ + any(transition["is_state_terminal"] for transition in exp) + for exp in experiences + ], + dtype=torch.float32, + device=device, + ), + "discount": torch.as_tensor( + [(gamma ** len(elem)) for elem in experiences], + dtype=torch.float32, + device=device, ), + } + if all(elem[-1]["next_action"] is not None for elem in experiences): + # batch_exp["next_action"] = torch.as_tensor( + # [elem[-1]["next_action"] for elem in experiences], device=device + # ) + batch_exp["next_c_action"] = torch.as_tensor( + [elem[-1]["next_c_action"] for elem in experiences], device=device + ) + batch_exp["next_d_action"] = torch.as_tensor( + [elem[-1]["next_d_action"] for elem in experiences], device=device + ) + + return batch_exp + + +def batch_experiences(experiences, device, phi, gamma, batch_states=batch_states): + """Takes a batch of k experiences each of which contains j + + consecutive transitions and vectorizes them, where j is between 1 and n. 
# NOTE(review): the patch's context for `batch_experiences` (defined above)
# is truncated at this point; the remainder of that function is unchanged
# upstream code not shown in this diff.

# --- recurrent_c.py ---
import itertools

import numpy as np
import torch
from torch import nn


def is_recurrent(layer):
    """Return True iff ``layer`` is a recurrent module supported by PFRL.

    Args:
        layer (callable): Any callable object.

    Returns:
        bool: True iff ``layer`` is recurrent and supported by PFRL.
    """
    # Imported here to avoid a circular import with pfrl.nn.
    from pfrl.nn import Recurrent

    supported_types = (nn.LSTM, nn.RNN, nn.GRU, Recurrent)
    return isinstance(layer, supported_types)
+ """ + if recurrent_state is None: + return None + elif isinstance(recurrent_state, torch.Tensor): + mask = torch.ones_like(recurrent_state) + mask[:, indices] = 0 + return recurrent_state * mask + elif isinstance(recurrent_state, tuple): + return tuple(mask_recurrent_state_at(s, indices) for s in recurrent_state) + else: + raise ValueError("Invalid recurrent state: {}".format(recurrent_state)) + + +def get_recurrent_state_at(recurrent_state, indices, detach): + """Get a recurrent state at given indices. + + This function can be used to save a recurrent state so that you can + reuse it when you replay past sequences. + + Args: + indices (int or array-like of ints): Which recurrent state to get. + + Returns: + object: Recurrent state of given indices. + """ + if recurrent_state is None: + return None + elif isinstance(recurrent_state, torch.Tensor): + if detach: + recurrent_state = recurrent_state.detach() + return recurrent_state[:, indices] + elif isinstance(recurrent_state, tuple): + return tuple( + get_recurrent_state_at(s, indices, detach) for s in recurrent_state + ) + else: + raise ValueError("Invalid recurrent state: {}".format(recurrent_state)) + + +def concatenate_recurrent_states(split_recurrent_states): + """Concatenate recurrent states into a batch. + + This function can be used to make a batched recurrent state from separate + recurrent states obtained via the `get_recurrent_state_at` function. + + Args: + split_recurrent_states (Sequence): Recurrent states to concatenate. + + Returns: + object: Batched recurrent_state. 
+ """ + if all(s is None for s in split_recurrent_states): + return None + else: + non_none_s = next(s for s in split_recurrent_states if s is not None) + if isinstance(non_none_s, torch.Tensor): + new_ss = [ + s if s is not None else torch.zeros_like(non_none_s) + for s in split_recurrent_states + ] + return torch.stack(new_ss, dim=1) + elif isinstance(non_none_s, np.ndarray): + new_ss = [ + s if s is not None else np.zeros_like(non_none_s) + for s in split_recurrent_states + ] + return np.stack(new_ss, axis=1) + elif isinstance(non_none_s, tuple): + return tuple( + concatenate_recurrent_states( + [s[i] if s is not None else None for s in split_recurrent_states] + ) + for i in range(len(non_none_s)) + ) + else: + raise ValueError("Invalid recurrent state: {}".format(non_none_s)) + + +def pack_one_step_batch_as_sequences(xs): + if isinstance(xs, tuple): + return tuple(pack_one_step_batch_as_sequences(x) for x in xs) + else: + assert isinstance(xs, torch.Tensor) + # xs: (B, ...)-shaped tensor + # seqs: B-sized list of (1, ...)-shaped tensors + seqs = [x for x in xs.split(1)] + assert len(xs) == len(seqs) + assert (1,) + xs.shape[1:] == seqs[0].shape + return nn.utils.rnn.pack_sequence(seqs) + + +def unpack_sequences_as_one_step_batch(pack): + if isinstance(pack, nn.utils.rnn.PackedSequence): + return pack.data + elif isinstance(pack, tuple): + return tuple(unpack_sequences_as_one_step_batch(x) for x in pack) + else: + return pack + + +def one_step_forward(rnn, batch_input, recurrent_state): + """One-step batch forward computation of a recurrent module. + + Args: + rnn (torch.nn.Module): Recurrent module. + batch_input (BatchData): One-step batched input. + recurrent_state (object): Batched recurrent state. + + Returns: + object: One-step batched output. + object: New batched recurrent state. 
+ """ + + batch_obs, batch_action, batch_reward = batch_input + pack = pack_one_step_batch_as_sequences(batch_obs) + y, recurrent_state = rnn(pack, [recurrent_state, batch_action, batch_reward]) + return unpack_sequences_as_one_step_batch(y), recurrent_state + + +def pack_and_forward(rnn, sequences, recurrent_state): + """Pack sequences, multi-step forward, and then unwrap `PackedSequence`. + + Args: + rnn (torch.nn.Module): Recurrent module. + sequences (object): Sequences of input data. + recurrent_state (object): Batched recurrent state. + + Returns: + object: Sequence of output data, packed with time axis first. + object: New batched recurrent state. + """ + pack = pack_sequences_recursive(sequences[0]) + + batch_actions = sequences[1].reshape(-1, sequences[1].shape[-1]) + batch_rewards = sequences[2].reshape(-1, sequences[2].shape[-1]) + y, recurrent_state = rnn(pack, [recurrent_state,batch_actions, batch_rewards]) + return unwrap_packed_sequences_recursive(y), recurrent_state + + +def flatten_sequences_time_first(sequences): + """Flatten sequences with time axis first. + + The resulting order is the same as how + `torch.nn.utils.rnn.pack_sequence` will pack sequences into a tensor. + + Args: + sequences: Sequences with batch axis first. + + Returns: + list: Flattened sequences with time axis first. + """ + ret = [] + for batch in itertools.zip_longest(*sequences): + ret.extend([x for x in batch if x is not None]) + return ret + + +def wrap_packed_sequences_recursive(unwrapped, batch_sizes, sorted_indices): + """Wrap packed tensors by `PackedSequence`. + + Args: + unwrapped (object): Packed but unwrapped tensor(s). + batch_sizes (Tensor): See `PackedSequence.batch_sizes`. + sorted_indices (Tensor): See `PackedSequence.sorted_indices`. + + Returns: + object: Packed sequences. If `unwrapped` is a tensor, then the returned + value is a `PackedSequence`. If `unwrapped` is a tuple of tensors, + then the returned value is a tuple of `PackedSequence`s. 
+ """ + if isinstance(unwrapped, torch.Tensor): + return torch.nn.utils.rnn.PackedSequence( + unwrapped, batch_sizes=batch_sizes, sorted_indices=sorted_indices + ) + if isinstance(unwrapped, tuple): + return tuple( + wrap_packed_sequences_recursive(x, batch_sizes, sorted_indices) + for x in unwrapped + ) + return unwrapped + + +def unwrap_packed_sequences_recursive(packed): + """Unwrap `PackedSequence` class of packed sequences recursively. + + This function extract `torch.Tensor` that + `torch.nn.utils.rnn.PackedSequence` holds internally. Sequences in the + internal tensor is ordered with time axis first. + + Unlike `torch.nn.pad_packed_sequence`, this function just returns the + underlying tensor as it is without padding. + + To wrap the data by `PackedSequence` again, use + `wrap_packed_sequences_recursive`. + + Args: + packed (object): Packed sequences. + + Returns: + object: Unwrapped packed sequences. If `packed` is a `PackedSequence`, + then the returned value is `PackedSequence.data`, the underlying + tensor. If `Packed` is a tuple of `PackedSequence`, then the + returned value is a tuple of the underlying tensors. + """ + if isinstance(packed, torch.nn.utils.rnn.PackedSequence): + return packed.data + if isinstance(packed, tuple): + return tuple(unwrap_packed_sequences_recursive(x) for x in packed) + return packed + + +def pack_sequences_recursive(sequences): + """Pack sequences into PackedSequence recursively. + + This function works similarly to `torch.nn.utils.rnn.pack_sequence` except + that it works recursively for tuples. + + When each given sequence is an N-tuple of `torch.Tensor`s, the function + returns an N-tuple of `torch.nn.utils.rnn.PackedSequence`, packing i-th + tensors separately for i=1,...,N. + + Args: + sequences (object): Batch of sequences to pack. + + Returns: + object: Packed sequences. If `sequences` is a list of tensors, then the + returned value is a `PackedSequence`. 
def pack_sequences_recursive(sequences):
    """Pack sequences into `PackedSequence`(s), recursing through tuples.

    Works like `torch.nn.utils.rnn.pack_sequence` except that, when each
    sequence is an N-tuple of tensors, the i-th tensors are packed
    separately for i=1..N and an N-tuple of `PackedSequence`s is returned.

    Args:
        sequences (object): Batch of sequences to pack.

    Returns:
        object: Packed sequences; the input is returned unchanged when it
        holds neither tensors nor tuples.
    """
    assert sequences
    head = sequences[0]
    if isinstance(head, torch.Tensor):
        return nn.utils.rnn.pack_sequence(sequences)
    if isinstance(head, tuple):
        return tuple(
            pack_sequences_recursive([seq[i] for seq in sequences])
            for i in range(len(head))
        )
    return sequences


def get_packed_sequence_info(packed):
    """Return ``(batch_sizes, sorted_indices)`` of a `PackedSequence`.

    Args:
        packed (object): Packed sequences. When it contains multiple
            `PackedSequence`s, only one is inspected, assuming all share
            the same `batch_sizes` and `sorted_indices`.

    Returns:
        tuple or None: ``(batch_sizes, sorted_indices)``, or None if no
        `PackedSequence` is found.
    """
    if isinstance(packed, torch.nn.utils.rnn.PackedSequence):
        return packed.batch_sizes, packed.sorted_indices
    if isinstance(packed, tuple):
        for item in packed:
            info = get_packed_sequence_info(item)
            if info is not None:
                return info
    return None


def recurrent_state_as_numpy(recurrent_state):
    """Convert a torch recurrent state to numpy, recursing through tuples.

    Args:
        recurrent_state (object): Recurrent state as torch.Tensor(s), or
            None.

    Returns:
        object: Recurrent state as numpy.ndarray(s), or None.
    """
    if recurrent_state is None:
        return None
    if isinstance(recurrent_state, tuple):
        return tuple(recurrent_state_as_numpy(s) for s in recurrent_state)
    if isinstance(recurrent_state, torch.Tensor):
        return recurrent_state.detach().cpu().numpy()
    raise ValueError("Invalid recurrent state: {}".format(recurrent_state))
+ """ + if recurrent_state is None: + return None + elif isinstance(recurrent_state, np.ndarray): + return torch.from_numpy(recurrent_state).to(device) + elif isinstance(recurrent_state, tuple): + return tuple(recurrent_state_from_numpy(s, device) for s in recurrent_state) + else: + raise ValueError("Invalid recurrent state: {}".format(recurrent_state)) + + +def detach_recurrent_state(recurrent_state): + """Detach recurrent state. + + Args: + recurrent_state (object): Recurrent state in torch.Tensor. + + Returns: + object: Detached recurrent state. + """ + if recurrent_state is None: + return + elif isinstance(recurrent_state, torch.Tensor): + return recurrent_state.detach() + elif isinstance(recurrent_state, tuple): + return tuple(detach_recurrent_state(s) for s in recurrent_state) + else: + raise ValueError("Invalid recurrent state: {}".format(recurrent_state)) diff --git a/rnn_c.py b/rnn_c.py new file mode 100644 index 000000000..caadd5bc8 --- /dev/null +++ b/rnn_c.py @@ -0,0 +1,1824 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +import numbers +import warnings +import weakref +from typing import List, Optional, overload, Tuple +from typing_extensions import deprecated + +import torch +from torch import _VF, Tensor +from torch.nn import init +from torch.nn.parameter import Parameter +from torch.nn.utils.rnn import PackedSequence + +from .module import Module + + +__all__ = [ + "RNNBase", + "RNN", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", +] + +_rnn_impls = { + "RNN_TANH": _VF.rnn_tanh, + "RNN_RELU": _VF.rnn_relu, +} + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return 
_apply_permutation(tensor, permutation, dim) + + +class RNNBase(Module): + r"""Base class for RNN modules (RNN, LSTM, GRU). + + Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization + and utility methods for parameter storage management. + + .. note:: + The forward method is not implemented by the RNNBase class. + + .. note:: + LSTM and GRU classes override some methods implemented by RNNBase. + """ + + __constants__ = [ + "mode", + "input_size", + "hidden_size", + "num_layers", + "bias", + "batch_first", + "dropout", + "bidirectional", + "proj_size", + ] + __jit_unused_properties__ = ["all_weights"] + + mode: str + input_size: int + hidden_size: int + num_layers: int + bias: bool + batch_first: bool + dropout: float + bidirectional: bool + proj_size: int + + def __init__( + self, + mode: str, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.proj_size = proj_size + self._flat_weight_refs: List[Optional[weakref.ReferenceType[Parameter]]] = [] + num_directions = 2 if bidirectional else 1 + + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0 and num_layers == 1: + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got 
dropout={dropout} and " + f"num_layers={num_layers}" + ) + + if not isinstance(hidden_size, int): + raise TypeError( + f"hidden_size should be of type int, got: {type(hidden_size).__name__}" + ) + if hidden_size <= 0: + raise ValueError("hidden_size must be greater than zero") + if num_layers <= 0: + raise ValueError("num_layers must be greater than zero") + if proj_size < 0: + raise ValueError( + "proj_size should be a positive integer or zero to disable projections" + ) + if proj_size >= hidden_size: + raise ValueError("proj_size has to be smaller than hidden_size") + + if mode == "LSTM": + gate_size = 4 * hidden_size + elif mode == "GRU": + gate_size = 3 * hidden_size + elif mode == "RNN_TANH": + gate_size = hidden_size + elif mode == "RNN_RELU": + gate_size = hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + real_hidden_size = proj_size if proj_size > 0 else hidden_size + layer_input_size = ( + input_size if layer == 0 else real_hidden_size * num_directions + ) + + w_ih = Parameter( + torch.empty((gate_size, layer_input_size), **factory_kwargs) + ) + w_hh = Parameter( + torch.empty((gate_size, real_hidden_size), **factory_kwargs) + ) + b_ih = Parameter(torch.empty(gate_size, **factory_kwargs)) + # Second bias vector included for CuDNN compatibility. Only one + # bias vector is needed in standard definition. + b_hh = Parameter(torch.empty(gate_size, **factory_kwargs)) + layer_params: Tuple[Tensor, ...] 
= () + if self.proj_size == 0: + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh) + else: + layer_params = (w_ih, w_hh) + else: + w_hr = Parameter( + torch.empty((proj_size, hidden_size), **factory_kwargs) + ) + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) + else: + layer_params = (w_ih, w_hh, w_hr) + + suffix = "_reverse" if direction == 1 else "" + param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] + if bias: + param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] + if self.proj_size > 0: + param_names += ["weight_hr_l{}{}"] + param_names = [x.format(layer, suffix) for x in param_names] + + for name, param in zip(param_names, layer_params): + setattr(self, name, param) + self._flat_weights_names.extend(param_names) + self._all_weights.append(param_names) + + self._init_flat_weights() + + self.reset_parameters() + + def _init_flat_weights(self): + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] + self.flatten_parameters() + + def __setattr__(self, attr, value): + if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: + # keep self._flat_weights up to date if you do self.weight = ... + idx = self._flat_weights_names.index(attr) + self._flat_weights[idx] = value + super().__setattr__(attr, value) + + def flatten_parameters(self) -> None: + """Reset parameter data pointer so that they can use faster code paths. + + Right now, this works only if the module is on the GPU and cuDNN is enabled. + Otherwise, it's a no-op. 
+ """ + # Short-circuits if _flat_weights is only partially instantiated + if len(self._flat_weights) != len(self._flat_weights_names): + return + + for w in self._flat_weights: + if not isinstance(w, Tensor): + return + # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN + # or the tensors in _flat_weights are of different dtypes + + first_fw = self._flat_weights[0] + dtype = first_fw.dtype + for fw in self._flat_weights: + if ( + not isinstance(fw, Tensor) + or not (fw.dtype == dtype) + or not fw.is_cuda + or not torch.backends.cudnn.is_acceptable(fw) + ): + return + + # If any parameters alias, we fall back to the slower, copying code path. This is + # a sufficient check, because overlapping parameter buffers that don't completely + # alias would break the assumptions of the uniqueness check in + # Module.named_parameters(). + unique_data_ptrs = {p.data_ptr() for p in self._flat_weights} + if len(unique_data_ptrs) != len(self._flat_weights): + return + + with torch.cuda.device_of(first_fw): + import torch.backends.cudnn.rnn as rnn + + # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is + # an inplace operation on self._flat_weights + with torch.no_grad(): + if torch._use_cudnn_rnn_flatten_weight(): + num_weights = 4 if self.bias else 2 + if self.proj_size > 0: + num_weights += 1 + torch._cudnn_rnn_flatten_weight( + self._flat_weights, + num_weights, + self.input_size, + rnn.get_cudnn_mode(self.mode), + self.hidden_size, + self.proj_size, + self.num_layers, + self.batch_first, + bool(self.bidirectional), + ) + + def _apply(self, fn, recurse=True): + self._flat_weight_refs = [] + ret = super()._apply(fn, recurse) + + # Resets _flat_weights + # Note: be v. careful before removing this, as 3rd party device types + # likely rely on this behavior to properly .to() modules like LSTM. 
+ self._init_flat_weights() + + return ret + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + if not torch.jit.is_scripting(): + if ( + input.dtype != self._flat_weights[0].dtype + and not torch._C._is_any_autocast_enabled() + ): + raise ValueError( + f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" + ) + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) + if self.input_size != input.size(-1): + raise RuntimeError( + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) + + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> Tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + if self.proj_size > 0: + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.proj_size, + ) + else: + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: Tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) + + def _weights_have_changed(self): + # Returns True if the weight tensors have changed since the last forward pass. + # This is the case when used with torch.func.functional_call(), for example. 
+ weights_changed = False + for ref, name in zip(self._flat_weight_refs, self._flat_weights_names): + weight = getattr(self, name) if hasattr(self, name) else None + if weight is not None and ref is not None and ref() is not weight: + weights_changed = True + break + return weights_changed + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ): + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size(hidden, expected_hidden_size) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if self.proj_size != 0: + s += ", proj_size={proj_size}" + if self.num_layers != 1: + s += ", num_layers={num_layers}" + if self.bias is not True: + s += ", bias={bias}" + if self.batch_first is not False: + s += ", batch_first={batch_first}" + if self.dropout != 0: + s += ", dropout={dropout}" + if self.bidirectional is not False: + s += ", bidirectional={bidirectional}" + return s.format(**self.__dict__) + + def _update_flat_weights(self): + if not torch.jit.is_scripting(): + if self._weights_have_changed(): + self._init_flat_weights() + + def __getstate__(self): + # If weights have been changed, update the _flat_weights in __getstate__ here. + self._update_flat_weights() + # Don't serialize the weight references. + state = self.__dict__.copy() + del state["_flat_weight_refs"] + return state + + def __setstate__(self, d): + super().__setstate__(d) + if "all_weights" in d: + self._all_weights = d["all_weights"] + # In PyTorch 1.8 we added a proj_size member variable to LSTM. + # LSTMs that were serialized via torch.save(module) before PyTorch 1.8 + # don't have it, so to preserve compatibility we set proj_size here. 
+ if "proj_size" not in d: + self.proj_size = 0 + + if not isinstance(self._all_weights[0][0], str): + num_layers = self.num_layers + num_directions = 2 if self.bidirectional else 1 + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + weights = [ + "weight_ih_l{}{}", + "weight_hh_l{}{}", + "bias_ih_l{}{}", + "bias_hh_l{}{}", + "weight_hr_l{}{}", + ] + weights = [x.format(layer, suffix) for x in weights] + if self.bias: + if self.proj_size > 0: + self._all_weights += [weights] + self._flat_weights_names.extend(weights) + else: + self._all_weights += [weights[:4]] + self._flat_weights_names.extend(weights[:4]) + else: + if self.proj_size > 0: + self._all_weights += [weights[:2]] + [weights[-1:]] + self._flat_weights_names.extend( + weights[:2] + [weights[-1:]] + ) + else: + self._all_weights += [weights[:2]] + self._flat_weights_names.extend(weights[:2]) + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] + + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] + + @property + def all_weights(self) -> List[List[Parameter]]: + return [ + [getattr(self, weight) for weight in weights] + for weights in self._all_weights + ] + + def _replicate_for_data_parallel(self): + replica = super()._replicate_for_data_parallel() + # Need to copy these caches, otherwise the replica will share the same + # flat weights list. + replica._flat_weights = replica._flat_weights[:] + replica._flat_weights_names = replica._flat_weights_names[:] + return replica + + +class RNN(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` + non-linearity to an input sequence. 
For each element in the input sequence, + each layer computes the following function: + + .. math:: + h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh}) + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is + the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the + previous layer at time `t-1` or the initial hidden state at time `0`. + If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`. + + .. code-block:: python + + # Efficient implementation equivalent to the following with bidirectional=False + def forward(x, h_0=None): + if batch_first: + x = x.transpose(0, 1) + seq_len, batch_size, _ = x.size() + if h_0 is None: + h_0 = torch.zeros(num_layers, batch_size, hidden_size) + h_t_minus_1 = h_0 + h_t = h_0 + output = [] + for t in range(seq_len): + for layer in range(num_layers): + h_t[layer] = torch.tanh( + x[t] @ weight_ih[layer].T + + bias_ih[layer] + + h_t_minus_1[layer] @ weight_hh[layer].T + + bias_hh[layer] + ) + output.append(h_t[-1]) + h_t_minus_1 = h_t + output = torch.stack(output) + if batch_first: + output = output.transpose(0, 1) + return output, h_t + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two RNNs together to form a `stacked RNN`, + with the second RNN taking in outputs of the first RNN and + computing the final results. Default: 1 + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. 
Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + RNN layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` + + Inputs: input, h_0 + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for the input sequence batch. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the RNN, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. + + Attributes: + weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, + of shape `(hidden_size, input_size)` for `k = 0`. 
Otherwise, the shape is + `(hidden_size, num_directions * hidden_size)` + weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, + of shape `(hidden_size, hidden_size)` + bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, + of shape `(hidden_size)` + bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, + of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.RNN(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + nonlinearity: str = "tanh", + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: + ... + + @overload + def __init__(self, *args, **kwargs): + ... + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + if len(args) > 3: + self.nonlinearity = args[3] + args = args[:3] + args[4:] + else: + self.nonlinearity = kwargs.pop("nonlinearity", "tanh") + if self.nonlinearity == "tanh": + mode = "RNN_TANH" + elif self.nonlinearity == "relu": + mode = "RNN_RELU" + else: + raise ValueError( + f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'." 
+ ) + super().__init__(mode, *args, **kwargs) + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> Tuple[PackedSequence, Tensor]: + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + num_directions = 2 if self.bidirectional else 1 + orig_input = input + + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate + if hx is None: + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
+ hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError( + f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
+ hx = self.permute_hidden(hx, sorted_indices) + + assert hx is not None + self.check_forward_args(input, hx, batch_sizes) + assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU" + if batch_sizes is None: + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.rnn_relu( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + else: + result = _VF.rnn_relu( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + + output = result[0] + hidden = result[1] + + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + +# XXX: LSTM and GRU implementation is different from RNNBase, this is because: +# 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in +# its current state could not support the python Union Type or Any Type +# 2. TorchScript static typing does not allow a Function or Callable type in +# Dict values, so we have to separately call _VF instead of using _rnn_impls +# 3. 
This is temporary only and in the transition state that we want to make it +# on time for the release +# +# More discussion details in https://github.com/pytorch/pytorch/pull/23266 +# +# TODO: remove the overriding implementations for LSTM and GRU when TorchScript +# support expressing these two modules generally. + + +class LSTM(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None) + + Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} \\ + i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ + f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ + o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ + c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ + h_t = o_t \odot \tanh(c_t) \\ + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell + state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}` + is the hidden state of the layer at time `t-1` or the initial hidden + state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`, + :math:`o_t` are the input, forget, cell, and output gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes + the LSTM cell in the following way. 
First, the dimension of :math:`h_t` will be changed from + ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly). + Second, the output hidden state of each layer will be multiplied by a learnable projection + matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output + of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact + dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two LSTMs together to form a `stacked LSTM`, + with the second LSTM taking in outputs of the first LSTM and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + LSTM layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False`` + proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 + + Inputs: input, (h_0, c_0) + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. 
+ See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + initial hidden state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + initial cell state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{cell} ={} & \text{hidden\_size} \\ + H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ + \end{aligned} + + Outputs: output, (h_n, c_n) + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the LSTM, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. When ``bidirectional=True``, `output` will contain + a concatenation of the forward and reverse hidden states at each time step in the sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + final hidden state for each element in the sequence. When ``bidirectional=True``, + `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively. 
+ * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + final cell state for each element in the sequence. When ``bidirectional=True``, + `c_n` will contain a concatenation of the final forward and reverse cell states, respectively. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If + ``proj_size > 0`` was specified, the shape will be + `(4*hidden_size, num_directions * proj_size)` for `k > 0` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` + was specified, the shape will be `(4*hidden_size, proj_size)`. + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)` + weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer + of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was + specified. + weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction. 
+ Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified. + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the + former contains the final forward and reverse hidden states, while the latter contains the + final forward hidden state and the initial reverse hidden state. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + ``proj_size`` should be smaller than ``hidden_size``. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: + ... + + @overload + def __init__(self, *args, **kwargs): + ... 
+ + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + def get_expected_cell_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> Tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args( + self, + input: Tensor, + hidden: Tuple[Tensor, Tensor], # type: ignore[override] + batch_sizes: Optional[Tensor], + ): + self.check_input(input, batch_sizes) + self.check_hidden_size( + hidden[0], + self.get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {}, got {}", + ) + self.check_hidden_size( + hidden[1], + self.get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {}, got {}", + ) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden( # type: ignore[override] + self, + hx: Tuple[Tensor, Tensor], + permutation: Optional[Tensor], + ) -> Tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811 + pass + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: PackedSequence, hx: 
Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + do_permute = False + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
+ hx = self.permute_hidden(hx, sorted_indices) + else: + if input.dim() not in (2, 3): + raise ValueError( + f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + self.check_forward_args(input, hx, batch_sizes) + else: + if is_batched: + if hx[0].dim() != 3 or hx[1].dim() != 3: + msg = ( + "For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ( + "For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
+ self.check_forward_args(input, hx, batch_sizes) + hx = self.permute_hidden(hx, sorted_indices) + + if batch_sizes is None: + result = _VF.lstm( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.lstm( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + +class GRU(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. 
+ :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. 
+ * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` + containing the initial hidden state for the input sequence. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the GRU, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for the input sequence. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. 
note:: + For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. + In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the + previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix + `W` and addition of bias: + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) + \end{aligned} + + This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}` + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) + \end{aligned} + + This implementation differs on purpose for efficiency. + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: + ... + + @overload + def __init__(self, *args, **kwargs): + ... 
+ + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + super().__init__("GRU", *args, **kwargs) + + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: # noqa: F811 + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> Tuple[PackedSequence, Tensor]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
+ hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError( + f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
+ hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.gru( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + +class RNNCellBase(Module): + __constants__ = ["input_size", "hidden_size", "bias"] + + input_size: int + hidden_size: int + bias: bool + weight_ih: Tensor + weight_hh: Tensor + # WARNING: bias_ih and bias_hh purposely not defined here. 
+ # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_ih = Parameter( + torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs) + ) + self.weight_hh = Parameter( + torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs) + ) + if bias: + self.bias_ih = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) + self.bias_hh = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + self.reset_parameters() + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + +class RNNCell(RNNCellBase): + r"""An Elman RNN cell with tanh or ReLU non-linearity. + + .. math:: + + h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) + + If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. 
Default: ``'tanh'`` + + Inputs: input, hidden + - **input**: tensor containing input features + - **hidden**: tensor containing the initial hidden state + Defaults to zero if not provided. + + Outputs: h' + - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + Examples:: + + >>> rnn = nn.RNNCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... 
output.append(hx) + """ + + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] + nonlinearity: str + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError( + f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() not in (1, 2): + raise ValueError( + f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell. + + .. 
math:: + + \begin{array}{ll} + i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ + c' = f \odot c + i \odot g \\ + h' = o \odot \tanh(c') \\ + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, (h_0, c_0) + - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features + - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state + - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state + + If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. + + Outputs: (h_1, c_1) + - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state + - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(4*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(4*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Examples:: + + >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) + >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size) + >>> hx = torch.randn(3, 20) # (batch, hidden_size) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(input.size()[0]): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + >>> output = torch.stack(output, dim=0) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tensor]: + if input.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None: + for idx, value in enumerate(hx): + if value.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell. + + .. 
math:: + + \begin{array}{ll} + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ + n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ + h' = (1 - z) \odot n + z \odot h + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, hidden + - **input** : tensor containing input features + - **hidden** : tensor containing the initial hidden + state for each element in the batch. + Defaults to zero if not provided. + + Outputs: h' + - **h'** : tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(3*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(3*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Examples:: + + >>> rnn = nn.GRUCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret