diff --git a/gym/envs/__init__.py b/gym/envs/__init__.py index 57cbd8c0..0736c973 100644 --- a/gym/envs/__init__.py +++ b/gym/envs/__init__.py @@ -27,6 +27,7 @@ "MiniCheetahCfg": ".mini_cheetah.mini_cheetah_config", "MiniCheetahRefCfg": ".mini_cheetah.mini_cheetah_ref_config", "MiniCheetahOscCfg": ".mini_cheetah.mini_cheetah_osc_config", + "MiniCheetahFineTuneCfg": ".mini_cheetah.mini_cheetah_finetune_config", "MITHumanoidCfg": ".mit_humanoid.mit_humanoid_config", "A1Cfg": ".a1.a1_config", "AnymalCFlatCfg": ".anymal_c.flat.anymal_c_flat_config", @@ -39,6 +40,7 @@ "MiniCheetahRunnerCfg": ".mini_cheetah.mini_cheetah_config", "MiniCheetahRefRunnerCfg": ".mini_cheetah.mini_cheetah_ref_config", "MiniCheetahOscRunnerCfg": ".mini_cheetah.mini_cheetah_osc_config", + "MiniCheetahFineTuneRunnerCfg": ".mini_cheetah.mini_cheetah_finetune_config", "MITHumanoidRunnerCfg": ".mit_humanoid.mit_humanoid_config", "A1RunnerCfg": ".a1.a1_config", "AnymalCFlatRunnerCfg": ".anymal_c.flat.anymal_c_flat_config", @@ -59,6 +61,11 @@ "MiniCheetahOscCfg", "MiniCheetahOscRunnerCfg", ], + "mini_cheetah_finetune": [ + "MiniCheetahRef", + "MiniCheetahFineTuneCfg", + "MiniCheetahFineTuneRunnerCfg", + ], "humanoid": ["MIT_Humanoid", "MITHumanoidCfg", "MITHumanoidRunnerCfg"], "humanoid_running": [ "HumanoidRunning", diff --git a/gym/envs/base/fixed_robot.py b/gym/envs/base/fixed_robot.py index 1fe4c5d5..6d9dc971 100644 --- a/gym/envs/base/fixed_robot.py +++ b/gym/envs/base/fixed_robot.py @@ -40,64 +40,60 @@ def __init__(self, gym, sim, cfg, sim_params, sim_device, headless): self.reset() def step(self): - """Apply actions, simulate, call self.post_physics_step() - and pre_physics_step() - - Args: - actions (torch.Tensor): Tensor of shape - (num_envs, num_actions_per_env) - """ - self._reset_buffers() - self._pre_physics_step() - # * step physics and render each frame + self._pre_decimation_step() self._render() for _ in range(self.cfg.control.decimation): + self._pre_compute_torques() self.torques = self._compute_torques() - - if self.cfg.asset.disable_motors: - self.torques[:] = 0.0 - torques_to_gym_tensor = torch.zeros( - self.num_envs, self.num_dof, device=self.device - ) - - # todo encapsulate - next_torques_idx = 0 - for dof_idx in range(self.num_dof): - if self.cfg.control.actuated_joints_mask[dof_idx]: - torques_to_gym_tensor[:, dof_idx] = self.torques[ - :, next_torques_idx - ] - next_torques_idx += 1 - else: - torques_to_gym_tensor[:, dof_idx] = torch.zeros( - self.num_envs, device=self.device - ) - - self.gym.set_dof_actuation_force_tensor( - self.sim, gymtorch.unwrap_tensor(torques_to_gym_tensor) - ) - self.gym.simulate(self.sim) - if self.device == "cpu": - self.gym.fetch_results(self.sim, True) - self.gym.refresh_dof_state_tensor(self.sim) - - self._post_physics_step() + self._post_compute_torques() + self._step_physx_sim() + self._post_physx_step() + self._post_decimation_step() self._check_terminations_and_timeouts() env_ids = self.to_be_reset.nonzero(as_tuple=False).flatten() self._reset_idx(env_ids) - def _pre_physics_step(self): - pass + def _pre_decimation_step(self): + return None + + def _pre_compute_torques(self): + return None + + def _post_compute_torques(self): + if self.cfg.asset.disable_motors: + self.torques[:] = 0.0 - def _post_physics_step(self): + def _step_physx_sim(self): + next_torques_idx = 0 + torques_to_gym_tensor = torch.zeros( + self.num_envs, self.num_dof, device=self.device + ) + for dof_idx in range(self.num_dof): + if self.cfg.control.actuated_joints_mask[dof_idx]: + 
torques_to_gym_tensor[:, dof_idx] = self.torques[:, next_torques_idx] + next_torques_idx += 1 + else: + torques_to_gym_tensor[:, dof_idx] = torch.zeros( + self.num_envs, device=self.device + ) + self.gym.set_dof_actuation_force_tensor( + self.sim, gymtorch.unwrap_tensor(self.torques) + ) + self.gym.simulate(self.sim) + if self.device == "cpu": + self.gym.fetch_results(self.sim, True) + self.gym.refresh_dof_state_tensor(self.sim) + + def _post_physx_step(self): """ check terminations, compute observations and rewards """ self.gym.refresh_actor_root_state_tensor(self.sim) self.gym.refresh_net_contact_force_tensor(self.sim) + def _post_decimation_step(self): self.episode_length_buf += 1 self.common_step_counter += 1 @@ -212,18 +208,6 @@ def _process_rigid_body_props(self, props, env_id): return props def _compute_torques(self): - """Compute torques from actions. - Actions can be interpreted as position or velocity targets given - to a PD controller, or directly as scaled torques. - [NOTE]: torques must have the same dimension as the number of DOFs, - even if some DOFs are not actuated. - - Args: - actions (torch.Tensor): Actions - - Returns: - [torch.Tensor]: Torques sent to the simulation - """ actuated_dof_pos = torch.zeros( self.num_envs, self.num_actuators, device=self.device ) @@ -415,10 +399,7 @@ def _init_buffers(self): self.default_act_pos = self.default_act_pos.unsqueeze(0) # * store indices of actuated joints self.act_idx = to_torch(actuated_idx, dtype=torch.long, device=self.device) - # * check that init range highs and lows are consistent - # * and repopulate to match - if self.cfg.init_state.reset_mode == "reset_to_range": - self.initialize_ranges_for_initial_conditions() + self.initialize_ranges_for_initial_conditions() def initialize_ranges_for_initial_conditions(self): self.dof_pos_range = torch.zeros( diff --git a/gym/envs/base/fixed_robot_config.py b/gym/envs/base/fixed_robot_config.py index 8c7427f8..a9bdef47 100644 --- a/gym/envs/base/fixed_robot_config.py +++ b/gym/envs/base/fixed_robot_config.py @@ -123,34 +123,33 @@ class FixedRobotCfgPPO(BaseConfig): class logging: enable_local_saving = True - class policy: + class actor: init_noise_std = 1.0 hidden_dims = [512, 256, 128] - critic_hidden_dims = [512, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" - # only for 'ActorCriticRecurrent': - # rnn_type = 'lstm' - # rnn_hidden_size = 512 - # rnn_num_layers = 1 - obs = [ "observation_a", "observation_b", "these_need_to_be_atributes_(states)_of_the_robot_env", ] - - critic_obs = [ - "observation_x", - "observation_y", - "critic_obs_can_be_the_same_or_different_than_actor_obs", - ] + normalize_obs = True actions = ["tau_ff"] disable_actions = False class noise: - noise = 0.1 # implement as needed, also in your robot class + observation_a = 0.1 # implement as needed, also in your robot class + + class critic: + hidden_dims = [512, 256, 128] + activation = "elu" + normalize_obs = True + obs = [ + "observation_x", + "observation_y", + "critic_obs_can_be_the_same_or_different_than_actor_obs", + ] class rewards: class weights: @@ -165,20 +164,25 @@ class termination_weight: termination = 0.0 class algorithm: - # * training params - value_loss_coef = 1.0 - use_clipped_value_loss = True + # both + gamma = 0.99 + lam = 0.95 + # shared + batch_size = 2**15 + max_gradient_steps = 10 + # new + storage_size = 2**17 # new + batch_size = 2**15 # new + clip_param = 0.2 - entropy_coef = 0.01 - num_learning_epochs = 5 - # * mini batch size = num_envs*nsteps / 
nminibatches - num_mini_batches = 4 learning_rate = 1.0e-3 + max_grad_norm = 1.0 + # Critic + use_clipped_value_loss = True + # Actor + entropy_coef = 0.01 schedule = "adaptive" # could be adaptive, fixed - gamma = 0.99 - lam = 0.95 desired_kl = 0.01 - max_grad_norm = 1.0 class runner: policy_class_name = "ActorCritic" diff --git a/gym/envs/base/legged_robot_config.py b/gym/envs/base/legged_robot_config.py index 7e6cabc6..2e73d924 100644 --- a/gym/envs/base/legged_robot_config.py +++ b/gym/envs/base/legged_robot_config.py @@ -239,6 +239,7 @@ class actor: # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" normalize_obs = True + smooth_exploration = False obs = [ "observation_a", @@ -288,20 +289,27 @@ class termination_weight: termination = 0.01 class algorithm: - # * training params - value_loss_coef = 1.0 - use_clipped_value_loss = True + # both + gamma = 0.99 + lam = 0.95 + # shared + batch_size = 2**15 + max_gradient_steps = 24 + # new + storage_size = 2**17 # new + batch_size = 2**15 # new + clip_param = 0.2 - entropy_coef = 0.01 - num_learning_epochs = 5 - # * mini batch size = num_envs*nsteps / nminibatches - num_mini_batches = 4 learning_rate = 1.0e-3 + max_grad_norm = 1.0 + # Critic + use_clipped_value_loss = True + # Actor + entropy_coef = 0.01 schedule = "adaptive" # could be adaptive, fixed - gamma = 0.99 - lam = 0.95 desired_kl = 0.01 - max_grad_norm = 1.0 + lr_range = [2e-4, 1e-2] + lr_ratio = 1.3 class runner: policy_class_name = "ActorCritic" diff --git a/gym/envs/base/task_skeleton.py b/gym/envs/base/task_skeleton.py index ffbac6ce..0974970a 100644 --- a/gym/envs/base/task_skeleton.py +++ b/gym/envs/base/task_skeleton.py @@ -45,6 +45,7 @@ def reset(self): """Reset all robots""" self._reset_idx(torch.arange(self.num_envs, device=self.device)) self.step() + self.episode_length_buf[:] = 0 def _reset_buffers(self): self.to_be_reset[:] = False @@ -67,7 +68,7 @@ def _eval_reward(self, name): def _check_terminations_and_timeouts(self): """Check if environments need to be reset""" contact_forces = self.contact_forces[:, self.termination_contact_indices, :] - self.terminated = torch.any(torch.norm(contact_forces, dim=-1) > 1.0, dim=1) + self.terminated |= torch.any(torch.norm(contact_forces, dim=-1) > 1.0, dim=1) self.timed_out = self.episode_length_buf >= self.max_episode_length self.to_be_reset = self.timed_out | self.terminated diff --git a/gym/envs/cartpole/cartpole_config.py b/gym/envs/cartpole/cartpole_config.py index f18d5a3b..6e073d5c 100644 --- a/gym/envs/cartpole/cartpole_config.py +++ b/gym/envs/cartpole/cartpole_config.py @@ -67,7 +67,7 @@ class CartpoleRunnerCfg(FixedRobotCfgPPO): seed = -1 runner_class_name = "OnPolicyRunner" - class policy(FixedRobotCfgPPO.policy): + class actor(FixedRobotCfgPPO.actor): init_noise_std = 1.0 num_layers = 2 num_units = 32 diff --git a/gym/envs/mini_cheetah/mini_cheetah_config.py b/gym/envs/mini_cheetah/mini_cheetah_config.py index 28c70123..6de32218 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_config.py @@ -130,6 +130,8 @@ class actor: hidden_dims = [256, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" + smooth_exploration = True + exploration_sample_freq = 16 obs = [ "base_lin_vel", @@ -190,24 +192,11 @@ class termination_weight: termination = 0.01 class algorithm(LeggedRobotRunnerCfg.algorithm): - # * training params - value_loss_coef = 1.0 - use_clipped_value_loss = True - clip_param = 0.2 - entropy_coef = 0.02 - 
num_learning_epochs = 4 - # * mini batch size = num_envs*nsteps / nminibatches - num_mini_batches = 8 - learning_rate = 1.0e-5 - schedule = "adaptive" # can be adaptive or fixed - discount_horizon = 1.0 # [s] - # GAE_bootstrap_horizon = 2.0 # [s] - desired_kl = 0.01 - max_grad_norm = 1.0 + pass class runner(LeggedRobotRunnerCfg.runner): run_name = "" experiment_name = "mini_cheetah" - max_iterations = 500 + max_iterations = 800 algorithm_class_name = "PPO2" num_steps_per_env = 32 diff --git a/gym/envs/mini_cheetah/mini_cheetah_finetune_config.py b/gym/envs/mini_cheetah/mini_cheetah_finetune_config.py new file mode 100644 index 00000000..31f41894 --- /dev/null +++ b/gym/envs/mini_cheetah/mini_cheetah_finetune_config.py @@ -0,0 +1,162 @@ +from gym.envs.mini_cheetah.mini_cheetah_ref_config import ( + MiniCheetahRefCfg, + MiniCheetahRefRunnerCfg, +) + +BASE_HEIGHT_REF = 0.33 + + +class MiniCheetahFineTuneCfg(MiniCheetahRefCfg): + class env(MiniCheetahRefCfg.env): + num_envs = 4096 + num_actuators = 12 + episode_length_s = 30.0 + + class terrain(MiniCheetahRefCfg.terrain): + pass + + class init_state(MiniCheetahRefCfg.init_state): + pass + + class control(MiniCheetahRefCfg.control): + # * PD Drive parameters: + stiffness = {"haa": 20.0, "hfe": 20.0, "kfe": 20.0} + damping = {"haa": 0.5, "hfe": 0.5, "kfe": 0.5} + gait_freq = 3.0 + ctrl_frequency = 100 + desired_sim_frequency = 500 + + class commands(MiniCheetahRefCfg.commands): + pass + + class push_robots(MiniCheetahRefCfg.push_robots): + pass + + class domain_rand(MiniCheetahRefCfg.domain_rand): + pass + + class asset(MiniCheetahRefCfg.asset): + pass + + class reward_settings(MiniCheetahRefCfg.reward_settings): + pass + + class scaling(MiniCheetahRefCfg.scaling): + pass + + +class MiniCheetahFineTuneRunnerCfg(MiniCheetahRefRunnerCfg): + seed = -1 + runner_class_name = "IPGRunner" + + class actor: + hidden_dims = [256, 256, 128] + # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation = "elu" + smooth_exploration = True + exploration_sample_freq = 16 + + normalize_obs = True + obs = [ + "base_ang_vel", + "projected_gravity", + "commands", + "dof_pos_obs", + "dof_vel", + "phase_obs", + ] + + actions = ["dof_pos_target"] + disable_actions = False + + class noise: + scale = 1.0 + dof_pos_obs = 0.01 + base_ang_vel = 0.01 + dof_pos = 0.005 + dof_vel = 0.005 + lin_vel = 0.05 + ang_vel = [0.3, 0.15, 0.4] + gravity_vec = 0.1 + + class critic: + hidden_dims = [256, 256, 128] + # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation = "elu" + + # TODO: Check normalization, SAC/IPG need gradient to pass back through actor + normalize_obs = False + obs = [ + "base_height", + "base_lin_vel", + "base_ang_vel", + "projected_gravity", + "commands", + "dof_pos_obs", + "dof_vel", + "phase_obs", + "dof_pos_target", + ] + + class reward: + class weights: + tracking_lin_vel = 4.0 + tracking_ang_vel = 2.0 + orientation = 1.0 + min_base_height = 1.5 + stand_still = 1.0 + swing_grf = 3.0 + stance_grf = 3.0 + action_rate = 0.01 + action_rate2 = 0.001 + + class termination_weight: + termination = 0.15 + + class state_estimator: + class network: + hidden_dims = [128, 128] + activation = "tanh" + dropouts = None + + obs = [ + "base_ang_vel", + "projected_gravity", + "dof_pos_obs", + "dof_vel", + "torques", + "phase_obs", + ] + targets = ["base_height", "base_lin_vel", "grf"] + normalize_obs = True + + class algorithm(MiniCheetahRefRunnerCfg.algorithm): + desired_kl = 0.02 # 0.02 for smooth-exploration, else 0.01 + + # IPG + polyak = 0.995 + 
use_cv = False + inter_nu = 0.9 + beta = "off_policy" + storage_size = 30_000 + # val_interpolation = 0.8 + + # Finetuning + clip_param = 0.2 + max_gradient_steps = 8 + batch_size = 30_000 + learning_rate = 5e-5 # ACTOR + schedule = "fixed" + + class runner(MiniCheetahRefRunnerCfg.runner): + run_name = "" + experiment_name = "mini_cheetah_ref" + max_iterations = 20 # number of policy updates + algorithm_class_name = "PPO_IPG" + num_steps_per_env = 32 + + # Finetuning + resume = True + load_run = "Jul24_22-48-41_nu05_B8" + checkpoint = 1000 + save_interval = 1 diff --git a/gym/envs/mini_cheetah/mini_cheetah_osc_config.py b/gym/envs/mini_cheetah/mini_cheetah_osc_config.py index a4acefa6..d5f767e8 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_osc_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_osc_config.py @@ -166,6 +166,7 @@ class policy: critic_hidden_dims = [256, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" + smooth_exploration = False obs = [ "base_ang_vel", diff --git a/gym/envs/mini_cheetah/mini_cheetah_ref.py b/gym/envs/mini_cheetah/mini_cheetah_ref.py index 57e1f0e7..dbee12c0 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_ref.py +++ b/gym/envs/mini_cheetah/mini_cheetah_ref.py @@ -23,6 +23,7 @@ def _init_buffers(self): self.phase_obs = torch.zeros( self.num_envs, 2, dtype=torch.float, device=self.device ) + self.grf = self._compute_grf() def _reset_system(self, env_ids): super()._reset_system(env_ids) @@ -41,14 +42,25 @@ def _post_decimation_step(self): self.phase_obs = torch.cat( (torch.sin(self.phase), torch.cos(self.phase)), dim=1 ) + self.grf = self._compute_grf() def _resample_commands(self, env_ids): super()._resample_commands(env_ids) - # * with 10% chance, reset to 0 commands - rand_ids = torch_rand_float( - 0, 1, (len(env_ids), 1), device=self.device - ).squeeze(1) - self.commands[env_ids, :3] *= (rand_ids < 0.9).unsqueeze(1) + # * with 20% chance, reset to 0 commands except for forward + self.commands[env_ids, 1:] *= ( + torch_rand_float(0, 1, (len(env_ids), 1), device=self.device).squeeze(1) + < 0.8 + ).unsqueeze(1) + # * with 20% chance, reset to 0 commands except for rotation + self.commands[env_ids, :2] *= ( + torch_rand_float(0, 1, (len(env_ids), 1), device=self.device).squeeze(1) + < 0.8 + ).unsqueeze(1) + # * with 10% chance, reset to 0 + self.commands[env_ids, :] *= ( + torch_rand_float(0, 1, (len(env_ids), 1), device=self.device).squeeze(1) + < 0.9 + ).unsqueeze(1) def _switch(self): c_vel = torch.linalg.norm(self.commands, dim=1) @@ -56,6 +68,13 @@ def _switch(self): -torch.square(torch.max(torch.zeros_like(c_vel), c_vel - 0.1)) / 0.1 ) + def _compute_grf(self, grf_norm=True): + grf = torch.norm(self.contact_forces[:, self.feet_indices, :], dim=-1) + if grf_norm: + return torch.clamp_max(grf / 80.0, 1.0) + else: + return grf + def _reward_swing_grf(self): """Reward non-zero grf during swing (0 to pi)""" in_contact = torch.gt( diff --git a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py index 3c92882d..2e808719 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py @@ -67,12 +67,15 @@ class scaling(MiniCheetahCfg.scaling): class MiniCheetahRefRunnerCfg(MiniCheetahRunnerCfg): seed = -1 - runner_class_name = "OnPolicyRunner" + runner_class_name = "IPGRunner" class actor: hidden_dims = [256, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" + smooth_exploration = True + 
exploration_sample_freq = 16 + normalize_obs = True obs = [ "base_ang_vel", @@ -100,7 +103,9 @@ class critic: hidden_dims = [256, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" - normalize_obs = True + + # TODO: Check normalization, SAC/IPG need gradient to pass back through actor + normalize_obs = False obs = [ "base_height", "base_lin_vel", @@ -118,47 +123,46 @@ class weights: tracking_lin_vel = 4.0 tracking_ang_vel = 2.0 lin_vel_z = 0.0 - ang_vel_xy = 0.01 + ang_vel_xy = 0.0 orientation = 1.0 - torques = 5.0e-7 + torques = 0.0 dof_vel = 0.0 min_base_height = 1.5 collision = 0.0 action_rate = 0.01 action_rate2 = 0.001 - stand_still = 0.0 + stand_still = 1.0 dof_pos_limits = 0.0 feet_contact_forces = 0.0 dof_near_home = 0.0 - reference_traj = 1.5 - swing_grf = 1.5 - stance_grf = 1.5 + reference_traj = 0.0 + swing_grf = 3.0 + stance_grf = 3.0 class termination_weight: termination = 0.15 class algorithm(MiniCheetahRunnerCfg.algorithm): - # training params - value_loss_coef = 1.0 # deprecate for PPO2 - use_clipped_value_loss = True # deprecate for PPO2 - clip_param = 0.2 - entropy_coef = 0.01 - num_learning_epochs = 6 - # mini batch size = num_envs*nsteps/nminibatches - num_mini_batches = 4 - storage_size = 2**17 # new - mini_batch_size = 2**15 # new - learning_rate = 5.0e-5 - schedule = "adaptive" # can be adaptive, fixed - discount_horizon = 1.0 # [s] - lam = 0.95 - GAE_bootstrap_horizon = 2.0 # [s] - desired_kl = 0.01 - max_grad_norm = 1.0 + desired_kl = 0.02 # 0.02 for smooth-exploration, else 0.01 + + # IPG + polyak = 0.95 + use_cv = False + inter_nu = 0.2 + beta = "off_policy" + storage_size = 8 * 32 * 4096 # num_policies*num_steps*num_envs + val_interpolation = 0.8 # 0: use V(s'), 1: use Q(s', pi(s')) + learning_rate = 1.0e-3 + lr_range = [1e-5, 1e-2] + lr_ratio = 1.5 class runner(MiniCheetahRunnerCfg.runner): run_name = "" experiment_name = "mini_cheetah_ref" - max_iterations = 500 # number of policy updates - algorithm_class_name = "PPO2" + max_iterations = 800 # number of policy updates + algorithm_class_name = "LinkedIPG" num_steps_per_env = 32 + + # resume = False + # load_run = "Jul25_12-24-16_LinkedIPG_50Hz_nu02_v08" + # checkpoint = 1000 diff --git a/gym/envs/mini_cheetah/minimalist_cheetah.py b/gym/envs/mini_cheetah/minimalist_cheetah.py new file mode 100644 index 00000000..8052697d --- /dev/null +++ b/gym/envs/mini_cheetah/minimalist_cheetah.py @@ -0,0 +1,159 @@ +import torch + + +class MinimalistCheetah: + """ + Helper class for computing mini cheetah rewards + """ + + def __init__( + self, device="cpu", tracking_sigma=0.25, ctrl_dt=0.01, ctrl_decimation=5 + ): + self.device = device + self.tracking_sigma = tracking_sigma + + # Implemented as in legged robot action rate reward + self.dt = ctrl_dt * ctrl_decimation + + # Default joint angles from mini_cheetah_config.py + self.default_dof_pos = torch.tensor( + [0.0, -0.785398, 1.596976], device=self.device + ).repeat(4) + + # Scales + self.command_scales = torch.tensor([3.0, 1.0, 3.0]).to(self.device) + self.dof_pos_scales = torch.tensor([0.2, 0.3, 0.3]).to(self.device).repeat(4) + self.dof_vel_scales = torch.tensor([2.0, 2.0, 4.0]).to(self.device).repeat(4) + + # Previous 2 dof pos targets + self.dof_target_prev = None + self.dof_target_prev2 = None + + def set_states( + self, + base_height, + base_lin_vel, + base_ang_vel, + proj_gravity, + commands, + dof_pos_obs, + dof_vel, + phase_obs, + grf, + dof_pos_target, + ): + # Unsqueeze so first dim is batch_size + self.base_height = 
torch.tensor(base_height, device=self.device).unsqueeze(0) + self.base_lin_vel = torch.tensor(base_lin_vel, device=self.device).unsqueeze(0) + self.base_ang_vel = torch.tensor(base_ang_vel, device=self.device).unsqueeze(0) + self.proj_gravity = torch.tensor(proj_gravity, device=self.device).unsqueeze(0) + self.commands = ( + torch.tensor(commands, device=self.device).unsqueeze(0) + * self.command_scales + ) + self.dof_pos_obs = ( + torch.tensor(dof_pos_obs, device=self.device).unsqueeze(0) + * self.dof_pos_scales + ) + self.dof_vel = ( + torch.tensor(dof_vel, device=self.device).unsqueeze(0) * self.dof_vel_scales + ) + self.grf = torch.tensor(grf, device=self.device).unsqueeze(0) + + # Get phase sin + phase_obs = torch.tensor(phase_obs, device=self.device).unsqueeze(0) + self.phase_sin = phase_obs[:, 0].unsqueeze(0) + + # Set target history + self.dof_pos_target = ( + torch.tensor(dof_pos_target, device=self.device).unsqueeze(0) + * self.dof_pos_scales + ) + if self.dof_target_prev is None: + self.dof_target_prev = self.dof_pos_target + if self.dof_target_prev2 is None: + self.dof_target_prev2 = self.dof_target_prev + + def post_process(self): + self.dof_target_prev2 = self.dof_target_prev + self.dof_target_prev = self.dof_pos_target + + def _sqrdexp(self, x, scale=1.0): + """shorthand helper for squared exponential""" + return torch.exp(-torch.square(x / scale) / self.tracking_sigma) + + def _switch(self, scale=0.1): + # TODO: Check scale, RS commands are scaled differently than QGym + c_vel = torch.linalg.norm(self.commands, dim=1) + return torch.exp( + -torch.square(torch.max(torch.zeros_like(c_vel), c_vel - 0.1)) / scale + ) + + def _reward_min_base_height(self, target_height=0.3, scale=0.3): + """Squared exponential saturating at base_height target""" + # TODO: Check scale + error = (self.base_height - target_height) / scale + error = torch.clamp(error, max=0, min=None).flatten() + return self._sqrdexp(error) + + def _reward_tracking_lin_vel(self): + """Tracking of linear velocity commands (xy axes)""" + # just use lin_vel? + error = self.commands[:, :2] - self.base_lin_vel[:, :2] + # * scale by (1+|cmd|): if cmd=0, no scaling. 
+ error *= 1.0 / (1.0 + torch.abs(self.commands[:, :2])) + error = torch.sum(torch.square(error), dim=1) + return torch.exp(-error / self.tracking_sigma) * (1 - self._switch()) + + def _reward_tracking_ang_vel(self): + """Tracking of angular velocity commands (yaw)""" + ang_vel_error = torch.square(self.commands[:, 2] - self.base_ang_vel[:, 2]) + return torch.exp(-ang_vel_error / self.tracking_sigma) + + def _reward_orientation(self): + """Penalize non-flat base orientation""" + error = torch.square(self.proj_gravity[:, :2]) / self.tracking_sigma + return torch.sum(torch.exp(-error), dim=1) + + def _reward_swing_grf(self, contact_thresh=50 / 80): + """Reward non-zero grf during swing (0 to pi)""" + in_contact = torch.gt(self.grf, contact_thresh) + ph_off = torch.gt(self.phase_sin, 0) # phase <= pi + rew = in_contact * torch.cat((ph_off, ~ph_off, ~ph_off, ph_off), dim=1) + return -torch.sum(rew.float(), dim=1) * (1 - self._switch()) + + def _reward_stance_grf(self, contact_thresh=50 / 80): + """Reward non-zero grf during stance (pi to 2pi)""" + in_contact = torch.gt(self.grf, contact_thresh) + ph_off = torch.lt(self.phase_sin, 0) # phase >= pi + rew = in_contact * torch.cat((ph_off, ~ph_off, ~ph_off, ph_off), dim=1) + return torch.sum(rew.float(), dim=1) * (1 - self._switch()) + + def _reward_stand_still(self): + """Penalize motion at zero commands""" + # * normalize angles so we care about being within 5 deg + rew_pos = torch.mean( + self._sqrdexp((self.dof_pos_obs) / torch.pi * 36), + dim=1, + ) + rew_vel = torch.mean(self._sqrdexp(self.dof_vel), dim=1) + rew_base_vel = torch.mean(torch.square(self.base_lin_vel), dim=1) + rew_base_vel += torch.mean(torch.square(self.base_ang_vel), dim=1) + return (rew_vel + rew_pos - rew_base_vel) * self._switch() + + def _reward_action_rate(self): + """Penalize changes in actions""" + # TODO: check this + error = torch.square(self.dof_pos_target - self.dof_target_prev) / self.dt**2 + return -torch.sum(error, dim=1) + + def _reward_action_rate2(self): + """Penalize changes in actions""" + # TODO: check this + error = ( + torch.square( + self.dof_pos_target - 2 * self.dof_target_prev + self.dof_target_prev2 + ) + / self.dt**2 + ) + return -torch.sum(error, dim=1) diff --git a/gym/envs/mit_humanoid/mit_humanoid_config.py b/gym/envs/mit_humanoid/mit_humanoid_config.py index 86a8d606..6f8eaa1d 100644 --- a/gym/envs/mit_humanoid/mit_humanoid_config.py +++ b/gym/envs/mit_humanoid/mit_humanoid_config.py @@ -187,6 +187,7 @@ class actor: critic_hidden_dims = [512, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" + smooth_exploration = False obs = [ "base_height", diff --git a/gym/envs/pendulum/pendulum.py b/gym/envs/pendulum/pendulum.py index 95af1fb2..d0bcfb9c 100644 --- a/gym/envs/pendulum/pendulum.py +++ b/gym/envs/pendulum/pendulum.py @@ -1,15 +1,49 @@ import torch +import numpy as np +from math import sqrt from gym.envs.base.fixed_robot import FixedRobot class Pendulum(FixedRobot): - def _post_physics_step(self): - """Update all states that are not handled in PhysX""" - super()._post_physics_step() + def _init_buffers(self): + super()._init_buffers() + self.dof_pos_obs = torch.zeros(self.num_envs, 2, device=self.device) + + def _post_decimation_step(self): + super()._post_decimation_step() + self.dof_pos_obs = torch.cat([self.dof_pos.sin(), self.dof_pos.cos()], dim=1) + + def _reset_system(self, env_ids): + super()._reset_system(env_ids) + self.dof_pos_obs[env_ids] = torch.cat( + [self.dof_pos[env_ids].sin(), 
self.dof_pos[env_ids].cos()], dim=1 + ) + + def _check_terminations_and_timeouts(self): + super()._check_terminations_and_timeouts() + self.terminated = self.timed_out + + def reset_to_uniform(self, env_ids): + grid_points = int(sqrt(self.num_envs)) + lin_pos = torch.linspace( + self.dof_pos_range[0, 0], + self.dof_pos_range[0, 1], + grid_points, + device=self.device, + ) + lin_vel = torch.linspace( + self.dof_vel_range[0, 0], + self.dof_vel_range[0, 1], + grid_points, + device=self.device, + ) + grid = torch.cartesian_prod(lin_pos, lin_vel) + self.dof_pos[env_ids] = grid[:, 0].unsqueeze(-1) + self.dof_vel[env_ids] = grid[:, 1].unsqueeze(-1) def _reward_theta(self): - theta_rwd = torch.cos(self.dof_pos[:, 0]) / self.scales["dof_pos"] + theta_rwd = torch.cos(self.dof_pos[:, 0]) # no scaling return self._sqrdexp(theta_rwd.squeeze(dim=-1)) def _reward_omega(self): @@ -17,27 +51,31 @@ def _reward_omega(self): return self._sqrdexp(omega_rwd.squeeze(dim=-1)) def _reward_equilibrium(self): - error = torch.abs(self.dof_state) - error[:, 0] /= self.scales["dof_pos"] - error[:, 1] /= self.scales["dof_vel"] - return self._sqrdexp(torch.mean(error, dim=1), scale=0.01) - # return torch.exp( - # -error.pow(2).sum(dim=1) / self.cfg.reward_settings.tracking_sigma - # ) - - def _reward_torques(self): - """Penalize torques""" - return self._sqrdexp(torch.mean(torch.square(self.torques), dim=1), scale=0.2) + theta_norm = self._normalize_theta() + omega = self.dof_vel[:, 0] + error = torch.stack( + [theta_norm / self.scales["dof_pos"], omega / self.scales["dof_vel"]], dim=1 + ) + return self._sqrdexp(torch.mean(error, dim=1), scale=0.2) def _reward_energy(self): - m_pendulum = 1.0 - l_pendulum = 1.0 kinetic_energy = ( - 0.5 * m_pendulum * l_pendulum**2 * torch.square(self.dof_vel[:, 0]) + 0.5 + * self.cfg.asset.mass + * self.cfg.asset.length**2 + * torch.square(self.dof_vel[:, 0]) ) potential_energy = ( - m_pendulum * 9.81 * l_pendulum * torch.cos(self.dof_pos[:, 0]) + self.cfg.asset.mass + * 9.81 + * self.cfg.asset.length + * torch.cos(self.dof_pos[:, 0]) ) - desired_energy = m_pendulum * 9.81 * l_pendulum + desired_energy = self.cfg.asset.mass * 9.81 * self.cfg.asset.length energy_error = kinetic_energy + potential_energy - desired_energy return self._sqrdexp(energy_error / desired_energy) + + def _normalize_theta(self): + # normalize to range [-pi, pi] + theta = self.dof_pos[:, 0] + return ((theta + np.pi) % (2 * np.pi)) - np.pi diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py index 920356cd..51c44557 100644 --- a/gym/envs/pendulum/pendulum_config.py +++ b/gym/envs/pendulum/pendulum_config.py @@ -5,9 +5,9 @@ class PendulumCfg(FixedRobotCfg): class env(FixedRobotCfg.env): - num_envs = 2**13 - num_actuators = 1 # 1 for theta connecting base and pole - episode_length_s = 5.0 + num_envs = 4096 + num_actuators = 1 + episode_length_s = 10 class terrain(FixedRobotCfg.terrain): pass @@ -18,12 +18,12 @@ class viewer: lookat = [0.0, 0.0, 0.0] # [m] class init_state(FixedRobotCfg.init_state): - default_joint_angles = {"theta": 0.0} # -torch.pi / 2.0} + default_joint_angles = {"theta": 0} # -torch.pi / 2.0} # * default setup chooses how the initial conditions are chosen. 
# * "reset_to_basic" = a single position # * "reset_to_range" = uniformly random from a range defined below - reset_mode = "reset_to_range" + reset_mode = "reset_to_uniform" # * initial conditions for reset_to_range dof_pos_range = { @@ -33,8 +33,8 @@ class init_state(FixedRobotCfg.init_state): class control(FixedRobotCfg.control): actuated_joints_mask = [1] # angle - ctrl_frequency = 100 - desired_sim_frequency = 200 + ctrl_frequency = 10 + desired_sim_frequency = 100 stiffness = {"theta": 0.0} # [N*m/rad] damping = {"theta": 0.0} # [N*m*s/rad] @@ -44,6 +44,8 @@ class asset(FixedRobotCfg.asset): disable_gravity = False disable_motors = False # all torques set to 0 joint_damping = 0.1 + mass = 1.0 + length = 1.0 class reward_settings(FixedRobotCfg.reward_settings): tracking_sigma = 0.25 @@ -57,15 +59,18 @@ class scaling(FixedRobotCfg.scaling): class PendulumRunnerCfg(FixedRobotCfgPPO): seed = -1 - runner_class_name = "DataLoggingRunner" + runner_class_name = "IPGRunner" class actor: hidden_dims = [128, 64, 32] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "tanh" + smooth_exploration = False + exploration_sample_freq = 16 + normalize_obs = False obs = [ - "dof_pos", + "dof_pos_obs", "dof_vel", ] @@ -77,21 +82,22 @@ class noise: dof_vel = 0.0 class critic: - obs = [ - "dof_pos", - "dof_vel", - ] hidden_dims = [128, 64, 32] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "tanh" - standard_critic_nn = True + + normalize_obs = False + obs = [ + "dof_pos_obs", + "dof_vel", + ] class reward: class weights: theta = 0.0 omega = 0.0 equilibrium = 1.0 - energy = 0.0 + energy = 0.5 dof_vel = 0.0 torques = 0.025 @@ -99,27 +105,37 @@ class termination_weight: termination = 0.0 class algorithm(FixedRobotCfgPPO.algorithm): - # training params - value_loss_coef = 1.0 - use_clipped_value_loss = True + # both + gamma = 0.95 + # discount_horizon = 2.0 + lam = 0.98 + # shared + max_gradient_steps = 24 + # new + storage_size = 2**17 # new + batch_size = 2**16 # new clip_param = 0.2 + learning_rate = 1.0e-4 + max_grad_norm = 1.0 + # Critic + use_clipped_value_loss = True + # Actor entropy_coef = 0.01 - num_learning_epochs = 6 - # * mini batch size = num_envs*nsteps / nminibatches - num_mini_batches = 4 - learning_rate = 1.0e-3 schedule = "fixed" # could be adaptive, fixed - discount_horizon = 2.0 # [s] - lam = 0.98 - # GAE_bootstrap_horizon = .0 # [s] desired_kl = 0.01 - max_grad_norm = 1.0 - standard_loss = True - plus_c_penalty = 0.1 + + # IPG + polyak = 0.9 + use_cv = False + inter_nu = 0.5 + beta = "off_policy" + storage_size = 4 * 100 * 4096 # num_policies*num_steps*num_envs + val_interpolation = 0.5 # 0: use V(s'), 1: use Q(s', pi(s')) class runner(FixedRobotCfgPPO.runner): run_name = "" experiment_name = "pendulum" - max_iterations = 500 # number of policy updates - algorithm_class_name = "PPO2" - num_steps_per_env = 32 + max_iterations = 100 # number of policy updates + algorithm_class_name = "LinkedIPG" + num_steps_per_env = 100 + save_interval = 20 diff --git a/gym/utils/helpers.py b/gym/utils/helpers.py index 90af8bce..2f3a5dcf 100644 --- a/gym/utils/helpers.py +++ b/gym/utils/helpers.py @@ -181,6 +181,9 @@ def update_cfg_from_args(env_cfg, train_cfg, args): train_cfg.runner.checkpoint = args.checkpoint if args.rl_device is not None: train_cfg.runner.device = args.rl_device + # * IPG parameters + if args.inter_nu is not None: + train_cfg.algorithm.inter_nu = args.inter_nu def get_args(custom_parameters=None): @@ -301,6 +304,11 @@ def 
get_args(custom_parameters=None): "default": False, "help": "Use original config file for loaded policy.", }, + { + "name": "--inter_nu", + "type": float, + "help": "Interpolation parameter for IPG.", + }, ] # * parse arguments args = gymutil.parse_arguments( diff --git a/learning/algorithms/SE.py b/learning/algorithms/SE.py index 99dfb75f..9d1f2d2d 100644 --- a/learning/algorithms/SE.py +++ b/learning/algorithms/SE.py @@ -3,9 +3,56 @@ from learning.modules import StateEstimatorNN from learning.storage import SERolloutStorage +from learning.utils import create_uniform_generator class StateEstimator: + def __init__( + self, + state_estimator, + normalize_obs=True, + batch_size=2**15, + max_gradient_steps=10, + learning_rate=1e-3, + device="cpu", + **kwargs, + ): + self.device = device + + self.network = state_estimator.to(self.device) + + self.batch_size = batch_size + self.max_gradient_steps = max_gradient_steps + + self.learning_rate = learning_rate + self.mean_loss = 0.0 + self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) + + def update(self, data): + self.mean_loss = 0 + counter = 0 + generator = create_uniform_generator( + data, self.batch_size, self.max_gradient_steps + ) + for batch in generator: + loss = nn.functional.mse_loss( + self.network.evaluate(batch["SE_obs"]), batch["SE_targets"] + ) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + self.mean_loss += loss.item() + counter += 1 + self.mean_loss /= counter + + def estimate(self, obs): + return self.network.evaluate(obs) + + def export(self, path): + self.network.export(path) + + +class OldStateEstimator: """This class provides a learned state estimator. This is trained with supervised learning, using only on-policy data collected in a rollout storage. 
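Both StateEstimator.update above and the PPO2/IPG updates in the following files draw mini-batches through create_uniform_generator, and the IPG classes additionally import normalize and polyak_update from learning.utils; none of these helpers are defined in this diff. The sketch below shows one plausible contract that is consistent with the call sites (a dict of tensors with (num_steps, num_envs, ...) leading dimensions, max_gradient_steps uniformly sampled batches, and a Polyak-averaged target update). Only the names and call signatures come from the diff; the bodies are assumptions and may differ from the repo's actual implementation.

import torch


def create_uniform_generator(data, batch_size, max_gradient_steps=10):
    """Yield `max_gradient_steps` mini-batches sampled uniformly from `data`,
    assumed here to be a dict of tensors shaped (num_steps, num_envs, ...)."""
    flat = {k: v.flatten(0, 1) for k, v in data.items()}
    num_samples = next(iter(flat.values())).shape[0]
    device = next(iter(flat.values())).device
    for _ in range(max_gradient_steps):
        idx = torch.randint(0, num_samples, (min(batch_size, num_samples),), device=device)
        yield {k: v[idx] for k, v in flat.items()}


def normalize(x, eps=1e-8):
    # Zero-mean, unit-variance rescaling (applied to advantages in PPO2/IPG).
    return (x - x.mean()) / (x.std() + eps)


@torch.no_grad()
def polyak_update(source, target, polyak):
    # target <- polyak * target + (1 - polyak) * source, parameter-wise, matching
    # target_critic_q = polyak_update(critic_q, target_critic_q, polyak) below.
    for p_src, p_tgt in zip(source.parameters(), target.parameters()):
        p_tgt.mul_(polyak).add_(p_src, alpha=1.0 - polyak)
    return target
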
diff --git a/learning/algorithms/__init__.py b/learning/algorithms/__init__.py index ace1b9fe..8bee38d0 100644 --- a/learning/algorithms/__init__.py +++ b/learning/algorithms/__init__.py @@ -32,4 +32,6 @@ from .ppo import PPO from .ppo2 import PPO2 +from .ppo_ipg import PPO_IPG +from .linked_ipg import LinkedIPG from .SE import StateEstimator diff --git a/learning/algorithms/linked_ipg.py b/learning/algorithms/linked_ipg.py new file mode 100644 index 00000000..dbaef5eb --- /dev/null +++ b/learning/algorithms/linked_ipg.py @@ -0,0 +1,317 @@ +import torch +import torch.nn as nn +import torch.optim as optim + +from learning.utils import ( + create_uniform_generator, + compute_generalized_advantages, + normalize, + polyak_update, +) + + +class LinkedIPG: + def __init__( + self, + actor, + critic_v, + critic_q, + target_critic_q, + batch_size=2**15, + max_gradient_steps=10, + clip_param=0.2, + gamma=0.998, + lam=0.95, + entropy_coef=0.0, + learning_rate=1e-3, + max_grad_norm=1.0, + use_clipped_value_loss=True, + schedule="fixed", + desired_kl=0.01, + polyak=0.995, + use_cv=False, + inter_nu=0.2, + beta="off_policy", + device="cpu", + lr_range=[1e-4, 1e-2], + lr_ratio=1.3, + val_interpolation=0.5, + **kwargs, + ): + self.device = device + + self.desired_kl = desired_kl + self.schedule = schedule + self.learning_rate = learning_rate + self.lr_range = lr_range + self.lr_ratio = lr_ratio + + # * PPO components + self.actor = actor.to(self.device) + self.optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate) + self.critic_v = critic_v.to(self.device) + self.critic_v_optimizer = optim.Adam( + self.critic_v.parameters(), lr=learning_rate + ) + + # * IPG components + self.critic_q = critic_q.to(self.device) + self.critic_q_optimizer = optim.Adam( + self.critic_q.parameters(), lr=learning_rate + ) + self.target_critic_q = target_critic_q.to(self.device) + self.target_critic_q.load_state_dict(self.critic_q.state_dict()) + + # * PPO parameters + self.clip_param = clip_param + self.batch_size = batch_size + self.max_gradient_steps = max_gradient_steps + self.entropy_coef = entropy_coef + self.gamma = gamma + self.lam = lam + self.max_grad_norm = max_grad_norm + self.use_clipped_value_loss = use_clipped_value_loss + + # * IPG parameters + self.polyak = polyak + self.use_cv = use_cv + self.inter_nu = inter_nu + self.beta = beta + self.val_interpolation = val_interpolation + + def switch_to_train(self): + self.actor.train() + self.critic_v.train() + self.critic_q.train() + + def act(self, obs): + return self.actor.act(obs).detach() + + def update(self, data_onpol, data_offpol): + # On-policy GAE + data_onpol["values"] = self.critic_v.evaluate(data_onpol["critic_obs"]) + data_onpol["advantages"] = compute_generalized_advantages( + data_onpol, self.gamma, self.lam, self.critic_v + ) + data_onpol["returns"] = data_onpol["advantages"] + data_onpol["values"] + data_onpol["advantages"] = normalize(data_onpol["advantages"]) + + # TODO: Possibly use off-policy GAE for V-critic update + # data_offpol["values"] = self.critic_v.evaluate(data_offpol["critic_obs"]) + # data_offpol["advantages"] = compute_generalized_advantages( + # data_offpol, self.gamma, self.lam, self.critic_v + # ) + # data_offpol["returns"] = data_offpol["advantages"] + data_offpol["values"] + + self.update_critic_v(data_onpol) + self.update_critic_q(data_offpol) + self.update_actor(data_onpol, data_offpol) + + # def update_joint_critics(self, data_onpol, data_offpol): + # self.mean_q_loss = 0 + # self.mean_value_loss = 0 + # counter = 0 + 
# generator_onpol = create_uniform_generator( + # data_onpol, + # self.batch_size, + # max_gradient_steps=self.max_gradient_steps, + # ) + # generator_offpol = create_uniform_generator( + # data_offpol, + # self.batch_size, + # max_gradient_steps=self.max_gradient_steps, + # ) + # for batch_onpol, batch_offpol in zip(generator_onpol, generator_offpol): + # with torch.no_grad(): + # action_next_onpol = self.actor.act_inference( + # batch_onpol["next_actor_obs"] + # ) + # q_input_next_onpol = torch.cat( + # batch_onpol["next_critic_obs"], action_next_onpol + # ) + # action_next_offpol = self.actor.act_inference( + # batch_offpol["next_actor_obs"] + # ) + # q_input_next_offpol = torch.cat( + # batch_offpol["next_critic_obs"], action_next_offpol + # ) + # q_value_offpol = self.critic_q.evaluate(q_input_next_offpol) + + # loss_V_returns = self.critic_v.loss_fn( + # batch_onpol["critic_obs"], batch_onpol["returns"] + # ) + # loss_V_Q = nn.functional.mse_loss( + # self.critic_v.evaluate(batch_onpol["critic_obs"]), + # self.critic_q.evaluate(batch_onpol["critic_obs"]), + # reduction="mean") + # + # with torch.no_grad(): + # action_next = self.actor.act_inference(batch_offpol["next_actor_obs"]) + # q_input_next = torch.cat( + # (batch_offpol["next_critic_obs"], action_next), dim=-1 + # ) + + def update_critic_q(self, data): + self.mean_q_loss = 0 + counter = 0 + + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch in generator: + with torch.no_grad(): + action_next = self.actor.act_inference(batch["next_actor_obs"]) + q_input_next = torch.cat( + (batch["next_critic_obs"], action_next), dim=-1 + ) + q_next = self.target_critic_q.evaluate(q_input_next) + v_next = self.critic_v.evaluate(batch["next_critic_obs"]) + q_target = batch["rewards"] + batch["dones"].logical_not() * ( + self.gamma + * ( + q_next * self.val_interpolation + + v_next * (1 - self.val_interpolation) + ) + ) + q_input = torch.cat((batch["critic_obs"], batch["actions"]), dim=-1) + q_loss = self.critic_q.loss_fn(q_input, q_target) + self.critic_q_optimizer.zero_grad() + q_loss.backward() + nn.utils.clip_grad_norm_(self.critic_q.parameters(), self.max_grad_norm) + self.critic_q_optimizer.step() + self.mean_q_loss += q_loss.item() + counter += 1 + + # TODO: check where to do polyak update (IPG repo does it here) + self.target_critic_q = polyak_update( + self.critic_q, self.target_critic_q, self.polyak + ) + self.mean_q_loss /= counter + + def update_critic_v(self, data): + self.mean_value_loss = 0 + counter = 0 + + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch in generator: + value_loss = self.critic_v.loss_fn(batch["critic_obs"], batch["returns"]) + self.critic_v_optimizer.zero_grad() + value_loss.backward() + nn.utils.clip_grad_norm_(self.critic_v.parameters(), self.max_grad_norm) + self.critic_v_optimizer.step() + self.mean_value_loss += value_loss.item() + counter += 1 + self.mean_value_loss /= counter + + def update_actor(self, data_onpol, data_offpol): + self.mean_surrogate_loss = 0 + self.mean_offpol_loss = 0 + counter = 0 + + self.actor.update_distribution(data_onpol["actor_obs"]) + data_onpol["old_sigma"] = self.actor.action_std.detach() + data_onpol["old_mu"] = self.actor.action_mean.detach() + data_onpol["old_actions_log_prob"] = self.actor.get_actions_log_prob( + data_onpol["actions"] + ).detach() + + # Generate off-policy batches and use all on-policy data + 
generator_offpol = create_uniform_generator( + data_offpol, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch_offpol in generator_offpol: + self.actor.update_distribution(data_onpol["actor_obs"]) + actions_log_prob_onpol = self.actor.get_actions_log_prob( + data_onpol["actions"] + ) + mu_onpol = self.actor.action_mean + sigma_onpol = self.actor.action_std + + # * KL + if self.desired_kl is not None and self.schedule == "adaptive": + with torch.inference_mode(): + kl = torch.sum( + torch.log(sigma_onpol / data_onpol["old_sigma"] + 1.0e-5) + + ( + torch.square(data_onpol["old_sigma"]) + + torch.square(data_onpol["old_mu"] - mu_onpol) + ) + / (2.0 * torch.square(sigma_onpol)) + - 0.5, + axis=-1, + ) + kl_mean = torch.mean(kl) + lr_min, lr_max = self.lr_range + + if kl_mean > self.desired_kl * 2.0: + self.learning_rate = max( + lr_min, self.learning_rate / self.lr_ratio + ) + elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: + self.learning_rate = min( + lr_max, self.learning_rate * self.lr_ratio + ) + + for param_group in self.optimizer.param_groups: + # ! check this + param_group["lr"] = self.learning_rate + + # * On-policy surrogate loss + adv_onpol = data_onpol["advantages"] + if self.use_cv: + # TODO: control variate + critic_based_adv = 0 # get_control_variate(data_onpol, self.critic_v) + learning_signals = (adv_onpol - critic_based_adv) * (1 - self.inter_nu) + else: + learning_signals = adv_onpol * (1 - self.inter_nu) + + ratio = torch.exp( + actions_log_prob_onpol - data_onpol["old_actions_log_prob"] + ) + ratio_clipped = torch.clamp( + ratio, 1.0 - self.clip_param, 1.0 + self.clip_param + ) + surrogate = -learning_signals * ratio + surrogate_clipped = -learning_signals * ratio_clipped + loss_onpol = torch.max(surrogate, surrogate_clipped).mean() + + # * Off-policy loss + if self.beta == "on_policy": + loss_offpol = self.compute_loss_offpol(data_onpol) + elif self.beta == "off_policy": + loss_offpol = self.compute_loss_offpol(batch_offpol) + else: + raise ValueError(f"Invalid beta value: {self.beta}") + + if self.use_cv: + b = 1 + else: + b = self.inter_nu + + loss = loss_onpol + b * loss_offpol + + # * Gradient step + self.optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) + self.optimizer.step() + self.mean_surrogate_loss += loss_onpol.item() + self.mean_offpol_loss += b * loss_offpol.item() + counter += 1 + self.mean_surrogate_loss /= counter + self.mean_offpol_loss /= counter + + def compute_loss_offpol(self, data): + obs = data["actor_obs"] + actions = self.actor.act_inference(obs) + q_input = torch.cat((data["critic_obs"], actions), dim=-1) + q_value = self.critic_q.evaluate(q_input) + return -q_value.mean() diff --git a/learning/algorithms/ppo.py b/learning/algorithms/ppo.py index 92898ad5..ff09d788 100644 --- a/learning/algorithms/ppo.py +++ b/learning/algorithms/ppo.py @@ -34,7 +34,7 @@ import torch.nn as nn import torch.optim as optim -from learning.modules import ActorCritic +from learning.modules import ActorCritic, SmoothActor from learning.storage import RolloutStorage @@ -162,7 +162,11 @@ def update(self): old_mu_batch, old_sigma_batch, ) in generator: - self.actor_critic.act(obs_batch) + # TODO[lm]: Look into resampling noise here, gSDE paper seems to do it. 
+ if isinstance(self.actor_critic.actor, SmoothActor): + batch_size = obs_batch.shape[0] + self.actor_critic.actor.sample_weights(batch_size) + self.actor_critic.actor.update_distribution(obs_batch) actions_log_prob_batch = self.actor_critic.get_actions_log_prob( actions_batch ) diff --git a/learning/algorithms/ppo2.py b/learning/algorithms/ppo2.py index 5866052c..6c64e334 100644 --- a/learning/algorithms/ppo2.py +++ b/learning/algorithms/ppo2.py @@ -5,6 +5,7 @@ from learning.utils import ( create_uniform_generator, compute_generalized_advantages, + normalize, ) @@ -13,8 +14,8 @@ def __init__( self, actor, critic, - num_learning_epochs=1, - num_mini_batches=1, + batch_size=2**15, + max_gradient_steps=10, clip_param=0.2, gamma=0.998, lam=0.95, @@ -24,8 +25,9 @@ def __init__( use_clipped_value_loss=True, schedule="fixed", desired_kl=0.01, - loss_fn="MSE", device="cpu", + lr_range=[1e-4, 1e-2], + lr_ratio=1.3, **kwargs, ): self.device = device @@ -33,6 +35,8 @@ def __init__( self.desired_kl = desired_kl self.schedule = schedule self.learning_rate = learning_rate + self.lr_range = lr_range + self.lr_ratio = lr_ratio # * PPO components self.actor = actor.to(self.device) @@ -42,49 +46,42 @@ def __init__( # * PPO parameters self.clip_param = clip_param - self.num_learning_epochs = num_learning_epochs - self.num_mini_batches = num_mini_batches + self.batch_size = batch_size + self.max_gradient_steps = max_gradient_steps self.entropy_coef = entropy_coef self.gamma = gamma self.lam = lam self.max_grad_norm = max_grad_norm self.use_clipped_value_loss = use_clipped_value_loss - def test_mode(self): - self.actor.test() - self.critic.test() - def switch_to_train(self): self.actor.train() self.critic.train() - def act(self, obs, critic_obs): + def act(self, obs): return self.actor.act(obs).detach() - def update(self, data, last_obs=None): - if last_obs is None: - last_values = None - else: - with torch.no_grad(): - last_values = self.critic.evaluate(last_obs).detach() - compute_generalized_advantages( - data, self.gamma, self.lam, self.critic, last_values + def update(self, data): + data["values"] = self.critic.evaluate(data["critic_obs"]) + data["advantages"] = compute_generalized_advantages( + data, self.gamma, self.lam, self.critic ) - + data["returns"] = data["advantages"] + data["values"] self.update_critic(data) + data["advantages"] = normalize(data["advantages"]) self.update_actor(data) def update_critic(self, data): self.mean_value_loss = 0 counter = 0 - n, m = data.shape - total_data = n * m - batch_size = total_data // self.num_mini_batches - generator = create_uniform_generator(data, batch_size, self.num_learning_epochs) + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) for batch in generator: - value_batch = self.critic.evaluate(batch["critic_obs"]) - value_loss = self.critic.loss_fn(value_batch, batch["returns"]) + value_loss = self.critic.loss_fn(batch["critic_obs"], batch["returns"]) self.critic_optimizer.zero_grad() value_loss.backward() nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) @@ -94,25 +91,23 @@ def update_critic(self, data): self.mean_value_loss /= counter def update_actor(self, data): - # already done before - # compute_generalized_advantages(data, self.gamma, self.lam, self.critic) self.mean_surrogate_loss = 0 counter = 0 - self.actor.act(data["actor_obs"]) + self.actor.update_distribution(data["actor_obs"]) data["old_sigma_batch"] = self.actor.action_std.detach() data["old_mu_batch"] = 
self.actor.action_mean.detach() data["old_actions_log_prob_batch"] = self.actor.get_actions_log_prob( data["actions"] ).detach() - n, m = data.shape - total_data = n * m - batch_size = total_data // self.num_mini_batches - generator = create_uniform_generator(data, batch_size, self.num_learning_epochs) + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) for batch in generator: - # ! refactor how this is done - self.actor.act(batch["actor_obs"]) + self.actor.update_distribution(batch["actor_obs"]) actions_log_prob_batch = self.actor.get_actions_log_prob(batch["actions"]) mu_batch = self.actor.action_mean sigma_batch = self.actor.action_std @@ -132,11 +127,16 @@ def update_actor(self, data): axis=-1, ) kl_mean = torch.mean(kl) + lr_min, lr_max = self.lr_range if kl_mean > self.desired_kl * 2.0: - self.learning_rate = max(1e-5, self.learning_rate / 1.5) + self.learning_rate = max( + lr_min, self.learning_rate / self.lr_ratio + ) elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: - self.learning_rate = min(1e-2, self.learning_rate * 1.5) + self.learning_rate = min( + lr_max, self.learning_rate * self.lr_ratio + ) for param_group in self.optimizer.param_groups: # ! check this @@ -158,10 +158,7 @@ def update_actor(self, data): # * Gradient step self.optimizer.zero_grad() loss.backward() - nn.utils.clip_grad_norm_( - list(self.actor.parameters()) + list(self.critic.parameters()), - self.max_grad_norm, - ) + nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) self.optimizer.step() self.mean_surrogate_loss += surrogate_loss.item() counter += 1 diff --git a/learning/algorithms/ppo_ipg.py b/learning/algorithms/ppo_ipg.py new file mode 100644 index 00000000..c3f1c5be --- /dev/null +++ b/learning/algorithms/ppo_ipg.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.optim as optim + +from learning.utils import ( + create_uniform_generator, + compute_generalized_advantages, + normalize, + polyak_update, +) + + +class PPO_IPG: + def __init__( + self, + actor, + critic_v, + critic_q, + target_critic_q, + batch_size=2**15, + max_gradient_steps=10, + clip_param=0.2, + gamma=0.998, + lam=0.95, + entropy_coef=0.0, + learning_rate=1e-3, + max_grad_norm=1.0, + use_clipped_value_loss=True, + schedule="fixed", + desired_kl=0.01, + polyak=0.995, + use_cv=False, + inter_nu=0.2, + beta="off_policy", + device="cpu", + lr_range=[1e-4, 1e-2], + lr_ratio=1.3, + **kwargs, + ): + self.device = device + + self.desired_kl = desired_kl + self.schedule = schedule + self.learning_rate = learning_rate + self.lr_range = lr_range + self.lr_ratio = lr_ratio + + # * PPO components + self.actor = actor.to(self.device) + self.optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate) + self.critic_v = critic_v.to(self.device) + self.critic_v_optimizer = optim.Adam( + self.critic_v.parameters(), lr=learning_rate + ) + + # * IPG components + self.critic_q = critic_q.to(self.device) + self.critic_q_optimizer = optim.Adam( + self.critic_q.parameters(), lr=learning_rate + ) + self.target_critic_q = target_critic_q.to(self.device) + self.target_critic_q.load_state_dict(self.critic_q.state_dict()) + + # * PPO parameters + self.clip_param = clip_param + self.batch_size = batch_size + self.max_gradient_steps = max_gradient_steps + self.entropy_coef = entropy_coef + self.gamma = gamma + self.lam = lam + self.max_grad_norm = max_grad_norm + self.use_clipped_value_loss = use_clipped_value_loss + + # * IPG parameters + self.polyak = 
polyak + self.use_cv = use_cv + self.inter_nu = inter_nu + self.beta = beta + + def switch_to_train(self): + self.actor.train() + self.critic_v.train() + self.critic_q.train() + + def act(self, obs): + return self.actor.act(obs).detach() + + def update(self, data_onpol, data_offpol): + # On-policy GAE + data_onpol["values"] = self.critic_v.evaluate(data_onpol["critic_obs"]) + data_onpol["advantages"] = compute_generalized_advantages( + data_onpol, self.gamma, self.lam, self.critic_v + ) + data_onpol["returns"] = data_onpol["advantages"] + data_onpol["values"] + data_onpol["advantages"] = normalize(data_onpol["advantages"]) + + # TODO: Possibly use off-policy GAE for V-critic update + # data_offpol["values"] = self.critic_v.evaluate(data_offpol["critic_obs"]) + # data_offpol["advantages"] = compute_generalized_advantages( + # data_offpol, self.gamma, self.lam, self.critic_v + # ) + # data_offpol["returns"] = data_offpol["advantages"] + data_offpol["values"] + + self.update_critic_v(data_onpol) + self.update_critic_q(data_offpol) + self.update_actor(data_onpol, data_offpol) + + def update_critic_q(self, data): + self.mean_q_loss = 0 + counter = 0 + + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch in generator: + with torch.no_grad(): + action_next = self.actor.act_inference(batch["next_actor_obs"]) + q_input_next = torch.cat( + (batch["next_critic_obs"], action_next), dim=-1 + ) + q_next = self.target_critic_q.evaluate(q_input_next) + q_target = ( + batch["rewards"] + + self.gamma * batch["dones"].logical_not() * q_next + ) + q_input = torch.cat((batch["critic_obs"], batch["actions"]), dim=-1) + q_loss = self.critic_q.loss_fn(q_input, q_target) + self.critic_q_optimizer.zero_grad() + q_loss.backward() + nn.utils.clip_grad_norm_(self.critic_q.parameters(), self.max_grad_norm) + self.critic_q_optimizer.step() + self.mean_q_loss += q_loss.item() + counter += 1 + + # TODO: check where to do polyak update (IPG repo does it here) + self.target_critic_q = polyak_update( + self.critic_q, self.target_critic_q, self.polyak + ) + self.mean_q_loss /= counter + + def update_critic_v(self, data): + self.mean_value_loss = 0 + counter = 0 + + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch in generator: + value_loss = self.critic_v.loss_fn(batch["critic_obs"], batch["returns"]) + self.critic_v_optimizer.zero_grad() + value_loss.backward() + nn.utils.clip_grad_norm_(self.critic_v.parameters(), self.max_grad_norm) + self.critic_v_optimizer.step() + self.mean_value_loss += value_loss.item() + counter += 1 + self.mean_value_loss /= counter + + def update_actor(self, data_onpol, data_offpol): + self.mean_surrogate_loss = 0 + self.mean_offpol_loss = 0 + counter = 0 + + self.actor.update_distribution(data_onpol["actor_obs"]) + data_onpol["old_sigma"] = self.actor.action_std.detach() + data_onpol["old_mu"] = self.actor.action_mean.detach() + data_onpol["old_actions_log_prob"] = self.actor.get_actions_log_prob( + data_onpol["actions"] + ).detach() + + # Generate off-policy batches and use all on-policy data + generator_offpol = create_uniform_generator( + data_offpol, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + for batch_offpol in generator_offpol: + self.actor.update_distribution(data_onpol["actor_obs"]) + actions_log_prob_onpol = self.actor.get_actions_log_prob( + data_onpol["actions"] + ) + mu_onpol = self.actor.action_mean + 
sigma_onpol = self.actor.action_std + + # * KL + if self.desired_kl is not None and self.schedule == "adaptive": + with torch.inference_mode(): + kl = torch.sum( + torch.log(sigma_onpol / data_onpol["old_sigma"] + 1.0e-5) + + ( + torch.square(data_onpol["old_sigma"]) + + torch.square(data_onpol["old_mu"] - mu_onpol) + ) + / (2.0 * torch.square(sigma_onpol)) + - 0.5, + axis=-1, + ) + kl_mean = torch.mean(kl) + lr_min, lr_max = self.lr_range + + if kl_mean > self.desired_kl * 2.0: + self.learning_rate = max( + lr_min, self.learning_rate / self.lr_ratio + ) + elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: + self.learning_rate = min( + lr_max, self.learning_rate * self.lr_ratio + ) + + for param_group in self.optimizer.param_groups: + # ! check this + param_group["lr"] = self.learning_rate + + # * On-policy surrogate loss + adv_onpol = data_onpol["advantages"] + if self.use_cv: + # TODO: control variate + critic_based_adv = 0 # get_control_variate(data_onpol, self.critic_v) + learning_signals = (adv_onpol - critic_based_adv) * (1 - self.inter_nu) + else: + learning_signals = adv_onpol * (1 - self.inter_nu) + + ratio = torch.exp( + actions_log_prob_onpol - data_onpol["old_actions_log_prob"] + ) + ratio_clipped = torch.clamp( + ratio, 1.0 - self.clip_param, 1.0 + self.clip_param + ) + surrogate = -learning_signals * ratio + surrogate_clipped = -learning_signals * ratio_clipped + loss_onpol = torch.max(surrogate, surrogate_clipped).mean() + + # * Off-policy loss + if self.beta == "on_policy": + loss_offpol = self.compute_loss_offpol(data_onpol) + elif self.beta == "off_policy": + loss_offpol = self.compute_loss_offpol(batch_offpol) + else: + raise ValueError(f"Invalid beta value: {self.beta}") + + if self.use_cv: + b = 1 + else: + b = self.inter_nu + + loss = loss_onpol + b * loss_offpol.requires_grad_() + + # * Gradient step + self.optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) + self.optimizer.step() + self.mean_surrogate_loss += loss_onpol.item() + self.mean_offpol_loss += b * loss_offpol.item() + counter += 1 + self.mean_surrogate_loss /= counter + self.mean_offpol_loss /= counter + + def compute_loss_offpol(self, data): + obs = data["actor_obs"] + actions = self.actor.act_inference(obs) + q_input = torch.cat((data["critic_obs"], actions), dim=-1) + q_value = self.critic_q.evaluate(q_input) + return -q_value.mean() diff --git a/learning/modules/__init__.py b/learning/modules/__init__.py index 3caf5fec..410e2f25 100644 --- a/learning/modules/__init__.py +++ b/learning/modules/__init__.py @@ -34,3 +34,4 @@ from .state_estimator import StateEstimatorNN from .actor import Actor from .critic import Critic +from .smooth_actor import SmoothActor diff --git a/learning/modules/actor.py b/learning/modules/actor.py index fbaa6868..9fc37c3b 100644 --- a/learning/modules/actor.py +++ b/learning/modules/actor.py @@ -5,6 +5,8 @@ from .utils import export_network from .utils import RunningMeanStd +from gym import LEGGED_GYM_ROOT_DIR + class Actor(nn.Module): def __init__( @@ -25,7 +27,11 @@ def __init__( self.num_obs = num_obs self.num_actions = num_actions - self.NN = create_MLP(num_obs, num_actions, hidden_dims, activation) + self.hidden_dims = hidden_dims + self.activation = activation + self.NN = create_MLP( + num_obs, num_actions, hidden_dims, activation, latent=False + ) # Action noise self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) @@ -33,6 +39,9 @@ def __init__( # disable args validation for speedup 
Normal.set_default_validate_args = False + # Debug mode for plotting + self.debug = False + @property def action_mean(self): return self.distribution.mean @@ -51,7 +60,12 @@ def update_distribution(self, observations): def act(self, observations): self.update_distribution(observations) - return self.distribution.sample() + sample = self.distribution.sample() + if self.debug: + mean = self.distribution.mean + path = f"{LEGGED_GYM_ROOT_DIR}/plots/distribution_baseline.csv" + self.log_actions(mean[0][2], sample[0][2], path) + return sample def get_actions_log_prob(self, actions): return self.distribution.log_prob(actions).sum(dim=-1) @@ -67,3 +81,7 @@ def forward(self, observations): def export(self, path): export_network(self, "policy", path, self.num_obs) + + def log_actions(self, mean, sample, path): + with open(path, "a") as f: + f.write(str(mean.item()) + ", " + str(sample.item()) + "\n") diff --git a/learning/modules/critic.py b/learning/modules/critic.py index 732f8eda..f2e8678f 100644 --- a/learning/modules/critic.py +++ b/learning/modules/critic.py @@ -20,11 +20,14 @@ def __init__( if self._normalize_obs: self.obs_rms = RunningMeanStd(num_obs) - def evaluate(self, critic_observations): + def forward(self, x): if self._normalize_obs: with torch.no_grad(): - critic_observations = self.obs_rms(critic_observations) - return self.NN(critic_observations).squeeze() + x = self.obs_rms(x) + return self.NN(x).squeeze(-1) + + def evaluate(self, critic_observations): + return self.forward(critic_observations) - def loss_fn(self, input, target): - return nn.functional.mse_loss(input, target, reduction="mean") + def loss_fn(self, obs, target): + return nn.functional.mse_loss(self.forward(obs), target, reduction="mean") diff --git a/learning/modules/lqrc/plotting.py b/learning/modules/lqrc/plotting.py new file mode 100644 index 00000000..dcc913c9 --- /dev/null +++ b/learning/modules/lqrc/plotting.py @@ -0,0 +1,142 @@ +import matplotlib as mpl +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from matplotlib.colors import ListedColormap + +import numpy as np + + +def create_custom_bwr_colormap(): + # Define the colors for each segment + dark_blue = [0, 0, 0.5, 1] + light_blue = [0.5, 0.5, 1, 1] + white = [1, 1, 1, 1] + light_red = [1, 0.5, 0.5, 1] + dark_red = [0.5, 0, 0, 1] + + # Number of bins for each segment + n_bins = 128 + mid_band = 5 + + # Create the colormap segments + blue_segment = np.linspace(dark_blue, light_blue, n_bins // 2) + white_segment = np.tile(white, (mid_band, 1)) + red_segment = np.linspace(light_red, dark_red, n_bins // 2) + + # Stack segments to create the full colormap + colors = np.vstack((blue_segment, white_segment, red_segment)) + custom_bwr = ListedColormap(colors, name="custom_bwr") + + return custom_bwr + + +def plot_pendulum_critics_with_data( + x, + predictions, + targets, + actions, + title, + fn, + data, + colorbar_label="f(x)", + grid_size=64, +): + num_critics = len(x.keys()) + fig, axes = plt.subplots(nrows=2, ncols=num_critics, figsize=(4 * num_critics, 6)) + + # Determine global min and max error for consistent scaling + global_min_error = float("inf") + global_max_error = float("-inf") + global_min_prediction = float("inf") + global_max_prediction = float("-inf") + prediction_cmap = mpl.cm.get_cmap("viridis") + action_cmap = mpl.cm.get_cmap("bwr") + + for critic_name in x: + np_predictions = predictions[critic_name].detach().cpu().numpy().reshape(-1) + np_targets = ( + targets["Ground Truth MC 
Returns"].detach().cpu().numpy().reshape(-1) + ) + np_error = np_predictions - np_targets + global_min_error = min(global_min_error, np.min(np_error)) + global_max_error = max(global_max_error, np.max(np_error)) + global_min_prediction = min( + global_min_prediction, np.min(np_predictions), np.min(np_targets) + ) + global_max_prediction = max( + global_max_prediction, np.max(np_predictions), np.max(np_targets) + ) + prediction_norm = mcolors.CenteredNorm( + vcenter=(global_max_prediction + global_min_prediction) / 2, + halfrange=(global_max_prediction - global_min_prediction) / 2, + ) + action_norm = mcolors.CenteredNorm() + + xcord = np.linspace(-2 * np.pi, 2 * np.pi, grid_size) + ycord = np.linspace(-5, 5, grid_size) + + for idx, critic_name in enumerate(x): + np_predictions = predictions[critic_name].detach().cpu().numpy().reshape(-1) + np_targets = ( + targets["Ground Truth MC Returns"].detach().cpu().numpy().reshape(-1) + ) + np_error = np_predictions - np_targets + + # Plot Predictions + axes[0, idx].imshow( + np_predictions.reshape(grid_size, grid_size).T, + origin="lower", + extent=(xcord.min(), xcord.max(), ycord.min(), ycord.max()), + cmap=prediction_cmap, + norm=prediction_norm, + ) + axes[0, idx].set_title(f"{critic_name} Prediction") + + # Plot MC Trajectories + data = data.detach().cpu().numpy() + theta = data[:, :, 0] + omega = data[:, :, 1] + axes[1, 0].plot(theta, omega, lw=1) + axes[1, 0].set_xlabel("theta") + axes[1, 0].set_ylabel("theta_dot") + fig.suptitle(title, fontsize=16) + + # Plot Actions + np_actions = actions.detach().cpu().numpy().reshape(-1) + axes[1, 1].imshow( + np_actions.reshape(grid_size, grid_size).T, + origin="lower", + extent=(xcord.min(), xcord.max(), ycord.min(), ycord.max()), + cmap=action_cmap, + norm=action_norm, + ) + axes[1, 1].set_title("Actions") + ax1_mappable = mpl.cm.ScalarMappable(norm=action_norm, cmap=action_cmap) + + # Last axis empty + axes[1, 2].axis("off") + + # Ensure the axes are the same for all plots + for ax in axes.flat: + ax.set_xlim([xcord.min(), xcord.max()]) + ax.set_ylim([ycord.min(), ycord.max()]) + + plt.subplots_adjust( + top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.4, wspace=0.3 + ) + + fig.colorbar( + mpl.cm.ScalarMappable(norm=prediction_norm, cmap=prediction_cmap), + ax=axes[0, :].ravel().tolist(), + shrink=0.95, + label=colorbar_label, + ) + fig.colorbar( + ax1_mappable, + ax=axes[1, :].ravel().tolist(), + shrink=0.95, + label=colorbar_label, + ) + + plt.savefig(f"{fn}.png") + print(f"Saved to {fn}.png") diff --git a/learning/modules/smooth_actor.py b/learning/modules/smooth_actor.py new file mode 100644 index 00000000..f824d263 --- /dev/null +++ b/learning/modules/smooth_actor.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +from torch.distributions import Normal + +from .actor import Actor +from .utils import create_MLP + +from gym import LEGGED_GYM_ROOT_DIR + + +# The following implementation is based on the gSDE paper. 
See code: +# https://github.com/DLR-RM/stable-baselines3/blob/56f20e40a2206bbb16501a0f600e29ce1b112ef1/stable_baselines3/common/distributions.py#L421C7-L421C38 +class SmoothActor(Actor): + weights_dist: Normal + latent_sde: torch.Tensor + exploration_matrices: torch.Tensor + exploration_scale: float + + def __init__( + self, + *args, + full_std: bool = True, + use_exp_ln: bool = True, + learn_features: bool = True, + epsilon: float = 1e-6, + log_std_init: float = 0.0, + exploration_scale: float = 1.0, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.full_std = full_std + self.use_exp_ln = use_exp_ln + self.learn_features = learn_features + self.epsilon = epsilon + self.log_std_init = log_std_init + self.exploration_scale = exploration_scale # for finetuning + + # Create latent NN and last layer + self.latent_net = create_MLP( + self.num_obs, + self.num_actions, + self.hidden_dims, + self.activation, + latent=True, + ) + self.latent_dim = self.hidden_dims[-1] + self.mean_actions_net = nn.Linear(self.latent_dim, self.num_actions) + # Reduce the number of parameters if needed + if self.full_std: + log_std = torch.ones(self.latent_dim, self.num_actions) + else: + log_std = torch.ones(self.latent_dim, 1) + self.log_std = nn.Parameter(log_std * self.log_std_init, requires_grad=True) + # Sample an exploration matrix + self.sample_weights() + self.distribution = None + + # Debug mode for plotting + self.debug = False + + def sample_weights(self, batch_size=1): + # Sample weights for the noise exploration matrix + std = self.get_std + self.weights_dist = Normal(torch.zeros_like(std), std) + # Pre-compute matrices in case of parallel exploration + self.exploration_matrices = self.weights_dist.rsample((batch_size,)) + + @property + def get_std(self): + # TODO[lm]: Check if this is ok, and can use action_std in ActorCritic normally + if self.use_exp_ln: + # From gSDE paper, it allows to keep variance + # above zero and prevent it from growing too fast + below_threshold = torch.exp(self.log_std) * (self.log_std <= 0) + # Avoid NaN: zeros values that are below zero + safe_log_std = self.log_std * (self.log_std > 0) + self.epsilon + above_threshold = (torch.log1p(safe_log_std) + 1.0) * (self.log_std > 0) + std = below_threshold + above_threshold + else: + # Use normal exponential + std = torch.exp(self.log_std) + + if self.full_std: + return std + assert self.latent_dim is not None + # Reduce the number of parameters: + return torch.ones(self.latent_dim, 1).to(self.log_std.device) * std + + def update_distribution(self, observations): + if self._normalize_obs: + with torch.no_grad(): + observations = self.obs_rms(observations) + # Get latent features and compute distribution + self.latent_sde = self.latent_net(observations) + std_scaled = self.get_std * self.exploration_scale + if not self.learn_features: + self.latent_sde = self.latent_sde.detach() + if self.latent_sde.dim() == 2: + variance = torch.mm(self.latent_sde**2, std_scaled**2) + elif self.latent_sde.dim() == 3: + variance = torch.einsum("abc,cd->abd", self.latent_sde**2, std_scaled**2) + else: + raise ValueError("Invalid latent_sde dimension") + mean_actions = self.mean_actions_net(self.latent_sde) + self.distribution = Normal(mean_actions, torch.sqrt(variance + self.epsilon)) + + def act(self, observations): + self.update_distribution(observations) + mean = self.distribution.mean + sample = mean + self.get_noise() * self.exploration_scale + if self.debug: + path = f"{LEGGED_GYM_ROOT_DIR}/plots/distribution_smooth.csv" + 
self.log_actions(mean[0][2], sample[0][2], path) + return sample + + def act_inference(self, observations): + if self._normalize_obs: + with torch.no_grad(): + observations = self.obs_rms(observations) + latent_sde = self.latent_net(observations) + mean_actions = self.mean_actions_net(latent_sde) + return mean_actions + + def get_noise(self): + latent_sde = self.latent_sde + if not self.learn_features: + latent_sde = latent_sde.detach() + # Use batch matrix multiplication for efficient computation + # (batch_size, n_features) -> (batch_size, 1, n_features) + latent_sde = latent_sde.unsqueeze(dim=1) + # (batch_size, 1, n_actions) + noise = torch.bmm(latent_sde, self.exploration_matrices.to(latent_sde.device)) + return noise.squeeze(dim=1) diff --git a/learning/modules/state_estimator.py b/learning/modules/state_estimator.py index f8cfe221..26a606f7 100644 --- a/learning/modules/state_estimator.py +++ b/learning/modules/state_estimator.py @@ -1,6 +1,5 @@ import torch.nn as nn -from .utils import create_MLP -from .utils import export_network +from .utils import create_MLP, export_network, RunningMeanStd class StateEstimatorNN(nn.Module): @@ -21,6 +20,7 @@ def __init__( hidden_dims=[256, 128], activation="elu", dropouts=None, + normalize_obs=True, **kwargs, ): if kwargs: @@ -30,13 +30,19 @@ def __init__( ) super().__init__() + self._normalize_obs = normalize_obs + if self._normalize_obs: + self.obs_rms = RunningMeanStd(num_inputs) + self.num_inputs = num_inputs self.num_outputs = num_outputs self.NN = create_MLP(num_inputs, num_outputs, hidden_dims, activation, dropouts) print(f"State Estimator MLP: {self.NN}") - def evaluate(self, observations): - return self.NN(observations) + def evaluate(self, obs): + if self._normalize_obs: + obs = self.obs_rms(obs) + return self.NN(obs) def export(self, path): export_network(self.NN, "state_estimator", path, self.num_inputs) diff --git a/learning/modules/utils/neural_net.py b/learning/modules/utils/neural_net.py index 5fd6ab3b..d452da57 100644 --- a/learning/modules/utils/neural_net.py +++ b/learning/modules/utils/neural_net.py @@ -1,9 +1,13 @@ import torch import os -import copy +# import copy +import numpy as np -def create_MLP(num_inputs, num_outputs, hidden_dims, activation, dropouts=None): + +def create_MLP( + num_inputs, num_outputs, hidden_dims, activation, dropouts=None, latent=False +): activation = get_activation(activation) if dropouts is None: @@ -15,8 +19,12 @@ def create_MLP(num_inputs, num_outputs, hidden_dims, activation, dropouts=None): else: add_layer(layers, num_inputs, hidden_dims[0], activation, dropouts[0]) for i in range(len(hidden_dims)): + # TODO[lm]: Could also create a separate function that gives the latent + # reprentation used for smooth exploration (but if it doesn't mess up + # anything, this is simpler) if i == len(hidden_dims) - 1: - add_layer(layers, hidden_dims[i], num_outputs) + if not latent: + add_layer(layers, hidden_dims[i], num_outputs) else: add_layer( layers, @@ -56,7 +64,7 @@ def add_layer(layer_list, num_inputs, num_outputs, activation=None, dropout=0): layer_list.append(activation) -def export_network(network, network_name, path, num_inputs): +def export_network(network, network_name, path, num_inputs, latent=True): """ Thsi function traces and exports the given network module in .pt and .onnx file formats. 
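SmoothActor above implements gSDE-style exploration: a noise weight matrix is sampled per environment and held fixed between resamples, and the injected noise is a linear function of the latent features, so consecutive actions are perturbed consistently. A stripped-down sketch of that mechanism with stand-alone tensors (dimensions are arbitrary; nothing below uses the class itself):

import torch
from torch.distributions import Normal

latent_dim, n_actions, n_envs = 8, 3, 16
log_std = torch.zeros(latent_dim, n_actions)            # learnable parameter in the real actor
std = torch.exp(log_std)

# One exploration matrix per environment, kept fixed until the next resample
weights_dist = Normal(torch.zeros_like(std), std)
exploration_matrices = weights_dist.rsample((n_envs,))  # (n_envs, latent_dim, n_actions)

latent = torch.randn(n_envs, latent_dim)                # output of the latent MLP
noise = torch.bmm(latent.unsqueeze(1), exploration_matrices).squeeze(1)

# Equivalent distributional view: the action variance is state dependent
variance = latent**2 @ std**2
print(noise.shape, variance.shape)                      # both (16, 3)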
These can be used for evaluation on other systems @@ -71,10 +79,26 @@ def export_network(network, network_name, path, num_inputs): os.makedirs(path, exist_ok=True) path_TS = os.path.join(path, network_name + ".pt") # TorchScript path path_onnx = os.path.join(path, network_name + ".onnx") # ONNX path - model = copy.deepcopy(network).to("cpu") + # model = copy.deepcopy(network).to("cpu") + model = network.to("cpu") # no deepcopy # To trace model, must be evaluated once with arbitrary input model.eval() - dummy_input = torch.rand((1, num_inputs)) + dummy_input = torch.rand((num_inputs)) model_traced = torch.jit.trace(model, dummy_input) torch.jit.save(model_traced, path_TS) torch.onnx.export(model_traced, dummy_input, path_onnx) + + if latent: + # Export latent model + path_latent = os.path.join(path, network_name + "_latent.onnx") + model_latent = torch.nn.Sequential(model.obs_rms, model.latent_net) + model_latent.eval() + dummy_input = torch.rand((num_inputs)) + model_traced = torch.jit.trace(model_latent, dummy_input) + torch.onnx.export(model_traced, dummy_input, path_latent) + + # Save actor std of shape (num_actions, latent_dim) + # It is important that the shape is the same as the exploration matrix + path_std = os.path.join(path, network_name + "_std.txt") + std_transposed = model.get_std.detach().numpy().T + np.savetxt(path_std, std_transposed) diff --git a/learning/modules/utils/normalize.py b/learning/modules/utils/normalize.py index 246bafa4..0a60dcbf 100644 --- a/learning/modules/utils/normalize.py +++ b/learning/modules/utils/normalize.py @@ -41,7 +41,7 @@ def _update_mean_var_from_moments( def forward(self, input): if self.training: mean = input.mean(tuple(range(input.dim() - 1))) - var = input.var(tuple(range(input.dim() - 1))) + var = torch.nan_to_num(input.var(tuple(range(input.dim() - 1)))) ( self.running_mean, self.running_var, diff --git a/learning/runners/BaseRunner.py b/learning/runners/BaseRunner.py index 9e5ab2e3..39ce336d 100644 --- a/learning/runners/BaseRunner.py +++ b/learning/runners/BaseRunner.py @@ -1,6 +1,6 @@ import torch from learning.algorithms import * # noqa: F403 -from learning.modules import Actor, Critic +from learning.modules import Actor, Critic, SmoothActor from learning.utils import remove_zero_weighted_rewards @@ -22,7 +22,10 @@ def _set_up_alg(self): num_actor_obs = self.get_obs_size(self.actor_cfg["obs"]) num_actions = self.get_action_size(self.actor_cfg["actions"]) num_critic_obs = self.get_obs_size(self.critic_cfg["obs"]) - actor = Actor(num_actor_obs, num_actions, **self.actor_cfg) + if self.actor_cfg["smooth_exploration"]: + actor = SmoothActor(num_actor_obs, num_actions, **self.actor_cfg) + else: + actor = Actor(num_actor_obs, num_actions, **self.actor_cfg) critic = Critic(num_critic_obs, **self.critic_cfg) alg_class = eval(self.cfg["algorithm_class_name"]) self.alg = alg_class(actor, critic, device=self.device, **self.alg_cfg) diff --git a/learning/runners/__init__.py b/learning/runners/__init__.py index b0f49217..a2cc7576 100644 --- a/learning/runners/__init__.py +++ b/learning/runners/__init__.py @@ -32,4 +32,5 @@ from .on_policy_runner import OnPolicyRunner from .my_runner import MyRunner -from .old_policy_runner import OldPolicyRunner \ No newline at end of file +from .old_policy_runner import OldPolicyRunner +from .ipg_runner import IPGRunner \ No newline at end of file diff --git a/learning/runners/finetune_runner.py b/learning/runners/finetune_runner.py new file mode 100644 index 00000000..c4a33f64 --- /dev/null +++ 
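For context on the export path modified above, here is a self-contained example of tracing a toy MLP and writing both TorchScript and ONNX artifacts with an unbatched dummy input; the network and file names are invented for illustration and are not part of this patch:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(12, 64), nn.ELU(), nn.Linear(64, 4)).to("cpu")
model.eval()                                    # must be evaluated once before tracing

dummy_input = torch.rand(12)                    # unbatched, matching the change above
traced = torch.jit.trace(model, dummy_input)
torch.jit.save(traced, "policy_demo.pt")
torch.onnx.export(traced, dummy_input, "policy_demo.onnx")

One caveat worth keeping in mind: with a fixed unbatched dummy input the ONNX graph is exported for that exact shape unless dynamic_axes is passed to torch.onnx.export.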
b/learning/runners/finetune_runner.py @@ -0,0 +1,454 @@ +from learning.algorithms import * # noqa: F403 +from learning.algorithms import StateEstimator +from learning.modules import Actor, SmoothActor, Critic, StateEstimatorNN +from learning.storage import DictStorage +from gym.envs.mini_cheetah.minimalist_cheetah import MinimalistCheetah +from .BaseRunner import BaseRunner + +import torch +import os +import scipy.io +from tensordict import TensorDict + +sim_storage = DictStorage() + + +class FineTuneRunner(BaseRunner): + def __init__( + self, + env, + train_cfg, + log_dir, + data_list, + data_length=3000, + data_name="SMOOTH_RL_CONTROLLER", + se_path=None, + use_simulator=True, + exploration_scale=1.0, + device="cpu", + ): + # Instead of super init, only set necessary attributes + self.env = env + self.parse_train_cfg(train_cfg) + self.num_steps_per_env = self.cfg["num_steps_per_env"] + + self.log_dir = log_dir + self.data_list = data_list # Describes structure of Robot-Software logs + self.data_length = data_length # Logs must contain at least this many steps + self.data_name = data_name + self.se_path = se_path + self.use_simulator = use_simulator + self.exploration_scale = exploration_scale + self.device = device + self._set_up_alg() + + def _set_up_alg(self): + num_actor_obs = self.get_obs_size(self.actor_cfg["obs"]) + num_actions = self.get_action_size(self.actor_cfg["actions"]) + num_critic_obs = self.get_obs_size(self.critic_cfg["obs"]) + if self.actor_cfg["smooth_exploration"]: + actor = SmoothActor( + num_obs=num_actor_obs, + num_actions=num_actions, + exploration_scale=self.exploration_scale, + **self.actor_cfg, + ) + else: + actor = Actor(num_actor_obs, num_actions, **self.actor_cfg) + + alg_name = self.cfg["algorithm_class_name"] + alg_class = eval(alg_name) + self.ipg = alg_name in ["PPO_IPG", "LinkedIPG"] + + if self.ipg: + critic_v = Critic(num_critic_obs, **self.critic_cfg) + critic_q = Critic(num_critic_obs + num_actions, **self.critic_cfg) + target_critic_q = Critic(num_critic_obs + num_actions, **self.critic_cfg) + self.alg = alg_class( + actor, + critic_v, + critic_q, + target_critic_q, + device=self.device, + **self.alg_cfg, + ) + else: + critic = Critic(num_critic_obs, **self.critic_cfg) + self.alg = alg_class(actor, critic, device=self.device, **self.alg_cfg) + + if "state_estimator" in self.train_cfg.keys() and self.se_path is not None: + self.se_cfg = self.train_cfg["state_estimator"] + state_estimator_network = StateEstimatorNN( + self.get_obs_size(self.se_cfg["obs"]), + self.get_obs_size(self.se_cfg["targets"]), + **self.se_cfg["network"], + ) + self.SE = StateEstimator( + state_estimator_network, device=self.device, **self.se_cfg + ) + self.load_se(self.se_path) + else: + self.SE = None + + def parse_train_cfg(self, train_cfg): + self.cfg = train_cfg["runner"] + self.alg_cfg = train_cfg["algorithm"] + self.actor_cfg = train_cfg["actor"] + self.critic_cfg = train_cfg["critic"] + self.train_cfg = train_cfg + + def get_data_dict(self, offpol=False, load_path=None, save_path=None): + # Concatenate data with loaded dict + loaded_data_dict = torch.load(load_path) if load_path else None + + checkpoint = self.cfg["checkpoint"] + if offpol: + # All files up until checkpoint + log_files = [ + file + for file in os.listdir(self.log_dir) + if file.endswith(".mat") and int(file.split(".")[0]) <= checkpoint + ] + log_files = sorted(log_files) + else: + # Single log file for checkpoint + log_files = [str(checkpoint) + ".mat"] + + # Initialize data dict + data = 
scipy.io.loadmat(os.path.join(self.log_dir, log_files[0])) + batch_size = (self.data_length - 1, len(log_files)) # -1 for next_obs + data_dict = TensorDict({}, device=self.device, batch_size=batch_size) + + # Collect all data + actor_obs_all = torch.empty(0).to(self.device) + critic_obs_all = torch.empty(0).to(self.device) + actions_all = torch.empty(0).to(self.device) + rewards_all = torch.empty(0).to(self.device) + for log in log_files: + data = scipy.io.loadmat(os.path.join(self.log_dir, log)) + self.data_struct = data[self.data_name][0][0] + if self.SE: + self.update_state_estimates() + + actor_obs = self.get_data_obs(self.actor_cfg["obs"], self.data_struct) + critic_obs = self.get_data_obs(self.critic_cfg["obs"], self.data_struct) + actor_obs_all = torch.cat((actor_obs_all, actor_obs), dim=1) + critic_obs_all = torch.cat((critic_obs_all, critic_obs), dim=1) + + actions_idx = self.data_list.index("dof_pos_target") + actions = ( + torch.tensor(self.data_struct[actions_idx]).to(self.device).float() + ) + actions = actions[: self.data_length] + actions = actions.reshape( + (self.data_length, 1, -1) + ) # shape (data_length, 1, n) + actions_all = torch.cat((actions_all, actions), dim=1) + + reward_weights = self.critic_cfg["reward"]["weights"] + rewards, _ = self.get_data_rewards(self.data_struct, reward_weights) + rewards = rewards[: self.data_length] + rewards = rewards.reshape((self.data_length, 1)) # shape (data_length, 1) + rewards_all = torch.cat((rewards_all, rewards), dim=1) + + data_dict["actor_obs"] = actor_obs_all[:-1] + data_dict["next_actor_obs"] = actor_obs_all[1:] + data_dict["critic_obs"] = critic_obs_all[:-1] + data_dict["next_critic_obs"] = critic_obs_all[1:] + data_dict["actions"] = actions_all[:-1] + data_dict["rewards"] = rewards_all[:-1] + + # No time outs and dones + data_dict["timed_out"] = torch.zeros(batch_size, device=self.device, dtype=bool) + data_dict["dones"] = torch.zeros(batch_size, device=self.device, dtype=bool) + + # Concatenate with loaded dict + if loaded_data_dict is not None: + loaded_batch_size = loaded_data_dict.batch_size + assert loaded_batch_size[0] == batch_size[0] + new_batch_size = ( + loaded_batch_size[0], + loaded_batch_size[1] + batch_size[1], + ) + data_dict = TensorDict( + { + key: torch.cat((loaded_data_dict[key], data_dict[key]), dim=1) + for key in data_dict.keys() + }, + device=self.device, + batch_size=new_batch_size, + ) + + if save_path: + torch.save(data_dict, save_path) + + return data_dict + + def get_data_obs(self, obs_list, data_struct): + obs_all = torch.empty(0).to(self.device) + for obs_name in obs_list: + data_idx = self.data_list.index(obs_name) + obs = torch.tensor(data_struct[data_idx]).to(self.device) + obs = obs.squeeze()[: self.data_length] + obs = obs.reshape((self.data_length, 1, -1)) # shape (data_length, 1, n) + obs_all = torch.cat((obs_all, obs), dim=-1) + + return obs_all.float() + + def get_data_rewards(self, data_struct, reward_weights): + ctrl_dt = 1.0 / self.env.cfg.control.ctrl_frequency + minimalist_cheetah = MinimalistCheetah(ctrl_dt=ctrl_dt, device=self.device) + rewards_dict = {name: [] for name in reward_weights.keys()} # for plotting + rewards_all = torch.empty(0).to(self.device) + + for i in range(self.data_length): + minimalist_cheetah.set_states( + base_height=data_struct[1][i], + base_lin_vel=data_struct[2][i], + base_ang_vel=data_struct[3][i], + proj_gravity=data_struct[4][i], + commands=data_struct[5][i], + dof_pos_obs=data_struct[6][i], + dof_vel=data_struct[7][i], + 
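get_data_dict above converts a logged trajectory into transitions by pairing each step with its successor, i.e. obs[:-1] against obs[1:], before stacking runs along the second batch dimension. A compact sketch of that construction with random stand-in data (field names follow the patch, the numbers and sizes do not):

import torch
from tensordict import TensorDict

data_length, n_obs, n_act = 100, 12, 4
obs = torch.randn(data_length, 1, n_obs)        # (steps, runs, obs_dim), one logged run
actions = torch.randn(data_length, 1, n_act)
rewards = torch.randn(data_length, 1)

batch_size = (data_length - 1, 1)               # -1 because the final step has no successor
data = TensorDict(
    {
        "actor_obs": obs[:-1],
        "next_actor_obs": obs[1:],
        "actions": actions[:-1],
        "rewards": rewards[:-1],
        "dones": torch.zeros(batch_size, dtype=torch.bool),
    },
    batch_size=batch_size,
)
print(data)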
phase_obs=data_struct[8][i], + grf=data_struct[9][i], + dof_pos_target=data_struct[10][i], + ) + total_rewards = 0 + for name, weight in reward_weights.items(): + reward = weight * eval(f"minimalist_cheetah._reward_{name}()") + rewards_dict[name].append(reward.item()) + total_rewards += reward + rewards_all = torch.cat((rewards_all, total_rewards), dim=0) + # Post process mini cheetah + minimalist_cheetah.post_process() + + rewards_dict["total"] = rewards_all.tolist() + rewards_all *= ctrl_dt # scaled for alg update + + return rewards_all.float(), rewards_dict + + def update_state_estimates(self): + se_obs = torch.empty(0).to(self.device) + for obs in self.se_cfg["obs"]: + data_idx = self.data_list.index(obs) + data = torch.tensor(self.data_struct[data_idx]).to(self.device) + data = data.squeeze()[: self.data_length] + data = data.reshape((self.data_length, -1)) + se_obs = torch.cat((se_obs, data), dim=-1) + + se_targets = self.SE.estimate(se_obs.float()) + + # Overwrite data struct with state estimates + idx = 0 + for target in self.se_cfg["targets"]: + data_idx = self.data_list.index(target) + dim = self.data_struct[data_idx].shape[1] + self.data_struct[data_idx] = ( + se_targets[:, idx : idx + dim].cpu().detach().numpy() + ) + idx += dim + + def load_data(self, load_path=None, save_path=None): + # Load on- and off-policy data + if self.use_simulator: + # Simulate on-policy data + self.data_onpol = TensorDict( + self.get_sim_data(), + batch_size=(self.num_steps_per_env, self.env.num_envs), + device=self.device, + ) + else: + self.data_onpol = self.get_data_dict() + + if self.ipg: + self.data_offpol = self.get_data_dict( + offpol=True, load_path=load_path, save_path=save_path + ) + else: + self.data_offpol = None + + def learn(self): + self.alg.switch_to_train() + + # Set fixed actor LR from config + for param_group in self.alg.optimizer.param_groups: + param_group["lr"] = self.alg_cfg["learning_rate"] + + # Single alg update on data + if self.data_offpol is None: + self.alg.update(self.data_onpol) + else: + self.alg.update(self.data_onpol, self.data_offpol) + + def get_sim_data(self): + rewards_dict = {} + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + + # * Initialize smooth exploration matrices + if self.actor_cfg["smooth_exploration"]: + self.alg.actor.sample_weights(batch_size=self.env.num_envs) + + # * Start up storage + transition = TensorDict({}, batch_size=self.env.num_envs, device=self.device) + transition.update( + { + "actor_obs": actor_obs, + "next_actor_obs": actor_obs, + "actions": self.alg.act(actor_obs), + "critic_obs": critic_obs, + "next_critic_obs": critic_obs, + "rewards": self.get_rewards({"termination": 0.0})["termination"], + "dones": self.get_timed_out(), + } + ) + sim_storage.initialize( + transition, + self.env.num_envs, + self.env.num_envs * self.num_steps_per_env, + device=self.device, + ) + + # * Rollout + with torch.inference_mode(): + for i in range(self.num_steps_per_env): + # * Re-sample noise matrix for smooth exploration + sample_freq = self.actor_cfg["exploration_sample_freq"] + if self.actor_cfg["smooth_exploration"] and i % sample_freq == 0: + self.alg.actor.sample_weights(batch_size=self.env.num_envs) + + actions = self.alg.act(actor_obs) + self.set_actions( + self.actor_cfg["actions"], + actions, + self.actor_cfg["disable_actions"], + ) + + transition.update( + { + "actor_obs": actor_obs, + "actions": actions, + "critic_obs": critic_obs, + } + ) + + self.env.step() + + actor_obs = 
self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + + # * get time_outs + timed_out = self.get_timed_out() + terminated = self.get_terminated() + dones = timed_out | terminated + + self.update_rewards(rewards_dict, terminated) + total_rewards = torch.stack(tuple(rewards_dict.values())).sum(dim=0) + + transition.update( + { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, + "rewards": total_rewards, + "timed_out": timed_out, + "dones": dones, + } + ) + sim_storage.add_transitions(transition) + + return sim_storage.data + + def update_rewards(self, rewards_dict, terminated): + rewards_dict.update( + self.get_rewards( + self.critic_cfg["reward"]["termination_weight"], mask=terminated + ) + ) + rewards_dict.update( + self.get_rewards( + self.critic_cfg["reward"]["weights"], + modifier=self.env.dt, + mask=~terminated, + ) + ) + + def save(self, path): + if self.ipg: + torch.save( + { + "actor_state_dict": self.alg.actor.state_dict(), + "critic_v_state_dict": self.alg.critic_v.state_dict(), + "critic_q_state_dict": self.alg.critic_q.state_dict(), + "target_critic_q_state_dict": self.alg.target_critic_q.state_dict(), + "optimizer_state_dict": self.alg.optimizer.state_dict(), + "critic_v_opt_state_dict": self.alg.critic_v_optimizer.state_dict(), + "critic_q_opt_state_dict": self.alg.critic_q_optimizer.state_dict(), + }, + path, + ) + return + torch.save( + { + "actor_state_dict": self.alg.actor.state_dict(), + "critic_state_dict": self.alg.critic.state_dict(), + "optimizer_state_dict": self.alg.optimizer.state_dict(), + "critic_optimizer_state_dict": self.alg.critic_optimizer.state_dict(), + }, + path, + ) + + def load(self, path, load_optimizer=True): + loaded_dict = torch.load(path) + self.alg.actor.load_state_dict(loaded_dict["actor_state_dict"]) + + if self.ipg: + self.alg.critic_v.load_state_dict(loaded_dict["critic_v_state_dict"]) + self.alg.critic_q.load_state_dict(loaded_dict["critic_q_state_dict"]) + # TODO: Possibly always load target with critic_q_state_dict + try: + self.alg.target_critic_q.load_state_dict( + loaded_dict["target_critic_q_state_dict"] + ) + except: + self.alg.target_critic_q.load_state_dict( + loaded_dict["critic_q_state_dict"] + ) + if load_optimizer: + self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) + self.alg.critic_v_optimizer.load_state_dict( + loaded_dict["critic_v_opt_state_dict"] + ) + self.alg.critic_q_optimizer.load_state_dict( + loaded_dict["critic_q_opt_state_dict"] + ) + else: + self.alg.critic.load_state_dict(loaded_dict["critic_state_dict"]) + if load_optimizer: + self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) + self.alg.critic_optimizer.load_state_dict( + loaded_dict["critic_optimizer_state_dict"] + ) + + def load_se(self, se_path): + se_dict = torch.load(se_path) + self.SE.network.load_state_dict(se_dict["SE_state_dict"]) + + def export(self, path): + # Need to make a copy of actor + if self.actor_cfg["smooth_exploration"]: + actor_copy = SmoothActor( + self.alg.actor.num_obs, self.alg.actor.num_actions, **self.actor_cfg + ) + else: + actor_copy = Actor( + self.alg.actor.num_obs, self.alg.actor.num_actions, **self.actor_cfg + ) + state_dict = { + name: param.detach().clone() + for name, param in self.alg.actor.state_dict().items() + } + actor_copy.load_state_dict(state_dict) + actor_copy.export(path) diff --git a/learning/runners/ipg_runner.py b/learning/runners/ipg_runner.py new file mode 100644 index 
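FineTuneRunner.get_data_rewards above resolves each reward term by building the method name from the weight key and passing it through eval(). The same dispatch can be expressed with getattr; the toy class below is purely an illustration of the pattern and is not part of the patch:

class ToyRewards:
    """Stand-in for the per-term reward methods on the cheetah helper."""

    def _reward_tracking_lin_vel(self):
        return 1.0

    def _reward_orientation(self):
        return -0.2


weights = {"tracking_lin_vel": 4.0, "orientation": 1.0}
helper = ToyRewards()

total = 0.0
for name, weight in weights.items():
    reward_fn = getattr(helper, f"_reward_{name}")   # same effect as the eval() call
    total += weight * reward_fn()
print(total)                                          # 4.0 * 1.0 + 1.0 * (-0.2) = 3.8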
00000000..252f6098 --- /dev/null +++ b/learning/runners/ipg_runner.py @@ -0,0 +1,293 @@ +import os +import torch +import torch.nn as nn +from tensordict import TensorDict + +from learning.utils import Logger + +from .BaseRunner import BaseRunner +from learning.algorithms import * # noqa: F403 +from learning.modules import Actor, SmoothActor, Critic +from learning.storage import DictStorage, ReplayBuffer + +logger = Logger() +storage_onpol = DictStorage() +storage_offpol = ReplayBuffer() + + +class IPGRunner(BaseRunner): + def __init__(self, env, train_cfg, device="cpu"): + super().__init__(env, train_cfg, device) + logger.initialize( + self.env.num_envs, + self.env.dt, + self.cfg["max_iterations"], + self.device, + ) + + def _set_up_alg(self): + alg_class_name = self.cfg["algorithm_class_name"] + if alg_class_name not in ["PPO_IPG", "LinkedIPG"]: + raise ValueError("IPGRunner only supports PPO_IPG or Linked_IPG") + + alg_class = eval(alg_class_name) + num_actor_obs = self.get_obs_size(self.actor_cfg["obs"]) + num_actions = self.get_action_size(self.actor_cfg["actions"]) + num_critic_obs = self.get_obs_size(self.critic_cfg["obs"]) + if self.actor_cfg["smooth_exploration"]: + actor = SmoothActor(num_actor_obs, num_actions, **self.actor_cfg) + else: + actor = Actor(num_actor_obs, num_actions, **self.actor_cfg) + critic_v = Critic(num_critic_obs, **self.critic_cfg) + critic_q = Critic(num_critic_obs + num_actions, **self.critic_cfg) + target_critic_q = Critic(num_critic_obs + num_actions, **self.critic_cfg) + self.alg = alg_class( + actor, + critic_v, + critic_q, + target_critic_q, + device=self.device, + **self.alg_cfg, + ) + + def learn(self): + self.set_up_logger() + + rewards_dict = {} + + self.alg.switch_to_train() + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + tot_iter = self.it + self.num_learning_iterations + # self.save() + + # * Initialize smooth exploration matrices + if self.actor_cfg["smooth_exploration"]: + self.alg.actor.sample_weights(batch_size=self.env.num_envs) + + # * start up both on- and off-policy storage + transition = TensorDict({}, batch_size=self.env.num_envs, device=self.device) + transition.update( + { + "actor_obs": actor_obs, + "next_actor_obs": actor_obs, + "actions": self.alg.act(actor_obs), + "critic_obs": critic_obs, + "next_critic_obs": critic_obs, + "rewards": self.get_rewards({"termination": 0.0})["termination"], + "dones": self.get_timed_out(), + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, + } + ) + storage_onpol.initialize( + transition, + self.env.num_envs, + self.env.num_envs * self.num_steps_per_env, + device=self.device, + ) + storage_offpol.initialize( + transition, + self.env.num_envs, + self.alg_cfg["storage_size"], + device=self.device, + ) + + # burn in observation normalization. 
+ if self.actor_cfg["normalize_obs"] or self.critic_cfg["normalize_obs"]: + self.burn_in_normalization() + + logger.tic("runtime") + for self.it in range(self.it + 1, tot_iter + 1): + logger.tic("iteration") + logger.tic("collection") + # * Rollout + with torch.inference_mode(): + for i in range(self.num_steps_per_env): + # * Re-sample noise matrix for smooth exploration + sample_freq = self.actor_cfg["exploration_sample_freq"] + if self.actor_cfg["smooth_exploration"] and i % sample_freq == 0: + self.alg.actor.sample_weights(batch_size=self.env.num_envs) + + actions = self.alg.act(actor_obs) + self.set_actions( + self.actor_cfg["actions"], + actions, + self.actor_cfg["disable_actions"], + ) + + transition.update( + { + "actor_obs": actor_obs, + "actions": actions, + "critic_obs": critic_obs, + } + ) + + self.env.step() + + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + + # * get time_outs + timed_out = self.get_timed_out() + terminated = self.get_terminated() + dones = timed_out | terminated + + self.update_rewards(rewards_dict, terminated) + total_rewards = torch.stack(tuple(rewards_dict.values())).sum(dim=0) + + transition.update( + { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, + "rewards": total_rewards, + "timed_out": timed_out, + "dones": dones, + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, + } + ) + # add transition to both storages + storage_onpol.add_transitions(transition) + storage_offpol.add_transitions(transition) + + logger.log_rewards(rewards_dict) + logger.log_rewards({"total_rewards": total_rewards}) + logger.finish_step(dones) + logger.toc("collection") + + logger.tic("learning") + self.alg.update(storage_onpol.data, storage_offpol.get_data()) + logger.toc("learning") + + if self.it % self.save_interval == 0: + self.save(save_storage=False) + storage_onpol.clear() # only clear on-policy storage + + logger.log_all_categories() + logger.finish_iteration() + logger.toc("iteration") + logger.toc("runtime") + logger.print_to_terminal() + + # self.save() + + @torch.no_grad + def burn_in_normalization(self, n_iterations=100): + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + for _ in range(n_iterations): + actions = self.alg.act(actor_obs) + self.set_actions(self.actor_cfg["actions"], actions) + self.env.step() + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + # TODO: Check this, seems to perform better without critic eval + # self.alg.critic_v.evaluate(critic_obs) + q_input = torch.cat((critic_obs, actions), dim=-1) + self.alg.critic_q.evaluate(q_input) + self.alg.target_critic_q.evaluate(q_input) + self.env.reset() + + def update_rewards(self, rewards_dict, terminated): + rewards_dict.update( + self.get_rewards( + self.critic_cfg["reward"]["termination_weight"], mask=terminated + ) + ) + rewards_dict.update( + self.get_rewards( + self.critic_cfg["reward"]["weights"], + modifier=self.env.dt, + mask=~terminated, + ) + ) + + def set_up_logger(self): + logger.register_rewards(list(self.critic_cfg["reward"]["weights"].keys())) + logger.register_rewards( + list(self.critic_cfg["reward"]["termination_weight"].keys()) + ) + logger.register_rewards(["total_rewards"]) + logger.register_category( + "algorithm", + self.alg, + [ + "mean_value_loss", + "mean_surrogate_loss", + "learning_rate", + # IPG specific + 
"mean_q_loss", + "mean_offpol_loss", + ], + ) + logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) + + logger.attach_torch_obj_to_wandb( + (self.alg.actor, self.alg.critic_v, self.alg.critic_q) + ) + + def save(self, save_storage=False): + os.makedirs(self.log_dir, exist_ok=True) + path = os.path.join(self.log_dir, "model_{}.pt".format(self.it)) + torch.save( + { + "actor_state_dict": self.alg.actor.state_dict(), + "critic_v_state_dict": self.alg.critic_v.state_dict(), + "critic_q_state_dict": self.alg.critic_q.state_dict(), + "target_critic_q_state_dict": self.alg.target_critic_q.state_dict(), + "optimizer_state_dict": self.alg.optimizer.state_dict(), + "critic_v_opt_state_dict": self.alg.critic_v_optimizer.state_dict(), + "critic_q_opt_state_dict": self.alg.critic_q_optimizer.state_dict(), + "iter": self.it, + }, + path, + ) + if save_storage: + path_onpol = os.path.join(self.log_dir, "data_onpol_{}".format(self.it)) + path_offpol = os.path.join(self.log_dir, "data_offpol_{}".format(self.it)) + torch.save(storage_onpol.data.cpu(), path_onpol + ".pt") + torch.save(storage_offpol.get_data().cpu(), path_offpol + ".pt") + + def load(self, path, load_optimizer=True, load_actor_std=True): + loaded_dict = torch.load(path) + if load_actor_std: + self.alg.actor.load_state_dict(loaded_dict["actor_state_dict"]) + else: + std_init = self.alg.actor.std.detach().clone() + self.alg.actor.load_state_dict(loaded_dict["actor_state_dict"]) + self.alg.actor.std = nn.Parameter(std_init) + self.alg.critic_v.load_state_dict(loaded_dict["critic_v_state_dict"]) + self.alg.critic_q.load_state_dict(loaded_dict["critic_q_state_dict"]) + try: + self.alg.target_critic_q.load_state_dict( + loaded_dict["target_critic_q_state_dict"] + ) + except: + self.alg.target_critic_q.load_state_dict(loaded_dict["critic_q_state_dict"]) + if load_optimizer: + self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) + self.alg.critic_v_optimizer.load_state_dict( + loaded_dict["critic_v_opt_state_dict"] + ) + self.alg.critic_q_optimizer.load_state_dict( + loaded_dict["critic_q_opt_state_dict"] + ) + self.it = loaded_dict["iter"] + + def switch_to_eval(self): + self.alg.actor.eval() + self.alg.critic_v.eval() + self.alg.critic_q.eval() + + def get_inference_actions(self): + obs = self.get_noisy_obs(self.actor_cfg["obs"], self.actor_cfg["noise"]) + return self.alg.actor.act_inference(obs) + + def export(self, path): + self.alg.actor.export(path) diff --git a/learning/runners/old_policy_runner.py b/learning/runners/old_policy_runner.py index 42f019a3..8b58dbe4 100644 --- a/learning/runners/old_policy_runner.py +++ b/learning/runners/old_policy_runner.py @@ -4,7 +4,7 @@ from learning.utils import Logger from .BaseRunner import BaseRunner from learning.algorithms import PPO # noqa: F401 -from learning.modules import ActorCritic, Actor, Critic +from learning.modules import ActorCritic, Actor, Critic, SmoothActor logger = Logger() @@ -24,7 +24,10 @@ def _set_up_alg(self): num_actor_obs = self.get_obs_size(self.actor_cfg["obs"]) num_actions = self.get_action_size(self.actor_cfg["actions"]) num_critic_obs = self.get_obs_size(self.critic_cfg["obs"]) - actor = Actor(num_actor_obs, num_actions, **self.actor_cfg) + if self.actor_cfg["smooth_exploration"]: + actor = SmoothActor(num_actor_obs, num_actions, **self.actor_cfg) + else: + actor = Actor(num_actor_obs, num_actions, **self.actor_cfg) critic = Critic(num_critic_obs, **self.critic_cfg) actor_critic = ActorCritic(actor, critic) alg_class = 
eval(self.cfg["algorithm_class_name"]) @@ -42,6 +45,10 @@ def learn(self): self.save() + # * Initialize smooth exploration matrices + if self.actor_cfg["smooth_exploration"]: + self.alg.actor_critic.actor.sample_weights(batch_size=self.env.num_envs) + logger.tic("runtime") for self.it in range(self.it + 1, tot_iter + 1): logger.tic("iteration") @@ -49,6 +56,13 @@ def learn(self): # * Rollout with torch.inference_mode(): for i in range(self.num_steps_per_env): + # * Re-sample noise matrix for smooth exploration + sample_freq = self.actor_cfg["exploration_sample_freq"] + if self.actor_cfg["smooth_exploration"] and i % sample_freq == 0: + self.alg.actor_critic.actor.sample_weights( + batch_size=self.env.num_envs + ) + actions = self.alg.act(actor_obs, critic_obs) self.set_actions( self.actor_cfg["actions"], diff --git a/learning/runners/on_policy_runner.py b/learning/runners/on_policy_runner.py index 2cd6a9d8..ffad7a99 100644 --- a/learning/runners/on_policy_runner.py +++ b/learning/runners/on_policy_runner.py @@ -21,7 +21,7 @@ def __init__(self, env, train_cfg, device="cpu"): self.device, ) - def learn(self): + def learn(self, states_to_log_dict=None): self.set_up_logger() rewards_dict = {} @@ -30,17 +30,25 @@ def learn(self): actor_obs = self.get_obs(self.actor_cfg["obs"]) critic_obs = self.get_obs(self.critic_cfg["obs"]) tot_iter = self.it + self.num_learning_iterations - self.save() + # self.save() + + # * Initialize smooth exploration matrices + if self.actor_cfg["smooth_exploration"]: + self.alg.actor.sample_weights(batch_size=self.env.num_envs) # * start up storage transition = TensorDict({}, batch_size=self.env.num_envs, device=self.device) transition.update( { "actor_obs": actor_obs, - "actions": self.alg.act(actor_obs, critic_obs), + "next_actor_obs": actor_obs, + "actions": self.alg.act(actor_obs), "critic_obs": critic_obs, + "next_critic_obs": critic_obs, "rewards": self.get_rewards({"termination": 0.0})["termination"], "dones": self.get_timed_out(), + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, } ) storage.initialize( @@ -50,14 +58,30 @@ def learn(self): device=self.device, ) + # burn in observation normalization. 
+ if self.actor_cfg["normalize_obs"] or self.critic_cfg["normalize_obs"]: + self.burn_in_normalization() + logger.tic("runtime") for self.it in range(self.it + 1, tot_iter + 1): logger.tic("iteration") logger.tic("collection") + + # * Simulate environment and log states + if states_to_log_dict is not None: + it_idx = self.it - 1 + if it_idx % 10 == 0: + self.sim_and_log_states(states_to_log_dict, it_idx) + # * Rollout with torch.inference_mode(): for i in range(self.num_steps_per_env): - actions = self.alg.act(actor_obs, critic_obs) + # * Re-sample noise matrix for smooth exploration + sample_freq = self.actor_cfg["exploration_sample_freq"] + if self.actor_cfg["smooth_exploration"] and i % sample_freq == 0: + self.alg.actor.sample_weights(batch_size=self.env.num_envs) + + actions = self.alg.act(actor_obs) self.set_actions( self.actor_cfg["actions"], actions, @@ -89,9 +113,13 @@ def learn(self): transition.update( { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, "rewards": total_rewards, "timed_out": timed_out, "dones": dones, + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, } ) storage.add_transitions(transition) @@ -103,19 +131,36 @@ def learn(self): logger.tic("learning") self.alg.update(storage.data) - storage.clear() logger.toc("learning") - logger.log_all_categories() + if self.it % self.save_interval == 0: + self.save() + storage.clear() + + logger.log_all_categories() logger.finish_iteration() logger.toc("iteration") logger.toc("runtime") logger.print_to_terminal() - if self.it % self.save_interval == 0: - self.save() self.save() + @torch.no_grad + def burn_in_normalization(self, n_iterations=100): + actor_obs = self.get_obs(self.actor_cfg["obs"]) + # critic_obs = self.get_obs(self.critic_cfg["obs"]) + for _ in range(n_iterations): + actions = self.alg.act(actor_obs) + self.set_actions(self.actor_cfg["actions"], actions) + self.env.step() + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + # TODO: Check this, seems to perform better without critic eval + # critic_obs = self.get_obs(self.critic_cfg["obs"]) + # self.alg.critic.evaluate(critic_obs) + self.env.reset() + def update_rewards(self, rewards_dict, terminated): rewards_dict.update( self.get_rewards( @@ -137,13 +182,15 @@ def set_up_logger(self): ) logger.register_rewards(["total_rewards"]) logger.register_category( - "algorithm", self.alg, ["mean_value_loss", "mean_surrogate_loss"] + "algorithm", + self.alg, + ["mean_value_loss", "mean_surrogate_loss", "learning_rate"], ) logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) logger.attach_torch_obj_to_wandb((self.alg.actor, self.alg.critic)) - def save(self): + def save(self, save_storage=False): os.makedirs(self.log_dir, exist_ok=True) path = os.path.join(self.log_dir, "model_{}.pt".format(self.it)) torch.save( @@ -156,6 +203,9 @@ def save(self): }, path, ) + if save_storage: + path_data = os.path.join(self.log_dir, "data_{}".format(self.it)) + torch.save(storage.data.cpu(), path_data + ".pt") def load(self, path, load_optimizer=True): loaded_dict = torch.load(path) @@ -178,3 +228,42 @@ def get_inference_actions(self): def export(self, path): self.alg.actor.export(path) + + def sim_and_log_states(self, states_to_log_dict, it_idx): + # Simulate environment for as many steps as expected in the dict. + # Log states to the dict, as well as whether the env terminated. 
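The burn-in pass here (and the analogous one in the IPG runner) only exists to warm up the running observation statistics before the first gradient update, so early batches are not normalized with near-empty estimates. A minimal stand-alone version of the idea, using a toy running mean/variance tracker rather than the repo's RunningMeanStd:

import torch


class ToyRunningStats:
    """Small running mean/variance tracker (illustration only)."""

    def __init__(self, dim):
        self.count = 1e-4
        self.mean = torch.zeros(dim)
        self.var = torch.ones(dim)

    def update(self, batch):
        b_mean, b_var, b_count = batch.mean(0), batch.var(0), batch.shape[0]
        delta = b_mean - self.mean
        total = self.count + b_count
        self.mean = self.mean + delta * b_count / total
        m2 = self.var * self.count + b_var * b_count
        self.var = (m2 + delta**2 * self.count * b_count / total) / total
        self.count = total

    def normalize(self, x):
        return (x - self.mean) / torch.sqrt(self.var + 1e-8)


stats = ToyRunningStats(dim=6)
for _ in range(100):                       # burn-in on synthetic rollout observations
    stats.update(torch.randn(256, 6) * 3.0 + 1.0)
print(stats.mean, stats.var)               # close to mean 1 and variance 9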
+ steps = states_to_log_dict["terminated"].shape[2] + actor_obs = self.get_obs(self.policy_cfg["actor_obs"]) + critic_obs = self.get_obs(self.policy_cfg["critic_obs"]) + + with torch.inference_mode(): + for i in range(steps): + sample_freq = self.policy_cfg["exploration_sample_freq"] + if self.policy_cfg["smooth_exploration"] and i % sample_freq == 0: + self.alg.actor_critic.actor.sample_weights( + batch_size=self.env.num_envs + ) + + actions = self.alg.act(actor_obs, critic_obs) + self.set_actions( + self.policy_cfg["actions"], + actions, + self.policy_cfg["disable_actions"], + ) + + self.env.step() + + actor_obs = self.get_noisy_obs( + self.policy_cfg["actor_obs"], self.policy_cfg["noise"] + ) + critic_obs = self.get_obs(self.policy_cfg["critic_obs"]) + + # Log states (just for the first env) + terminated = self.get_terminated()[0] + for state in states_to_log_dict: + if state == "terminated": + states_to_log_dict[state][0, it_idx, i, :] = terminated + else: + states_to_log_dict[state][0, it_idx, i, :] = getattr( + self.env, state + )[0, :] diff --git a/learning/storage/__init__.py b/learning/storage/__init__.py index f10df85f..47f2593b 100644 --- a/learning/storage/__init__.py +++ b/learning/storage/__init__.py @@ -3,4 +3,5 @@ from .rollout_storage import RolloutStorage from .SE_storage import SERolloutStorage -from .dict_storage import DictStorage \ No newline at end of file +from .dict_storage import DictStorage +from .replay_buffer import ReplayBuffer \ No newline at end of file diff --git a/learning/storage/replay_buffer.py b/learning/storage/replay_buffer.py new file mode 100644 index 00000000..350a2c25 --- /dev/null +++ b/learning/storage/replay_buffer.py @@ -0,0 +1,54 @@ +import torch +from tensordict import TensorDict + + +class ReplayBuffer: + def __init__(self): + self.initialized = False + + def initialize( + self, + dummy_dict, + num_envs=2**12, + max_storage=2**17, + device="cpu", + ): + self.device = device + self.num_envs = num_envs + max_length = max_storage // num_envs + self.max_length = max_length + self.data = TensorDict({}, batch_size=(max_length, num_envs), device=device) + self.fill_count = 0 + self.add_index = 0 + + for key in dummy_dict.keys(): + if dummy_dict[key].dim() == 1: # if scalar + self.data[key] = torch.zeros( + (max_length, num_envs), + dtype=dummy_dict[key].dtype, + device=self.device, + ) + else: + self.data[key] = torch.zeros( + (max_length, num_envs, dummy_dict[key].shape[1]), + dtype=dummy_dict[key].dtype, + device=self.device, + ) + + @torch.inference_mode + def add_transitions(self, transition: TensorDict): + if self.fill_count >= self.max_length and self.add_index >= self.max_length: + self.add_index = 0 + self.data[self.add_index] = transition + self.fill_count += 1 + self.add_index += 1 + + def get_data(self): + return self.data[: min(self.fill_count, self.max_length), :] + + def clear(self): + self.fill_count = 0 + self.add_index = 0 + with torch.inference_mode(): + for tensor in self.data: + tensor.zero_() diff --git a/learning/utils/__init__.py b/learning/utils/__init__.py index 15c674d6..ec798653 100644 --- a/learning/utils/__init__.py +++ b/learning/utils/__init__.py @@ -1,8 +1,5 @@ -from .utils import ( - remove_zero_weighted_rewards, - set_discount_from_horizon, -) +from .utils import * from .dict_utils import * from .logger import Logger from .PBRS.PotentialBasedRewardShaping import PotentialBasedRewardShaping \ No newline at end of file diff --git a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index 2d19e072..99534c22 
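ReplayBuffer above keeps a fixed window of environment-major transitions and wraps its write index once the window is full. The tiny sketch below shows the same circular-write idea on a plain tensor (sizes are arbitrary and the wrap condition is simplified relative to the class above):

import torch

max_length, num_envs, obs_dim = 4, 2, 3
buffer = torch.zeros(max_length, num_envs, obs_dim)
add_index, fill_count = 0, 0

for step in range(10):                       # more steps than the window holds
    transition = torch.full((num_envs, obs_dim), float(step))
    if add_index >= max_length:              # wrap and start overwriting the oldest slots
        add_index = 0
    buffer[add_index] = transition
    add_index += 1
    fill_count += 1

valid = buffer[: min(fill_count, max_length)]
print(valid[:, 0, 0])                        # rows now hold steps 8, 9, 6, 7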
100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -9,30 +9,29 @@ def compute_MC_returns(data: TensorDict, gamma, critic=None): else: last_values = critic.evaluate(data["critic_obs"][-1]) - data.update({"returns": torch.zeros_like(data["rewards"])}) - data["returns"][-1] = last_values * ~data["dones"][-1] + returns = torch.zeros_like(data["rewards"]) + returns[-1] = last_values * ~data["dones"][-1] for k in reversed(range(data["rewards"].shape[0] - 1)): not_done = ~data["dones"][k] - data["returns"][k] = ( - data["rewards"][k] + gamma * data["returns"][k + 1] * not_done - ) - data["returns"] = (data["returns"] - data["returns"].mean()) / ( - data["returns"].std() + 1e-8 - ) - return + returns[k] = data["rewards"][k] + gamma * returns[k + 1] * not_done + + return normalize(returns) @torch.no_grad -def compute_generalized_advantages(data, gamma, lam, critic, last_values=None): - data.update({"values": critic.evaluate(data["critic_obs"])}) +def normalize(input, eps=1e-8): + return (input - input.mean()) / (input.std() + eps) - data.update({"advantages": torch.zeros_like(data["values"])}) +@torch.no_grad +def compute_generalized_advantages(data, gamma, lam, critic): + last_values = critic.evaluate(data["next_critic_obs"][-1]) + advantages = torch.zeros_like(data["values"]) if last_values is not None: # todo check this # since we don't have observations for the last step, need last value plugged in not_done = ~data["dones"][-1] - data["advantages"][-1] = ( + advantages[-1] = ( data["rewards"][-1] + gamma * data["values"][-1] * data["timed_out"][-1] + gamma * last_values * not_done @@ -47,23 +46,30 @@ def compute_generalized_advantages(data, gamma, lam, critic, last_values=None): + gamma * data["values"][k + 1] * not_done - data["values"][k] ) - data["advantages"][k] = ( - td_error + gamma * lam * not_done * data["advantages"][k + 1] - ) + advantages[k] = td_error + gamma * lam * not_done * advantages[k + 1] - data["returns"] = data["advantages"] + data["values"] - - data["advantages"] = (data["advantages"] - data["advantages"].mean()) / ( - data["advantages"].std() + 1e-8 - ) + return advantages # todo change num_epochs to num_batches @torch.no_grad -def create_uniform_generator(data, batch_size, num_epochs, keys=None): +def create_uniform_generator( + data, batch_size, num_epochs=1, max_gradient_steps=None, keys=None +): n, m = data.shape total_data = n * m + + if batch_size > total_data: + Warning("Batch size is larger than total data, using available data only.") + batch_size = total_data + num_batches_per_epoch = total_data // batch_size + if max_gradient_steps: + if max_gradient_steps < num_batches_per_epoch: + num_batches_per_epoch = max_gradient_steps + num_epochs = max_gradient_steps // num_batches_per_epoch + num_epochs = max(num_epochs, 1) + for epoch in range(num_epochs): indices = torch.randperm(total_data, device=data.device) for i in range(num_batches_per_epoch): diff --git a/learning/utils/utils.py b/learning/utils/utils.py index 7a54c65a..d4eac980 100644 --- a/learning/utils/utils.py +++ b/learning/utils/utils.py @@ -17,3 +17,9 @@ def set_discount_from_horizon(dt, horizon): discount_factor = 1 - 1 / discrete_time_horizon return discount_factor + + +def polyak_update(online, target, polyak_factor): + for op, tp in zip(online.parameters(), target.parameters()): + tp.data.copy_((1.0 - polyak_factor) * op.data + polyak_factor * tp.data) + return target diff --git a/scripts/eval_rewards.py b/scripts/eval_rewards.py new file mode 100644 index 
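polyak_update above implements the usual soft target update, target <- (1 - rho) * online + rho * target, so with rho close to 1 the target network trails the online Q-critic slowly. A short usage sketch with two throwaway linear layers (the 0.995 factor is just an example value):

import torch
import torch.nn as nn


def polyak_update(online, target, polyak_factor):
    # Same rule as in learning/utils/utils.py: target <- (1 - rho) * online + rho * target
    for op, tp in zip(online.parameters(), target.parameters()):
        tp.data.copy_((1.0 - polyak_factor) * op.data + polyak_factor * tp.data)
    return target


online = nn.Linear(4, 2)
target = nn.Linear(4, 2)
target.load_state_dict(online.state_dict())        # start from identical weights

with torch.no_grad():
    online.weight += 1.0                           # pretend an optimizer step moved the online net

polyak_update(online, target, polyak_factor=0.995)
gap = (online.weight - target.weight).abs().mean()
print(f"target lags the online net by ~{gap.item():.3f}")   # about 0.995 of the step remains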
00000000..1c527e5f --- /dev/null +++ b/scripts/eval_rewards.py @@ -0,0 +1,191 @@ +from gym.envs import __init__ # noqa: F401 +from gym.utils import get_args, task_registry +from gym.utils.helpers import class_to_dict + +from learning.runners.finetune_runner import FineTuneRunner + +from gym import LEGGED_GYM_ROOT_DIR + +import os +import scipy.io +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +ROOT_DIR = f"{LEGGED_GYM_ROOT_DIR}/logs/mini_cheetah_ref/" +# SE_PATH = f"{LEGGED_GYM_ROOT_DIR}/logs/SE/model_1000.pt" # if None: no SE +SE_PATH = None +LOAD_RUN = "Jul24_22-48-41_nu05_B8" + +REWARDS_FILE = ( + "rewards_nu09_nosim2.csv" # generate this file from logs, if None: just plot +) + +PLOT_REWARDS = { + "Nu=0.5 no sim": "rewards_nu05_nosim.csv", + "Nu=0.9 no sim": "rewards_nu09_nosim.csv", + "Nu=0.9 no sim 2": "rewards_nu09_nosim2.csv", + "Nu=0.95 no sim": "rewards_nu095_nosim.csv", +} + +# Data struct fields from Robot-Software logs +DATA_LIST = [ + "header", + "base_height", # 1 + "base_lin_vel", # 2 + "base_ang_vel", # 3 + "projected_gravity", # 4 + "commands", # 5 + "dof_pos_obs", # 6 + "dof_vel", # 7 + "phase_obs", # 8 + "grf", # 9 + "dof_pos_target", # 10 + "torques", # 11 + "exploration_noise", # 12 + "footer", +] + +REWARD_WEIGHTS = { + "tracking_lin_vel": 4.0, + "tracking_ang_vel": 2.0, + "min_base_height": 1.5, + "orientation": 1.0, + "stand_still": 1.0, + "swing_grf": 3.0, + "stance_grf": 3.0, + "action_rate": 0.01, + "action_rate2": 0.001, +} + +DEVICE = "cuda" + + +def update_rewards_df(runner): + data_dict = {} + log_files = [ + file + for file in os.listdir(os.path.join(ROOT_DIR, LOAD_RUN)) + if file.endswith(".mat") + ] + log_files = sorted(log_files) + for file in log_files: + iteration = int(file.split(".")[0]) + path = os.path.join(ROOT_DIR, LOAD_RUN, file) + data = scipy.io.loadmat(path) + runner.data_struct = data[runner.data_name][0][0] + if runner.SE: + runner.update_state_estimates() + data_dict[iteration] = runner.data_struct + + rewards_path = os.path.join(ROOT_DIR, LOAD_RUN, REWARDS_FILE) + if os.path.exists(rewards_path): + os.remove(rewards_path) + + # Get rewards from runner + for it, data_struct in data_dict.items(): + _, rewards_dict = runner.get_data_rewards(data_struct, REWARD_WEIGHTS) + + # Save rewards in dataframe + if not os.path.exists(rewards_path): + rewards_df = pd.DataFrame(columns=["iteration", "type", "mean", "std"]) + else: + rewards_df = pd.read_csv(rewards_path) + for name, rewards in rewards_dict.items(): + rewards = np.array(rewards) + mean = rewards.mean() + std = rewards.std() + rewards_df = rewards_df._append( + { + "iteration": it, + "type": name, + "mean": mean, + "std": std, + }, + ignore_index=True, + ) + rewards_df.to_csv(rewards_path, index=False) + + +def plot_rewards(rewards_df, axs, name): + for i, key in enumerate(REWARD_WEIGHTS.keys()): + rewards_mean = rewards_df[rewards_df["type"] == key]["mean"].reset_index( + drop=True + ) + rewards_std = rewards_df[rewards_df["type"] == key]["std"].reset_index( + drop=True + ) + axs[i].plot(rewards_mean, label=name) + axs[i].fill_between( + range(len(rewards_mean)), + rewards_mean - rewards_std, + rewards_mean + rewards_std, + alpha=0.2, + ) + axs[i].set_title(key) + axs[i].legend() + + i = num_plots - 1 + total_mean = rewards_df[rewards_df["type"] == "total"]["mean"].reset_index( + drop=True + ) + total_std = rewards_df[rewards_df["type"] == "total"]["std"].reset_index(drop=True) + axs[i].plot(total_mean, label=name) + axs[i].fill_between( + 
+        range(len(total_mean)),
+        total_mean - total_std,
+        total_mean + total_std,
+        alpha=0.2,
+    )
+    axs[i].set_title("Total Rewards")
+    axs[i].legend()
+
+
+def setup():
+    args = get_args()
+
+    env_cfg, train_cfg = task_registry.create_cfgs(args)
+    task_registry.make_gym_and_sim()
+    env = task_registry.make_env(name=args.task, env_cfg=env_cfg)
+
+    train_cfg = class_to_dict(train_cfg)
+    log_dir = os.path.join(ROOT_DIR, train_cfg["runner"]["load_run"])
+
+    runner = FineTuneRunner(
+        env,
+        train_cfg,
+        log_dir,
+        data_list=DATA_LIST,
+        device=DEVICE,
+        se_path=SE_PATH,
+    )
+
+    return runner
+
+
+if __name__ == "__main__":
+    if REWARDS_FILE is not None:
+        runner = setup()
+        update_rewards_df(runner)
+
+    # Plot rewards stats
+    num_plots = len(REWARD_WEIGHTS) + 1  # +1 for total rewards
+    cols = 5
+    rows = np.ceil(num_plots / cols).astype(int)
+    fig, axs = plt.subplots(rows, cols, figsize=(20, 8))
+    fig.suptitle("IPG Finetuning Rewards")
+    axs = axs.flatten()
+
+    for name, file in PLOT_REWARDS.items():
+        path = os.path.join(ROOT_DIR, LOAD_RUN, file)
+        rewards_df = pd.read_csv(path)
+        plot_rewards(rewards_df, axs, name)
+
+    for i in range(num_plots):
+        axs[i].set_xlabel("Iter")
+    for i in range(num_plots, len(axs)):
+        axs[i].axis("off")
+
+    plt.tight_layout()
+    plt.savefig(f"{ROOT_DIR}/{LOAD_RUN}/rewards_stats.png")
+    plt.show()
diff --git a/scripts/export_policy.py b/scripts/export_policy.py
index a5897359..d591ea14 100644
--- a/scripts/export_policy.py
+++ b/scripts/export_policy.py
@@ -22,6 +22,7 @@ def setup_and_export(args):
         LEGGED_GYM_ROOT_DIR,
         "logs",
         train_cfg.runner.experiment_name,
+        train_cfg.runner.load_run,
         "exported",
     )
     runner.export(path)
diff --git a/scripts/finetune.py b/scripts/finetune.py
new file mode 100644
index 00000000..4db64925
--- /dev/null
+++ b/scripts/finetune.py
@@ -0,0 +1,168 @@
+from gym.envs import __init__  # noqa: F401
+from gym.utils import get_args, task_registry
+from gym.utils.helpers import class_to_dict
+
+from learning.runners.finetune_runner import FineTuneRunner
+
+from gym import LEGGED_GYM_ROOT_DIR
+
+import os
+import torch
+import numpy as np
+import pandas as pd
+
+ROOT_DIR = f"{LEGGED_GYM_ROOT_DIR}/logs/mini_cheetah_ref/"
+# SE_PATH = f"{LEGGED_GYM_ROOT_DIR}/logs/SE/model_1000.pt"  # if None: no SE
+SE_PATH = None
+OUTPUT_FILE = "output.txt"
+LOSSES_FILE = "losses.csv"
+
+USE_SIMULATOR = False  # For on-policy data
+DATA_LENGTH = 12_000  # Robot-Software logs must contain at least this many steps
+
+# Load/save off-policy storage, this can contain many runs
+OFFPOL_LOAD_FILE = None  # "offpol_data.pt"
+OFFPOL_SAVE_FILE = None  # "offpol_data.pt"
+
+# Scales
+EXPLORATION_SCALE = 0.8  # used during data collection
+ACTION_SCALES = np.tile(np.array([0.2, 0.3, 0.3]), 4)
+
+# Data struct fields from Robot-Software logs
+DATA_LIST = [
+    "header",
+    "base_height",  # 1
+    "base_lin_vel",  # 2
+    "base_ang_vel",  # 3
+    "projected_gravity",  # 4
+    "commands",  # 5
+    "dof_pos_obs",  # 6
+    "dof_vel",  # 7
+    "phase_obs",  # 8
+    "grf",  # 9
+    "dof_pos_target",  # 10
+    "torques",  # 11
+    "exploration_noise",  # 12
+    "footer",
+]
+
+DEVICE = "cuda"
+
+
+def setup():
+    args = get_args()
+
+    env_cfg, train_cfg = task_registry.create_cfgs(args)
+    task_registry.make_gym_and_sim()
+    env = task_registry.make_env(name=args.task, env_cfg=env_cfg)
+
+    train_cfg = class_to_dict(train_cfg)
+    log_dir = os.path.join(ROOT_DIR, train_cfg["runner"]["load_run"])
+
+    runner = FineTuneRunner(
+        env,
+        train_cfg,
+        log_dir,
+        data_list=DATA_LIST,
+        data_length=DATA_LENGTH,
+        se_path=SE_PATH,
+        use_simulator=USE_SIMULATOR,
+        exploration_scale=EXPLORATION_SCALE,
+        device=DEVICE,
+    )
+
+    return runner
+
+
+def finetune(runner):
+    # Load model
+    load_run = runner.cfg["load_run"]
+    checkpoint = runner.cfg["checkpoint"]
+    model_path = os.path.join(ROOT_DIR, load_run, "model_" + str(checkpoint) + ".pt")
+    runner.load(model_path)
+
+    # Load data
+    load_path = (
+        os.path.join(ROOT_DIR, load_run, OFFPOL_LOAD_FILE) if OFFPOL_LOAD_FILE else None
+    )
+    save_path = (
+        os.path.join(ROOT_DIR, load_run, OFFPOL_SAVE_FILE) if OFFPOL_SAVE_FILE else None
+    )
+    runner.load_data(load_path=load_path, save_path=save_path)
+
+    # Get old inference actions
+    action_scales = torch.tensor(ACTION_SCALES).to(DEVICE)
+    actions_old = action_scales * runner.alg.actor.act_inference(
+        runner.data_onpol["actor_obs"]
+    )
+
+    # Perform a single update
+    runner.learn()
+
+    # Compare old to new actions
+    actions_new = action_scales * runner.alg.actor.act_inference(
+        runner.data_onpol["actor_obs"]
+    )
+    diff = actions_new - actions_old
+
+    # Save and export
+    save_path = os.path.join(ROOT_DIR, load_run, "model_" + str(checkpoint + 1) + ".pt")
+    export_path = os.path.join(ROOT_DIR, load_run, "exported_" + str(checkpoint + 1))
+    runner.save(save_path)
+    runner.export(export_path)
+
+    # Print to output file
+    with open(os.path.join(ROOT_DIR, load_run, OUTPUT_FILE), "a") as f:
+        f.write(f"############ Checkpoint: {checkpoint} #######################\n")
+        f.write(f"############## Nu={runner.alg.inter_nu} ###################\n")
+        f.write("############### DATA ###############\n")
+        f.write(f"Data on-policy shape: {runner.data_onpol.shape}\n")
+        if runner.data_offpol is not None:
+            f.write(f"Data off-policy shape: {runner.data_offpol.shape}\n")
+        f.write("############## LOSSES ##############\n")
+        f.write(f"Mean Value Loss: {runner.alg.mean_value_loss}\n")
+        f.write(f"Mean Surrogate Loss: {runner.alg.mean_surrogate_loss}\n")
+        if runner.data_offpol is not None:
+            f.write(f"Mean Q Loss: {runner.alg.mean_q_loss}\n")
+            f.write(f"Mean Offpol Loss: {runner.alg.mean_offpol_loss}\n")
+        f.write("############## ACTIONS #############\n")
+        f.write(f"Mean action diff per actuator: {diff.mean(dim=(0, 1))}\n")
+        f.write(f"Std action diff per actuator: {diff.std(dim=(0, 1))}\n")
+        f.write(f"Overall mean action diff: {diff.mean()}\n")
+
+    # Log losses to csv
+    losses_path = os.path.join(ROOT_DIR, load_run, LOSSES_FILE)
+    if not os.path.exists(losses_path):
+        if runner.ipg:
+            losses_df = pd.DataFrame(
+                columns=[
+                    "checkpoint",
+                    "value_loss",
+                    "q_loss",
+                    "surrogate_loss",
+                    "offpol_loss",
+                ]
+            )
+        else:
+            losses_df = pd.DataFrame(
+                columns=["checkpoint", "value_loss", "surrogate_loss"]
+            )
+    else:
+        losses_df = pd.read_csv(losses_path)
+
+    append_data = {
+        "checkpoint": checkpoint,
+        "value_loss": runner.alg.mean_value_loss,
+        "surrogate_loss": runner.alg.mean_surrogate_loss,
+    }
+    if runner.data_offpol is not None:
+        append_data["q_loss"] = runner.alg.mean_q_loss
+        append_data["offpol_loss"] = runner.alg.mean_offpol_loss
+
+    losses_df = losses_df._append(append_data, ignore_index=True)
+    losses_df.to_csv(losses_path, index=False)
+
+
+if __name__ == "__main__":
+    runner = setup()
+    finetune(runner)
diff --git a/scripts/finetune_multiple.sh b/scripts/finetune_multiple.sh
new file mode 100644
index 00000000..e3dd96c1
--- /dev/null
+++ b/scripts/finetune_multiple.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+source /home/lmolnar/miniconda3/etc/profile.d/conda.sh
+
+# Args
+LOAD_RUN="Jul24_22-48-41_nu05_B8"
+CHECKPOINT=1000
+N_RUNS=5
+INTER_NU=0.9 # can be fixed or adaptive
+EVAL=false # no finetuning, just evaluate without exploration (set in RS)
+
+# Set directories
+QGYM_DIR="/home/lmolnar/workspace/QGym"
+RS_DIR="/home/lmolnar/workspace/Robot-Software"
+
+QGYM_LOG_DIR="${QGYM_DIR}/logs/mini_cheetah_ref/${LOAD_RUN}"
+
+# Change default folder name and copy to RS
+mv ${QGYM_LOG_DIR}/exported ${QGYM_LOG_DIR}/exported_${CHECKPOINT}
+cp ${QGYM_LOG_DIR}/exported_${CHECKPOINT}/* ${RS_DIR}/config/systems/quadruped/controllers/policies
+
+for i in $(seq 1 $N_RUNS)
+do
+    # Store logs in files labeled by checkpoint
+    LCM_FILE=${RS_DIR}/logging/lcm_logs/${CHECKPOINT}
+    MAT_FILE=${RS_DIR}/logging/matlab_logs/${CHECKPOINT}.mat
+    rm ${LCM_FILE}
+
+    # Run logging script in background
+    ${RS_DIR}/logging/scripts/run_lcm_logger.sh ${CHECKPOINT} &
+    PID1=$!
+
+    # Run quadruped script and cancel logging script when done
+    ${RS_DIR}/build/bin/run_quadruped m s
+    kill $PID1
+
+    # Convert logs to .mat and copy to QGym
+    conda deactivate
+    conda activate robot-sw
+    ${RS_DIR}/logging/scripts/sim_data_convert.sh ${CHECKPOINT}
+    cp ${RS_DIR}/logging/matlab_logs/${CHECKPOINT}.mat ${QGYM_LOG_DIR}
+
+    # Finetune in QGym
+    # INTER_NU=$(echo "0.05 * $i" | bc) # adaptive
+    if [ "$EVAL" = false ] ; then
+        conda deactivate
+        conda activate qgym
+        python ${QGYM_DIR}/scripts/finetune.py --task=mini_cheetah_finetune --headless --load_run=${LOAD_RUN} --checkpoint=${CHECKPOINT} --inter_nu=${INTER_NU}
+    fi
+
+    # Copy policy to RS
+    CHECKPOINT=$((CHECKPOINT + 1))
+    cp ${QGYM_LOG_DIR}/exported_${CHECKPOINT}/* ${RS_DIR}/config/systems/quadruped/controllers/policies
+done
\ No newline at end of file
diff --git a/scripts/generate_commands.py b/scripts/generate_commands.py
new file mode 100644
index 00000000..9fe21d79
--- /dev/null
+++ b/scripts/generate_commands.py
@@ -0,0 +1,45 @@
+import numpy as np
+import random
+
+# Command ranges during training:
+# x_range = [-2.0, 3.0]  # [m/s]
+# y_range = [-1.0, 1.0]  # [m/s]
+# yaw_range = [-3.0, 3.0]  # [rad/s]
+
+# Command ranges for finetuning:
+x_range = [-0.67, 1.0]  # [m/s]
+y_range = [-0.33, 0.33]  # [m/s]
+yaw_range = [-2.0, 2.0]  # [rad/s]
+
+# Generate structured command sequence (fixed lin/ang vel, yaw, some random)
+N = 100
+cmds_zero = np.array([[0, 0, 0]]).repeat(N, axis=0)
+cmds = np.zeros((N, 3))
+for _ in range(4):
+    cmds = np.append(cmds, np.array([[x_range[1], 0, 0]]).repeat(3 * N, axis=0), axis=0)
+    cmds = np.append(cmds, cmds_zero, axis=0)
+    cmds = np.append(cmds, np.array([[x_range[0], 0, 0]]).repeat(3 * N, axis=0), axis=0)
+    cmds = np.append(cmds, cmds_zero, axis=0)
+    cmds = np.append(cmds, np.array([[0, y_range[1], 0]]).repeat(2 * N, axis=0), axis=0)
+    cmds = np.append(cmds, cmds_zero, axis=0)
+    cmds = np.append(cmds, np.array([[0, y_range[0], 0]]).repeat(2 * N, axis=0), axis=0)
+    cmds = np.append(cmds, cmds_zero, axis=0)
+    cmds = np.append(
+        cmds, np.array([[0, 0, yaw_range[1]]]).repeat(2 * N, axis=0), axis=0
+    )
+    cmds = np.append(cmds, cmds_zero, axis=0)
+    cmds = np.append(
+        cmds, np.array([[0, 0, yaw_range[0]]]).repeat(2 * N, axis=0), axis=0
+    )
+    cmds = np.append(cmds, cmds_zero, axis=0)
+
+    for i in range(5):
+        x = random.uniform(x_range[0], x_range[1])
+        y = random.uniform(y_range[0], y_range[1])
+        yaw = random.uniform(yaw_range[0], yaw_range[1])
+        cmds = np.append(cmds, np.array([[x, y, yaw]]).repeat(2 * N, axis=0), axis=0)
+
+print(cmds.shape)
+
+# Export to txt
+np.savetxt("commands_long.txt", cmds, fmt="%.3f")
diff --git a/scripts/plot_losses.py b/scripts/plot_losses.py
new file mode 100644
index 00000000..e165b860
--- /dev/null
+++ b/scripts/plot_losses.py
@@ -0,0 +1,30 @@
+from gym import LEGGED_GYM_ROOT_DIR
+
+import os
+import matplotlib.pyplot as plt
+import pandas as pd
+
+ROOT_DIR = f"{LEGGED_GYM_ROOT_DIR}/logs/mini_cheetah_ref/"
+LOAD_RUN = "Jul26_11-58-37_LinkedIPG_100Hz_nu02_v08"
+
+PLOT_LOSSES = {
+    "Nu=0.9 sim": "losses_nu09_sim.csv",
+    "Nu=0.9 no sim": "losses_nu09_nosim.csv",
+}
+
+LABELS = ["value_loss", "surrogate_loss", "q_loss", "offpol_loss"]
+
+fig, axs = plt.subplots(2, 2, figsize=(10, 6))
+plt.suptitle("IPG Finetuning Losses")
+
+for i, label in enumerate(LABELS):
+    for name, file in PLOT_LOSSES.items():
+        data = pd.read_csv(os.path.join(ROOT_DIR, LOAD_RUN, file))
+        axs[i // 2, i % 2].plot(data[label], label=name)
+    axs[i // 2, i % 2].set_title(label)
+    axs[i // 2, i % 2].set_xlabel("Checkpoint")
+    axs[i // 2, i % 2].set_ylabel("Loss")
+    axs[i // 2, i % 2].legend()
+
+plt.tight_layout()
+plt.show()
diff --git a/scripts/train.py b/scripts/train.py
index 283717e8..40966682 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,5 +1,5 @@
 from gym.envs import __init__  # noqa: F401
-from gym.utils import get_args, task_registry, randomize_episode_counters
+from gym.utils import get_args, task_registry
 from gym.utils.logging_and_saving import wandb_singleton
 from gym.utils.logging_and_saving import local_code_save_helper

@@ -12,9 +12,8 @@ def setup():
     task_registry.make_gym_and_sim()
     wandb_helper.setup_wandb(env_cfg=env_cfg, train_cfg=train_cfg, args=args)
     env = task_registry.make_env(name=args.task, env_cfg=env_cfg)
-    randomize_episode_counters(env)
+    # randomize_episode_counters(env)

-    randomize_episode_counters(env)
     policy_runner = task_registry.make_alg_runner(env, train_cfg)

     local_code_save_helper.save_local_files_to_logs(train_cfg.log_dir)
diff --git a/scripts/visualize_ipg.py b/scripts/visualize_ipg.py
new file mode 100644
index 00000000..95b03cad
--- /dev/null
+++ b/scripts/visualize_ipg.py
@@ -0,0 +1,124 @@
+import os
+import shutil
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+from learning.modules.actor import Actor
+from learning.modules.critic import Critic
+
+from learning.utils import (
+    compute_generalized_advantages,
+    compute_MC_returns,
+)
+from learning.modules.lqrc.plotting import plot_pendulum_critics_with_data
+from gym import LEGGED_GYM_ROOT_DIR
+
+DEVICE = "cpu"
+
+# * Setup
+LOAD_RUN = "Jul25_11-28-41_LinkedIPG_nu05_v05"
+TITLE = "LinkedIPG nu=0.5 blend=0.5 polyak=0.9"
+IT_RANGE = range(20, 101, 20)
+
+RUN_DIR = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "pendulum", LOAD_RUN)
+PLOT_DIR = os.path.join(RUN_DIR, "visualize_critic")
+os.makedirs(PLOT_DIR, exist_ok=True)
+
+# * V-Critic, Q-Critic and Actor
+critic_q = Critic(
+    num_obs=4, hidden_dims=[128, 64, 32], activation="tanh", normalize_obs=False
+).to(DEVICE)
+critic_v = Critic(
+    num_obs=3, hidden_dims=[128, 64, 32], activation="tanh", normalize_obs=False
+).to(DEVICE)
+actor = Actor(
+    num_obs=3,
+    num_actions=1,
+    hidden_dims=[128, 64, 32],
+    activation="tanh",
+    normalize_obs=False,
+).to(DEVICE)
+
+# * Params
+n_envs = 4096  # that were trained with
+gamma = 0.95
+lam = 1.0
+episode_steps = 100
+visualize_steps = 10  # just to show rollouts
+n_trajs = 64
+rand_perm = torch.randperm(n_envs)
+traj_idx = rand_perm[0:n_trajs]
+test_idx = rand_perm[n_trajs : n_trajs + 1000]
+
+for it in IT_RANGE:
+    # load data
+    base_data = torch.load(os.path.join(RUN_DIR, "data_onpol_{}.pt".format(it))).to(
+        DEVICE
+    )
+    data = base_data.detach().clone()
+
+    dof_pos = base_data["dof_pos"].detach().clone()
+    dof_vel = base_data["dof_vel"].detach().clone()
+    graphing_obs = torch.cat((dof_pos, dof_vel), dim=2)
+
+    # compute ground-truth
+    graphing_data = {
+        data_name: {} for data_name in ["obs", "values", "returns", "actions"]
+    }
+
+    episode_rollouts = compute_MC_returns(base_data, gamma)
+    print(f"Initializing value offset to: {episode_rollouts.mean().item()}")
+    graphing_data["obs"]["Ground Truth MC Returns"] = graphing_obs[0, :]
+    graphing_data["values"]["Ground Truth MC Returns"] = episode_rollouts[0, :]
+    graphing_data["returns"]["Ground Truth MC Returns"] = episode_rollouts[0, :]
+
+    # load models
+    model = torch.load(os.path.join(RUN_DIR, "model_{}.pt".format(it)))
+    critic_v.load_state_dict(model["critic_v_state_dict"])
+    critic_q.load_state_dict(model["critic_q_state_dict"])
+    actor.load_state_dict(model["actor_state_dict"])
+
+    # V-critic values and returns
+    data["values"] = critic_v.evaluate(data["critic_obs"])
+    data["advantages"] = compute_generalized_advantages(data, gamma, lam, critic_v)
+    data["returns"] = data["advantages"] + data["values"]
+
+    with torch.no_grad():
+        graphing_data["obs"]["V-Critic"] = graphing_obs[0, :]
+        graphing_data["values"]["V-Critic"] = critic_v.evaluate(
+            data[0, :]["critic_obs"]
+        )
+        graphing_data["returns"]["V-Critic"] = data[0, :]["returns"]
+        graphing_data["actions"] = actor(data[0, :]["actor_obs"])
+
+    # Q-critic values and returns
+    actions = actor(data["actor_obs"])
+    critic_q_obs = torch.cat((data["critic_obs"], actions), dim=2)
+    data["values"] = critic_q.evaluate(critic_q_obs)
+    data["next_critic_obs"] = critic_q_obs  # needed for GAE
+    data["advantages"] = compute_generalized_advantages(data, gamma, lam, critic_q)
+    data["returns"] = data["advantages"] + data["values"]
+
+    with torch.no_grad():
+        graphing_data["obs"]["Q-Critic"] = graphing_obs[0, :]
+        graphing_data["values"]["Q-Critic"] = critic_q.evaluate(critic_q_obs[0, :])
+        graphing_data["returns"]["Q-Critic"] = data[0, :]["returns"]
+
+    # generate plots
+    grid_size = int(np.sqrt(n_envs))
+    plot_pendulum_critics_with_data(
+        x=graphing_data["obs"],
+        predictions=graphing_data["values"],
+        targets=graphing_data["returns"],
+        actions=graphing_data["actions"],
+        title=f"{TITLE} Iteration {it}",
+        fn=PLOT_DIR + f"/IPG_it{it}",
+        data=graphing_obs[:visualize_steps, traj_idx],
+        grid_size=grid_size,
+    )
+
+    plt.close()
+
+this_file = os.path.join(LEGGED_GYM_ROOT_DIR, "scripts", "visualize_ipg.py")
+shutil.copy(this_file, os.path.join(PLOT_DIR, os.path.basename(this_file)))
diff --git a/scripts/visualize_ppo.py b/scripts/visualize_ppo.py
new file mode 100644
index 00000000..61cf123a
--- /dev/null
+++ b/scripts/visualize_ppo.py
@@ -0,0 +1,102 @@
+import os
+import shutil
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+from learning.modules.actor import Actor
+from learning.modules.critic import Critic
+
+from learning.utils import (
+    compute_generalized_advantages,
+    compute_MC_returns,
+)
+from learning.modules.lqrc.plotting import plot_pendulum_critics_with_data
+from gym import LEGGED_GYM_ROOT_DIR
+
+DEVICE = "cpu"
+
+# * Setup
+LOAD_RUN = "Jul17_15-39-53_PPO"
+TITLE = "PPO"
+IT_RANGE = range(20, 101, 20)
+
+RUN_DIR = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "pendulum", LOAD_RUN)
+PLOT_DIR = os.path.join(RUN_DIR, "visualize_critic")
+os.makedirs(PLOT_DIR, exist_ok=True)
+
+# * Critic and Actor
+critic = Critic(
+    num_obs=3, hidden_dims=[128, 64, 32], activation="tanh", normalize_obs=False
+).to(DEVICE)
+actor = Actor(
+    num_obs=3,
+    num_actions=1,
+    hidden_dims=[128, 64, 32],
+    activation="tanh",
+    normalize_obs=False,
+).to(DEVICE)
+
+# * Params
+n_envs = 4096  # that were trained with
+gamma = 0.95
+lam = 1.0
+visualize_steps = 10  # just to show rollouts
+n_trajs = 64
+rand_perm = torch.randperm(n_envs)
+traj_idx = rand_perm[0:n_trajs]
+test_idx = rand_perm[n_trajs : n_trajs + 1000]
+
+for it in IT_RANGE:
+    # load data
+    base_data = torch.load(os.path.join(RUN_DIR, "data_{}.pt".format(it))).to(DEVICE)
+    data = base_data.detach().clone()
+
+    dof_pos = base_data["dof_pos"].detach().clone()
+    dof_vel = base_data["dof_vel"].detach().clone()
+    graphing_obs = torch.cat((dof_pos, dof_vel), dim=2)
+
+    # compute ground-truth
+    graphing_data = {
+        data_name: {} for data_name in ["obs", "values", "returns", "actions"]
+    }
+
+    episode_rollouts = compute_MC_returns(base_data, gamma)
+    print(f"Initializing value offset to: {episode_rollouts.mean().item()}")
+    graphing_data["obs"]["Ground Truth MC Returns"] = graphing_obs[0, :]
+    graphing_data["values"]["Ground Truth MC Returns"] = episode_rollouts[0, :]
+    graphing_data["returns"]["Ground Truth MC Returns"] = episode_rollouts[0, :]
+
+    # load models
+    model = torch.load(os.path.join(RUN_DIR, "model_{}.pt".format(it)))
+    critic.load_state_dict(model["critic_state_dict"])
+    actor.load_state_dict(model["actor_state_dict"])
+
+    # compute values and returns
+    data["values"] = critic.evaluate(data["critic_obs"])
+    data["advantages"] = compute_generalized_advantages(data, gamma, lam, critic)
+    data["returns"] = data["advantages"] + data["values"]
+
+    with torch.no_grad():
+        graphing_data["obs"]["V-Critic"] = graphing_obs[0, :]
+        graphing_data["values"]["V-Critic"] = critic.evaluate(data[0, :]["critic_obs"])
+        graphing_data["returns"]["V-Critic"] = data[0, :]["returns"]
+        graphing_data["actions"]["V-Critic"] = actor(data[0, :]["actor_obs"])
+
+    # generate plots
+    grid_size = int(np.sqrt(n_envs))
+    plot_pendulum_critics_with_data(
+        x=graphing_data["obs"],
+        predictions=graphing_data["values"],
+        targets=graphing_data["returns"],
+        actions=graphing_data["actions"],
+        title=f"{TITLE} Iteration {it}",
+        fn=PLOT_DIR + f"/PPO_it{it}",
+        data=graphing_obs[:visualize_steps, traj_idx],
+        grid_size=grid_size,
+    )
+
+    plt.close()
+
+this_file = os.path.join(LEGGED_GYM_ROOT_DIR, "scripts", "visualize_ppo.py")
+shutil.copy(this_file, os.path.join(PLOT_DIR, os.path.basename(this_file)))
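
Note on the reworked compute_MC_returns in learning/utils/dict_utils.py: it now leaves the rollout TensorDict untouched and returns normalized discounted returns instead of writing them back into data. The recursion is the usual bootstrapped one; below is a toy, plain-tensor sketch of the same computation (the reward/done values are made up for illustration and the helper name discounted_returns is not part of this patch):

import torch


def discounted_returns(rewards, dones, gamma, last_values):
    # returns[k] = rewards[k] + gamma * returns[k + 1] * ~dones[k],
    # bootstrapped from last_values at the final step, then normalized
    # the same way as normalize() in dict_utils.py.
    returns = torch.zeros_like(rewards)
    returns[-1] = last_values * ~dones[-1]
    for k in reversed(range(rewards.shape[0] - 1)):
        returns[k] = rewards[k] + gamma * returns[k + 1] * ~dones[k]
    return (returns - returns.mean()) / (returns.std() + 1e-8)


rewards = torch.tensor([1.0, 1.0, 1.0, 0.0])  # 4 steps, single environment
dones = torch.tensor([False, False, False, True])
print(discounted_returns(rewards, dones, gamma=0.99, last_values=torch.tensor(0.0)))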
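
Note on the max_gradient_steps argument added to create_uniform_generator: it caps the total number of minibatches yielded rather than fixing an epoch count, and the epoch count is then derived from that cap. A small standalone sketch of the same arithmetic (plain Python; plan_batches and the example numbers below are illustrative, not from this patch):

def plan_batches(total_data, batch_size, num_epochs=1, max_gradient_steps=None):
    # Mirrors the batching arithmetic in create_uniform_generator.
    batch_size = min(batch_size, total_data)
    num_batches_per_epoch = total_data // batch_size
    if max_gradient_steps:
        if max_gradient_steps < num_batches_per_epoch:
            num_batches_per_epoch = max_gradient_steps
        num_epochs = max(max_gradient_steps // num_batches_per_epoch, 1)
    return num_epochs, num_batches_per_epoch


# e.g. 100_000 stored transitions, minibatches of 2**15, at most 10 gradient steps:
# 3 batches per epoch, 3 epochs -> 9 gradient steps in total, never more than the cap.
print(plan_batches(100_000, 2**15, max_gradient_steps=10))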
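
Note on the new polyak_update helper in learning/utils/utils.py: with the convention used here, polyak_factor is the weight kept on the target parameters, so values close to 1 make the target network track the online network slowly. A minimal usage sketch follows; the two nn.Sequential stand-ins and the 0.9 factor are illustrative only, and the import path assumes the module is importable as laid out above:

import torch.nn as nn

from learning.utils.utils import polyak_update

# Illustrative stand-ins for an online and a target Q-critic.
online_critic = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 1))
target_critic = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 1))

# Start the target as an exact copy of the online network.
target_critic.load_state_dict(online_critic.state_dict())

# After each gradient step on online_critic, nudge the target toward it:
# tp <- (1 - 0.9) * op + 0.9 * tp, i.e. the target keeps 90% of its old weights.
target_critic = polyak_update(online_critic, target_critic, polyak_factor=0.9)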