diff --git a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py
index bd8140b2..258806f4 100644
--- a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py
+++ b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py
@@ -8,7 +8,7 @@ class MiniCheetahRefCfg(MiniCheetahCfg):
     class env(MiniCheetahCfg.env):
-        num_envs = 4096
+        num_envs = 1  # Fine tuning
         num_actuators = 12
         episode_length_s = 5.0
@@ -67,7 +67,7 @@ class scaling(MiniCheetahCfg.scaling):
 
 class MiniCheetahRefRunnerCfg(MiniCheetahRunnerCfg):
     seed = -1
-    runner_class_name = "OnPolicyRunner"
+    runner_class_name = "HybridPolicyRunner"
 
     class actor:
         hidden_dims = [256, 256, 128]
@@ -87,6 +87,8 @@ class actor:
         disable_actions = False
 
         class noise:
+            noise_multiplier = 10.0  # Fine tuning: multiplies all noise
+
             scale = 1.0
             dof_pos_obs = 0.01
             base_ang_vel = 0.01
@@ -138,11 +140,39 @@ class termination_weight:
             termination = 0.15
 
     class algorithm(MiniCheetahRunnerCfg.algorithm):
-        pass
+        # both
+        gamma = 0.99
+        lam = 0.95
+        # shared
+        batch_size = 500  # 2**15
+        max_gradient_steps = 10
+        # new
+        storage_size = 2**17  # new
+
+        clip_param = 0.2
+        learning_rate = 1.0e-3
+        max_grad_norm = 1.0
+        # Critic
+        use_clipped_value_loss = True
+        # Actor
+        entropy_coef = 0.01
+        schedule = "adaptive"  # could be adaptive, fixed
+        desired_kl = 0.01
+
+        # GePPO
+        vtrace = True
+        normalize_advantages = False  # weighted normalization in GePPO loss
+        recursive_advantages = True  # applies to vtrace
+        is_trunc = 1.0
 
     class runner(MiniCheetahRunnerCfg.runner):
         run_name = ""
         experiment_name = "mini_cheetah_ref"
-        max_iterations = 500  # number of policy updates
-        algorithm_class_name = "PPO2"
-        num_steps_per_env = 32  # deprecate
+        max_iterations = 50  # number of policy updates
+        algorithm_class_name = "GePPO"
+        num_steps_per_env = 500  # deprecate
+        num_old_policies = 4
+
+        # Fine tuning
+        resume = True
+        load_run = "rollout_32"  # pretrained PPO run
diff --git a/gym/envs/pendulum/pendulum.py b/gym/envs/pendulum/pendulum.py
index 39c8b331..d768bbf6 100644
--- a/gym/envs/pendulum/pendulum.py
+++ b/gym/envs/pendulum/pendulum.py
@@ -33,11 +33,7 @@ def _reward_equilibrium(self):
         error = torch.stack(
             [theta_norm / self.scales["dof_pos"], omega / self.scales["dof_vel"]], dim=1
         )
-        return self._sqrdexp(torch.mean(error, dim=1), scale=0.01)
-
-    def _reward_torques(self):
-        """Penalize torques"""
-        return self._sqrdexp(torch.mean(torch.square(self.torques), dim=1), scale=0.2)
+        return self._sqrdexp(torch.mean(error, dim=1), scale=0.2)
 
     def _reward_energy(self):
         kinetic_energy = (
diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py
index 0fb705ae..11493419 100644
--- a/gym/envs/pendulum/pendulum_config.py
+++ b/gym/envs/pendulum/pendulum_config.py
@@ -59,7 +59,7 @@ class scaling(FixedRobotCfg.scaling):
 
 class PendulumRunnerCfg(FixedRobotCfgPPO):
     seed = -1
-    runner_class_name = "OnPolicyRunner"
+    runner_class_name = "HybridPolicyRunner"
 
     class actor:
         hidden_dims = [128, 64, 32]
@@ -115,7 +115,7 @@ class algorithm(FixedRobotCfgPPO.algorithm):
         storage_size = 2**17  # new
         batch_size = 2**16  # new
         clip_param = 0.2
-        learning_rate = 1.0e-4
+        learning_rate = 3e-4
         max_grad_norm = 1.0
         # Critic
         use_clipped_value_loss = True
@@ -124,9 +124,16 @@
         schedule = "fixed"  # could be adaptive, fixed
         desired_kl = 0.01
 
+        # GePPO
+        vtrace = True
+        normalize_advantages = False  # weighted normalization in GePPO loss
+        recursive_advantages = True  # applies to vtrace
+        is_trunc = 1.0
+
     class runner(FixedRobotCfgPPO.runner):
         run_name = ""
         experiment_name = "pendulum"
         max_iterations = 200  # number of policy updates
-        algorithm_class_name = "PPO2"
+        algorithm_class_name = "GePPO"
         num_steps_per_env = 32
+        num_old_policies = 4
diff --git a/learning/algorithms/__init__.py b/learning/algorithms/__init__.py
index 78231181..6703faac 100644
--- a/learning/algorithms/__init__.py
+++ b/learning/algorithms/__init__.py
@@ -32,5 +32,6 @@
 from .ppo import PPO
 from .ppo2 import PPO2
+from .geppo import GePPO
 from .SE import StateEstimator
 from .sac import SAC
\ No newline at end of file
diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py
new file mode 100644
index 00000000..a56d432f
--- /dev/null
+++ b/learning/algorithms/geppo.py
@@ -0,0 +1,214 @@
+import torch
+import torch.nn as nn
+
+from .ppo2 import PPO2
+from learning.utils import (
+    create_uniform_generator,
+    compute_generalized_advantages,
+    compute_gae_vtrace,
+    normalize,
+)
+
+
+# Implementation based on GePPO repo: https://github.com/jqueeney/geppo
+class GePPO(PPO2):
+    def __init__(
+        self,
+        actor,
+        critic,
+        num_steps_per_env=32,
+        vtrace=True,
+        normalize_advantages=False,
+        recursive_advantages=True,
+        is_trunc=1.0,
+        eps_ppo=0.2,
+        eps_geppo=0.1,
+        eps_vary=False,
+        adapt_lr=True,
+        adapt_factor=0.03,
+        adapt_minthresh=0.0,
+        adapt_maxthresh=1.0,
+        **kwargs,
+    ):
+        super().__init__(actor, critic, **kwargs)
+        self.num_steps_per_env = num_steps_per_env
+
+        # GAE parameters
+        self.vtrace = vtrace
+        self.normalize_advantages = normalize_advantages
+        self.recursive_advantages = recursive_advantages
+
+        # Importance sampling truncation
+        self.is_trunc = is_trunc
+
+        # Clipping parameter
+        self.eps_ppo = eps_ppo
+        self.eps_geppo = eps_geppo
+        self.eps_vary = eps_vary
+
+        # Learning rate
+        self.adapt_lr = adapt_lr
+        self.adapt_factor = adapt_factor
+        self.adapt_minthresh = adapt_minthresh
+        self.adapt_maxthresh = adapt_maxthresh
+
+        self.updated = False
+
+    def update(self, data, weights):
+        data["values"] = self.critic.evaluate(data["critic_obs"])
+
+        # Compute GAE with and without V-trace
+        adv, ret = self.compute_gae_all(data, vtrace=False)
+        adv_vtrace, ret_vtrace = self.compute_gae_all(data, vtrace=True)
+
+        # Only use V-trace if we have updated once already
+        if self.vtrace and self.updated:
+            data["advantages"] = adv_vtrace
+            data["returns"] = ret_vtrace
+        else:
+            data["advantages"] = adv
+            data["returns"] = ret
+            self.updated = True
+
+        # Update critic and actor
+        data["weights"] = weights
+        self.update_critic(data)
+        data["advantages"] = normalize(data["advantages"])
+        self.update_actor(data)
+
+        # Update pik weights
+        if self.actor.store_pik:
+            self.actor.update_pik_weights()
+
+        # Logging: Store mean advantages and returns
+        self.adv_mean = adv.mean().item()
+        self.ret_mean = ret.mean().item()
+        self.adv_vtrace_mean = adv_vtrace.mean().item()
+        self.ret_vtrace_mean = ret_vtrace.mean().item()
+
+    def compute_gae_all(self, data, vtrace):
+        # Compute GAE for each policy and concatenate
+        adv = torch.zeros_like(data["values"]).to(self.device)
+        ret = torch.zeros_like(data["values"]).to(self.device)
+        steps = self.num_steps_per_env
+        loaded_policies = data["values"].shape[0] // steps
+
+        for i in range(loaded_policies):
+            data_i = data[i * steps : (i + 1) * steps]
+            if vtrace:
+                adv_i, ret_i = compute_gae_vtrace(
+                    data_i,
+                    self.gamma,
+                    self.lam,
+                    self.is_trunc,
+                    self.actor,
+                    self.critic,
+                    rec=self.recursive_advantages,
+                )
+            else:
+                adv_i = compute_generalized_advantages(
+                    data_i, self.gamma, self.lam, self.critic
+                )
+                ret_i = adv_i + data_i["values"]
+            adv[i * steps : (i + 1) * steps] = adv_i
+            ret[i * steps : (i + 1) * steps] = ret_i
+
+        return adv, ret
+
+    def update_critic(self, data):
+        self.mean_value_loss = 0
+        counter = 0
+
+        generator = create_uniform_generator(
+            data,
+            self.batch_size,
+            max_gradient_steps=self.max_gradient_steps,
+        )
+        for batch in generator:
+            # GePPO critic loss uses weights
+            value_loss = self.critic.loss_fn(
+                batch["critic_obs"], batch["returns"], batch["weights"]
+            )
+            self.critic_optimizer.zero_grad()
+            value_loss.backward()
+            nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+            self.critic_optimizer.step()
+            self.mean_value_loss += value_loss.item()
+            counter += 1
+        self.mean_value_loss /= counter
+
+    def update_actor(self, data):
+        # Update clipping eps
+        if self.eps_vary:
+            log_prob_pik = self.actor.get_pik_log_prob(
+                data["actor_obs"], data["actions"]
+            )
+            offpol_ratio = torch.exp(log_prob_pik - data["log_prob"])
+            # TODO: I am taking the mean over 2 dims, check if this is correct
+            eps_old = torch.mean(data["weights"] * torch.abs(offpol_ratio - 1.0))
+            self.eps_geppo = max(self.eps_ppo - eps_old.item(), 0.0)
+
+        self.mean_surrogate_loss = 0
+        counter = 0
+
+        generator = create_uniform_generator(
+            data,
+            self.batch_size,
+            max_gradient_steps=self.max_gradient_steps,
+        )
+        for batch in generator:
+            self.actor.act(batch["actor_obs"])
+            entropy_batch = self.actor.entropy
+
+            # * GePPO Surrogate loss
+            log_prob = self.actor.get_actions_log_prob(batch["actions"])
+            log_prob_pik = self.actor.get_pik_log_prob(
+                batch["actor_obs"], batch["actions"]
+            )
+            ratio = torch.exp(log_prob - batch["log_prob"])
+            offpol_ratio = torch.exp(log_prob_pik - batch["log_prob"])
+
+            advantages = batch["advantages"]
+            if self.normalize_advantages:
+                adv_mean = torch.mean(
+                    offpol_ratio * batch["weights"] * advantages
+                ) / torch.mean(offpol_ratio * batch["weights"])
+                adv_std = torch.std(offpol_ratio * batch["weights"] * advantages)
+                advantages = (advantages - adv_mean) / (adv_std + 1e-8)
+
+            surrogate = -torch.squeeze(advantages) * ratio
+            surrogate_clipped = -torch.squeeze(advantages) * torch.clamp(
+                ratio, offpol_ratio - self.eps_geppo, offpol_ratio + self.eps_geppo
+            )
+            surrogate_loss = (
+                torch.max(surrogate, surrogate_clipped) * batch["weights"]
+            ).mean()
+
+            loss = surrogate_loss - self.entropy_coef * entropy_batch.mean()
+
+            # * Gradient step
+            self.optimizer.zero_grad()
+            loss.backward()
+            nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+            self.optimizer.step()
+            self.mean_surrogate_loss += surrogate_loss.item()
+            counter += 1
+        self.mean_surrogate_loss /= counter
+
+        # Compute TV, add to self for logging
+        self.actor.act(data["actor_obs"])
+        log_prob = self.actor.get_actions_log_prob(data["actions"])
+        log_prob_pik = self.actor.get_pik_log_prob(data["actor_obs"], data["actions"])
+        ratio = torch.exp(log_prob - data["log_prob"])
+        clip_center = torch.exp(log_prob_pik - data["log_prob"])
+        ratio_diff = torch.abs(ratio - clip_center)
+        self.tv = 0.5 * torch.mean(data["weights"] * ratio_diff)
+
+        # Adapt learning rate
+        if self.adapt_lr:
+            if self.tv > (self.adapt_maxthresh * (0.5 * self.eps_geppo)):
+                self.learning_rate /= 1 + self.adapt_factor
+            elif self.tv < (self.adapt_minthresh * (0.5 * self.eps_geppo)):
+                self.learning_rate *= 1 + self.adapt_factor
+            for param_group in self.optimizer.param_groups:
+                param_group["lr"] = self.learning_rate
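For reference, the surrogate loss in geppo.py above differs from PPO in that the current-policy/behavior-policy ratio is clipped around the frozen "pik"/behavior ratio (offpol_ratio) instead of around 1.0, and each sample is weighted by its policy weight. A minimal standalone sketch of that clipping rule, not part of the patch; the tensors and the eps value below are made up for illustration:

import torch

def geppo_surrogate(log_prob, log_prob_behavior, log_prob_pik, advantages, weights, eps=0.1):
    # ratio of the current policy to the (possibly old) behavior policy that collected the data
    ratio = torch.exp(log_prob - log_prob_behavior)
    # clip center: ratio of the frozen pre-update snapshot (pik) to the behavior policy
    clip_center = torch.exp(log_prob_pik - log_prob_behavior)
    clipped = torch.clamp(ratio, clip_center - eps, clip_center + eps)
    # pessimistic maximum of the negated objectives, weighted per sample
    surrogate = torch.max(-advantages * ratio, -advantages * clipped)
    return (weights * surrogate).mean()

# dummy batch of 8 samples
lp, lp_beh, lp_pik = torch.randn(8), torch.randn(8), torch.randn(8)
adv, w = torch.randn(8), torch.ones(8)
print(geppo_surrogate(lp, lp_beh, lp_pik, adv, w))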
diff --git a/learning/algorithms/ppo2.py b/learning/algorithms/ppo2.py
index 09c60ab9..c1c4c126 100644
--- a/learning/algorithms/ppo2.py
+++ b/learning/algorithms/ppo2.py
@@ -63,6 +63,11 @@ def update(self, data):
             data, self.gamma, self.lam, self.critic
         )
         data["returns"] = data["advantages"] + data["values"]
+
+        # Logging: Store mean advantages and returns
+        self.adv_mean = data["advantages"].mean().item()
+        self.ret_mean = data["returns"].mean().item()
+
         self.update_critic(data)
         data["advantages"] = normalize(data["advantages"])
         self.update_actor(data)
diff --git a/learning/modules/actor.py b/learning/modules/actor.py
index fbaa6868..ab6e15e9 100644
--- a/learning/modules/actor.py
+++ b/learning/modules/actor.py
@@ -1,8 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.distributions import Normal
-from .utils import create_MLP
-from .utils import export_network
+from .utils import StaticNN, create_MLP, export_network
 from .utils import RunningMeanStd
@@ -15,6 +14,7 @@ def __init__(
         activation="elu",
         init_noise_std=1.0,
         normalize_obs=True,
+        store_pik=False,
         **kwargs,
     ):
         super().__init__()
@@ -33,6 +33,14 @@ def __init__(
         # disable args validation for speedup
         Normal.set_default_validate_args = False
 
+        self.store_pik = store_pik
+        if self.store_pik:
+            self._NN_pik = StaticNN(
+                create_MLP(num_obs, num_actions, hidden_dims, activation)
+            )
+            self._std_pik = self.std.detach().clone()
+            self.update_pik_weights()
+
     @property
     def action_mean(self):
         return self.distribution.mean
@@ -45,6 +53,14 @@ def action_std(self):
     def entropy(self):
         return self.distribution.entropy().sum(dim=-1)
 
+    @property
+    def obs_running_mean(self):
+        return self.obs_rms.running_mean
+
+    @property
+    def obs_running_std(self):
+        return self.obs_rms.running_var.sqrt()
+
     def update_distribution(self, observations):
         mean = self.act_inference(observations)
         self.distribution = Normal(mean, mean * 0.0 + self.std)
@@ -67,3 +83,19 @@ def forward(self, observations):
 
     def export(self, path):
         export_network(self, "policy", path, self.num_obs)
+
+    def update_pik_weights(self):
+        with torch.no_grad():
+            nn_state_dict = self.NN.state_dict()
+            self._NN_pik.model.load_state_dict(nn_state_dict)
+            self._std_pik = self.std.detach().clone()
+
+    def get_pik_log_prob(self, observations, actions):
+        if self._normalize_obs:
+            # TODO: Check if this updates the normalization mean/std
+            with torch.no_grad():
+                observations = self.obs_rms(observations)
+        mean_pik = self._NN_pik(observations)
+        std_pik = self._std_pik.to(mean_pik.device)
+        distribution = Normal(mean_pik, mean_pik * 0.0 + std_pik)
+        return distribution.log_prob(actions).sum(dim=-1)
diff --git a/learning/modules/critic.py b/learning/modules/critic.py
index 96c53438..cec533a7 100644
--- a/learning/modules/critic.py
+++ b/learning/modules/critic.py
@@ -24,10 +24,14 @@ def forward(self, x):
         if self._normalize_obs:
             with torch.no_grad():
                 x = self.obs_rms(x)
-        return self.NN(x).squeeze()
+        return self.NN(x).squeeze(-1)
 
     def evaluate(self, critic_observations):
         return self.forward(critic_observations)
 
-    def loss_fn(self, obs, target):
-        return nn.functional.mse_loss(self.forward(obs), target, reduction="mean")
+    def loss_fn(self, obs, target, weights=None):
+        if weights is None:
+            return nn.functional.mse_loss(self.forward(obs), target, reduction="mean")
+
+        # Compute MSE loss with weights
+        return (weights * (self.forward(obs) - target).pow(2)).mean()
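The store_pik machinery added to the actor above keeps a frozen copy of the policy network and its std so that log-probabilities under the pre-update snapshot can still be evaluated after gradient steps. A standalone sketch of the idea, not part of the patch; the network sizes and tensors below are arbitrary:

import copy
import torch
import torch.nn as nn
from torch.distributions import Normal

net = nn.Sequential(nn.Linear(4, 32), nn.ELU(), nn.Linear(32, 2))
std = torch.ones(2) * 0.5

# snapshot the current weights (roughly what update_pik_weights does via StaticNN)
net_pik = copy.deepcopy(net).requires_grad_(False)
std_pik = std.detach().clone()

obs, actions = torch.randn(8, 4), torch.randn(8, 2)
with torch.no_grad():
    mean_pik = net_pik(obs)
log_prob_pik = Normal(mean_pik, std_pik.expand_as(mean_pik)).log_prob(actions).sum(dim=-1)
print(log_prob_pik.shape)  # torch.Size([8])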
diff --git a/learning/modules/utils/__init__.py b/learning/modules/utils/__init__.py
index 1422b82d..cf34428b 100644
--- a/learning/modules/utils/__init__.py
+++ b/learning/modules/utils/__init__.py
@@ -1,2 +1,2 @@
-from .neural_net import create_MLP, export_network
+from .neural_net import StaticNN, create_MLP, export_network
 from .normalize import RunningMeanStd
diff --git a/learning/modules/utils/neural_net.py b/learning/modules/utils/neural_net.py
index 5fd6ab3b..002b4ed8 100644
--- a/learning/modules/utils/neural_net.py
+++ b/learning/modules/utils/neural_net.py
@@ -3,6 +3,20 @@
 import copy
 
 
+class StaticNN:
+    """
+    A static neural network wrapper that is used just for inference.
+    """
+
+    def __init__(self, model):
+        self.model = model
+
+    def __call__(self, x):
+        model = self.model.to(x.device)
+        with torch.no_grad():
+            return model(x)
+
+
 def create_MLP(num_inputs, num_outputs, hidden_dims, activation, dropouts=None):
     activation = get_activation(activation)
diff --git a/learning/modules/utils/normalize.py b/learning/modules/utils/normalize.py
index 246bafa4..00471325 100644
--- a/learning/modules/utils/normalize.py
+++ b/learning/modules/utils/normalize.py
@@ -40,8 +40,9 @@ def _update_mean_var_from_moments(
     def forward(self, input):
         if self.training:
+            # TODO: check this, it got rid of NaN values in first iteration
             mean = input.mean(tuple(range(input.dim() - 1)))
-            var = input.var(tuple(range(input.dim() - 1)))
+            var = torch.nan_to_num(input.var(tuple(range(input.dim() - 1))))
             (
                 self.running_mean,
                 self.running_var,
diff --git a/learning/runners/BaseRunner.py b/learning/runners/BaseRunner.py
index 9e5ab2e3..fc5d745d 100644
--- a/learning/runners/BaseRunner.py
+++ b/learning/runners/BaseRunner.py
@@ -49,6 +49,8 @@ def get_noise(self, obs_list, noise_dict):
             noise_tensor = torch.ones(obs_size).to(self.device) * torch.tensor(
                 noise_dict[obs]
             ).to(self.device)
+            if "noise_multiplier" in noise_dict.keys():
+                noise_tensor *= noise_dict["noise_multiplier"]
             if obs in self.env.scales.keys():
                 noise_tensor /= self.env.scales[obs]
             noise_vec[obs_index : obs_index + obs_size] = noise_tensor
diff --git a/learning/runners/__init__.py b/learning/runners/__init__.py
index a0840b06..8831a0ac 100644
--- a/learning/runners/__init__.py
+++ b/learning/runners/__init__.py
@@ -33,4 +33,5 @@
 from .on_policy_runner import OnPolicyRunner
 from .my_runner import MyRunner
 from .old_policy_runner import OldPolicyRunner
-from .off_policy_runner import OffPolicyRunner
\ No newline at end of file
+from .off_policy_runner import OffPolicyRunner
+from .hybrid_policy_runner import HybridPolicyRunner
\ No newline at end of file
diff --git a/learning/runners/hybrid_policy_runner.py b/learning/runners/hybrid_policy_runner.py
new file mode 100644
index 00000000..3e70f332
--- /dev/null
+++ b/learning/runners/hybrid_policy_runner.py
@@ -0,0 +1,275 @@
+import os
+import torch
+from tensordict import TensorDict
+
+from learning.utils import Logger
+
+from .BaseRunner import BaseRunner
+from learning.modules import Actor, Critic
+from learning.storage import ReplayBuffer
+from learning.algorithms import GePPO
+
+logger = Logger()
+storage = ReplayBuffer()
+
+
+class HybridPolicyRunner(BaseRunner):
+    def __init__(self, env, train_cfg, device="cpu"):
+        super().__init__(env, train_cfg, device)
+        logger.initialize(
+            self.env.num_envs,
+            self.env.dt,
+            self.cfg["max_iterations"],
+            self.device,
+        )
+
+        # TODO: Weights hardcoded for 4 policies
+        self.num_old_policies = self.cfg["num_old_policies"]
+        self.weights = torch.tensor([0.4, 0.3, 0.2, 0.1]).to(self.device)
+
+    def _set_up_alg(self):
+        alg_class = eval(self.cfg["algorithm_class_name"])
+        if alg_class != GePPO:
ValueError("HybridPolicyRunner only supports GePPO") + + num_actor_obs = self.get_obs_size(self.actor_cfg["obs"]) + num_actions = self.get_action_size(self.actor_cfg["actions"]) + num_critic_obs = self.get_obs_size(self.critic_cfg["obs"]) + # Store pik for the actor + actor = Actor(num_actor_obs, num_actions, store_pik=True, **self.actor_cfg) + critic = Critic(num_critic_obs, **self.critic_cfg) + self.alg = alg_class( + actor, + critic, + device=self.device, + num_steps_per_env=self.num_steps_per_env, # GePPO needs this + **self.alg_cfg, + ) + + def learn(self): + self.set_up_logger() + + rewards_dict = {} + + self.alg.switch_to_train() + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + tot_iter = self.it + self.num_learning_iterations + self.save() + + # * start up storage + transition = TensorDict({}, batch_size=self.env.num_envs, device=self.device) + transition.update( + { + "actor_obs": actor_obs, + "next_actor_obs": actor_obs, + "actions": self.alg.act(actor_obs), + "critic_obs": critic_obs, + "next_critic_obs": critic_obs, + "rewards": self.get_rewards({"termination": 0.0})["termination"], + "dones": self.get_timed_out(), + } + ) + max_storage = self.env.num_envs * self.num_steps_per_env * self.num_old_policies + storage.initialize( + dummy_dict=transition, + num_envs=self.env.num_envs, + max_storage=max_storage, + device=self.device, + ) + + # burn in observation normalization. + if self.actor_cfg["normalize_obs"] or self.critic_cfg["normalize_obs"]: + self.burn_in_normalization() + + logger.tic("runtime") + for self.it in range(self.it + 1, tot_iter + 1): + logger.tic("iteration") + logger.tic("collection") + # * Rollout + with torch.inference_mode(): + for i in range(self.num_steps_per_env): + actions = self.alg.act(actor_obs) + self.set_actions( + self.actor_cfg["actions"], + actions, + self.actor_cfg["disable_actions"], + ) + # Store additional data for GePPO + log_prob = self.alg.actor.get_actions_log_prob(actions) + action_mean = self.alg.actor.action_mean + action_std = self.alg.actor.action_std + + transition.update( + { + "actor_obs": actor_obs, + "actions": actions, + "critic_obs": critic_obs, + "log_prob": log_prob, + "action_mean": action_mean, + "action_std": action_std, + } + ) + + self.env.step() + + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + + # * get time_outs + timed_out = self.get_timed_out() + terminated = self.get_terminated() + dones = timed_out | terminated + + self.update_rewards(rewards_dict, terminated) + total_rewards = torch.stack(tuple(rewards_dict.values())).sum(dim=0) + + transition.update( + { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, + "rewards": total_rewards, + "timed_out": timed_out, + "dones": dones, + } + ) + storage.add_transitions(transition) + + logger.log_rewards(rewards_dict) + logger.log_rewards({"total_rewards": total_rewards}) + logger.finish_step(dones) + logger.toc("collection") + + # Compute GePPO weights + num_policies = storage.fill_count // self.num_steps_per_env + num_policies = min(num_policies, self.num_old_policies) + weights_active = self.weights[:num_policies] + weights_active = weights_active * num_policies / weights_active.sum() + idx_newest = (self.it - 1) % self.num_old_policies + indices_all = [ + i % self.num_old_policies + for i in range(idx_newest, idx_newest - num_policies, -1) + ] + weights_all = weights_active[indices_all] + weights_all = 
+            weights_all = weights_all.repeat_interleave(self.num_steps_per_env)
+            weights_all = weights_all.unsqueeze(-1).repeat(1, self.env.num_envs)
+
+            # Update GePPO with weights
+            logger.tic("learning")
+            self.alg.update(storage.get_data(), weights=weights_all)
+            logger.toc("learning")
+            logger.log_all_categories()
+
+            logger.finish_iteration()
+            logger.toc("iteration")
+            logger.toc("runtime")
+            logger.print_to_terminal()
+
+            if self.it % self.save_interval == 0:
+                self.save()
+        self.save()
+
+    @torch.no_grad
+    def burn_in_normalization(self, n_iterations=100):
+        actor_obs = self.get_obs(self.actor_cfg["obs"])
+        critic_obs = self.get_obs(self.critic_cfg["obs"])
+        for _ in range(n_iterations):
+            actions = self.alg.act(actor_obs)
+            self.set_actions(self.actor_cfg["actions"], actions)
+            self.env.step()
+            actor_obs = self.get_noisy_obs(
+                self.actor_cfg["obs"], self.actor_cfg["noise"]
+            )
+            critic_obs = self.get_obs(self.critic_cfg["obs"])
+            self.alg.critic.evaluate(critic_obs)
+        self.env.reset()
+
+    def update_rewards(self, rewards_dict, terminated):
+        rewards_dict.update(
+            self.get_rewards(
+                self.critic_cfg["reward"]["termination_weight"], mask=terminated
+            )
+        )
+        rewards_dict.update(
+            self.get_rewards(
+                self.critic_cfg["reward"]["weights"],
+                modifier=self.env.dt,
+                mask=~terminated,
+            )
+        )
+
+    def set_up_logger(self):
+        logger.register_rewards(list(self.critic_cfg["reward"]["weights"].keys()))
+        logger.register_rewards(
+            list(self.critic_cfg["reward"]["termination_weight"].keys())
+        )
+        logger.register_rewards(["total_rewards"])
+        logger.register_category(
+            "algorithm",
+            self.alg,
+            [
+                "learning_rate",
+                "mean_value_loss",
+                "mean_surrogate_loss",
+                "adv_mean",
+                "ret_mean",
+                # GePPO specific:
+                "adv_vtrace_mean",
+                "ret_vtrace_mean",
+                "eps_geppo",
+                "tv",
+            ],
+        )
+        logger.register_category(
+            "actor",
+            self.alg.actor,
+            [
+                "action_mean",
+                "action_std",
+                "entropy",
+                "obs_running_mean",
+                "obs_running_std",
+            ],
+        )
+
+        logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic))
+
+    def save(self):
+        os.makedirs(self.log_dir, exist_ok=True)
+        path = os.path.join(self.log_dir, "model_{}.pt".format(self.it))
+        torch.save(
+            {
+                "actor_state_dict": self.alg.actor.state_dict(),
+                "critic_state_dict": self.alg.critic.state_dict(),
+                "optimizer_state_dict": self.alg.optimizer.state_dict(),
+                "critic_optimizer_state_dict": self.alg.critic_optimizer.state_dict(),
+                "iter": self.it,
+            },
+            path,
+        )
+
+    def load(self, path, load_optimizer=True):
+        loaded_dict = torch.load(path)
+        self.alg.actor.load_state_dict(loaded_dict["actor_state_dict"])
+        # Update pik NN weights
+        self.alg.actor.update_pik_weights()
+        self.alg.critic.load_state_dict(loaded_dict["critic_state_dict"])
+        if load_optimizer:
+            self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
+            self.alg.critic_optimizer.load_state_dict(
+                loaded_dict["critic_optimizer_state_dict"]
+            )
+        self.it = loaded_dict["iter"]
+
+    def switch_to_eval(self):
+        self.alg.actor.eval()
+        self.alg.critic.eval()
+
+    def get_inference_actions(self):
+        obs = self.get_noisy_obs(self.actor_cfg["obs"], self.actor_cfg["noise"])
+        return self.alg.actor.act_inference(obs)
+
+    def export(self, path):
+        self.alg.actor.export(path)
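The weight computation in learn() above spreads the fixed base weights [0.4, 0.3, 0.2, 0.1] over however many policies the buffer currently holds, renormalizes them, reindexes them relative to the newest slot of the circular buffer, and broadcasts them to one weight per stored step and environment. A standalone sketch of that mapping, not part of the patch; the iteration count, rollout length, and environment count below are made up:

import torch

def geppo_step_weights(base_weights, num_policies, it, num_steps_per_env, num_envs):
    w = base_weights[:num_policies]
    w = w * num_policies / w.sum()  # renormalize over the policies actually stored
    idx_newest = (it - 1) % len(base_weights)
    # reorder relative to the newest slot in the circular buffer
    order = [i % len(base_weights) for i in range(idx_newest, idx_newest - num_policies, -1)]
    w = w[order]
    w = w.repeat_interleave(num_steps_per_env)  # one weight per stored step
    return w.unsqueeze(-1).repeat(1, num_envs)  # broadcast across environments

weights = geppo_step_weights(
    torch.tensor([0.4, 0.3, 0.2, 0.1]), num_policies=2, it=2, num_steps_per_env=4, num_envs=3
)
print(weights.shape)  # torch.Size([8, 3])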
"algorithm", self.alg, ["mean_value_loss", "mean_surrogate_loss"] + "algorithm", + self.alg, + [ + "learning_rate", + "mean_value_loss", + "mean_surrogate_loss", + "adv_mean", + "ret_mean", + ], + ) + logger.register_category( + "actor", + self.alg.actor, + [ + "action_mean", + "action_std", + "entropy", + "obs_running_mean", + "obs_running_std", + ], ) - logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) diff --git a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index 99534c22..d8138343 100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -51,6 +51,84 @@ def compute_generalized_advantages(data, gamma, lam, critic): return advantages +# Implementation based on GePPO repo: https://github.com/jqueeney/geppo +@torch.no_grad +def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic, rec=False): + if actor.store_pik is False: + raise ValueError("Need to store pik for V-trace") + + # n: rollout length, e: num envs + n, e = data.shape + log_prob_pik = actor.get_pik_log_prob(data["actor_obs"], data["actions"]) + ratio = torch.exp(log_prob_pik - data["log_prob"]) # shape [n, e] + ratio_trunc = torch.clamp_max(ratio, is_trunc) # [n, e] + + if rec: + # Recursive version + last_values = critic.evaluate(data["next_critic_obs"][-1]) + advantages = torch.zeros_like(data["values"]) + if last_values is not None: + # TODO: check this (copied from regular GAE) + # since we don't have obs for the last step, need last value plugged in + not_done = ~data["dones"][-1] + advantages[-1] = ( + data["rewards"][-1] + + gamma * data["values"][-1] * data["timed_out"][-1] + + gamma * last_values * not_done + - data["values"][-1] + ) + + for k in reversed(range(data["values"].shape[0] - 1)): + not_done = ~data["dones"][k] + td_error = ( + data["rewards"][k] + + gamma * data["values"][k] * data["timed_out"][k] + + gamma * data["values"][k + 1] * not_done + - data["values"][k] + ) + advantages[k] = ( + td_error + gamma * lam * not_done * ratio_trunc[k] * advantages[k + 1] + ) + + returns = advantages * ratio_trunc + data["values"] + + else: + # GePPO paper version + ratio_trunc_T = ratio_trunc.transpose(0, 1) # [e, n] + ratio_trunc_repeat = ratio_trunc_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] + ratio_trunc_L = torch.tril(ratio_trunc_repeat, -1) + ones_U = torch.triu(torch.ones((n, n)), 0).to(data.device) + + # cumprod along axis 1, keep shape [e, n, n] + ratio_trunc_prods = torch.tril(torch.cumprod(ratio_trunc_L + ones_U, axis=1), 0) + + # everything in data dict is [n, e] + values = data["values"] + values_next = critic.evaluate(data["next_critic_obs"]) + if values_next.dim() == 1: + values_next = values_next.reshape((n, e)) + not_done = ~data["dones"] + + delta = data["rewards"] + gamma * values_next * not_done - values # [n, e] + delta_T = delta.transpose(0, 1) # [e, n] + delta_repeat = delta_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] + + rate_L = torch.tril(torch.ones((n, n)) * gamma * lam, -1).to( + data.device + ) # [n, n] + rates = torch.tril(torch.cumprod(rate_L + ones_U, axis=0), 0) + rates_repeat = rates.unsqueeze(0).repeat(e, 1, 1) # [e, n, n] + + # element-wise multiplication: + intermediate = rates_repeat * ratio_trunc_prods * delta_repeat # [e, n, n] + advantages = torch.sum(intermediate, axis=1) # [e, n] + + advantages = advantages.transpose(0, 1) # [n, e] + returns = advantages * ratio_trunc + values # [n, e] + + return advantages, returns + + # todo change 
 # todo change num_epochs to num_batches
 @torch.no_grad
 def create_uniform_generator(