From 617742bb8511a8f6f39445b151e848ce7a2de96a Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Wed, 19 Jun 2024 13:48:25 -0400 Subject: [PATCH 01/10] Integrated GAE with Vtrace into PPO2 --- gym/envs/pendulum/pendulum.py | 6 +--- gym/envs/pendulum/pendulum_config.py | 5 +++ learning/algorithms/ppo2.py | 38 +++++++++++++++++++--- learning/modules/actor.py | 21 +++++++++++++ learning/utils/dict_utils.py | 47 ++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 10 deletions(-) diff --git a/gym/envs/pendulum/pendulum.py b/gym/envs/pendulum/pendulum.py index 39c8b331..d768bbf6 100644 --- a/gym/envs/pendulum/pendulum.py +++ b/gym/envs/pendulum/pendulum.py @@ -33,11 +33,7 @@ def _reward_equilibrium(self): error = torch.stack( [theta_norm / self.scales["dof_pos"], omega / self.scales["dof_vel"]], dim=1 ) - return self._sqrdexp(torch.mean(error, dim=1), scale=0.01) - - def _reward_torques(self): - """Penalize torques""" - return self._sqrdexp(torch.mean(torch.square(self.torques), dim=1), scale=0.2) + return self._sqrdexp(torch.mean(error, dim=1), scale=0.2) def _reward_energy(self): kinetic_energy = ( diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py index 0fb705ae..129c5d9e 100644 --- a/gym/envs/pendulum/pendulum_config.py +++ b/gym/envs/pendulum/pendulum_config.py @@ -73,6 +73,8 @@ class actor: "dof_vel", ] + store_pik = True + actions = ["tau_ff"] disable_actions = False @@ -124,6 +126,9 @@ class algorithm(FixedRobotCfgPPO.algorithm): schedule = "fixed" # could be adaptive, fixed desired_kl = 0.01 + # GePPO + geppo = True + class runner(FixedRobotCfgPPO.runner): run_name = "" experiment_name = "pendulum" diff --git a/learning/algorithms/ppo2.py b/learning/algorithms/ppo2.py index 09c60ab9..e9433160 100644 --- a/learning/algorithms/ppo2.py +++ b/learning/algorithms/ppo2.py @@ -5,6 +5,7 @@ from learning.utils import ( create_uniform_generator, compute_generalized_advantages, + compute_gae_vtrace, normalize, ) @@ -25,6 +26,8 @@ def __init__( use_clipped_value_loss=True, schedule="fixed", desired_kl=0.01, + geppo=False, + is_trunc=1.0, device="cpu", **kwargs, ): @@ -50,6 +53,10 @@ def __init__( self.max_grad_norm = max_grad_norm self.use_clipped_value_loss = use_clipped_value_loss + # * GePPO parameters + self.geppo = geppo + self.is_trunc = is_trunc + def switch_to_train(self): self.actor.train() self.critic.train() @@ -58,15 +65,36 @@ def act(self, obs): return self.actor.act(obs).detach() def update(self, data): - data["values"] = self.critic.evaluate(data["critic_obs"]) - data["advantages"] = compute_generalized_advantages( - data, self.gamma, self.lam, self.critic - ) - data["returns"] = data["advantages"] + data["values"] + values = self.critic.evaluate(data["critic_obs"]) + # handle single env case + if values.dim() == 1: + values = values.unsqueeze(-1) + data["values"] = values + + if self.geppo: + advantages, returns = compute_gae_vtrace( + data, self.gamma, self.lam, self.is_trunc, self.actor, self.critic + ) + # handle single env case + if advantages.dim() == 1: + advantages = advantages.unsqueeze(-1) + if returns.dim() == 1: + returns = returns.unsqueeze(-1) + data["advantages"] = advantages + data["returns"] = returns + else: + data["advantages"] = compute_generalized_advantages( + data, self.gamma, self.lam, self.critic + ) + data["returns"] = data["advantages"] + data["values"] + self.update_critic(data) data["advantages"] = normalize(data["advantages"]) self.update_actor(data) + if self.actor.store_pik: + 
self.actor.update_pik_weights() + def update_critic(self, data): self.mean_value_loss = 0 counter = 0 diff --git a/learning/modules/actor.py b/learning/modules/actor.py index fbaa6868..d71d2ac2 100644 --- a/learning/modules/actor.py +++ b/learning/modules/actor.py @@ -15,6 +15,7 @@ def __init__( activation="elu", init_noise_std=1.0, normalize_obs=True, + store_pik=False, **kwargs, ): super().__init__() @@ -33,6 +34,12 @@ def __init__( # disable args validation for speedup Normal.set_default_validate_args = False + self.store_pik = store_pik + if self.store_pik: + self.NN_pik = create_MLP(num_obs, num_actions, hidden_dims, activation) + self.std_pik = self.std.detach().clone() + self.update_pik_weights() + @property def action_mean(self): return self.distribution.mean @@ -67,3 +74,17 @@ def forward(self, observations): def export(self, path): export_network(self, "policy", path, self.num_obs) + + def update_pik_weights(self): + nn_state_dict = self.NN.state_dict() + self.NN_pik.load_state_dict(nn_state_dict) + self.std_pik = self.std.detach().clone() + + def get_pik_log_prob(self, observations, actions): + if self._normalize_obs: + with torch.no_grad(): + observations = self.obs_rms(observations) + mean_pik = self.NN_pik(observations) + std_pik = self.std_pik.to(mean_pik.device) + distribution = Normal(mean_pik, mean_pik * 0.0 + std_pik) + return distribution.log_prob(actions).sum(dim=-1) diff --git a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index 99534c22..6e29b58f 100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -51,6 +51,53 @@ def compute_generalized_advantages(data, gamma, lam, critic): return advantages +# Implementation based on GePPO repo: https://github.com/jqueeney/geppo +@torch.no_grad +def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic): + if actor.store_pik is False: + raise NotImplementedError("Need to store pik for V-trace") + + log_prob = actor.get_actions_log_prob(data["actions"]) + log_prob_pik = actor.get_pik_log_prob(data["actor_obs"], data["actions"]) + + # n: rollout length, e: num envs + # TODO: Double check GePPO code and paper (they diverge imo) + ratio = torch.exp(log_prob - log_prob_pik) # shape [n, e] + + n, e = ratio.shape + ones_U = torch.triu(torch.ones((n, n)), 0).to(data.device) + + ratio_trunc = torch.clamp_max(ratio, is_trunc) # [n, e] + ratio_trunc_T = ratio_trunc.transpose(0, 1) # [e, n] + ratio_trunc_repeat = ratio_trunc_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] + ratio_trunc_L = torch.tril(ratio_trunc_repeat, -1) + # cumprod along axis 1, keep shape [e, n, n] + ratio_trunc_prods = torch.tril(torch.cumprod(ratio_trunc_L + ones_U, axis=1), 0) + + # everything in data dict is [n, e] + values = critic.evaluate(data["critic_obs"]) + values_next = critic.evaluate(data["next_critic_obs"]) + not_done = ~data["dones"] + + delta = data["rewards"] + gamma * values_next * not_done - values # [n, e] + delta_T = delta.transpose(0, 1) # [e, n] + delta_repeat = delta_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] + + rate_L = torch.tril(torch.ones((n, n)) * gamma * lam, -1).to(data.device) # [n, n] + rates = torch.tril(torch.cumprod(rate_L + ones_U, axis=0), 0) + rates_repeat = rates.unsqueeze(0).repeat(e, 1, 1) # [e, n, n] + batch_prod = torch.bmm(rates_repeat, ratio_trunc_prods) # [e, n, n] + + # element-wise multiplication: + intermediate = batch_prod * delta_repeat # [e, n, n] + advantages = torch.sum(intermediate, axis=1) # [e, n] + + advantages = advantages.transpose(0, 1) # [n, e] + returns = 
advantages * ratio_trunc + values # [n, e] + + return advantages, returns + + # todo change num_epochs to num_batches @torch.no_grad def create_uniform_generator( From 0ad0e48fc7f576a9ce43a5963b9c09a035e95a2c Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Wed, 19 Jun 2024 14:30:13 -0400 Subject: [PATCH 02/10] new GePPO class and log advantages, returns --- gym/envs/pendulum/pendulum_config.py | 5 +-- learning/algorithms/__init__.py | 1 + learning/algorithms/geppo.py | 50 ++++++++++++++++++++++++++++ learning/algorithms/ppo2.py | 33 +++--------------- learning/runners/on_policy_runner.py | 8 +++++ 5 files changed, 67 insertions(+), 30 deletions(-) create mode 100644 learning/algorithms/geppo.py diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py index 129c5d9e..b9ad2ec2 100644 --- a/gym/envs/pendulum/pendulum_config.py +++ b/gym/envs/pendulum/pendulum_config.py @@ -73,6 +73,7 @@ class actor: "dof_vel", ] + # GePPO store_pik = True actions = ["tau_ff"] @@ -127,11 +128,11 @@ class algorithm(FixedRobotCfgPPO.algorithm): desired_kl = 0.01 # GePPO - geppo = True + is_trunc = 1.0 class runner(FixedRobotCfgPPO.runner): run_name = "" experiment_name = "pendulum" max_iterations = 200 # number of policy updates - algorithm_class_name = "PPO2" + algorithm_class_name = "GePPO" num_steps_per_env = 32 diff --git a/learning/algorithms/__init__.py b/learning/algorithms/__init__.py index 78231181..6703faac 100644 --- a/learning/algorithms/__init__.py +++ b/learning/algorithms/__init__.py @@ -32,5 +32,6 @@ from .ppo import PPO from .ppo2 import PPO2 +from .geppo import GePPO from .SE import StateEstimator from .sac import SAC \ No newline at end of file diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py new file mode 100644 index 00000000..42b655ec --- /dev/null +++ b/learning/algorithms/geppo.py @@ -0,0 +1,50 @@ +from .ppo2 import PPO2 +from learning.utils import ( + compute_generalized_advantages, + compute_gae_vtrace, + normalize, +) + + +# Implementation based on GePPO repo: https://github.com/jqueeney/geppo +class GePPO(PPO2): + def __init__(self, actor, critic, is_trunc=1.0, **kwargs): + super().__init__(actor, critic, **kwargs) + + # Importance sampling truncation + self.is_trunc = is_trunc + + def update(self, data): + values = self.critic.evaluate(data["critic_obs"]) + # Handle single env case + if values.dim() == 1: + values = values.unsqueeze(-1) + data["values"] = values + + # Compute V-trace GAE + adv_vtrace, ret_vtrace = compute_gae_vtrace( + data, self.gamma, self.lam, self.is_trunc, self.actor, self.critic + ) + # Handle single env case + if adv_vtrace.dim() == 1: + adv_vtrace = adv_vtrace.unsqueeze(-1) + if ret_vtrace.dim() == 1: + ret_vtrace = ret_vtrace.unsqueeze(-1) + data["advantages"] = adv_vtrace + data["returns"] = ret_vtrace + + self.update_critic(data) + data["advantages"] = normalize(data["advantages"]) + self.update_actor(data) + + if self.actor.store_pik: + self.actor.update_pik_weights() + + # Logging: Store mean GAE with and without V-trace + adv = compute_generalized_advantages(data, self.gamma, self.lam, self.critic) + ret = adv + values + + self.adv_mean = adv.mean().item() + self.ret_mean = ret.mean().item() + self.adv_vtrace_mean = adv_vtrace.mean().item() + self.ret_vtrace_mean = ret_vtrace.mean().item() diff --git a/learning/algorithms/ppo2.py b/learning/algorithms/ppo2.py index e9433160..a489ce5e 100644 --- a/learning/algorithms/ppo2.py +++ b/learning/algorithms/ppo2.py @@ -5,7 +5,6 @@ from learning.utils 
import ( create_uniform_generator, compute_generalized_advantages, - compute_gae_vtrace, normalize, ) @@ -26,8 +25,6 @@ def __init__( use_clipped_value_loss=True, schedule="fixed", desired_kl=0.01, - geppo=False, - is_trunc=1.0, device="cpu", **kwargs, ): @@ -53,10 +50,6 @@ def __init__( self.max_grad_norm = max_grad_norm self.use_clipped_value_loss = use_clipped_value_loss - # * GePPO parameters - self.geppo = geppo - self.is_trunc = is_trunc - def switch_to_train(self): self.actor.train() self.critic.train() @@ -66,35 +59,19 @@ def act(self, obs): def update(self, data): values = self.critic.evaluate(data["critic_obs"]) - # handle single env case + # Handle single env case if values.dim() == 1: values = values.unsqueeze(-1) data["values"] = values - - if self.geppo: - advantages, returns = compute_gae_vtrace( - data, self.gamma, self.lam, self.is_trunc, self.actor, self.critic - ) - # handle single env case - if advantages.dim() == 1: - advantages = advantages.unsqueeze(-1) - if returns.dim() == 1: - returns = returns.unsqueeze(-1) - data["advantages"] = advantages - data["returns"] = returns - else: - data["advantages"] = compute_generalized_advantages( - data, self.gamma, self.lam, self.critic - ) - data["returns"] = data["advantages"] + data["values"] + data["advantages"] = compute_generalized_advantages( + data, self.gamma, self.lam, self.critic + ) + data["returns"] = data["advantages"] + data["values"] self.update_critic(data) data["advantages"] = normalize(data["advantages"]) self.update_actor(data) - if self.actor.store_pik: - self.actor.update_pik_weights() - def update_critic(self, data): self.mean_value_loss = 0 counter = 0 diff --git a/learning/runners/on_policy_runner.py b/learning/runners/on_policy_runner.py index c1758000..f92dd6e8 100644 --- a/learning/runners/on_policy_runner.py +++ b/learning/runners/on_policy_runner.py @@ -6,6 +6,7 @@ from .BaseRunner import BaseRunner from learning.storage import DictStorage +from learning.algorithms import GePPO logger = Logger() storage = DictStorage() @@ -164,6 +165,13 @@ def set_up_logger(self): ) logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) + if isinstance(self.alg, GePPO): + logger.register_category( + "GePPO", + self.alg, + ["adv_mean", "ret_mean", "adv_vtrace_mean", "ret_vtrace_mean"], + ) + logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) def save(self): From ded68ef124a9afc520fd82bf97fc08349b529816 Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Fri, 21 Jun 2024 10:02:24 -0400 Subject: [PATCH 03/10] added HybridPolicyRunner and GePPO actor update --- gym/envs/pendulum/pendulum_config.py | 10 +- learning/algorithms/geppo.py | 145 ++++++++++++- learning/runners/__init__.py | 3 +- learning/runners/hybrid_policy_runner.py | 250 +++++++++++++++++++++++ learning/utils/dict_utils.py | 2 +- 5 files changed, 392 insertions(+), 18 deletions(-) create mode 100644 learning/runners/hybrid_policy_runner.py diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py index b9ad2ec2..905021f7 100644 --- a/gym/envs/pendulum/pendulum_config.py +++ b/gym/envs/pendulum/pendulum_config.py @@ -5,7 +5,7 @@ class PendulumCfg(FixedRobotCfg): class env(FixedRobotCfg.env): - num_envs = 4096 + num_envs = 2048 num_actuators = 1 episode_length_s = 10 @@ -59,7 +59,7 @@ class scaling(FixedRobotCfg.scaling): class PendulumRunnerCfg(FixedRobotCfgPPO): seed = -1 - runner_class_name = "OnPolicyRunner" + runner_class_name = "HybridPolicyRunner" class actor: hidden_dims = [128, 64, 
32] @@ -73,9 +73,6 @@ class actor: "dof_vel", ] - # GePPO - store_pik = True - actions = ["tau_ff"] disable_actions = False @@ -135,4 +132,5 @@ class runner(FixedRobotCfgPPO.runner): experiment_name = "pendulum" max_iterations = 200 # number of policy updates algorithm_class_name = "GePPO" - num_steps_per_env = 32 + num_steps_per_env = 64 + num_old_policies = 4 diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py index 42b655ec..eb381325 100644 --- a/learning/algorithms/geppo.py +++ b/learning/algorithms/geppo.py @@ -1,5 +1,9 @@ +import torch +import torch.nn as nn + from .ppo2 import PPO2 from learning.utils import ( + create_uniform_generator, compute_generalized_advantages, compute_gae_vtrace, normalize, @@ -8,20 +12,37 @@ # Implementation based on GePPO repo: https://github.com/jqueeney/geppo class GePPO(PPO2): - def __init__(self, actor, critic, is_trunc=1.0, **kwargs): + def __init__( + self, + actor, + critic, + is_trunc=1.0, + eps_ppo=0.2, + eps_vary=True, + **kwargs, + ): super().__init__(actor, critic, **kwargs) # Importance sampling truncation self.is_trunc = is_trunc - def update(self, data): + # Clipping parameter + self.eps_ppo = eps_ppo + self.eps_vary = eps_vary + self.eps = self.eps_ppo # TODO: This should be computed + + self.updated = False + + def update(self, data, policy_weights): values = self.critic.evaluate(data["critic_obs"]) # Handle single env case if values.dim() == 1: values = values.unsqueeze(-1) data["values"] = values - # Compute V-trace GAE + # Compute GAE with and without V-trace + adv = compute_generalized_advantages(data, self.gamma, self.lam, self.critic) + ret = adv + values adv_vtrace, ret_vtrace = compute_gae_vtrace( data, self.gamma, self.lam, self.is_trunc, self.actor, self.critic ) @@ -30,21 +51,125 @@ def update(self, data): adv_vtrace = adv_vtrace.unsqueeze(-1) if ret_vtrace.dim() == 1: ret_vtrace = ret_vtrace.unsqueeze(-1) - data["advantages"] = adv_vtrace - data["returns"] = ret_vtrace + # Only use V-trace if we have updated once already + if self.updated: + data["advantages"] = adv_vtrace + data["returns"] = ret_vtrace + else: + data["advantages"] = adv + data["returns"] = ret + self.updated = True + + # Update critic and actor self.update_critic(data) data["advantages"] = normalize(data["advantages"]) - self.update_actor(data) + self.update_actor(data, policy_weights) + # Update pik weights if self.actor.store_pik: self.actor.update_pik_weights() - # Logging: Store mean GAE with and without V-trace - adv = compute_generalized_advantages(data, self.gamma, self.lam, self.critic) - ret = adv + values - + # Logging: Store mean advantages and returns self.adv_mean = adv.mean().item() self.ret_mean = ret.mean().item() self.adv_vtrace_mean = adv_vtrace.mean().item() self.ret_vtrace_mean = ret_vtrace.mean().item() + + def update_actor(self, data, policy_weights): + if self.eps_vary: + log_prob_pik = self.actor.get_pik_log_prob( + data["actor_obs"], data["actions"] + ) + offpol_ratio = torch.exp(data["log_prob"] - log_prob_pik) + # TODO: I am taking the mean over 2 dims, check if this is correct + eps_old = torch.mean(policy_weights * torch.abs(offpol_ratio - 1.0)) + self.eps = max(self.eps_ppo - eps_old.item(), 0.0) + + self.mean_surrogate_loss = 0 + counter = 0 + + self.actor.act(data["actor_obs"]) + data["old_sigma_batch"] = self.actor.action_std.detach() + data["old_mu_batch"] = self.actor.action_mean.detach() + data["old_actions_log_prob_batch"] = self.actor.get_actions_log_prob( + data["actions"] + ).detach() + + # Add 
policy weights to data + data["policy_weights"] = policy_weights + + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch in generator: + self.actor.act(batch["actor_obs"]) + actions_log_prob_batch = self.actor.get_actions_log_prob(batch["actions"]) + mu_batch = self.actor.action_mean + sigma_batch = self.actor.action_std + entropy_batch = self.actor.entropy + + # * KL + # TODO: Implement GePPO adaptive LR + if self.desired_kl is not None and self.schedule == "adaptive": + with torch.inference_mode(): + kl = torch.sum( + torch.log(sigma_batch / batch["old_sigma_batch"] + 1.0e-5) + + ( + torch.square(batch["old_sigma_batch"]) + + torch.square(batch["old_mu_batch"] - mu_batch) + ) + / (2.0 * torch.square(sigma_batch)) + - 0.5, + axis=-1, + ) + kl_mean = torch.mean(kl) + + if kl_mean > self.desired_kl * 2.0: + self.learning_rate = max(1e-5, self.learning_rate / 1.5) + elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: + self.learning_rate = min(1e-2, self.learning_rate * 1.5) + + for param_group in self.optimizer.param_groups: + # ! check this + param_group["lr"] = self.learning_rate + + # * GePPO Surrogate loss + log_prob_pik = self.actor.get_pik_log_prob( + batch["actor_obs"], batch["actions"] + ) + offpol_ratio = torch.exp(batch["log_prob"] - log_prob_pik) + advantages = batch["advantages"] + + # TODO: Center/clip advantages (optional) + # adv_mean = torch.mean( + # offpol_ratio * batch["policy_weights"] * advantages, dim=2 + # ) / torch.mean(offpol_ratio, batch["policy_weights"], dim=2) + # adv_std = torch.std( + # offpol_ratio * batch["policy_weights"] * advantages, dim=2 + # ) + + ratio = torch.exp( + actions_log_prob_batch + - torch.squeeze(batch["old_actions_log_prob_batch"]) + ) + surrogate = -torch.squeeze(advantages) * ratio + surrogate_clipped = -torch.squeeze(advantages) * torch.clamp( + ratio, offpol_ratio - self.eps, offpol_ratio + self.eps + ) + surrogate_loss = ( + torch.max(surrogate, surrogate_clipped) * batch["policy_weights"] + ).mean() + + loss = surrogate_loss - self.entropy_coef * entropy_batch.mean() + + # * Gradient step + self.optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) + self.optimizer.step() + self.mean_surrogate_loss += surrogate_loss.item() + counter += 1 + self.mean_surrogate_loss /= counter diff --git a/learning/runners/__init__.py b/learning/runners/__init__.py index a0840b06..8831a0ac 100644 --- a/learning/runners/__init__.py +++ b/learning/runners/__init__.py @@ -33,4 +33,5 @@ from .on_policy_runner import OnPolicyRunner from .my_runner import MyRunner from .old_policy_runner import OldPolicyRunner -from .off_policy_runner import OffPolicyRunner \ No newline at end of file +from .off_policy_runner import OffPolicyRunner +from .hybrid_policy_runner import HybridPolicyRunner \ No newline at end of file diff --git a/learning/runners/hybrid_policy_runner.py b/learning/runners/hybrid_policy_runner.py new file mode 100644 index 00000000..c3110e1a --- /dev/null +++ b/learning/runners/hybrid_policy_runner.py @@ -0,0 +1,250 @@ +import os +import torch +from tensordict import TensorDict + +from learning.utils import Logger + +from .BaseRunner import BaseRunner +from learning.modules import Actor, Critic +from learning.storage import ReplayBuffer +from learning.algorithms import GePPO + +logger = Logger() +storage = ReplayBuffer() + + +class HybridPolicyRunner(BaseRunner): + def __init__(self, env, train_cfg, device="cpu"): 
+ super().__init__(env, train_cfg, device) + logger.initialize( + self.env.num_envs, + self.env.dt, + self.cfg["max_iterations"], + self.device, + ) + + # TODO: Weights hardcoded for 4 policies + self.num_old_policies = self.cfg["num_old_policies"] + self.weights = torch.tensor([0.4, 0.3, 0.2, 0.1]).to(self.device) + + def _set_up_alg(self): + alg_class = eval(self.cfg["algorithm_class_name"]) + if alg_class != GePPO: + raise ValueError("HybridPolicyRunner only supports GePPO") + + num_actor_obs = self.get_obs_size(self.actor_cfg["obs"]) + num_actions = self.get_action_size(self.actor_cfg["actions"]) + num_critic_obs = self.get_obs_size(self.critic_cfg["obs"]) + # Store pik for the actor + actor = Actor(num_actor_obs, num_actions, store_pik=True, **self.actor_cfg) + critic = Critic(num_critic_obs, **self.critic_cfg) + self.alg = alg_class(actor, critic, device=self.device, **self.alg_cfg) + + def learn(self): + self.set_up_logger() + + rewards_dict = {} + + self.alg.switch_to_train() + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + tot_iter = self.it + self.num_learning_iterations + self.save() + + # * start up storage + transition = TensorDict({}, batch_size=self.env.num_envs, device=self.device) + transition.update( + { + "actor_obs": actor_obs, + "next_actor_obs": actor_obs, + "actions": self.alg.act(actor_obs), + "critic_obs": critic_obs, + "next_critic_obs": critic_obs, + "rewards": self.get_rewards({"termination": 0.0})["termination"], + "dones": self.get_timed_out(), + } + ) + max_storage = self.env.num_envs * self.num_steps_per_env * self.num_old_policies + storage.initialize( + dummy_dict=transition, + num_envs=self.env.num_envs, + max_storage=max_storage, + device=self.device, + ) + + # burn in observation normalization. 
+ if self.actor_cfg["normalize_obs"] or self.critic_cfg["normalize_obs"]: + self.burn_in_normalization() + + logger.tic("runtime") + for self.it in range(self.it + 1, tot_iter + 1): + logger.tic("iteration") + logger.tic("collection") + # * Rollout + with torch.inference_mode(): + for i in range(self.num_steps_per_env): + actions = self.alg.act(actor_obs) + self.set_actions( + self.actor_cfg["actions"], + actions, + self.actor_cfg["disable_actions"], + ) + # Store additional data for GePPO + log_prob = self.alg.actor.get_actions_log_prob(actions) + action_mean = self.alg.actor.action_mean + action_std = self.alg.actor.action_std + + transition.update( + { + "actor_obs": actor_obs, + "actions": actions, + "critic_obs": critic_obs, + "log_prob": log_prob, + "action_mean": action_mean, + "action_std": action_std, + } + ) + + self.env.step() + + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + + # * get time_outs + timed_out = self.get_timed_out() + terminated = self.get_terminated() + dones = timed_out | terminated + + self.update_rewards(rewards_dict, terminated) + total_rewards = torch.stack(tuple(rewards_dict.values())).sum(dim=0) + + transition.update( + { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, + "rewards": total_rewards, + "timed_out": timed_out, + "dones": dones, + } + ) + storage.add_transitions(transition) + + logger.log_rewards(rewards_dict) + logger.log_rewards({"total_rewards": total_rewards}) + logger.finish_step(dones) + logger.toc("collection") + + # Compute GePPO policy weights + n_policies = min(self.it, self.num_old_policies) + weights_active = self.weights[:n_policies] + weights_active = weights_active * n_policies / weights_active.sum() + idx_newest = (self.it - 1) % self.num_old_policies + indices_all = [ + i % self.num_old_policies + for i in range(idx_newest, idx_newest - n_policies, -1) + ] + weights_all = weights_active[indices_all] + weights_all = weights_all.repeat_interleave(self.num_steps_per_env) + weights_all = weights_all.unsqueeze(-1).repeat(1, self.env.num_envs) + + # Update GePPO with policy weights + logger.tic("learning") + self.alg.update(storage.get_data(), policy_weights=weights_all) + logger.toc("learning") + logger.log_all_categories() + + logger.finish_iteration() + logger.toc("iteration") + logger.toc("runtime") + logger.print_to_terminal() + + if self.it % self.save_interval == 0: + self.save() + self.save() + + @torch.no_grad + def burn_in_normalization(self, n_iterations=100): + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + for _ in range(n_iterations): + actions = self.alg.act(actor_obs) + self.set_actions(self.actor_cfg["actions"], actions) + self.env.step() + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + self.alg.critic.evaluate(critic_obs) + self.env.reset() + + def update_rewards(self, rewards_dict, terminated): + rewards_dict.update( + self.get_rewards( + self.critic_cfg["reward"]["termination_weight"], mask=terminated + ) + ) + rewards_dict.update( + self.get_rewards( + self.critic_cfg["reward"]["weights"], + modifier=self.env.dt, + mask=~terminated, + ) + ) + + def set_up_logger(self): + logger.register_rewards(list(self.critic_cfg["reward"]["weights"].keys())) + logger.register_rewards( + list(self.critic_cfg["reward"]["termination_weight"].keys()) + ) + 
logger.register_rewards(["total_rewards"]) + logger.register_category( + "algorithm", self.alg, ["mean_value_loss", "mean_surrogate_loss"] + ) + logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) + + # GePPO specific logging + logger.register_category( + "GePPO", + self.alg, + ["eps", "adv_mean", "ret_mean", "adv_vtrace_mean", "ret_vtrace_mean"], + ) + + logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) + + def save(self): + os.makedirs(self.log_dir, exist_ok=True) + path = os.path.join(self.log_dir, "model_{}.pt".format(self.it)) + torch.save( + { + "actor_state_dict": self.alg.actor.state_dict(), + "critic_state_dict": self.alg.critic.state_dict(), + "optimizer_state_dict": self.alg.optimizer.state_dict(), + "critic_optimizer_state_dict": self.alg.critic_optimizer.state_dict(), + "iter": self.it, + }, + path, + ) + + def load(self, path, load_optimizer=True): + loaded_dict = torch.load(path) + self.alg.actor.load_state_dict(loaded_dict["actor_state_dict"]) + self.alg.critic.load_state_dict(loaded_dict["critic_state_dict"]) + if load_optimizer: + self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) + self.alg.critic_optimizer.load_state_dict( + loaded_dict["critic_optimizer_state_dict"] + ) + self.it = loaded_dict["iter"] + + def switch_to_eval(self): + self.alg.actor.eval() + self.alg.critic.eval() + + def get_inference_actions(self): + obs = self.get_noisy_obs(self.actor_cfg["obs"], self.actor_cfg["noise"]) + return self.alg.actor.act_inference(obs) + + def export(self, path): + self.alg.actor.export(path) diff --git a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index 6e29b58f..a02b3dc1 100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -55,7 +55,7 @@ def compute_generalized_advantages(data, gamma, lam, critic): @torch.no_grad def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic): if actor.store_pik is False: - raise NotImplementedError("Need to store pik for V-trace") + raise ValueError("Need to store pik for V-trace") log_prob = actor.get_actions_log_prob(data["actions"]) log_prob_pik = actor.get_pik_log_prob(data["actor_obs"], data["actions"]) From 6133f3b392e94b1647fcc504cc5ec367706c7da3 Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Fri, 21 Jun 2024 15:03:51 -0400 Subject: [PATCH 04/10] bugfixes in V-trace GAE, pendulum learns now --- learning/algorithms/geppo.py | 57 +++++++++++++++--------- learning/modules/critic.py | 8 +++- learning/runners/hybrid_policy_runner.py | 10 +++-- learning/runners/on_policy_runner.py | 8 ---- learning/utils/dict_utils.py | 7 +-- 5 files changed, 50 insertions(+), 40 deletions(-) diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py index eb381325..cf732cab 100644 --- a/learning/algorithms/geppo.py +++ b/learning/algorithms/geppo.py @@ -33,7 +33,7 @@ def __init__( self.updated = False - def update(self, data, policy_weights): + def update(self, data, weights): values = self.critic.evaluate(data["critic_obs"]) # Handle single env case if values.dim() == 1: @@ -62,9 +62,10 @@ def update(self, data, policy_weights): self.updated = True # Update critic and actor + data["weights"] = weights self.update_critic(data) data["advantages"] = normalize(data["advantages"]) - self.update_actor(data, policy_weights) + self.update_actor(data) # Update pik weights if self.actor.store_pik: @@ -76,14 +77,37 @@ def update(self, data, policy_weights): self.adv_vtrace_mean = adv_vtrace.mean().item() self.ret_vtrace_mean = 
ret_vtrace.mean().item() - def update_actor(self, data, policy_weights): + def update_critic(self, data): + self.mean_value_loss = 0 + counter = 0 + + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch in generator: + # GePPO critic loss uses weights + value_loss = self.critic.loss_fn( + batch["critic_obs"], batch["returns"], batch["weights"] + ) + self.critic_optimizer.zero_grad() + value_loss.backward() + nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) + self.critic_optimizer.step() + self.mean_value_loss += value_loss.item() + counter += 1 + self.mean_value_loss /= counter + + def update_actor(self, data): + # Update clipping eps if self.eps_vary: log_prob_pik = self.actor.get_pik_log_prob( data["actor_obs"], data["actions"] ) - offpol_ratio = torch.exp(data["log_prob"] - log_prob_pik) + offpol_ratio = torch.exp(log_prob_pik - data["log_prob"]) # TODO: I am taking the mean over 2 dims, check if this is correct - eps_old = torch.mean(policy_weights * torch.abs(offpol_ratio - 1.0)) + eps_old = torch.mean(data["weights"] * torch.abs(offpol_ratio - 1.0)) self.eps = max(self.eps_ppo - eps_old.item(), 0.0) self.mean_surrogate_loss = 0 @@ -92,12 +116,6 @@ def update_actor(self, data, policy_weights): self.actor.act(data["actor_obs"]) data["old_sigma_batch"] = self.actor.action_std.detach() data["old_mu_batch"] = self.actor.action_mean.detach() - data["old_actions_log_prob_batch"] = self.actor.get_actions_log_prob( - data["actions"] - ).detach() - - # Add policy weights to data - data["policy_weights"] = policy_weights generator = create_uniform_generator( data, @@ -106,7 +124,6 @@ def update_actor(self, data, policy_weights): ) for batch in generator: self.actor.act(batch["actor_obs"]) - actions_log_prob_batch = self.actor.get_actions_log_prob(batch["actions"]) mu_batch = self.actor.action_mean sigma_batch = self.actor.action_std entropy_batch = self.actor.entropy @@ -140,27 +157,25 @@ def update_actor(self, data, policy_weights): log_prob_pik = self.actor.get_pik_log_prob( batch["actor_obs"], batch["actions"] ) - offpol_ratio = torch.exp(batch["log_prob"] - log_prob_pik) + offpol_ratio = torch.exp(log_prob_pik - batch["log_prob"]) advantages = batch["advantages"] # TODO: Center/clip advantages (optional) # adv_mean = torch.mean( - # offpol_ratio * batch["policy_weights"] * advantages, dim=2 - # ) / torch.mean(offpol_ratio, batch["policy_weights"], dim=2) + # offpol_ratio * batch["weights"] * advantages, dim=2 + # ) / torch.mean(offpol_ratio, batch["weights"], dim=2) # adv_std = torch.std( - # offpol_ratio * batch["policy_weights"] * advantages, dim=2 + # offpol_ratio * batch["weights"] * advantages, dim=2 # ) - ratio = torch.exp( - actions_log_prob_batch - - torch.squeeze(batch["old_actions_log_prob_batch"]) - ) + log_prob = self.actor.get_actions_log_prob(batch["actions"]) + ratio = torch.exp(log_prob - batch["log_prob"]) surrogate = -torch.squeeze(advantages) * ratio surrogate_clipped = -torch.squeeze(advantages) * torch.clamp( ratio, offpol_ratio - self.eps, offpol_ratio + self.eps ) surrogate_loss = ( - torch.max(surrogate, surrogate_clipped) * batch["policy_weights"] + torch.max(surrogate, surrogate_clipped) * batch["weights"] ).mean() loss = surrogate_loss - self.entropy_coef * entropy_batch.mean() diff --git a/learning/modules/critic.py b/learning/modules/critic.py index 96c53438..0ea7825d 100644 --- a/learning/modules/critic.py +++ b/learning/modules/critic.py @@ -29,5 +29,9 @@ def 
forward(self, x): def evaluate(self, critic_observations): return self.forward(critic_observations) - def loss_fn(self, obs, target): - return nn.functional.mse_loss(self.forward(obs), target, reduction="mean") + def loss_fn(self, obs, target, weights=None): + if weights is None: + return nn.functional.mse_loss(self.forward(obs), target, reduction="mean") + + # Compute MSE loss with weights + return (weights * (self.forward(obs) - target).pow(2)).mean() diff --git a/learning/runners/hybrid_policy_runner.py b/learning/runners/hybrid_policy_runner.py index c3110e1a..df3b3781 100644 --- a/learning/runners/hybrid_policy_runner.py +++ b/learning/runners/hybrid_policy_runner.py @@ -136,7 +136,7 @@ def learn(self): logger.finish_step(dones) logger.toc("collection") - # Compute GePPO policy weights + # Compute GePPO weights n_policies = min(self.it, self.num_old_policies) weights_active = self.weights[:n_policies] weights_active = weights_active * n_policies / weights_active.sum() @@ -149,9 +149,9 @@ def learn(self): weights_all = weights_all.repeat_interleave(self.num_steps_per_env) weights_all = weights_all.unsqueeze(-1).repeat(1, self.env.num_envs) - # Update GePPO with policy weights + # Update GePPO with weights logger.tic("learning") - self.alg.update(storage.get_data(), policy_weights=weights_all) + self.alg.update(storage.get_data(), weights=weights_all) logger.toc("learning") logger.log_all_categories() @@ -200,7 +200,9 @@ def set_up_logger(self): ) logger.register_rewards(["total_rewards"]) logger.register_category( - "algorithm", self.alg, ["mean_value_loss", "mean_surrogate_loss"] + "algorithm", + self.alg, + ["learning_rate", "mean_value_loss", "mean_surrogate_loss"], ) logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) diff --git a/learning/runners/on_policy_runner.py b/learning/runners/on_policy_runner.py index f92dd6e8..c1758000 100644 --- a/learning/runners/on_policy_runner.py +++ b/learning/runners/on_policy_runner.py @@ -6,7 +6,6 @@ from .BaseRunner import BaseRunner from learning.storage import DictStorage -from learning.algorithms import GePPO logger = Logger() storage = DictStorage() @@ -165,13 +164,6 @@ def set_up_logger(self): ) logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) - if isinstance(self.alg, GePPO): - logger.register_category( - "GePPO", - self.alg, - ["adv_mean", "ret_mean", "adv_vtrace_mean", "ret_vtrace_mean"], - ) - logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) def save(self): diff --git a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index a02b3dc1..1a82db2f 100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -57,12 +57,11 @@ def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic): if actor.store_pik is False: raise ValueError("Need to store pik for V-trace") - log_prob = actor.get_actions_log_prob(data["actions"]) log_prob_pik = actor.get_pik_log_prob(data["actor_obs"], data["actions"]) # n: rollout length, e: num envs # TODO: Double check GePPO code and paper (they diverge imo) - ratio = torch.exp(log_prob - log_prob_pik) # shape [n, e] + ratio = torch.exp(log_prob_pik - data["log_prob"]) # shape [n, e] n, e = ratio.shape ones_U = torch.triu(torch.ones((n, n)), 0).to(data.device) @@ -86,15 +85,13 @@ def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic): rate_L = torch.tril(torch.ones((n, n)) * gamma * lam, -1).to(data.device) # [n, n] rates = torch.tril(torch.cumprod(rate_L + ones_U, axis=0), 0) rates_repeat = 
rates.unsqueeze(0).repeat(e, 1, 1) # [e, n, n] - batch_prod = torch.bmm(rates_repeat, ratio_trunc_prods) # [e, n, n] # element-wise multiplication: - intermediate = batch_prod * delta_repeat # [e, n, n] + intermediate = rates_repeat * ratio_trunc_prods * delta_repeat # [e, n, n] advantages = torch.sum(intermediate, axis=1) # [e, n] advantages = advantages.transpose(0, 1) # [n, e] returns = advantages * ratio_trunc + values # [n, e] - return advantages, returns From 08b2b771240370efac0ccb2808ce9719b1bf8bb9 Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Fri, 21 Jun 2024 17:10:00 -0400 Subject: [PATCH 05/10] adapt LR based on GePPO paper --- gym/envs/pendulum/pendulum_config.py | 4 +- learning/algorithms/geppo.py | 59 +++++++++++------------- learning/runners/hybrid_policy_runner.py | 2 +- 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py index 905021f7..33112a0f 100644 --- a/gym/envs/pendulum/pendulum_config.py +++ b/gym/envs/pendulum/pendulum_config.py @@ -5,7 +5,7 @@ class PendulumCfg(FixedRobotCfg): class env(FixedRobotCfg.env): - num_envs = 2048 + num_envs = 4096 num_actuators = 1 episode_length_s = 10 @@ -132,5 +132,5 @@ class runner(FixedRobotCfgPPO.runner): experiment_name = "pendulum" max_iterations = 200 # number of policy updates algorithm_class_name = "GePPO" - num_steps_per_env = 64 + num_steps_per_env = 32 num_old_policies = 4 diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py index cf732cab..03242be4 100644 --- a/learning/algorithms/geppo.py +++ b/learning/algorithms/geppo.py @@ -19,6 +19,10 @@ def __init__( is_trunc=1.0, eps_ppo=0.2, eps_vary=True, + adapt_lr=True, + adapt_factor=0.03, + adapt_minthresh=0.0, + adapt_maxthresh=1.0, **kwargs, ): super().__init__(actor, critic, **kwargs) @@ -31,6 +35,12 @@ def __init__( self.eps_vary = eps_vary self.eps = self.eps_ppo # TODO: This should be computed + # Learning rate + self.adapt_lr = adapt_lr + self.adapt_factor = adapt_factor + self.adapt_minthresh = adapt_minthresh + self.adapt_maxthresh = adapt_maxthresh + self.updated = False def update(self, data, weights): @@ -113,10 +123,6 @@ def update_actor(self, data): self.mean_surrogate_loss = 0 counter = 0 - self.actor.act(data["actor_obs"]) - data["old_sigma_batch"] = self.actor.action_std.detach() - data["old_mu_batch"] = self.actor.action_mean.detach() - generator = create_uniform_generator( data, self.batch_size, @@ -124,35 +130,8 @@ def update_actor(self, data): ) for batch in generator: self.actor.act(batch["actor_obs"]) - mu_batch = self.actor.action_mean - sigma_batch = self.actor.action_std entropy_batch = self.actor.entropy - # * KL - # TODO: Implement GePPO adaptive LR - if self.desired_kl is not None and self.schedule == "adaptive": - with torch.inference_mode(): - kl = torch.sum( - torch.log(sigma_batch / batch["old_sigma_batch"] + 1.0e-5) - + ( - torch.square(batch["old_sigma_batch"]) - + torch.square(batch["old_mu_batch"] - mu_batch) - ) - / (2.0 * torch.square(sigma_batch)) - - 0.5, - axis=-1, - ) - kl_mean = torch.mean(kl) - - if kl_mean > self.desired_kl * 2.0: - self.learning_rate = max(1e-5, self.learning_rate / 1.5) - elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: - self.learning_rate = min(1e-2, self.learning_rate * 1.5) - - for param_group in self.optimizer.param_groups: - # ! 
check this - param_group["lr"] = self.learning_rate - # * GePPO Surrogate loss log_prob_pik = self.actor.get_pik_log_prob( batch["actor_obs"], batch["actions"] @@ -188,3 +167,21 @@ def update_actor(self, data): self.mean_surrogate_loss += surrogate_loss.item() counter += 1 self.mean_surrogate_loss /= counter + + # Compute TV, add to self for logging + self.actor.act(data["actor_obs"]) + log_prob = self.actor.get_actions_log_prob(data["actions"]) + log_prob_pik = self.actor.get_pik_log_prob(data["actor_obs"], data["actions"]) + ratio = torch.exp(log_prob - data["log_prob"]) + clip_center = torch.exp(log_prob_pik - data["log_prob"]) + ratio_diff = torch.abs(ratio - clip_center) + self.tv = 0.5 * torch.mean(data["weights"] * ratio_diff) + + # Adapt learning rate + if self.adapt_lr: + if self.tv > (self.adapt_maxthresh * (0.5 * self.eps)): + self.learning_rate /= 1 + self.adapt_factor + elif self.tv < (self.adapt_minthresh * (0.5 * self.eps)): + self.learning_rate *= 1 + self.adapt_factor + for param_group in self.optimizer.param_groups: + param_group["lr"] = self.learning_rate diff --git a/learning/runners/hybrid_policy_runner.py b/learning/runners/hybrid_policy_runner.py index df3b3781..e63a4dfe 100644 --- a/learning/runners/hybrid_policy_runner.py +++ b/learning/runners/hybrid_policy_runner.py @@ -210,7 +210,7 @@ def set_up_logger(self): logger.register_category( "GePPO", self.alg, - ["eps", "adv_mean", "ret_mean", "adv_vtrace_mean", "ret_vtrace_mean"], + ["eps", "tv", "adv_mean", "ret_mean", "adv_vtrace_mean", "ret_vtrace_mean"], ) logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) From 6845f0fbccf7ee0207b2e8f6efb0775bca0c4312 Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Sat, 22 Jun 2024 15:26:16 -0400 Subject: [PATCH 06/10] GePPO for mini cheetah ref, and constant eps_geppo param --- .../mini_cheetah/mini_cheetah_ref_config.py | 30 ++++++++++++++++--- learning/algorithms/geppo.py | 13 ++++---- learning/runners/hybrid_policy_runner.py | 9 +++++- learning/runners/on_policy_runner.py | 4 ++- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py index bd8140b2..9a729b82 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py @@ -8,7 +8,7 @@ class MiniCheetahRefCfg(MiniCheetahCfg): class env(MiniCheetahCfg.env): - num_envs = 4096 + num_envs = 1024 num_actuators = 12 episode_length_s = 5.0 @@ -67,7 +67,7 @@ class scaling(MiniCheetahCfg.scaling): class MiniCheetahRefRunnerCfg(MiniCheetahRunnerCfg): seed = -1 - runner_class_name = "OnPolicyRunner" + runner_class_name = "HybridPolicyRunner" class actor: hidden_dims = [256, 256, 128] @@ -138,11 +138,33 @@ class termination_weight: termination = 0.15 class algorithm(MiniCheetahRunnerCfg.algorithm): - pass + # both + gamma = 0.99 + lam = 0.95 + # shared + batch_size = 2**15 + max_gradient_steps = 10 + # new + storage_size = 2**17 # new + batch_size = 2**15 # new + + clip_param = 0.2 + learning_rate = 1.0e-3 + max_grad_norm = 1.0 + # Critic + use_clipped_value_loss = True + # Actor + entropy_coef = 0.01 + schedule = "adaptive" # could be adaptive, fixed + desired_kl = 0.01 + + # GePPO + is_trunc = 1.0 class runner(MiniCheetahRunnerCfg.runner): run_name = "" experiment_name = "mini_cheetah_ref" max_iterations = 500 # number of policy updates - algorithm_class_name = "PPO2" + algorithm_class_name = "GePPO" num_steps_per_env = 32 # deprecate + num_old_policies = 4 
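The num_old_policies value above sets how many past rollouts the HybridPolicyRunner keeps in its replay buffer and how their samples are weighted before being passed to GePPO.update() (patches 03/04). Below is a minimal standalone sketch of that weight construction; the helper name geppo_sample_weights and the toy shapes are illustrative only, while the base weights [0.4, 0.3, 0.2, 0.1], the rescaling, and the slot indexing mirror HybridPolicyRunner.learn().

import torch


def geppo_sample_weights(it, num_old_policies, num_steps_per_env, num_envs,
                         base_weights=(0.4, 0.3, 0.2, 0.1), device="cpu"):
    # Per-sample weights for the ring buffer, mirroring HybridPolicyRunner.learn().
    # Iteration `it` is stored in slot (it - 1) % num_old_policies, so slot j holds
    # the rollout of age (idx_newest - j) % num_old_policies (age 0 = newest).
    weights = torch.tensor(base_weights, device=device)
    n_policies = min(it, num_old_policies)
    # rescale the active subset so the weights average to one
    weights_active = weights[:n_policies]
    weights_active = weights_active * n_policies / weights_active.sum()
    # permute the age-ordered weights into buffer-slot order
    idx_newest = (it - 1) % num_old_policies
    indices_all = [i % num_old_policies
                   for i in range(idx_newest, idx_newest - n_policies, -1)]
    weights_all = weights_active[indices_all]
    # expand to one weight per stored transition: [n_policies * steps, num_envs]
    weights_all = weights_all.repeat_interleave(num_steps_per_env)
    return weights_all.unsqueeze(-1).repeat(1, num_envs)


# After two iterations only two slots are filled: the newest rollout (slot 1)
# gets weight 2 * 0.4 / 0.7, the older one (slot 0) gets 2 * 0.3 / 0.7.
w = geppo_sample_weights(it=2, num_old_policies=4, num_steps_per_env=3, num_envs=2)
print(w.shape)  # torch.Size([6, 2])

Because the active weights are rescaled to average one, the weighted critic MSE and clipped surrogate in geppo.py stay on roughly the same scale as their unweighted PPO2 counterparts, while the most recent rollout always receives the largest weight.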
diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py index 03242be4..18959585 100644 --- a/learning/algorithms/geppo.py +++ b/learning/algorithms/geppo.py @@ -18,7 +18,8 @@ def __init__( critic, is_trunc=1.0, eps_ppo=0.2, - eps_vary=True, + eps_geppo=0.1, + eps_vary=False, adapt_lr=True, adapt_factor=0.03, adapt_minthresh=0.0, @@ -32,8 +33,8 @@ def __init__( # Clipping parameter self.eps_ppo = eps_ppo + self.eps_geppo = eps_geppo self.eps_vary = eps_vary - self.eps = self.eps_ppo # TODO: This should be computed # Learning rate self.adapt_lr = adapt_lr @@ -118,7 +119,7 @@ def update_actor(self, data): offpol_ratio = torch.exp(log_prob_pik - data["log_prob"]) # TODO: I am taking the mean over 2 dims, check if this is correct eps_old = torch.mean(data["weights"] * torch.abs(offpol_ratio - 1.0)) - self.eps = max(self.eps_ppo - eps_old.item(), 0.0) + self.eps_geppo = max(self.eps_ppo - eps_old.item(), 0.0) self.mean_surrogate_loss = 0 counter = 0 @@ -151,7 +152,7 @@ def update_actor(self, data): ratio = torch.exp(log_prob - batch["log_prob"]) surrogate = -torch.squeeze(advantages) * ratio surrogate_clipped = -torch.squeeze(advantages) * torch.clamp( - ratio, offpol_ratio - self.eps, offpol_ratio + self.eps + ratio, offpol_ratio - self.eps_geppo, offpol_ratio + self.eps_geppo ) surrogate_loss = ( torch.max(surrogate, surrogate_clipped) * batch["weights"] @@ -179,9 +180,9 @@ def update_actor(self, data): # Adapt learning rate if self.adapt_lr: - if self.tv > (self.adapt_maxthresh * (0.5 * self.eps)): + if self.tv > (self.adapt_maxthresh * (0.5 * self.eps_geppo)): self.learning_rate /= 1 + self.adapt_factor - elif self.tv < (self.adapt_minthresh * (0.5 * self.eps)): + elif self.tv < (self.adapt_minthresh * (0.5 * self.eps_geppo)): self.learning_rate *= 1 + self.adapt_factor for param_group in self.optimizer.param_groups: param_group["lr"] = self.learning_rate diff --git a/learning/runners/hybrid_policy_runner.py b/learning/runners/hybrid_policy_runner.py index e63a4dfe..08d90fd8 100644 --- a/learning/runners/hybrid_policy_runner.py +++ b/learning/runners/hybrid_policy_runner.py @@ -210,7 +210,14 @@ def set_up_logger(self): logger.register_category( "GePPO", self.alg, - ["eps", "tv", "adv_mean", "ret_mean", "adv_vtrace_mean", "ret_vtrace_mean"], + [ + "eps_geppo", + "tv", + "adv_mean", + "ret_mean", + "adv_vtrace_mean", + "ret_vtrace_mean", + ], ) logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) diff --git a/learning/runners/on_policy_runner.py b/learning/runners/on_policy_runner.py index c1758000..a82f24f6 100644 --- a/learning/runners/on_policy_runner.py +++ b/learning/runners/on_policy_runner.py @@ -160,7 +160,9 @@ def set_up_logger(self): ) logger.register_rewards(["total_rewards"]) logger.register_category( - "algorithm", self.alg, ["mean_value_loss", "mean_surrogate_loss"] + "algorithm", + self.alg, + ["learning_rate", "mean_value_loss", "mean_surrogate_loss"], ) logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) From ed751e33aefeee80d75a1a5f80e7d08c168ee7d1 Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Tue, 25 Jun 2024 11:45:34 -0400 Subject: [PATCH 07/10] added recursive GAE vtrace, and split GAE by policy --- .../mini_cheetah/mini_cheetah_ref_config.py | 4 +- gym/envs/pendulum/pendulum_config.py | 5 +- learning/algorithms/geppo.py | 67 +++++++++---- learning/modules/actor.py | 1 + learning/runners/hybrid_policy_runner.py | 8 +- learning/utils/dict_utils.py | 96 ++++++++++++------- 6 files changed, 130 insertions(+), 51 
deletions(-) diff --git a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py index 9a729b82..142e8ed7 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py @@ -146,7 +146,6 @@ class algorithm(MiniCheetahRunnerCfg.algorithm): max_gradient_steps = 10 # new storage_size = 2**17 # new - batch_size = 2**15 # new clip_param = 0.2 learning_rate = 1.0e-3 @@ -159,6 +158,9 @@ class algorithm(MiniCheetahRunnerCfg.algorithm): desired_kl = 0.01 # GePPO + vtrace = True + normalize_advantages = False # weighted normalization in GePPO loss + recursive_advantages = True # applies to vtrace is_trunc = 1.0 class runner(MiniCheetahRunnerCfg.runner): diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py index 33112a0f..11493419 100644 --- a/gym/envs/pendulum/pendulum_config.py +++ b/gym/envs/pendulum/pendulum_config.py @@ -115,7 +115,7 @@ class algorithm(FixedRobotCfgPPO.algorithm): storage_size = 2**17 # new batch_size = 2**16 # new clip_param = 0.2 - learning_rate = 1.0e-4 + learning_rate = 3e-4 max_grad_norm = 1.0 # Critic use_clipped_value_loss = True @@ -125,6 +125,9 @@ class algorithm(FixedRobotCfgPPO.algorithm): desired_kl = 0.01 # GePPO + vtrace = True + normalize_advantages = False # weighted normalization in GePPO loss + recursive_advantages = True # applies to vtrace is_trunc = 1.0 class runner(FixedRobotCfgPPO.runner): diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py index 18959585..10198260 100644 --- a/learning/algorithms/geppo.py +++ b/learning/algorithms/geppo.py @@ -16,6 +16,10 @@ def __init__( self, actor, critic, + num_steps_per_env=32, + vtrace=True, + normalize_advantages=False, + recursive_advantages=True, is_trunc=1.0, eps_ppo=0.2, eps_geppo=0.1, @@ -27,6 +31,12 @@ def __init__( **kwargs, ): super().__init__(actor, critic, **kwargs) + self.num_steps_per_env = num_steps_per_env + + # GAE parameters + self.vtrace = vtrace + self.normalize_advantages = normalize_advantages + self.recursive_advantages = recursive_advantages # Importance sampling truncation self.is_trunc = is_trunc @@ -52,11 +62,8 @@ def update(self, data, weights): data["values"] = values # Compute GAE with and without V-trace - adv = compute_generalized_advantages(data, self.gamma, self.lam, self.critic) - ret = adv + values - adv_vtrace, ret_vtrace = compute_gae_vtrace( - data, self.gamma, self.lam, self.is_trunc, self.actor, self.critic - ) + adv, ret = self.compute_gae_all(data, vtrace=False) + adv_vtrace, ret_vtrace = self.compute_gae_all(data, vtrace=True) # Handle single env case if adv_vtrace.dim() == 1: adv_vtrace = adv_vtrace.unsqueeze(-1) @@ -64,7 +71,7 @@ def update(self, data, weights): ret_vtrace = ret_vtrace.unsqueeze(-1) # Only use V-trace if we have updated once already - if self.updated: + if self.vtrace and self.updated: data["advantages"] = adv_vtrace data["returns"] = ret_vtrace else: @@ -88,6 +95,35 @@ def update(self, data, weights): self.adv_vtrace_mean = adv_vtrace.mean().item() self.ret_vtrace_mean = ret_vtrace.mean().item() + def compute_gae_all(self, data, vtrace): + # Compute GAE for each policy and concatenate + adv = torch.zeros_like(data["values"]).to(self.device) + ret = torch.zeros_like(data["values"]).to(self.device) + steps = self.num_steps_per_env + loaded_policies = data["values"].shape[0] // steps + + for i in range(loaded_policies): + data_i = data[i * steps : (i + 1) * steps] + if vtrace: + adv_i, ret_i = 
compute_gae_vtrace( + data_i, + self.gamma, + self.lam, + self.is_trunc, + self.actor, + self.critic, + rec=self.recursive_advantages, + ) + else: + adv_i = compute_generalized_advantages( + data_i, self.gamma, self.lam, self.critic + ) + ret_i = adv_i + data_i["values"] + adv[i * steps : (i + 1) * steps] = adv_i + ret[i * steps : (i + 1) * steps] = ret_i + + return adv, ret + def update_critic(self, data): self.mean_value_loss = 0 counter = 0 @@ -134,22 +170,21 @@ def update_actor(self, data): entropy_batch = self.actor.entropy # * GePPO Surrogate loss + log_prob = self.actor.get_actions_log_prob(batch["actions"]) log_prob_pik = self.actor.get_pik_log_prob( batch["actor_obs"], batch["actions"] ) + ratio = torch.exp(log_prob - batch["log_prob"]) offpol_ratio = torch.exp(log_prob_pik - batch["log_prob"]) - advantages = batch["advantages"] - # TODO: Center/clip advantages (optional) - # adv_mean = torch.mean( - # offpol_ratio * batch["weights"] * advantages, dim=2 - # ) / torch.mean(offpol_ratio, batch["weights"], dim=2) - # adv_std = torch.std( - # offpol_ratio * batch["weights"] * advantages, dim=2 - # ) + advantages = batch["advantages"] + if self.normalize_advantages: + adv_mean = torch.mean( + offpol_ratio * batch["weights"] * advantages + ) / torch.mean(offpol_ratio * batch["weights"]) + adv_std = torch.std(offpol_ratio * batch["weights"] * advantages) + advantages = (advantages - adv_mean) / (adv_std + 1e-8) - log_prob = self.actor.get_actions_log_prob(batch["actions"]) - ratio = torch.exp(log_prob - batch["log_prob"]) surrogate = -torch.squeeze(advantages) * ratio surrogate_clipped = -torch.squeeze(advantages) * torch.clamp( ratio, offpol_ratio - self.eps_geppo, offpol_ratio + self.eps_geppo diff --git a/learning/modules/actor.py b/learning/modules/actor.py index d71d2ac2..a4573ef3 100644 --- a/learning/modules/actor.py +++ b/learning/modules/actor.py @@ -82,6 +82,7 @@ def update_pik_weights(self): def get_pik_log_prob(self, observations, actions): if self._normalize_obs: + # TODO: Check if this updates the normalization mean/std with torch.no_grad(): observations = self.obs_rms(observations) mean_pik = self.NN_pik(observations) diff --git a/learning/runners/hybrid_policy_runner.py b/learning/runners/hybrid_policy_runner.py index 08d90fd8..fb2c9890 100644 --- a/learning/runners/hybrid_policy_runner.py +++ b/learning/runners/hybrid_policy_runner.py @@ -38,7 +38,13 @@ def _set_up_alg(self): # Store pik for the actor actor = Actor(num_actor_obs, num_actions, store_pik=True, **self.actor_cfg) critic = Critic(num_critic_obs, **self.critic_cfg) - self.alg = alg_class(actor, critic, device=self.device, **self.alg_cfg) + self.alg = alg_class( + actor, + critic, + device=self.device, + num_steps_per_env=self.num_steps_per_env, # GePPO needs this + **self.alg_cfg, + ) def learn(self): self.set_up_logger() diff --git a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index 1a82db2f..3bde0380 100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -53,45 +53,77 @@ def compute_generalized_advantages(data, gamma, lam, critic): # Implementation based on GePPO repo: https://github.com/jqueeney/geppo @torch.no_grad -def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic): +def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic, rec=False): if actor.store_pik is False: raise ValueError("Need to store pik for V-trace") - log_prob_pik = actor.get_pik_log_prob(data["actor_obs"], data["actions"]) - # n: rollout length, e: num envs - # TODO: Double 
check GePPO code and paper (they diverge imo) + n, e = data.shape + log_prob_pik = actor.get_pik_log_prob(data["actor_obs"], data["actions"]) ratio = torch.exp(log_prob_pik - data["log_prob"]) # shape [n, e] + ratio_trunc = torch.clamp_max(ratio, is_trunc) # [n, e] - n, e = ratio.shape - ones_U = torch.triu(torch.ones((n, n)), 0).to(data.device) + if rec: + # Recursive version + last_values = critic.evaluate(data["next_critic_obs"][-1]) + advantages = torch.zeros_like(data["values"]) + if last_values is not None: + # TODO: check this (copied from regular GAE) + # since we don't have obs for the last step, need last value plugged in + not_done = ~data["dones"][-1] + advantages[-1] = ( + data["rewards"][-1] + + gamma * data["values"][-1] * data["timed_out"][-1] + + gamma * last_values * not_done + - data["values"][-1] + ) + + for k in reversed(range(data["values"].shape[0] - 1)): + not_done = ~data["dones"][k] + td_error = ( + data["rewards"][k] + + gamma * data["values"][k] * data["timed_out"][k] + + gamma * data["values"][k + 1] * not_done + - data["values"][k] + ) + advantages[k] = ( + td_error + gamma * lam * not_done * ratio_trunc[k] * advantages[k + 1] + ) + + returns = advantages * ratio_trunc + data["values"] + + else: + # GePPO paper version + ratio_trunc_T = ratio_trunc.transpose(0, 1) # [e, n] + ratio_trunc_repeat = ratio_trunc_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] + ratio_trunc_L = torch.tril(ratio_trunc_repeat, -1) + ones_U = torch.triu(torch.ones((n, n)), 0).to(data.device) + + # cumprod along axis 1, keep shape [e, n, n] + ratio_trunc_prods = torch.tril(torch.cumprod(ratio_trunc_L + ones_U, axis=1), 0) + + # everything in data dict is [n, e] + values = critic.evaluate(data["critic_obs"]) + values_next = critic.evaluate(data["next_critic_obs"]) + not_done = ~data["dones"] + + delta = data["rewards"] + gamma * values_next * not_done - values # [n, e] + delta_T = delta.transpose(0, 1) # [e, n] + delta_repeat = delta_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] + + rate_L = torch.tril(torch.ones((n, n)) * gamma * lam, -1).to( + data.device + ) # [n, n] + rates = torch.tril(torch.cumprod(rate_L + ones_U, axis=0), 0) + rates_repeat = rates.unsqueeze(0).repeat(e, 1, 1) # [e, n, n] + + # element-wise multiplication: + intermediate = rates_repeat * ratio_trunc_prods * delta_repeat # [e, n, n] + advantages = torch.sum(intermediate, axis=1) # [e, n] + + advantages = advantages.transpose(0, 1) # [n, e] + returns = advantages * ratio_trunc + values # [n, e] - ratio_trunc = torch.clamp_max(ratio, is_trunc) # [n, e] - ratio_trunc_T = ratio_trunc.transpose(0, 1) # [e, n] - ratio_trunc_repeat = ratio_trunc_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] - ratio_trunc_L = torch.tril(ratio_trunc_repeat, -1) - # cumprod along axis 1, keep shape [e, n, n] - ratio_trunc_prods = torch.tril(torch.cumprod(ratio_trunc_L + ones_U, axis=1), 0) - - # everything in data dict is [n, e] - values = critic.evaluate(data["critic_obs"]) - values_next = critic.evaluate(data["next_critic_obs"]) - not_done = ~data["dones"] - - delta = data["rewards"] + gamma * values_next * not_done - values # [n, e] - delta_T = delta.transpose(0, 1) # [e, n] - delta_repeat = delta_T.unsqueeze(-1).repeat(1, 1, n) # [e, n, n] - - rate_L = torch.tril(torch.ones((n, n)) * gamma * lam, -1).to(data.device) # [n, n] - rates = torch.tril(torch.cumprod(rate_L + ones_U, axis=0), 0) - rates_repeat = rates.unsqueeze(0).repeat(e, 1, 1) # [e, n, n] - - # element-wise multiplication: - intermediate = rates_repeat * ratio_trunc_prods * 
delta_repeat # [e, n, n] - advantages = torch.sum(intermediate, axis=1) # [e, n] - - advantages = advantages.transpose(0, 1) # [n, e] - returns = advantages * ratio_trunc + values # [n, e] return advantages, returns From 0329ae5fce430e75397db32946ce91504b5a86ae Mon Sep 17 00:00:00 2001 From: Lukas Molnar Date: Wed, 26 Jun 2024 11:43:13 -0400 Subject: [PATCH 08/10] fine tune by loadining PPO run and training with noise multiplied --- .../mini_cheetah/mini_cheetah_ref_config.py | 14 +++++-- learning/algorithms/ppo2.py | 4 ++ learning/modules/actor.py | 28 +++++++++----- learning/modules/utils/__init__.py | 2 +- learning/modules/utils/neural_net.py | 14 +++++++ learning/modules/utils/normalize.py | 9 ++++- learning/runners/BaseRunner.py | 2 + learning/runners/hybrid_policy_runner.py | 38 ++++++++++++------- learning/runners/on_policy_runner.py | 20 +++++++++- learning/utils/dict_utils.py | 4 +- 10 files changed, 102 insertions(+), 33 deletions(-) diff --git a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py index 142e8ed7..258806f4 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py @@ -8,7 +8,7 @@ class MiniCheetahRefCfg(MiniCheetahCfg): class env(MiniCheetahCfg.env): - num_envs = 1024 + num_envs = 1 # Fine tuning num_actuators = 12 episode_length_s = 5.0 @@ -87,6 +87,8 @@ class actor: disable_actions = False class noise: + noise_multiplier = 10.0 # Fine tuning: multiplies all noise + scale = 1.0 dof_pos_obs = 0.01 base_ang_vel = 0.01 @@ -142,7 +144,7 @@ class algorithm(MiniCheetahRunnerCfg.algorithm): gamma = 0.99 lam = 0.95 # shared - batch_size = 2**15 + batch_size = 500 # 2**15 max_gradient_steps = 10 # new storage_size = 2**17 # new @@ -166,7 +168,11 @@ class algorithm(MiniCheetahRunnerCfg.algorithm): class runner(MiniCheetahRunnerCfg.runner): run_name = "" experiment_name = "mini_cheetah_ref" - max_iterations = 500 # number of policy updates + max_iterations = 50 # number of policy updates algorithm_class_name = "GePPO" - num_steps_per_env = 32 # deprecate + num_steps_per_env = 500 # deprecate num_old_policies = 4 + + # Fine tuning + resume = True + load_run = "rollout_32" # pretrained PPO run diff --git a/learning/algorithms/ppo2.py b/learning/algorithms/ppo2.py index a489ce5e..8a29f346 100644 --- a/learning/algorithms/ppo2.py +++ b/learning/algorithms/ppo2.py @@ -68,6 +68,10 @@ def update(self, data): ) data["returns"] = data["advantages"] + data["values"] + # Logging: Store mean advantages and returns + self.adv_mean = data["advantages"].mean().item() + self.ret_mean = data["returns"].mean().item() + self.update_critic(data) data["advantages"] = normalize(data["advantages"]) self.update_actor(data) diff --git a/learning/modules/actor.py b/learning/modules/actor.py index a4573ef3..ab6e15e9 100644 --- a/learning/modules/actor.py +++ b/learning/modules/actor.py @@ -1,8 +1,7 @@ import torch import torch.nn as nn from torch.distributions import Normal -from .utils import create_MLP -from .utils import export_network +from .utils import StaticNN, create_MLP, export_network from .utils import RunningMeanStd @@ -36,8 +35,10 @@ def __init__( self.store_pik = store_pik if self.store_pik: - self.NN_pik = create_MLP(num_obs, num_actions, hidden_dims, activation) - self.std_pik = self.std.detach().clone() + self._NN_pik = StaticNN( + create_MLP(num_obs, num_actions, hidden_dims, activation) + ) + self._std_pik = self.std.detach().clone() self.update_pik_weights() @property @@ -52,6 
+53,14 @@ def action_std(self): def entropy(self): return self.distribution.entropy().sum(dim=-1) + @property + def obs_running_mean(self): + return self.obs_rms.running_mean + + @property + def obs_running_std(self): + return self.obs_rms.running_var.sqrt() + def update_distribution(self, observations): mean = self.act_inference(observations) self.distribution = Normal(mean, mean * 0.0 + self.std) @@ -76,16 +85,17 @@ def export(self, path): export_network(self, "policy", path, self.num_obs) def update_pik_weights(self): - nn_state_dict = self.NN.state_dict() - self.NN_pik.load_state_dict(nn_state_dict) - self.std_pik = self.std.detach().clone() + with torch.no_grad(): + nn_state_dict = self.NN.state_dict() + self._NN_pik.model.load_state_dict(nn_state_dict) + self._std_pik = self.std.detach().clone() def get_pik_log_prob(self, observations, actions): if self._normalize_obs: # TODO: Check if this updates the normalization mean/std with torch.no_grad(): observations = self.obs_rms(observations) - mean_pik = self.NN_pik(observations) - std_pik = self.std_pik.to(mean_pik.device) + mean_pik = self._NN_pik(observations) + std_pik = self._std_pik.to(mean_pik.device) distribution = Normal(mean_pik, mean_pik * 0.0 + std_pik) return distribution.log_prob(actions).sum(dim=-1) diff --git a/learning/modules/utils/__init__.py b/learning/modules/utils/__init__.py index 1422b82d..cf34428b 100644 --- a/learning/modules/utils/__init__.py +++ b/learning/modules/utils/__init__.py @@ -1,2 +1,2 @@ -from .neural_net import create_MLP, export_network +from .neural_net import StaticNN, create_MLP, export_network from .normalize import RunningMeanStd diff --git a/learning/modules/utils/neural_net.py b/learning/modules/utils/neural_net.py index 5fd6ab3b..002b4ed8 100644 --- a/learning/modules/utils/neural_net.py +++ b/learning/modules/utils/neural_net.py @@ -3,6 +3,20 @@ import copy +class StaticNN: + """ + A static neural network wrapper that is used just for inference. 
+ """ + + def __init__(self, model): + self.model = model + + def __call__(self, x): + model = self.model.to(x.device) + with torch.no_grad(): + return model(x) + + def create_MLP(num_inputs, num_outputs, hidden_dims, activation, dropouts=None): activation = get_activation(activation) diff --git a/learning/modules/utils/normalize.py b/learning/modules/utils/normalize.py index 246bafa4..4da03653 100644 --- a/learning/modules/utils/normalize.py +++ b/learning/modules/utils/normalize.py @@ -40,8 +40,13 @@ def _update_mean_var_from_moments( def forward(self, input): if self.training: - mean = input.mean(tuple(range(input.dim() - 1))) - var = input.var(tuple(range(input.dim() - 1))) + # TODO: check this, it got rid of NaN values in first iteration + dim = tuple(range(input.dim() - 1)) + mean = input.mean(dim) + if input.dim() <= 2: + var = torch.zeros_like(mean) + else: + var = input.var(dim) ( self.running_mean, self.running_var, diff --git a/learning/runners/BaseRunner.py b/learning/runners/BaseRunner.py index 9e5ab2e3..fc5d745d 100644 --- a/learning/runners/BaseRunner.py +++ b/learning/runners/BaseRunner.py @@ -49,6 +49,8 @@ def get_noise(self, obs_list, noise_dict): noise_tensor = torch.ones(obs_size).to(self.device) * torch.tensor( noise_dict[obs] ).to(self.device) + if "noise_multiplier" in noise_dict.keys(): + noise_tensor *= noise_dict["noise_multiplier"] if obs in self.env.scales.keys(): noise_tensor /= self.env.scales[obs] noise_vec[obs_index : obs_index + obs_size] = noise_tensor diff --git a/learning/runners/hybrid_policy_runner.py b/learning/runners/hybrid_policy_runner.py index fb2c9890..3e70f332 100644 --- a/learning/runners/hybrid_policy_runner.py +++ b/learning/runners/hybrid_policy_runner.py @@ -143,13 +143,14 @@ def learn(self): logger.toc("collection") # Compute GePPO weights - n_policies = min(self.it, self.num_old_policies) - weights_active = self.weights[:n_policies] - weights_active = weights_active * n_policies / weights_active.sum() + num_policies = storage.fill_count // self.num_steps_per_env + num_policies = min(num_policies, self.num_old_policies) + weights_active = self.weights[:num_policies] + weights_active = weights_active * num_policies / weights_active.sum() idx_newest = (self.it - 1) % self.num_old_policies indices_all = [ i % self.num_old_policies - for i in range(idx_newest, idx_newest - n_policies, -1) + for i in range(idx_newest, idx_newest - num_policies, -1) ] weights_all = weights_active[indices_all] weights_all = weights_all.repeat_interleave(self.num_steps_per_env) @@ -208,21 +209,28 @@ def set_up_logger(self): logger.register_category( "algorithm", self.alg, - ["learning_rate", "mean_value_loss", "mean_surrogate_loss"], - ) - logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) - - # GePPO specific logging - logger.register_category( - "GePPO", - self.alg, [ - "eps_geppo", - "tv", + "learning_rate", + "mean_value_loss", + "mean_surrogate_loss", "adv_mean", "ret_mean", + # GePPO specific: "adv_vtrace_mean", "ret_vtrace_mean", + "eps_geppo", + "tv", + ], + ) + logger.register_category( + "actor", + self.alg.actor, + [ + "action_mean", + "action_std", + "entropy", + "obs_running_mean", + "obs_running_std", ], ) @@ -245,6 +253,8 @@ def save(self): def load(self, path, load_optimizer=True): loaded_dict = torch.load(path) self.alg.actor.load_state_dict(loaded_dict["actor_state_dict"]) + # Update pik NN weights + self.alg.actor.update_pik_weights() self.alg.critic.load_state_dict(loaded_dict["critic_state_dict"]) if load_optimizer: 
self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) diff --git a/learning/runners/on_policy_runner.py b/learning/runners/on_policy_runner.py index a82f24f6..6cd7e2d6 100644 --- a/learning/runners/on_policy_runner.py +++ b/learning/runners/on_policy_runner.py @@ -162,9 +162,25 @@ def set_up_logger(self): logger.register_category( "algorithm", self.alg, - ["learning_rate", "mean_value_loss", "mean_surrogate_loss"], + [ + "learning_rate", + "mean_value_loss", + "mean_surrogate_loss", + "adv_mean", + "ret_mean", + ], + ) + logger.register_category( + "actor", + self.alg.actor, + [ + "action_mean", + "action_std", + "entropy", + "obs_running_mean", + "obs_running_std", + ], ) - logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) diff --git a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index 3bde0380..d8138343 100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -103,8 +103,10 @@ def compute_gae_vtrace(data, gamma, lam, is_trunc, actor, critic, rec=False): ratio_trunc_prods = torch.tril(torch.cumprod(ratio_trunc_L + ones_U, axis=1), 0) # everything in data dict is [n, e] - values = critic.evaluate(data["critic_obs"]) + values = data["values"] values_next = critic.evaluate(data["next_critic_obs"]) + if values_next.dim() == 1: + values_next = values_next.reshape((n, e)) not_done = ~data["dones"] delta = data["rewards"] + gamma * values_next * not_done - values # [n, e] From b4e94f66a3974c426adefec763f5e5daf518a81a Mon Sep 17 00:00:00 2001 From: Steve Heim Date: Tue, 2 Jul 2024 16:12:56 -0400 Subject: [PATCH 09/10] handle value size at source --- learning/algorithms/ppo2.py | 6 +----- learning/modules/critic.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/learning/algorithms/ppo2.py b/learning/algorithms/ppo2.py index 8a29f346..c1c4c126 100644 --- a/learning/algorithms/ppo2.py +++ b/learning/algorithms/ppo2.py @@ -58,11 +58,7 @@ def act(self, obs): return self.actor.act(obs).detach() def update(self, data): - values = self.critic.evaluate(data["critic_obs"]) - # Handle single env case - if values.dim() == 1: - values = values.unsqueeze(-1) - data["values"] = values + data["values"] = self.critic.evaluate(data["critic_obs"]) data["advantages"] = compute_generalized_advantages( data, self.gamma, self.lam, self.critic ) diff --git a/learning/modules/critic.py b/learning/modules/critic.py index 0ea7825d..cec533a7 100644 --- a/learning/modules/critic.py +++ b/learning/modules/critic.py @@ -24,7 +24,7 @@ def forward(self, x): if self._normalize_obs: with torch.no_grad(): x = self.obs_rms(x) - return self.NN(x).squeeze() + return self.NN(x).squeeze(-1) def evaluate(self, critic_observations): return self.forward(critic_observations) From 8f0680be4df5edc72d66459e07ad9eff6584916e Mon Sep 17 00:00:00 2001 From: Steve Heim Date: Tue, 2 Jul 2024 17:38:28 -0400 Subject: [PATCH 10/10] fixed "handle one env" at the source in criitic --- learning/algorithms/geppo.py | 11 +---------- learning/modules/utils/normalize.py | 8 ++------ 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/learning/algorithms/geppo.py b/learning/algorithms/geppo.py index 10198260..a56d432f 100644 --- a/learning/algorithms/geppo.py +++ b/learning/algorithms/geppo.py @@ -55,20 +55,11 @@ def __init__( self.updated = False def update(self, data, weights): - values = self.critic.evaluate(data["critic_obs"]) - # Handle single env case - if values.dim() == 
1: - values = values.unsqueeze(-1) - data["values"] = values + data["values"] = self.critic.evaluate(data["critic_obs"]) # Compute GAE with and without V-trace adv, ret = self.compute_gae_all(data, vtrace=False) adv_vtrace, ret_vtrace = self.compute_gae_all(data, vtrace=True) - # Handle single env case - if adv_vtrace.dim() == 1: - adv_vtrace = adv_vtrace.unsqueeze(-1) - if ret_vtrace.dim() == 1: - ret_vtrace = ret_vtrace.unsqueeze(-1) # Only use V-trace if we have updated once already if self.vtrace and self.updated: diff --git a/learning/modules/utils/normalize.py b/learning/modules/utils/normalize.py index 4da03653..00471325 100644 --- a/learning/modules/utils/normalize.py +++ b/learning/modules/utils/normalize.py @@ -41,12 +41,8 @@ def _update_mean_var_from_moments( def forward(self, input): if self.training: # TODO: check this, it got rid of NaN values in first iteration - dim = tuple(range(input.dim() - 1)) - mean = input.mean(dim) - if input.dim() <= 2: - var = torch.zeros_like(mean) - else: - var = input.var(dim) + mean = input.mean(tuple(range(input.dim() - 1))) + var = torch.nan_to_num(input.var(tuple(range(input.dim() - 1)))) ( self.running_mean, self.running_var,
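
The following self-contained Python sketches illustrate the main techniques this series introduces. None of them are part of the patches; identifiers that do not appear in the diffs above are hypothetical.

The pi_k bookkeeping in learning/modules/actor.py (store_pik, StaticNN, update_pik_weights, get_pik_log_prob) amounts to keeping an inference-only snapshot of the actor that is refreshed after every policy update and queried for log-probabilities during the GePPO update. A minimal sketch of that idea, using a hypothetical FrozenPolicy wrapper (deepcopy-based) rather than the repository's classes:

import copy
import torch
import torch.nn as nn


class FrozenPolicy:
    """Inference-only snapshot of a policy network, playing the role of pi_k."""

    def __init__(self, model: nn.Module):
        # deep copy so later gradient updates to `model` never leak into the snapshot
        self.model = copy.deepcopy(model).eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

    def sync(self, model: nn.Module):
        # call once per policy update, analogous to update_pik_weights()
        self.model.load_state_dict(model.state_dict())

    @torch.no_grad()
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.to(x.device)(x)

Because the snapshot is held in a plain Python wrapper rather than a registered nn.Module, its weights never end up in the actor's state_dict, which appears to be why the runner's load() re-syncs it explicitly via update_pik_weights().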
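
The GePPO surrogate in update_actor clips the current policy's probability ratio around the ratio of the snapshot pi_k to the behavior policy that collected each sample, instead of around 1 as in PPO. A sketch of that loss on flat [batch] tensors, assuming the usual PPO pessimistic max-of-surrogates reduction (the reduction itself lies outside the hunks shown above):

import torch


def geppo_surrogate_loss(log_prob, log_prob_pik, log_prob_behavior,
                         advantages, eps_geppo=0.2):
    # ratio of the current policy to the (possibly old) behavior policy
    ratio = torch.exp(log_prob - log_prob_behavior)
    # ratio of the snapshot pi_k to the behavior policy; clipping is centered here
    offpol_ratio = torch.exp(log_prob_pik - log_prob_behavior)

    surrogate = -advantages * ratio
    surrogate_clipped = -advantages * torch.clamp(
        ratio, offpol_ratio - eps_geppo, offpol_ratio + eps_geppo
    )
    # assumption: same pessimistic combination as standard PPO
    return torch.max(surrogate, surrogate_clipped).mean()

With freshly collected on-policy data the behavior policy is pi_k itself, offpol_ratio is 1, and this reduces to the ordinary PPO clipped objective.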
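
The normalize_advantages branch re-centers the advantages with an importance-weighted mean so that samples drawn from older policies do not bias the baseline; batch["weights"] are the per-policy mixture weights. In isolation it is just:

import torch


def weighted_advantage_normalization(advantages, offpol_ratio, weights, eps=1e-8):
    # importance-weighted mean; `weights` are the per-policy mixture weights
    w = offpol_ratio * weights
    adv_mean = torch.mean(w * advantages) / torch.mean(w)
    adv_std = torch.std(w * advantages)
    return (advantages - adv_mean) / (adv_std + eps)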
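
The rec=True branch of compute_gae_vtrace trades the [e, n, n] triangular-matrix construction for a backward recursion in which every bootstrapped advantage is damped by the truncated importance ratio of the step that produced it. Stripped of the timeout bootstrapping and the data-dictionary plumbing, the recursion is roughly (shapes [T, N], illustrative names):

import torch


@torch.no_grad()
def vtrace_gae_recursive(rewards, values, next_values, dones, ratio,
                         gamma=0.99, lam=0.95, is_trunc=1.0):
    # `ratio` is the pi_k / behavior-policy probability ratio per step
    rho = torch.clamp_max(ratio, is_trunc)        # truncated importance ratios
    not_done = (~dones).float()
    advantages = torch.zeros_like(values)
    next_adv = torch.zeros_like(values[-1])
    for t in reversed(range(values.shape[0])):
        delta = rewards[t] + gamma * next_values[t] * not_done[t] - values[t]
        advantages[t] = delta + gamma * lam * not_done[t] * rho[t] * next_adv
        next_adv = advantages[t]
    returns = rho * advantages + values           # critic targets, weighted as in the patch
    return advantages, returns

A convenient sanity check: with ratio = torch.ones_like(rewards), both this recursion and the matrix version should collapse to ordinary GAE and can be compared against compute_generalized_advantages on random data.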
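
The critic change in [PATCH 09/10] (squeeze() -> squeeze(-1)) is what lets the per-algorithm "handle single env case" reshaping be deleted: an unqualified squeeze drops every size-1 dimension, including the environment dimension when only one environment is simulated, whereas squeeze(-1) removes only the trailing value dimension. For example:

import torch

out = torch.randn(32, 1, 1)        # [steps, num_envs=1, value_dim=1]
print(out.squeeze().shape)         # torch.Size([32])    -- env dimension lost
print(out.squeeze(-1).shape)       # torch.Size([32, 1]) -- only the value dim removed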
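
The final normalize.py hunk works because PyTorch's unbiased variance over a single sample is NaN, which would otherwise contaminate the running observation statistics as soon as a batch holds only one sample (e.g. num_envs = 1 during fine-tuning); torch.nan_to_num maps it back to zero:

import torch

x = torch.randn(1, 4)                     # one sample, as with a single environment
print(x.var(dim=0))                       # tensor([nan, nan, nan, nan]) -- unbiased var needs >= 2 samples
print(torch.nan_to_num(x.var(dim=0)))     # tensor([0., 0., 0., 0.])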