diff --git a/gym/envs/__init__.py b/gym/envs/__init__.py index 57cbd8c0..f9d8b788 100644 --- a/gym/envs/__init__.py +++ b/gym/envs/__init__.py @@ -19,7 +19,7 @@ "Anymal": ".anymal_c.anymal", "A1": ".a1.a1", "HumanoidRunning": ".mit_humanoid.humanoid_running", - "Pendulum": ".pendulum.pendulum", + "Pendulum": ".pendulum.pendulum" } config_dict = { @@ -27,11 +27,13 @@ "MiniCheetahCfg": ".mini_cheetah.mini_cheetah_config", "MiniCheetahRefCfg": ".mini_cheetah.mini_cheetah_ref_config", "MiniCheetahOscCfg": ".mini_cheetah.mini_cheetah_osc_config", + "MiniCheetahSACCfg": ".mini_cheetah.mini_cheetah_SAC_config", "MITHumanoidCfg": ".mit_humanoid.mit_humanoid_config", "A1Cfg": ".a1.a1_config", "AnymalCFlatCfg": ".anymal_c.flat.anymal_c_flat_config", "HumanoidRunningCfg": ".mit_humanoid.humanoid_running_config", "PendulumCfg": ".pendulum.pendulum_config", + "PendulumSACCfg": ".pendulum.pendulum_SAC_config", } runner_config_dict = { @@ -39,11 +41,13 @@ "MiniCheetahRunnerCfg": ".mini_cheetah.mini_cheetah_config", "MiniCheetahRefRunnerCfg": ".mini_cheetah.mini_cheetah_ref_config", "MiniCheetahOscRunnerCfg": ".mini_cheetah.mini_cheetah_osc_config", + "MiniCheetahSACRunnerCfg": ".mini_cheetah.mini_cheetah_SAC_config", "MITHumanoidRunnerCfg": ".mit_humanoid.mit_humanoid_config", "A1RunnerCfg": ".a1.a1_config", "AnymalCFlatRunnerCfg": ".anymal_c.flat.anymal_c_flat_config", "HumanoidRunningRunnerCfg": ".mit_humanoid.humanoid_running_config", "PendulumRunnerCfg": ".pendulum.pendulum_config", + "PendulumSACRunnerCfg": ".pendulum.pendulum_SAC_config", } task_dict = { @@ -59,6 +63,11 @@ "MiniCheetahOscCfg", "MiniCheetahOscRunnerCfg", ], + "sac_mini_cheetah": [ + "MiniCheetahRef", + "MiniCheetahSACCfg", + "MiniCheetahSACRunnerCfg" + ], "humanoid": ["MIT_Humanoid", "MITHumanoidCfg", "MITHumanoidRunnerCfg"], "humanoid_running": [ "HumanoidRunning", @@ -66,7 +75,8 @@ "HumanoidRunningRunnerCfg", ], "flat_anymal_c": ["Anymal", "AnymalCFlatCfg", "AnymalCFlatRunnerCfg"], - "pendulum": ["Pendulum", "PendulumCfg", "PendulumRunnerCfg"] + "pendulum": ["Pendulum", "PendulumCfg", "PendulumRunnerCfg"], + "sac_pendulum": ["Pendulum", "PendulumSACCfg", "PendulumSACRunnerCfg"], } for class_name, class_location in class_dict.items(): diff --git a/gym/envs/base/fixed_robot.py b/gym/envs/base/fixed_robot.py index 1fe4c5d5..6d9dc971 100644 --- a/gym/envs/base/fixed_robot.py +++ b/gym/envs/base/fixed_robot.py @@ -40,64 +40,60 @@ def __init__(self, gym, sim, cfg, sim_params, sim_device, headless): self.reset() def step(self): - """Apply actions, simulate, call self.post_physics_step() - and pre_physics_step() - - Args: - actions (torch.Tensor): Tensor of shape - (num_envs, num_actions_per_env) - """ - self._reset_buffers() - self._pre_physics_step() - # * step physics and render each frame + self._pre_decimation_step() self._render() for _ in range(self.cfg.control.decimation): + self._pre_compute_torques() self.torques = self._compute_torques() - - if self.cfg.asset.disable_motors: - self.torques[:] = 0.0 - torques_to_gym_tensor = torch.zeros( - self.num_envs, self.num_dof, device=self.device - ) - - # todo encapsulate - next_torques_idx = 0 - for dof_idx in range(self.num_dof): - if self.cfg.control.actuated_joints_mask[dof_idx]: - torques_to_gym_tensor[:, dof_idx] = self.torques[ - :, next_torques_idx - ] - next_torques_idx += 1 - else: - torques_to_gym_tensor[:, dof_idx] = torch.zeros( - self.num_envs, device=self.device - ) - - self.gym.set_dof_actuation_force_tensor( - self.sim, gymtorch.unwrap_tensor(torques_to_gym_tensor) - ) - 
self.gym.simulate(self.sim) - if self.device == "cpu": - self.gym.fetch_results(self.sim, True) - self.gym.refresh_dof_state_tensor(self.sim) - - self._post_physics_step() + self._post_compute_torques() + self._step_physx_sim() + self._post_physx_step() + self._post_decimation_step() self._check_terminations_and_timeouts() env_ids = self.to_be_reset.nonzero(as_tuple=False).flatten() self._reset_idx(env_ids) - def _pre_physics_step(self): - pass + def _pre_decimation_step(self): + return None + + def _pre_compute_torques(self): + return None + + def _post_compute_torques(self): + if self.cfg.asset.disable_motors: + self.torques[:] = 0.0 - def _post_physics_step(self): + def _step_physx_sim(self): + next_torques_idx = 0 + torques_to_gym_tensor = torch.zeros( + self.num_envs, self.num_dof, device=self.device + ) + for dof_idx in range(self.num_dof): + if self.cfg.control.actuated_joints_mask[dof_idx]: + torques_to_gym_tensor[:, dof_idx] = self.torques[:, next_torques_idx] + next_torques_idx += 1 + else: + torques_to_gym_tensor[:, dof_idx] = torch.zeros( + self.num_envs, device=self.device + ) + self.gym.set_dof_actuation_force_tensor( + self.sim, gymtorch.unwrap_tensor(self.torques) + ) + self.gym.simulate(self.sim) + if self.device == "cpu": + self.gym.fetch_results(self.sim, True) + self.gym.refresh_dof_state_tensor(self.sim) + + def _post_physx_step(self): """ check terminations, compute observations and rewards """ self.gym.refresh_actor_root_state_tensor(self.sim) self.gym.refresh_net_contact_force_tensor(self.sim) + def _post_decimation_step(self): self.episode_length_buf += 1 self.common_step_counter += 1 @@ -212,18 +208,6 @@ def _process_rigid_body_props(self, props, env_id): return props def _compute_torques(self): - """Compute torques from actions. - Actions can be interpreted as position or velocity targets given - to a PD controller, or directly as scaled torques. - [NOTE]: torques must have the same dimension as the number of DOFs, - even if some DOFs are not actuated. 
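Note on the refactored `_step_physx_sim` above: the loop still builds `torques_to_gym_tensor` (compact actuated torques expanded to the full DOF layout), but the call now passes `gymtorch.unwrap_tensor(self.torques)` instead of the expanded tensor, whereas the pre-refactor code passed `torques_to_gym_tensor`. For robots whose `actuated_joints_mask` leaves some DOFs unactuated, the two tensors have different widths, so this looks unintended. A minimal vectorized sketch (not a drop-in patch), assuming `self.act_idx` from `_init_buffers` holds the actuated DOF indices:

def _step_physx_sim(self):
    # Expand compact actuated torques (num_envs, num_actuators) into the
    # full (num_envs, num_dof) layout that PhysX expects.
    torques_full = torch.zeros(self.num_envs, self.num_dof, device=self.device)
    torques_full[:, self.act_idx] = self.torques
    self.gym.set_dof_actuation_force_tensor(
        self.sim, gymtorch.unwrap_tensor(torques_full)
    )
    self.gym.simulate(self.sim)
    if self.device == "cpu":
        self.gym.fetch_results(self.sim, True)
    self.gym.refresh_dof_state_tensor(self.sim)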
- - Args: - actions (torch.Tensor): Actions - - Returns: - [torch.Tensor]: Torques sent to the simulation - """ actuated_dof_pos = torch.zeros( self.num_envs, self.num_actuators, device=self.device ) @@ -415,10 +399,7 @@ def _init_buffers(self): self.default_act_pos = self.default_act_pos.unsqueeze(0) # * store indices of actuated joints self.act_idx = to_torch(actuated_idx, dtype=torch.long, device=self.device) - # * check that init range highs and lows are consistent - # * and repopulate to match - if self.cfg.init_state.reset_mode == "reset_to_range": - self.initialize_ranges_for_initial_conditions() + self.initialize_ranges_for_initial_conditions() def initialize_ranges_for_initial_conditions(self): self.dof_pos_range = torch.zeros( diff --git a/gym/envs/base/fixed_robot_config.py b/gym/envs/base/fixed_robot_config.py index 8c7427f8..ecf12dae 100644 --- a/gym/envs/base/fixed_robot_config.py +++ b/gym/envs/base/fixed_robot_config.py @@ -123,34 +123,33 @@ class FixedRobotCfgPPO(BaseConfig): class logging: enable_local_saving = True - class policy: + class actor: init_noise_std = 1.0 hidden_dims = [512, 256, 128] - critic_hidden_dims = [512, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" - # only for 'ActorCriticRecurrent': - # rnn_type = 'lstm' - # rnn_hidden_size = 512 - # rnn_num_layers = 1 - obs = [ "observation_a", "observation_b", "these_need_to_be_atributes_(states)_of_the_robot_env", ] - - critic_obs = [ - "observation_x", - "observation_y", - "critic_obs_can_be_the_same_or_different_than_actor_obs", - ] + normalize_obs = True actions = ["tau_ff"] disable_actions = False class noise: - noise = 0.1 # implement as needed, also in your robot class + observation_a = 0.1 # implement as needed, also in your robot class + + class critic: + hidden_dims = [512, 256, 128] + activation = "elu" + normalize_obs = True + obs = [ + "observation_x", + "observation_y", + "critic_obs_can_be_the_same_or_different_than_actor_obs", + ] class rewards: class weights: @@ -165,20 +164,25 @@ class termination_weight: termination = 0.0 class algorithm: - # * training params - value_loss_coef = 1.0 - use_clipped_value_loss = True + # both + gamma = 0.99 + lam = 0.95 + # shared + batch_size = 2**15 + max_gradient_steps = 10 + # new + storage_size = 2**17 # new + batch_size = 2**15 # new + clip_param = 0.2 - entropy_coef = 0.01 - num_learning_epochs = 5 - # * mini batch size = num_envs*nsteps / nminibatches - num_mini_batches = 4 learning_rate = 1.0e-3 + max_grad_norm = 1.0 + # Critic + use_clipped_value_loss = True + # Actor + entropy_coef = 0.01 schedule = "adaptive" # could be adaptive, fixed - gamma = 0.99 - lam = 0.95 desired_kl = 0.01 - max_grad_norm = 1.0 class runner: policy_class_name = "ActorCritic" @@ -189,6 +193,7 @@ class runner: # * logging # * check for potential saves every this many iterations save_interval = 50 + log_storage = False run_name = "" experiment_name = "fixed_robot" diff --git a/gym/envs/base/legged_robot_config.py b/gym/envs/base/legged_robot_config.py index 7e6cabc6..ce95ed0a 100644 --- a/gym/envs/base/legged_robot_config.py +++ b/gym/envs/base/legged_robot_config.py @@ -238,13 +238,12 @@ class actor: hidden_dims = [512, 256, 128] # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" - normalize_obs = True - obs = [ "observation_a", "observation_b", "these_need_to_be_atributes_(states)_of_the_robot_env", ] + normalize_obs = True actions = ["q_des"] disable_actions = False @@ -288,25 +287,30 @@ class termination_weight: 
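On the rewritten `algorithm` blocks (this one and the matching one in `legged_robot_config.py` below): the old PPO knobs `num_learning_epochs` / `num_mini_batches` are replaced by `batch_size` and `max_gradient_steps`, which PPO2 now feeds straight into `create_uniform_generator`. Note that `batch_size = 2**15` is assigned twice in each block (once under "# shared", once under "# new"), with the second assignment silently winning. For reference, a rough correspondence to the old settings, assuming the generator draws `max_gradient_steps` uniformly sampled mini-batches per update:

# Old PPO knobs               New PPO2 knobs (approximate equivalence)
# num_mini_batches         -> batch_size ~= storage_size // num_mini_batches
# num_learning_epochs      -> max_gradient_steps ~= num_learning_epochs * num_mini_batches
# e.g. storage_size = 2**17 split into 2**15-sample batches gives 4 batches per pass;
#      max_gradient_steps = 10 then caps the total number of gradient steps per update.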
termination = 0.01 class algorithm: - # * training params - value_loss_coef = 1.0 - use_clipped_value_loss = True + # both + gamma = 0.99 + lam = 0.95 + # shared + batch_size = 2**15 + max_gradient_steps = 10 + # new + storage_size = 2**17 # new + batch_size = 2**15 # new + clip_param = 0.2 - entropy_coef = 0.01 - num_learning_epochs = 5 - # * mini batch size = num_envs*nsteps / nminibatches - num_mini_batches = 4 learning_rate = 1.0e-3 + max_grad_norm = 1.0 + # Critic + use_clipped_value_loss = True + # Actor + entropy_coef = 0.01 schedule = "adaptive" # could be adaptive, fixed - gamma = 0.99 - lam = 0.95 desired_kl = 0.01 - max_grad_norm = 1.0 class runner: policy_class_name = "ActorCritic" algorithm_class_name = "PPO2" - num_steps_per_env = 24 + num_steps_per_env = 24 # deprecate max_iterations = 1500 save_interval = 50 run_name = "" diff --git a/gym/envs/base/task_skeleton.py b/gym/envs/base/task_skeleton.py index ffbac6ce..0974970a 100644 --- a/gym/envs/base/task_skeleton.py +++ b/gym/envs/base/task_skeleton.py @@ -45,6 +45,7 @@ def reset(self): """Reset all robots""" self._reset_idx(torch.arange(self.num_envs, device=self.device)) self.step() + self.episode_length_buf[:] = 0 def _reset_buffers(self): self.to_be_reset[:] = False @@ -67,7 +68,7 @@ def _eval_reward(self, name): def _check_terminations_and_timeouts(self): """Check if environments need to be reset""" contact_forces = self.contact_forces[:, self.termination_contact_indices, :] - self.terminated = torch.any(torch.norm(contact_forces, dim=-1) > 1.0, dim=1) + self.terminated |= torch.any(torch.norm(contact_forces, dim=-1) > 1.0, dim=1) self.timed_out = self.episode_length_buf >= self.max_episode_length self.to_be_reset = self.timed_out | self.terminated diff --git a/gym/envs/cartpole/cartpole_config.py b/gym/envs/cartpole/cartpole_config.py index f18d5a3b..6e073d5c 100644 --- a/gym/envs/cartpole/cartpole_config.py +++ b/gym/envs/cartpole/cartpole_config.py @@ -67,7 +67,7 @@ class CartpoleRunnerCfg(FixedRobotCfgPPO): seed = -1 runner_class_name = "OnPolicyRunner" - class policy(FixedRobotCfgPPO.policy): + class actor(FixedRobotCfgPPO.actor): init_noise_std = 1.0 num_layers = 2 num_units = 32 diff --git a/gym/envs/mini_cheetah/mini_cheetah_SAC_config.py b/gym/envs/mini_cheetah/mini_cheetah_SAC_config.py new file mode 100644 index 00000000..69da4cf8 --- /dev/null +++ b/gym/envs/mini_cheetah/mini_cheetah_SAC_config.py @@ -0,0 +1,183 @@ +from gym.envs.base.legged_robot_config import LeggedRobotRunnerCfg +from gym.envs.mini_cheetah.mini_cheetah_ref_config import MiniCheetahRefCfg + +BASE_HEIGHT_REF = 0.3 + + +class MiniCheetahSACCfg(MiniCheetahRefCfg): + class env(MiniCheetahRefCfg.env): + num_envs = 1 + episode_length_s = 4 # TODO + + class terrain(MiniCheetahRefCfg.terrain): + pass + + class init_state(MiniCheetahRefCfg.init_state): + pass + + class control(MiniCheetahRefCfg.control): + # * PD Drive parameters: + stiffness = {"haa": 20.0, "hfe": 20.0, "kfe": 20.0} + damping = {"haa": 0.5, "hfe": 0.5, "kfe": 0.5} + gait_freq = 3.0 + ctrl_frequency = 20 # TODO + desired_sim_frequency = 100 + + class commands(MiniCheetahRefCfg.commands): + pass + + class push_robots(MiniCheetahRefCfg.push_robots): + pass + + class domain_rand(MiniCheetahRefCfg.domain_rand): + pass + + class asset(MiniCheetahRefCfg.asset): + file = ( + "{LEGGED_GYM_ROOT_DIR}/resources/robots/" + + "mini_cheetah/urdf/mini_cheetah_simple.urdf" + ) + foot_name = "foot" + penalize_contacts_on = ["shank"] + terminate_after_contacts_on = ["base"] + end_effector_names = 
["foot"] + collapse_fixed_joints = False + self_collisions = 1 + flip_visual_attachments = False + disable_gravity = False + disable_motors = False + joint_damping = 0.1 + rotor_inertia = [0.002268, 0.002268, 0.005484] * 4 + + class reward_settings(MiniCheetahRefCfg.reward_settings): + soft_dof_pos_limit = 0.9 + soft_dof_vel_limit = 0.9 + soft_torque_limit = 0.9 + max_contact_force = 600.0 + base_height_target = BASE_HEIGHT_REF + tracking_sigma = 0.25 + + class scaling(MiniCheetahRefCfg.scaling): + base_ang_vel = 0.3 + base_lin_vel = BASE_HEIGHT_REF + dof_vel = 4 * [2.0, 2.0, 4.0] + base_height = 0.3 + dof_pos = 4 * [0.2, 0.3, 0.3] + dof_pos_obs = dof_pos + dof_pos_target = 4 * [0.2, 0.3, 0.3] + tau_ff = 4 * [18, 18, 28] + commands = [3, 1, 3] + + +class MiniCheetahSACRunnerCfg(LeggedRobotRunnerCfg): + seed = -1 + runner_class_name = "OffPolicyRunner" + + class actor: + hidden_dims = { + "latent": [128, 128], + "mean": [64], + "std": [64], + } + # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation = { + "latent": "elu", + "mean": "elu", + "std": "elu", + } + + # TODO[lm]: Handle normalization + normalize_obs = False + obs = [ + "base_ang_vel", + "projected_gravity", + "commands", + "dof_pos_obs", + "dof_vel", + "phase_obs", + ] + + actions = ["dof_pos_target"] + add_noise = False # TODO + disable_actions = False + + class noise: + scale = 1.0 + dof_pos_obs = 0.01 + base_ang_vel = 0.01 + dof_pos = 0.005 + dof_vel = 0.005 + lin_vel = 0.05 + ang_vel = [0.3, 0.15, 0.4] + gravity_vec = 0.1 + + class critic: + hidden_dims = [128, 128, 64] + # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation = "elu" + + # TODO[lm]: Handle normalization + normalize_obs = False + obs = [ + "base_height", + "base_lin_vel", + "base_ang_vel", + "projected_gravity", + "commands", + "dof_pos_obs", + "dof_vel", + "phase_obs", + "dof_pos_target", + ] + + class reward: + class weights: + tracking_lin_vel = 4.0 + tracking_ang_vel = 2.0 + lin_vel_z = 0.0 + ang_vel_xy = 0.01 + orientation = 1.0 + torques = 5.0e-7 + dof_vel = 0.0 + min_base_height = 1.5 + collision = 0.0 + action_rate = 0.01 + action_rate2 = 0.001 + stand_still = 0.0 + dof_pos_limits = 0.0 + feet_contact_forces = 0.0 + dof_near_home = 0.0 + reference_traj = 1.5 + swing_grf = 1.5 + stance_grf = 1.5 + + class termination_weight: + termination = 0.15 + + class algorithm(LeggedRobotRunnerCfg.algorithm): + # Taken from SAC pendulum + initial_fill = 500 + storage_size = 10**6 + batch_size = 256 + max_gradient_steps = 1 # 10 + action_max = 1.0 # TODO + action_min = -1.0 # TODO + actor_noise_std = 0.5 # TODO + log_std_max = 4.0 + log_std_min = -20.0 + alpha = 0.2 + target_entropy = -12.0 # -action_dim + max_grad_norm = 1.0 + polyak = 0.995 # flipped compared to SB3 (polyak == 1-tau) + gamma = 0.99 + alpha_lr = 1e-4 + actor_lr = 1e-4 + critic_lr = 1e-4 + + class runner(LeggedRobotRunnerCfg.runner): + run_name = "" + experiment_name = "sac_mini_cheetah" + max_iterations = 50_000 + algorithm_class_name = "SAC" + save_interval = 10_000 + num_steps_per_env = 1 diff --git a/gym/envs/mini_cheetah/mini_cheetah_config.py b/gym/envs/mini_cheetah/mini_cheetah_config.py index 28c70123..ff82119f 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_config.py @@ -130,7 +130,6 @@ class actor: hidden_dims = [256, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" - obs = [ "base_lin_vel", "base_ang_vel", @@ -140,6 +139,7 @@ class actor: "dof_vel", "dof_pos_target", ] + 
normalize_obs = True actions = ["dof_pos_target"] add_noise = True disable_actions = False @@ -168,6 +168,7 @@ class critic: "dof_vel", "dof_pos_target", ] + normalize_obs = True class reward: class weights: @@ -190,20 +191,7 @@ class termination_weight: termination = 0.01 class algorithm(LeggedRobotRunnerCfg.algorithm): - # * training params - value_loss_coef = 1.0 - use_clipped_value_loss = True - clip_param = 0.2 - entropy_coef = 0.02 - num_learning_epochs = 4 - # * mini batch size = num_envs*nsteps / nminibatches - num_mini_batches = 8 - learning_rate = 1.0e-5 - schedule = "adaptive" # can be adaptive or fixed - discount_horizon = 1.0 # [s] - # GAE_bootstrap_horizon = 2.0 # [s] - desired_kl = 0.01 - max_grad_norm = 1.0 + pass class runner(LeggedRobotRunnerCfg.runner): run_name = "" diff --git a/gym/envs/mini_cheetah/mini_cheetah_osc_config.py b/gym/envs/mini_cheetah/mini_cheetah_osc_config.py index a4acefa6..5aaafde3 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_osc_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_osc_config.py @@ -160,6 +160,7 @@ class scaling(MiniCheetahCfg.scaling): class MiniCheetahOscRunnerCfg(MiniCheetahRunnerCfg): seed = -1 + runner_class_name = "OnPolicyRunner" class policy: hidden_dims = [256, 256, 128] diff --git a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py index 3c92882d..bd8140b2 100644 --- a/gym/envs/mini_cheetah/mini_cheetah_ref_config.py +++ b/gym/envs/mini_cheetah/mini_cheetah_ref_config.py @@ -73,7 +73,6 @@ class actor: hidden_dims = [256, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" - normalize_obs = True obs = [ "base_ang_vel", "projected_gravity", @@ -82,6 +81,7 @@ class actor: "dof_vel", "phase_obs", ] + normalize_obs = True actions = ["dof_pos_target"] disable_actions = False @@ -100,7 +100,6 @@ class critic: hidden_dims = [256, 256, 128] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "elu" - normalize_obs = True obs = [ "base_height", "base_lin_vel", @@ -112,6 +111,7 @@ class critic: "phase_obs", "dof_pos_target", ] + normalize_obs = True class reward: class weights: @@ -138,27 +138,11 @@ class termination_weight: termination = 0.15 class algorithm(MiniCheetahRunnerCfg.algorithm): - # training params - value_loss_coef = 1.0 # deprecate for PPO2 - use_clipped_value_loss = True # deprecate for PPO2 - clip_param = 0.2 - entropy_coef = 0.01 - num_learning_epochs = 6 - # mini batch size = num_envs*nsteps/nminibatches - num_mini_batches = 4 - storage_size = 2**17 # new - mini_batch_size = 2**15 # new - learning_rate = 5.0e-5 - schedule = "adaptive" # can be adaptive, fixed - discount_horizon = 1.0 # [s] - lam = 0.95 - GAE_bootstrap_horizon = 2.0 # [s] - desired_kl = 0.01 - max_grad_norm = 1.0 + pass class runner(MiniCheetahRunnerCfg.runner): run_name = "" experiment_name = "mini_cheetah_ref" max_iterations = 500 # number of policy updates algorithm_class_name = "PPO2" - num_steps_per_env = 32 + num_steps_per_env = 32 # deprecate diff --git a/gym/envs/mit_humanoid/mit_humanoid_config.py b/gym/envs/mit_humanoid/mit_humanoid_config.py index 86a8d606..c7e061a9 100644 --- a/gym/envs/mit_humanoid/mit_humanoid_config.py +++ b/gym/envs/mit_humanoid/mit_humanoid_config.py @@ -198,6 +198,7 @@ class actor: "dof_vel", "dof_pos_history", ] + normalize_obs = True actions = ["dof_pos_target"] disable_actions = False @@ -226,6 +227,7 @@ class critic: "dof_vel", "dof_pos_history", ] + normalize_obs = True class reward: class weights: diff --git 
a/gym/envs/pendulum/pendulum.py b/gym/envs/pendulum/pendulum.py index 95af1fb2..d0bcfb9c 100644 --- a/gym/envs/pendulum/pendulum.py +++ b/gym/envs/pendulum/pendulum.py @@ -1,15 +1,49 @@ import torch +import numpy as np +from math import sqrt from gym.envs.base.fixed_robot import FixedRobot class Pendulum(FixedRobot): - def _post_physics_step(self): - """Update all states that are not handled in PhysX""" - super()._post_physics_step() + def _init_buffers(self): + super()._init_buffers() + self.dof_pos_obs = torch.zeros(self.num_envs, 2, device=self.device) + + def _post_decimation_step(self): + super()._post_decimation_step() + self.dof_pos_obs = torch.cat([self.dof_pos.sin(), self.dof_pos.cos()], dim=1) + + def _reset_system(self, env_ids): + super()._reset_system(env_ids) + self.dof_pos_obs[env_ids] = torch.cat( + [self.dof_pos[env_ids].sin(), self.dof_pos[env_ids].cos()], dim=1 + ) + + def _check_terminations_and_timeouts(self): + super()._check_terminations_and_timeouts() + self.terminated = self.timed_out + + def reset_to_uniform(self, env_ids): + grid_points = int(sqrt(self.num_envs)) + lin_pos = torch.linspace( + self.dof_pos_range[0, 0], + self.dof_pos_range[0, 1], + grid_points, + device=self.device, + ) + lin_vel = torch.linspace( + self.dof_vel_range[0, 0], + self.dof_vel_range[0, 1], + grid_points, + device=self.device, + ) + grid = torch.cartesian_prod(lin_pos, lin_vel) + self.dof_pos[env_ids] = grid[:, 0].unsqueeze(-1) + self.dof_vel[env_ids] = grid[:, 1].unsqueeze(-1) def _reward_theta(self): - theta_rwd = torch.cos(self.dof_pos[:, 0]) / self.scales["dof_pos"] + theta_rwd = torch.cos(self.dof_pos[:, 0]) # no scaling return self._sqrdexp(theta_rwd.squeeze(dim=-1)) def _reward_omega(self): @@ -17,27 +51,31 @@ def _reward_omega(self): return self._sqrdexp(omega_rwd.squeeze(dim=-1)) def _reward_equilibrium(self): - error = torch.abs(self.dof_state) - error[:, 0] /= self.scales["dof_pos"] - error[:, 1] /= self.scales["dof_vel"] - return self._sqrdexp(torch.mean(error, dim=1), scale=0.01) - # return torch.exp( - # -error.pow(2).sum(dim=1) / self.cfg.reward_settings.tracking_sigma - # ) - - def _reward_torques(self): - """Penalize torques""" - return self._sqrdexp(torch.mean(torch.square(self.torques), dim=1), scale=0.2) + theta_norm = self._normalize_theta() + omega = self.dof_vel[:, 0] + error = torch.stack( + [theta_norm / self.scales["dof_pos"], omega / self.scales["dof_vel"]], dim=1 + ) + return self._sqrdexp(torch.mean(error, dim=1), scale=0.2) def _reward_energy(self): - m_pendulum = 1.0 - l_pendulum = 1.0 kinetic_energy = ( - 0.5 * m_pendulum * l_pendulum**2 * torch.square(self.dof_vel[:, 0]) + 0.5 + * self.cfg.asset.mass + * self.cfg.asset.length**2 + * torch.square(self.dof_vel[:, 0]) ) potential_energy = ( - m_pendulum * 9.81 * l_pendulum * torch.cos(self.dof_pos[:, 0]) + self.cfg.asset.mass + * 9.81 + * self.cfg.asset.length + * torch.cos(self.dof_pos[:, 0]) ) - desired_energy = m_pendulum * 9.81 * l_pendulum + desired_energy = self.cfg.asset.mass * 9.81 * self.cfg.asset.length energy_error = kinetic_energy + potential_energy - desired_energy return self._sqrdexp(energy_error / desired_energy) + + def _normalize_theta(self): + # normalize to range [-pi, pi] + theta = self.dof_pos[:, 0] + return ((theta + np.pi) % (2 * np.pi)) - np.pi diff --git a/gym/envs/pendulum/pendulum_SAC_config.py b/gym/envs/pendulum/pendulum_SAC_config.py new file mode 100644 index 00000000..e3b3047b --- /dev/null +++ b/gym/envs/pendulum/pendulum_SAC_config.py @@ -0,0 +1,113 @@ +import torch +from 
gym.envs.base.fixed_robot_config import FixedRobotCfgPPO +from gym.envs.pendulum.pendulum_config import PendulumCfg + + +class PendulumSACCfg(PendulumCfg): + class env(PendulumCfg.env): + num_envs = 1024 + episode_length_s = 10 + + class init_state(PendulumCfg.init_state): + reset_mode = "reset_to_uniform" + default_joint_angles = {"theta": 0.0} + dof_pos_range = { + "theta": [-torch.pi, torch.pi], + } + dof_vel_range = {"theta": [-5, 5]} + + class control(PendulumCfg.control): + ctrl_frequency = 10 + desired_sim_frequency = 100 + + class asset(PendulumCfg.asset): + joint_damping = 0.1 + + class reward_settings(PendulumCfg.reward_settings): + tracking_sigma = 0.25 + + class scaling(PendulumCfg.scaling): + dof_vel = 5.0 + dof_pos = 2.0 * torch.pi + tau_ff = 1.0 + + +class PendulumSACRunnerCfg(FixedRobotCfgPPO): + seed = -1 + runner_class_name = "OffPolicyRunner" + + class actor: + hidden_dims = { + "latent": [128, 64], + "mean": [32], + "std": [32], + } + # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation = { + "latent": "elu", + "mean": "elu", + "std": "elu", + } + + normalize_obs = False + obs = [ + "dof_pos_obs", + "dof_vel", + ] + actions = ["tau_ff"] + disable_actions = False + + class noise: + dof_pos = 0.0 + dof_vel = 0.0 + + class critic: + obs = [ + "dof_pos_obs", + "dof_vel", + ] + hidden_dims = [128, 64, 32] + # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid + activation = "elu" + # TODO[lm]: Current normalization uses torch.no_grad, this should be changed + normalize_obs = False + + class reward: + class weights: + theta = 0.0 + omega = 0.0 + equilibrium = 1.0 + energy = 0.5 + dof_vel = 0.0 + torques = 0.025 + + class termination_weight: + termination = 0.0 + + class algorithm(FixedRobotCfgPPO.algorithm): + initial_fill = 0 + storage_size = 100 * 1024 # steps_per_episode * num_envs + batch_size = 256 # 4096 + max_gradient_steps = 1 # 10 # SB3: 1 + action_max = 2.0 + action_min = -2.0 + actor_noise_std = 1.0 + log_std_max = 4.0 + log_std_min = -20.0 + alpha = 0.2 + target_entropy = -1.0 + max_grad_norm = 1.0 + polyak = 0.995 # flipped compared to stable-baselines3 (polyak == 1-tau) + gamma = 0.99 + alpha_lr = 1e-4 + actor_lr = 1e-3 + critic_lr = 1e-3 + + class runner(FixedRobotCfgPPO.runner): + run_name = "" + experiment_name = "sac_pendulum" + max_iterations = 30_000 # number of policy updates + algorithm_class_name = "SAC" + num_steps_per_env = 1 + save_interval = 2000 + log_storage = True diff --git a/gym/envs/pendulum/pendulum_config.py b/gym/envs/pendulum/pendulum_config.py index 920356cd..d8086666 100644 --- a/gym/envs/pendulum/pendulum_config.py +++ b/gym/envs/pendulum/pendulum_config.py @@ -5,9 +5,9 @@ class PendulumCfg(FixedRobotCfg): class env(FixedRobotCfg.env): - num_envs = 2**13 - num_actuators = 1 # 1 for theta connecting base and pole - episode_length_s = 5.0 + num_envs = 4096 + num_actuators = 1 + episode_length_s = 10 class terrain(FixedRobotCfg.terrain): pass @@ -18,12 +18,12 @@ class viewer: lookat = [0.0, 0.0, 0.0] # [m] class init_state(FixedRobotCfg.init_state): - default_joint_angles = {"theta": 0.0} # -torch.pi / 2.0} + default_joint_angles = {"theta": 0} # -torch.pi / 2.0} # * default setup chooses how the initial conditions are chosen. 
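The comment block here still only documents "reset_to_basic" and "reset_to_range", while the new default is "reset_to_uniform", implemented in `Pendulum.reset_to_uniform` above. That reset tiles the (theta, theta_dot) box with an `int(sqrt(num_envs))`-per-axis grid, so it implicitly assumes `num_envs` is a perfect square (1024 and 4096 both are) and that `env_ids` covers every environment, since `grid` has exactly `grid_points**2` rows. A hypothetical variant that also tolerates partial resets, under the same perfect-square assumption, would index the grid by `env_ids`:

def reset_to_uniform(self, env_ids):
    # Each env is pinned to its own cell of a grid_points x grid_points grid
    # over (dof_pos_range, dof_vel_range); requires num_envs to be a perfect square.
    grid_points = int(sqrt(self.num_envs))
    lin_pos = torch.linspace(
        self.dof_pos_range[0, 0], self.dof_pos_range[0, 1], grid_points, device=self.device
    )
    lin_vel = torch.linspace(
        self.dof_vel_range[0, 0], self.dof_vel_range[0, 1], grid_points, device=self.device
    )
    grid = torch.cartesian_prod(lin_pos, lin_vel)  # (grid_points**2, 2)
    self.dof_pos[env_ids] = grid[env_ids, 0].unsqueeze(-1)
    self.dof_vel[env_ids] = grid[env_ids, 1].unsqueeze(-1)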
# * "reset_to_basic" = a single position # * "reset_to_range" = uniformly random from a range defined below - reset_mode = "reset_to_range" + reset_mode = "reset_to_uniform" # * initial conditions for reset_to_range dof_pos_range = { @@ -33,8 +33,8 @@ class init_state(FixedRobotCfg.init_state): class control(FixedRobotCfg.control): actuated_joints_mask = [1] # angle - ctrl_frequency = 100 - desired_sim_frequency = 200 + ctrl_frequency = 10 + desired_sim_frequency = 100 stiffness = {"theta": 0.0} # [N*m/rad] damping = {"theta": 0.0} # [N*m*s/rad] @@ -44,6 +44,8 @@ class asset(FixedRobotCfg.asset): disable_gravity = False disable_motors = False # all torques set to 0 joint_damping = 0.1 + mass = 1.0 + length = 1.0 class reward_settings(FixedRobotCfg.reward_settings): tracking_sigma = 0.25 @@ -57,15 +59,17 @@ class scaling(FixedRobotCfg.scaling): class PendulumRunnerCfg(FixedRobotCfgPPO): seed = -1 - runner_class_name = "DataLoggingRunner" + runner_class_name = "OnPolicyRunner" class actor: hidden_dims = [128, 64, 32] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "tanh" + # TODO[lm]: Handle normalization in SAC, then also use it here again + normalize_obs = False obs = [ - "dof_pos", + "dof_pos_obs", "dof_vel", ] @@ -77,21 +81,23 @@ class noise: dof_vel = 0.0 class critic: - obs = [ - "dof_pos", - "dof_vel", - ] hidden_dims = [128, 64, 32] # * can be elu, relu, selu, crelu, lrelu, tanh, sigmoid activation = "tanh" - standard_critic_nn = True + + # TODO[lm]: Handle normalization in SAC, then also use it here again + normalize_obs = False + obs = [ + "dof_pos_obs", + "dof_vel", + ] class reward: class weights: theta = 0.0 omega = 0.0 equilibrium = 1.0 - energy = 0.0 + energy = 0.5 dof_vel = 0.0 torques = 0.025 @@ -99,27 +105,30 @@ class termination_weight: termination = 0.0 class algorithm(FixedRobotCfgPPO.algorithm): - # training params - value_loss_coef = 1.0 - use_clipped_value_loss = True + # both + gamma = 0.95 + # discount_horizon = 2.0 + lam = 0.98 + # shared + max_gradient_steps = 24 + # new + storage_size = 2**17 # new + batch_size = 2**16 # new clip_param = 0.2 + learning_rate = 1.0e-4 + max_grad_norm = 1.0 + # Critic + use_clipped_value_loss = True + # Actor entropy_coef = 0.01 - num_learning_epochs = 6 - # * mini batch size = num_envs*nsteps / nminibatches - num_mini_batches = 4 - learning_rate = 1.0e-3 schedule = "fixed" # could be adaptive, fixed - discount_horizon = 2.0 # [s] - lam = 0.98 - # GAE_bootstrap_horizon = .0 # [s] desired_kl = 0.01 - max_grad_norm = 1.0 - standard_loss = True - plus_c_penalty = 0.1 class runner(FixedRobotCfgPPO.runner): run_name = "" experiment_name = "pendulum" - max_iterations = 500 # number of policy updates + max_iterations = 100 # number of policy updates algorithm_class_name = "PPO2" - num_steps_per_env = 32 + num_steps_per_env = 100 + save_interval = 20 + log_storage = True diff --git a/learning/algorithms/__init__.py b/learning/algorithms/__init__.py index ace1b9fe..78231181 100644 --- a/learning/algorithms/__init__.py +++ b/learning/algorithms/__init__.py @@ -33,3 +33,4 @@ from .ppo import PPO from .ppo2 import PPO2 from .SE import StateEstimator +from .sac import SAC \ No newline at end of file diff --git a/learning/algorithms/ppo.py b/learning/algorithms/ppo.py index 92898ad5..befeea79 100644 --- a/learning/algorithms/ppo.py +++ b/learning/algorithms/ppo.py @@ -105,9 +105,6 @@ def init_storage( self.device, ) - def test_mode(self): - self.actor_critic.test() - def train_mode(self): self.actor_critic.train() diff --git 
a/learning/algorithms/ppo2.py b/learning/algorithms/ppo2.py index 5866052c..09c60ab9 100644 --- a/learning/algorithms/ppo2.py +++ b/learning/algorithms/ppo2.py @@ -5,6 +5,7 @@ from learning.utils import ( create_uniform_generator, compute_generalized_advantages, + normalize, ) @@ -13,8 +14,8 @@ def __init__( self, actor, critic, - num_learning_epochs=1, - num_mini_batches=1, + batch_size=2**15, + max_gradient_steps=10, clip_param=0.2, gamma=0.998, lam=0.95, @@ -24,7 +25,6 @@ def __init__( use_clipped_value_loss=True, schedule="fixed", desired_kl=0.01, - loss_fn="MSE", device="cpu", **kwargs, ): @@ -42,49 +42,42 @@ def __init__( # * PPO parameters self.clip_param = clip_param - self.num_learning_epochs = num_learning_epochs - self.num_mini_batches = num_mini_batches + self.batch_size = batch_size + self.max_gradient_steps = max_gradient_steps self.entropy_coef = entropy_coef self.gamma = gamma self.lam = lam self.max_grad_norm = max_grad_norm self.use_clipped_value_loss = use_clipped_value_loss - def test_mode(self): - self.actor.test() - self.critic.test() - def switch_to_train(self): self.actor.train() self.critic.train() - def act(self, obs, critic_obs): + def act(self, obs): return self.actor.act(obs).detach() - def update(self, data, last_obs=None): - if last_obs is None: - last_values = None - else: - with torch.no_grad(): - last_values = self.critic.evaluate(last_obs).detach() - compute_generalized_advantages( - data, self.gamma, self.lam, self.critic, last_values + def update(self, data): + data["values"] = self.critic.evaluate(data["critic_obs"]) + data["advantages"] = compute_generalized_advantages( + data, self.gamma, self.lam, self.critic ) - + data["returns"] = data["advantages"] + data["values"] self.update_critic(data) + data["advantages"] = normalize(data["advantages"]) self.update_actor(data) def update_critic(self, data): self.mean_value_loss = 0 counter = 0 - n, m = data.shape - total_data = n * m - batch_size = total_data // self.num_mini_batches - generator = create_uniform_generator(data, batch_size, self.num_learning_epochs) + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) for batch in generator: - value_batch = self.critic.evaluate(batch["critic_obs"]) - value_loss = self.critic.loss_fn(value_batch, batch["returns"]) + value_loss = self.critic.loss_fn(batch["critic_obs"], batch["returns"]) self.critic_optimizer.zero_grad() value_loss.backward() nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) @@ -94,8 +87,6 @@ def update_critic(self, data): self.mean_value_loss /= counter def update_actor(self, data): - # already done before - # compute_generalized_advantages(data, self.gamma, self.lam, self.critic) self.mean_surrogate_loss = 0 counter = 0 @@ -106,12 +97,12 @@ def update_actor(self, data): data["actions"] ).detach() - n, m = data.shape - total_data = n * m - batch_size = total_data // self.num_mini_batches - generator = create_uniform_generator(data, batch_size, self.num_learning_epochs) + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) for batch in generator: - # ! 
refactor how this is done self.actor.act(batch["actor_obs"]) actions_log_prob_batch = self.actor.get_actions_log_prob(batch["actions"]) mu_batch = self.actor.action_mean @@ -158,10 +149,7 @@ def update_actor(self, data): # * Gradient step self.optimizer.zero_grad() loss.backward() - nn.utils.clip_grad_norm_( - list(self.actor.parameters()) + list(self.critic.parameters()), - self.max_grad_norm, - ) + nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) self.optimizer.step() self.mean_surrogate_loss += surrogate_loss.item() counter += 1 diff --git a/learning/algorithms/sac.py b/learning/algorithms/sac.py new file mode 100644 index 00000000..e0651d2d --- /dev/null +++ b/learning/algorithms/sac.py @@ -0,0 +1,252 @@ +import torch + +# import torch.nn as nn +import torch.optim as optim + +from learning.utils import create_uniform_generator, polyak_update + + +class SAC: + def __init__( + self, + actor, + critic_1, + critic_2, + target_critic_1, + target_critic_2, + batch_size=2**15, + max_gradient_steps=10, + action_max=1.0, + action_min=-1.0, + actor_noise_std=1.0, + log_std_max=4.0, + log_std_min=-20.0, + alpha=0.2, + alpha_lr=1e-4, + target_entropy=None, + max_grad_norm=1.0, + polyak=0.995, + gamma=0.99, + actor_lr=1e-4, + critic_lr=1e-4, + device="cpu", + **kwargs, + ): + self.device = device + + # * SAC components + self.actor = actor.to(self.device) + self.critic_1 = critic_1.to(self.device) + self.critic_2 = critic_2.to(self.device) + self.target_critic_1 = target_critic_1.to(self.device) + self.target_critic_2 = target_critic_2.to(self.device) + self.target_critic_1.load_state_dict(self.critic_1.state_dict()) + self.target_critic_2.load_state_dict(self.critic_2.state_dict()) + + self.log_alpha = torch.log(torch.tensor(alpha)).requires_grad_() + + self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr) + self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr) + self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr) + self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr) + + # this is probably something to put into the neural network + self.action_max = action_max + self.action_min = action_min + + self.action_delta = (self.action_max - self.action_min) / 2.0 + self.action_offset = (self.action_max + self.action_min) / 2.0 + + self.max_grad_norm = max_grad_norm + self.target_entropy = ( + target_entropy if target_entropy else -self.actor.num_actions + ) + + # * SAC parameters + self.max_gradient_steps = max_gradient_steps + self.batch_size = batch_size + self.polyak = polyak + self.gamma = gamma + # self.ent_coef = "fixed" + # self.target_entropy = "fixed" + + @property + def alpha(self): + return self.log_alpha.exp() + + def switch_to_train(self): + self.actor.train() + self.critic_1.train() + self.critic_2.train() + self.target_critic_1.train() + self.target_critic_2.train() + + def switch_to_eval(self): + self.actor.eval() + self.critic_1.eval() + self.critic_2.eval() + self.target_critic_1.eval() + self.target_critic_2.eval() + + def act(self, obs): + mean, std = self.actor.forward(obs, deterministic=False) + distribution = torch.distributions.Normal(mean, std) + actions = distribution.rsample() + actions_normalized = torch.tanh(actions) + # RSL also does a resahpe(-1, self.action_size), not sure why + actions_scaled = ( + actions_normalized * self.action_delta + self.action_offset + ).clamp(self.action_min, self.action_max) + return actions_scaled + + def act_inference(self, obs): + mean = 
self.actor.forward(obs, deterministic=True) + actions_normalized = torch.tanh(mean) + actions_scaled = ( + actions_normalized * self.action_delta + self.action_offset + ).clamp(self.action_min, self.action_max) + return actions_scaled + + def update(self, data): + generator = create_uniform_generator( + data, + self.batch_size, + max_gradient_steps=self.max_gradient_steps, + ) + + count = 0 + self.mean_actor_loss = 0 + self.mean_alpha_loss = 0 + self.mean_critic_1_loss = 0 + self.mean_critic_2_loss = 0 + + for batch in generator: + self.update_critic(batch) + self.update_actor_and_alpha(batch) + + count += 1 + # Update Target Networks + self.target_critic_1 = polyak_update( + self.critic_1, self.target_critic_1, self.polyak + ) + self.target_critic_2 = polyak_update( + self.critic_2, self.target_critic_2, self.polyak + ) + self.mean_actor_loss /= count + self.mean_alpha_loss /= count + self.mean_critic_1_loss /= count + self.mean_critic_2_loss /= count + + return None + + def update_critic(self, batch): + critic_obs = batch["critic_obs"] + actions = batch["actions"] + rewards = batch["rewards"] + next_actor_obs = batch["next_actor_obs"] + next_critic_obs = batch["next_critic_obs"] + dones = batch["dones"] + + with torch.no_grad(): + # * self._sample_action(actor_next_obs) + mean, std = self.actor.forward(next_actor_obs, deterministic=False) + distribution = torch.distributions.Normal(mean, std) + next_actions = distribution.rsample() + + ## * self._scale_actions(actions, intermediate=True) + actions_normalized = torch.tanh(next_actions) + # RSL also does a resahpe(-1, self.action_size), not sure why + actions_scaled = ( + actions_normalized * self.action_delta + self.action_offset + ).clamp(self.action_min, self.action_max) + ## * + action_logp = ( + distribution.log_prob(next_actions) + - torch.log(1.0 - actions_normalized.pow(2) + 1e-6) + ).sum(-1) + + # * returns target_action = actions_scaled, target_action_logp = action_logp + target_action = actions_scaled + target_action_logp = action_logp + + # * self._critic_input + # ! 
def should put the action computation into the actor + target_critic_in = torch.cat((next_critic_obs, target_action), dim=-1) + target_critic_prediction_1 = self.target_critic_1.forward(target_critic_in) + target_critic_prediction_2 = self.target_critic_2.forward(target_critic_in) + + target_next = ( + torch.min(target_critic_prediction_1, target_critic_prediction_2) + - self.alpha.detach() * target_action_logp + ) + # the detach inside torch.no_grad() should be redundant + target = rewards + self.gamma * dones.logical_not() * target_next + + critic_in = torch.cat((critic_obs, actions), dim=-1) + + # critic_prediction_1 = self.critic_1.forward(critic_in) + critic_loss_1 = self.critic_1.loss_fn(critic_in, target) + self.critic_1_optimizer.zero_grad() + critic_loss_1.backward() + # nn.utils.clip_grad_norm_(self.critic_1.parameters(), self.max_grad_norm) + self.critic_1_optimizer.step() + + # critic_prediction_2 = self.critic_2.forward(critic_in) + critic_loss_2 = self.critic_2.loss_fn(critic_in, target) + self.critic_2_optimizer.zero_grad() + critic_loss_2.backward() + # nn.utils.clip_grad_norm_(self.critic_2.parameters(), self.max_grad_norm) + self.critic_2_optimizer.step() + + self.mean_critic_1_loss += critic_loss_1.item() + self.mean_critic_2_loss += critic_loss_2.item() + + return + + def update_actor_and_alpha(self, batch): + actor_obs = batch["actor_obs"] + critic_obs = batch["critic_obs"] + + mean, std = self.actor.forward(actor_obs, deterministic=False) + distribution = torch.distributions.Normal(mean, std) + actions = distribution.rsample() + + ## * self._scale_actions(actions, intermediate=True) + actions_normalized = torch.tanh(actions) + # RSL also does a resahpe(-1, self.action_size), not sure why + actions_scaled = ( + actions_normalized * self.action_delta + self.action_offset + ).clamp(self.action_min, self.action_max) + ## * + action_logp = ( + distribution.log_prob(actions) + - torch.log(1.0 - actions_normalized.pow(2) + 1e-6) + ).sum(-1) + + # * returns target_action = actions_scaled, target_action_logp = action_logp + actor_prediction = actions_scaled + actor_prediction_logp = action_logp + + # entropy loss + alpha_loss = -( + self.log_alpha * (action_logp + self.target_entropy).detach() + ).mean() + + self.log_alpha_optimizer.zero_grad() + alpha_loss.backward() + self.log_alpha_optimizer.step() + + critic_in = torch.cat((critic_obs, actor_prediction), dim=-1) + q_value_1 = self.critic_1.forward(critic_in) + q_value_2 = self.critic_2.forward(critic_in) + actor_loss = ( + self.alpha.detach() * actor_prediction_logp + - torch.min(q_value_1, q_value_2) + ).mean() + self.actor_optimizer.zero_grad() + actor_loss.backward() + # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) + self.actor_optimizer.step() + + self.mean_alpha_loss += alpha_loss.item() + self.mean_actor_loss += actor_loss.item() diff --git a/learning/modules/__init__.py b/learning/modules/__init__.py index 3caf5fec..8d4a3ab1 100644 --- a/learning/modules/__init__.py +++ b/learning/modules/__init__.py @@ -34,3 +34,4 @@ from .state_estimator import StateEstimatorNN from .actor import Actor from .critic import Critic +from .chimera_actor import ChimeraActor \ No newline at end of file diff --git a/learning/modules/chimera_actor.py b/learning/modules/chimera_actor.py new file mode 100644 index 00000000..1addc83d --- /dev/null +++ b/learning/modules/chimera_actor.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from torch.distributions import Normal +from .utils import create_MLP +from 
.utils import export_network +from .utils import RunningMeanStd + + +class ChimeraActor(nn.Module): + def __init__( + self, + num_obs, + num_actions, + hidden_dims, + activation, + std_init=1.0, + log_std_max=4.0, + log_std_min=-20.0, + normalize_obs=True, + **kwargs, + ): + super().__init__() + + self._normalize_obs = normalize_obs + if self._normalize_obs: + self.obs_rms = RunningMeanStd(num_obs) + + self.log_std_max = log_std_max + self.log_std_min = log_std_min + self.log_std_init = torch.tensor([std_init]).log() # refactor + + self.num_obs = num_obs + self.num_actions = num_actions + + self.latent_NN = create_MLP( + num_inputs=num_obs, + num_outputs=hidden_dims["latent"][-1], + hidden_dims=hidden_dims["latent"][:-1], + activation=activation["latent"], + ) + self.mean_NN = create_MLP( + num_inputs=hidden_dims["latent"][-1], + num_outputs=num_actions, + hidden_dims=hidden_dims["mean"], + activation=activation["mean"], + ) + self.std_NN = create_MLP( + num_inputs=hidden_dims["latent"][-1], + num_outputs=num_actions, + hidden_dims=hidden_dims["std"], + activation=activation["std"], + ) + + # maybe zap + self.distribution = Normal(torch.zeros(num_actions), torch.ones(num_actions)) + Normal.set_default_validate_args = False + + def forward(self, x, deterministic=True): + if self._normalize_obs: + with torch.no_grad(): + x = self.obs_rms(x) + latent = self.latent_NN(x) + mean = self.mean_NN(latent) + if deterministic: + return mean + log_std = self.log_std_init + self.std_NN(latent) + return mean, log_std.clamp(self.log_std_min, self.log_std_max).exp() + + def act(self, x): + mean, std = self.forward(x, deterministic=False) + self.distribution = Normal(mean, std) + return self.distribution.sample() + + def inference_policy(self, x): + return self.forward(x, deterministic=True) + + def export(self, path): + export_network(self.inference_policy, "policy", path, self.num_obs) + + def to(self, device): + super().to(device) + self.log_std_init = self.log_std_init.to(device) + return self diff --git a/learning/modules/critic.py b/learning/modules/critic.py index 732f8eda..96c53438 100644 --- a/learning/modules/critic.py +++ b/learning/modules/critic.py @@ -20,11 +20,14 @@ def __init__( if self._normalize_obs: self.obs_rms = RunningMeanStd(num_obs) - def evaluate(self, critic_observations): + def forward(self, x): if self._normalize_obs: with torch.no_grad(): - critic_observations = self.obs_rms(critic_observations) - return self.NN(critic_observations).squeeze() + x = self.obs_rms(x) + return self.NN(x).squeeze() + + def evaluate(self, critic_observations): + return self.forward(critic_observations) - def loss_fn(self, input, target): - return nn.functional.mse_loss(input, target, reduction="mean") + def loss_fn(self, obs, target): + return nn.functional.mse_loss(self.forward(obs), target, reduction="mean") diff --git a/learning/modules/lqrc/plotting.py b/learning/modules/lqrc/plotting.py new file mode 100644 index 00000000..1c9b3581 --- /dev/null +++ b/learning/modules/lqrc/plotting.py @@ -0,0 +1,159 @@ +import matplotlib as mpl +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +from matplotlib.colors import ListedColormap + +import numpy as np + + +def create_custom_bwr_colormap(): + # Define the colors for each segment + dark_blue = [0, 0, 0.5, 1] + light_blue = [0.5, 0.5, 1, 1] + white = [1, 1, 1, 1] + light_red = [1, 0.5, 0.5, 1] + dark_red = [0.5, 0, 0, 1] + + # Number of bins for each segment + n_bins = 128 + mid_band = 5 + + # Create the colormap segments + blue_segment 
= np.linspace(dark_blue, light_blue, n_bins // 2) + white_segment = np.tile(white, (mid_band, 1)) + red_segment = np.linspace(light_red, dark_red, n_bins // 2) + + # Stack segments to create the full colormap + colors = np.vstack((blue_segment, white_segment, red_segment)) + custom_bwr = ListedColormap(colors, name="custom_bwr") + + return custom_bwr + + +def plot_pendulum_multiple_critics_w_data( + x, + predictions, + targets, + title, + fn, + data, + colorbar_label="f(x)", + grid_size=64, + actions=None, +): + num_critics = len(x.keys()) + fig, axes = plt.subplots(nrows=2, ncols=num_critics, figsize=(4 * num_critics, 6)) + + # Determine global min and max error for consistent scaling + global_min_error = float("inf") + global_max_error = float("-inf") + global_min_prediction = float("inf") + global_max_prediction = float("-inf") + prediction_cmap = mpl.cm.get_cmap("viridis") + error_cmap = create_custom_bwr_colormap() + action_cmap = mpl.cm.get_cmap("viridis") + + for critic_name in x: + np_predictions = predictions[critic_name].detach().cpu().numpy().reshape(-1) + np_targets = ( + targets["Ground Truth MC Returns"].detach().cpu().numpy().reshape(-1) + ) + np_error = np_predictions - np_targets + global_min_error = min(global_min_error, np.min(np_error)) + global_max_error = max(global_max_error, np.max(np_error)) + global_min_prediction = min( + global_min_prediction, np.min(np_predictions), np.min(np_targets) + ) + global_max_prediction = max( + global_max_prediction, np.max(np_predictions), np.max(np_targets) + ) + error_norm = mcolors.TwoSlopeNorm( + vmin=global_min_error - 1e-5, vcenter=0, vmax=global_max_error + 1e-5 + ) + prediction_norm = mcolors.CenteredNorm( + vcenter=(global_max_prediction + global_min_prediction) / 2, + halfrange=(global_max_prediction - global_min_prediction) / 2, + ) + action_norm = mcolors.CenteredNorm() + + xcord = np.linspace(-2 * np.pi, 2 * np.pi, grid_size) + ycord = np.linspace(-5, 5, grid_size) + + for ix, critic_name in enumerate(x): + np_predictions = predictions[critic_name].detach().cpu().numpy().reshape(-1) + np_targets = ( + targets["Ground Truth MC Returns"].detach().cpu().numpy().reshape(-1) + ) + np_error = np_predictions - np_targets + + # Plot Predictions + axes[0, ix].imshow( + np_predictions.reshape(grid_size, grid_size).T, + origin="lower", + extent=(xcord.min(), xcord.max(), ycord.min(), ycord.max()), + cmap=prediction_cmap, + norm=prediction_norm, + ) + axes[0, ix].set_title(f"{critic_name} Prediction") + + if ix == 0: + continue + + if actions is None: + # Plot Errors + axes[1, ix].imshow( + np_error.reshape(grid_size, grid_size).T, + origin="lower", + extent=(xcord.min(), xcord.max(), ycord.min(), ycord.max()), + cmap=error_cmap, + norm=error_norm, + ) + axes[1, ix].set_title(f"{critic_name} Error") + ax1_mappable = mpl.cm.ScalarMappable(norm=error_norm, cmap=error_cmap) + + else: + # Plot Actions + np_actions = actions[critic_name].detach().cpu().numpy().reshape(-1) + axes[1, ix].imshow( + np_actions.reshape(grid_size, grid_size).T, + origin="lower", + extent=(xcord.min(), xcord.max(), ycord.min(), ycord.max()), + cmap=action_cmap, + norm=action_norm, + ) + axes[1, ix].set_title(f"{critic_name} Action") + ax1_mappable = mpl.cm.ScalarMappable(norm=action_norm, cmap=action_cmap) + + # Plot MC Trajectories + data = data.detach().cpu().numpy() + theta = data[:, :, 0] + omega = data[:, :, 1] + axes[1, 0].plot(theta, omega, lw=1) + axes[1, 0].set_xlabel("theta") + axes[1, 0].set_ylabel("theta_dot") + fig.suptitle(title, fontsize=16) + + 
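A small portability note on the plotting helper above: `mpl.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9, so the colormap lookups would fail on newer Matplotlib. If newer versions need to be supported, the colormap registry (available since roughly Matplotlib 3.6) is a drop-in alternative:

# Equivalent lookups without the deprecated cm.get_cmap API
prediction_cmap = mpl.colormaps["viridis"]
action_cmap = mpl.colormaps["viridis"]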
# Ensure the axes are the same for all plots + for ax in axes.flat: + ax.set_xlim([xcord.min(), xcord.max()]) + ax.set_ylim([ycord.min(), ycord.max()]) + + plt.subplots_adjust( + top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.4, wspace=0.3 + ) + + fig.colorbar( + mpl.cm.ScalarMappable(norm=prediction_norm, cmap=prediction_cmap), + ax=axes[0, :].ravel().tolist(), + shrink=0.95, + label=colorbar_label, + ) + fig.colorbar( + ax1_mappable, + ax=axes[1, :].ravel().tolist(), + shrink=0.95, + label=colorbar_label, + ) + + plt.savefig(f"{fn}.png") + print(f"Saved to {fn}.png") diff --git a/learning/runners/BaseRunner.py b/learning/runners/BaseRunner.py index 9e5ab2e3..698fdf1d 100644 --- a/learning/runners/BaseRunner.py +++ b/learning/runners/BaseRunner.py @@ -10,6 +10,7 @@ def __init__(self, env, train_cfg, device="cpu"): self.env = env self.parse_train_cfg(train_cfg) + self.log_storage = self.cfg["log_storage"] self.num_steps_per_env = self.cfg["num_steps_per_env"] self.save_interval = self.cfg["save_interval"] self.num_learning_iterations = self.cfg["max_iterations"] diff --git a/learning/runners/__init__.py b/learning/runners/__init__.py index b0f49217..a0840b06 100644 --- a/learning/runners/__init__.py +++ b/learning/runners/__init__.py @@ -32,4 +32,5 @@ from .on_policy_runner import OnPolicyRunner from .my_runner import MyRunner -from .old_policy_runner import OldPolicyRunner \ No newline at end of file +from .old_policy_runner import OldPolicyRunner +from .off_policy_runner import OffPolicyRunner \ No newline at end of file diff --git a/learning/runners/off_policy_runner.py b/learning/runners/off_policy_runner.py new file mode 100644 index 00000000..960112a4 --- /dev/null +++ b/learning/runners/off_policy_runner.py @@ -0,0 +1,299 @@ +import os +import torch +from tensordict import TensorDict + +from learning.utils import Logger + +from .BaseRunner import BaseRunner +from learning.modules import Critic, ChimeraActor +from learning.storage import ReplayBuffer +from learning.algorithms import SAC +from learning.utils import export_to_numpy + +logger = Logger() +storage = ReplayBuffer() + + +class OffPolicyRunner(BaseRunner): + def __init__(self, env, train_cfg, device="cpu"): + super().__init__(env, train_cfg, device) + logger.initialize( + self.env.num_envs, + self.env.dt, + self.cfg["max_iterations"], + self.device, + ) + + def _set_up_alg(self): + num_actor_obs = self.get_obs_size(self.actor_cfg["obs"]) + num_actions = self.get_action_size(self.actor_cfg["actions"]) + num_critic_obs = self.get_obs_size(self.critic_cfg["obs"]) + actor = ChimeraActor(num_actor_obs, num_actions, **self.actor_cfg) + critic_1 = Critic(num_critic_obs + num_actions, **self.critic_cfg) + critic_2 = Critic(num_critic_obs + num_actions, **self.critic_cfg) + target_critic_1 = Critic(num_critic_obs + num_actions, **self.critic_cfg) + target_critic_2 = Critic(num_critic_obs + num_actions, **self.critic_cfg) + + print(actor) + + self.alg = SAC( + actor, + critic_1, + critic_2, + target_critic_1, + target_critic_2, + device=self.device, + **self.alg_cfg, + ) + + def learn(self): + self.set_up_logger() + + rewards_dict = {} + + self.alg.switch_to_train() + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + actions = self.alg.act(actor_obs) + tot_iter = self.it + self.num_learning_iterations + + # * start up storage + transition = TensorDict({}, batch_size=self.env.num_envs, device=self.device) + transition.update( + { + "actor_obs": actor_obs, + "next_actor_obs": 
actor_obs, + "actions": self.alg.act(actor_obs), + "critic_obs": critic_obs, + "next_critic_obs": critic_obs, + "rewards": self.get_rewards({"termination": 0.0})["termination"], + "timed_out": self.get_timed_out(), + "terminated": self.get_terminated(), + "dones": self.get_timed_out() | self.get_terminated(), + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, + } + ) + storage.initialize( + transition, + self.env.num_envs, + self.alg_cfg["storage_size"], + device=self.device, + ) + + # fill buffer + for _ in range(self.alg_cfg["initial_fill"]): + with torch.inference_mode(): + actions = torch.rand_like(actions) * 2 - 1 + self.set_actions( + self.actor_cfg["actions"], + actions, + self.actor_cfg["disable_actions"], + ) + transition.update( + { + "actor_obs": actor_obs, + "actions": actions, + "critic_obs": critic_obs, + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, + } + ) + + self.env.step() + + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + + # * get time_outs + timed_out = self.get_timed_out() + terminated = self.get_terminated() + dones = timed_out | terminated + + self.update_rewards(rewards_dict, terminated) + total_rewards = torch.stack(tuple(rewards_dict.values())).sum(dim=0) + + transition.update( + { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, + "rewards": total_rewards, + "timed_out": timed_out, + "terminated": terminated, + "dones": dones, + } + ) + + storage.add_transitions(transition) + # print every 10% of initial fill + if (self.alg_cfg["initial_fill"] > 10) and ( + _ % (self.alg_cfg["initial_fill"] // 10) == 0 + ): + print(f"Filled {100 * _ / self.alg_cfg['initial_fill']}%") + + logger.tic("runtime") + for self.it in range(self.it + 1, tot_iter + 1): + logger.tic("iteration") + logger.tic("collection") + # * Rollout + with torch.inference_mode(): + for i in range(self.num_steps_per_env): + actions = self.alg.act(actor_obs) + self.set_actions( + self.actor_cfg["actions"], + actions, + self.actor_cfg["disable_actions"], + ) + + transition.update( + { + "actor_obs": actor_obs, + "actions": actions, + "critic_obs": critic_obs, + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, + } + ) + + self.env.step() + + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + + # * get time_outs + timed_out = self.get_timed_out() + terminated = self.get_terminated() + dones = timed_out | terminated + + self.update_rewards(rewards_dict, terminated) + total_rewards = torch.stack(tuple(rewards_dict.values())).sum(dim=0) + + transition.update( + { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, + "rewards": total_rewards, + "timed_out": timed_out, + "dones": dones, + } + ) + storage.add_transitions(transition) + + logger.log_rewards(rewards_dict) + logger.log_rewards({"total_rewards": total_rewards}) + logger.finish_step(dones) + logger.toc("collection") + + logger.tic("learning") + self.alg.update(storage.get_data()) + logger.toc("learning") + logger.log_all_categories() + + logger.finish_iteration() + logger.toc("iteration") + logger.toc("runtime") + logger.print_to_terminal() + + if self.it % self.save_interval == 0: + self.save() + + # self.save() + + def update_rewards(self, rewards_dict, terminated): + rewards_dict.update( + self.get_rewards( + self.critic_cfg["reward"]["termination_weight"], mask=terminated + ) + ) + rewards_dict.update( + 
self.get_rewards( + self.critic_cfg["reward"]["weights"], + modifier=self.env.dt, + mask=~terminated, + ) + ) + + def set_up_logger(self): + logger.register_rewards(list(self.critic_cfg["reward"]["weights"].keys())) + logger.register_rewards( + list(self.critic_cfg["reward"]["termination_weight"].keys()) + ) + logger.register_rewards(["total_rewards"]) + logger.register_category( + "algorithm", + self.alg, + [ + "mean_critic_1_loss", + "mean_critic_2_loss", + "mean_actor_loss", + "mean_alpha_loss", + "alpha", + ], + ) + # logger.register_category("actor", self.alg.actor, ["action_std", "entropy"]) + + logger.attach_torch_obj_to_wandb( + (self.alg.actor, self.alg.critic_1, self.alg.critic_2) + ) + + def save(self): + os.makedirs(self.log_dir, exist_ok=True) + path = os.path.join(self.log_dir, "model_{}.pt".format(self.it)) + save_dict = { + "actor_state_dict": self.alg.actor.state_dict(), + "critic_1_state_dict": self.alg.critic_1.state_dict(), + "critic_2_state_dict": self.alg.critic_2.state_dict(), + "log_alpha": self.alg.log_alpha, + "actor_optimizer_state_dict": self.alg.actor_optimizer.state_dict(), + "critic_1_optimizer_state_dict": self.alg.critic_1_optimizer.state_dict(), + "critic_2_optimizer_state_dict": self.alg.critic_2_optimizer.state_dict(), + "log_alpha_optimizer_state_dict": self.alg.log_alpha_optimizer.state_dict(), + "iter": self.it, + } + torch.save(save_dict, path) + if self.log_storage: + path_data = os.path.join(self.log_dir, "data_{}".format(self.it)) + torch.save(storage.data.cpu(), path_data + ".pt") + export_to_numpy(storage.data, path_data + ".npz") + + def load(self, path, load_optimizer=True): + loaded_dict = torch.load(path) + self.alg.actor.load_state_dict(loaded_dict["actor_state_dict"]) + self.alg.critic_1.load_state_dict(loaded_dict["critic_1_state_dict"]) + self.alg.critic_2.load_state_dict(loaded_dict["critic_2_state_dict"]) + self.log_alpha = loaded_dict["log_alpha"] + if load_optimizer: + self.alg.actor_optimizer.load_state_dict( + loaded_dict["actor_optimizer_state_dict"] + ) + self.alg.critic_1_optimizer.load_state_dict( + loaded_dict["critic_1_optimizer_state_dict"] + ) + self.alg.critic_2_optimizer.load_state_dict( + loaded_dict["critic_2_optimizer_state_dict"] + ) + self.alg.log_alpha_optimizer.load_state_dict( + loaded_dict["log_alpha_optimizer_state_dict"] + ) + self.it = loaded_dict["iter"] + + def switch_to_eval(self): + self.alg.actor.eval() + self.alg.critic_1.eval() + self.alg.critic_2.eval() + + def get_inference_actions(self): + obs = self.get_noisy_obs(self.actor_cfg["obs"], self.actor_cfg["noise"]) + mean = self.alg.actor.forward(obs) + actions = torch.tanh(mean) + actions = (actions * self.alg.action_delta + self.alg.action_offset).clamp( + self.alg.action_min, self.alg.action_max + ) + return actions + + def export(self, path): + self.alg.actor.export(path) diff --git a/learning/runners/on_policy_runner.py b/learning/runners/on_policy_runner.py index 2cd6a9d8..b7fc3722 100644 --- a/learning/runners/on_policy_runner.py +++ b/learning/runners/on_policy_runner.py @@ -6,6 +6,7 @@ from .BaseRunner import BaseRunner from learning.storage import DictStorage +from learning.utils import export_to_numpy logger = Logger() storage = DictStorage() @@ -30,19 +31,30 @@ def learn(self): actor_obs = self.get_obs(self.actor_cfg["obs"]) critic_obs = self.get_obs(self.critic_cfg["obs"]) tot_iter = self.it + self.num_learning_iterations - self.save() # * start up storage transition = TensorDict({}, batch_size=self.env.num_envs, device=self.device) 
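One bookkeeping note that spans SAC and both runners: transitions now carry `timed_out` and `terminated` separately, but `SAC.update_critic` bootstraps with `dones.logical_not()`, where `dones = timed_out | terminated`, so time-limit truncations are treated as true terminal states (no bootstrapping past a timeout). For the pendulum this is currently moot, since `Pendulum._check_terminations_and_timeouts` sets `terminated = timed_out`, but for other tasks the distinction matters. Also note that the off-policy rollout loop only refreshes "timed_out" and "dones" (not "terminated"), so the buffered `terminated` values go stale; that is harmless only as long as nothing but "dones" is consumed. If bootstrapping through timeouts is ever wanted, a sketch of the alternative target (hypothetical, using the recorded "terminated" key) is:

# Bootstrap through time-limit truncations; stop only at real terminations.
not_terminal = batch["terminated"].logical_not()
target = rewards + self.gamma * not_terminal * target_next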
transition.update( { "actor_obs": actor_obs, - "actions": self.alg.act(actor_obs, critic_obs), + "next_actor_obs": actor_obs, + "actions": self.alg.act(actor_obs), "critic_obs": critic_obs, + "next_critic_obs": critic_obs, "rewards": self.get_rewards({"termination": 0.0})["termination"], - "dones": self.get_timed_out(), + "timed_out": self.get_timed_out(), + "terminated": self.get_terminated(), + "dones": self.get_timed_out() | self.get_terminated(), } ) + if self.log_storage: + transition.update( + { + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, + } + ) + storage.initialize( transition, self.env.num_envs, @@ -50,6 +62,12 @@ def learn(self): device=self.device, ) + self.save() + # burn in observation normalization. + if self.actor_cfg["normalize_obs"] or self.critic_cfg["normalize_obs"]: + self.burn_in_normalization() + self.env.reset() + logger.tic("runtime") for self.it in range(self.it + 1, tot_iter + 1): logger.tic("iteration") @@ -57,7 +75,7 @@ def learn(self): # * Rollout with torch.inference_mode(): for i in range(self.num_steps_per_env): - actions = self.alg.act(actor_obs, critic_obs) + actions = self.alg.act(actor_obs) self.set_actions( self.actor_cfg["actions"], actions, @@ -89,11 +107,22 @@ def learn(self): transition.update( { + "next_actor_obs": actor_obs, + "next_critic_obs": critic_obs, "rewards": total_rewards, "timed_out": timed_out, + "terminated": terminated, "dones": dones, } ) + if self.log_storage: + transition.update( + { + "dof_pos": self.env.dof_pos, + "dof_vel": self.env.dof_vel, + } + ) + storage.add_transitions(transition) logger.log_rewards(rewards_dict) @@ -101,6 +130,9 @@ def learn(self): logger.finish_step(dones) logger.toc("collection") + if self.it % self.save_interval == 0: + self.save() + logger.tic("learning") self.alg.update(storage.data) storage.clear() @@ -112,9 +144,22 @@ def learn(self): logger.toc("runtime") logger.print_to_terminal() - if self.it % self.save_interval == 0: - self.save() - self.save() + # self.save() + + @torch.no_grad + def burn_in_normalization(self, n_iterations=100): + actor_obs = self.get_obs(self.actor_cfg["obs"]) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + for _ in range(n_iterations): + actions = self.alg.act(actor_obs) + self.set_actions(self.actor_cfg["actions"], actions) + self.env.step() + actor_obs = self.get_noisy_obs( + self.actor_cfg["obs"], self.actor_cfg["noise"] + ) + critic_obs = self.get_obs(self.critic_cfg["obs"]) + self.alg.critic.evaluate(critic_obs) + self.env.reset() def update_rewards(self, rewards_dict, terminated): rewards_dict.update( @@ -156,6 +201,10 @@ def save(self): }, path, ) + if self.log_storage: + path_data = os.path.join(self.log_dir, "data_{}".format(self.it)) + torch.save(storage.data.cpu(), path_data + ".pt") + export_to_numpy(storage.data, path_data + ".npz") def load(self, path, load_optimizer=True): loaded_dict = torch.load(path) diff --git a/learning/storage/SE_storage.py b/learning/storage/SE_storage.py index a63f5f9f..0f4eaf41 100644 --- a/learning/storage/SE_storage.py +++ b/learning/storage/SE_storage.py @@ -53,9 +53,9 @@ def clear(self): def mini_batch_generator(self, num_mini_batches, num_epochs=8): """Generate mini batch for learning""" batch_size = self.num_envs * self.num_transitions_per_env - mini_batch_size = batch_size // num_mini_batches + batch_size = batch_size // num_mini_batches indices = torch.randperm( - num_mini_batches * mini_batch_size, + num_mini_batches * batch_size, requires_grad=False, device=self.device, ) @@ -69,8 +69,8 @@ def 
mini_batch_generator(self, num_mini_batches, num_epochs=8): for epoch in range(num_epochs): for i in range(num_mini_batches): - start = i * mini_batch_size - end = (i + 1) * mini_batch_size + start = i * batch_size + end = (i + 1) * batch_size batch_idx = indices[start:end] obs_batch = observations[batch_idx] diff --git a/learning/storage/__init__.py b/learning/storage/__init__.py index f10df85f..47f2593b 100644 --- a/learning/storage/__init__.py +++ b/learning/storage/__init__.py @@ -3,4 +3,5 @@ from .rollout_storage import RolloutStorage from .SE_storage import SERolloutStorage -from .dict_storage import DictStorage \ No newline at end of file +from .dict_storage import DictStorage +from .replay_buffer import ReplayBuffer \ No newline at end of file diff --git a/learning/storage/replay_buffer.py b/learning/storage/replay_buffer.py new file mode 100644 index 00000000..350a2c25 --- /dev/null +++ b/learning/storage/replay_buffer.py @@ -0,0 +1,54 @@ +import torch +from tensordict import TensorDict + + +class ReplayBuffer: + def __init__(self): + self.initialized = False + + def initialize( + self, + dummy_dict, + num_envs=2**12, + max_storage=2**17, + device="cpu", + ): + self.device = device + self.num_envs = num_envs + max_length = max_storage // num_envs + self.max_length = max_length + self.data = TensorDict({}, batch_size=(max_length, num_envs), device=device) + self.fill_count = 0 + self.add_index = 0 + + for key in dummy_dict.keys(): + if dummy_dict[key].dim() == 1: # if scalar + self.data[key] = torch.zeros( + (max_length, num_envs), + dtype=dummy_dict[key].dtype, + device=self.device, + ) + else: + self.data[key] = torch.zeros( + (max_length, num_envs, dummy_dict[key].shape[1]), + dtype=dummy_dict[key].dtype, + device=self.device, + ) + + @torch.inference_mode + def add_transitions(self, transition: TensorDict): + if self.fill_count >= self.max_length and self.add_index >= self.max_length: + self.add_index = 0 + self.data[self.add_index] = transition + self.fill_count += 1 + self.add_index += 1 + + def get_data(self): + return self.data[: min(self.fill_count, self.max_length), :] + + def clear(self): + self.fill_count = 0 + self.add_index = 0 + with torch.inference_mode(): + for tensor in self.data: + tensor.zero_() diff --git a/learning/storage/rollout_storage.py b/learning/storage/rollout_storage.py index 38a315d7..0f2885a3 100644 --- a/learning/storage/rollout_storage.py +++ b/learning/storage/rollout_storage.py @@ -184,9 +184,9 @@ def get_statistics(self): def mini_batch_generator(self, num_mini_batches=1, num_epochs=8): batch_size = self.num_envs * self.num_transitions_per_env - mini_batch_size = batch_size // num_mini_batches + batch_size = batch_size // num_mini_batches indices = torch.randperm( - num_mini_batches * mini_batch_size, + num_mini_batches * batch_size, requires_grad=False, device=self.device, ) @@ -207,8 +207,8 @@ def mini_batch_generator(self, num_mini_batches=1, num_epochs=8): for epoch in range(num_epochs): for i in range(num_mini_batches): - start = i * mini_batch_size - end = (i + 1) * mini_batch_size + start = i * batch_size + end = (i + 1) * batch_size batch_idx = indices[start:end] obs_batch = observations[batch_idx] diff --git a/learning/utils/__init__.py b/learning/utils/__init__.py index 15c674d6..b39dc16d 100644 --- a/learning/utils/__init__.py +++ b/learning/utils/__init__.py @@ -1,7 +1,7 @@ - from .utils import ( remove_zero_weighted_rewards, set_discount_from_horizon, + polyak_update ) from .dict_utils import * from .logger import Logger diff --git 
a/learning/utils/dict_utils.py b/learning/utils/dict_utils.py index 2d19e072..873dfdda 100644 --- a/learning/utils/dict_utils.py +++ b/learning/utils/dict_utils.py @@ -1,38 +1,45 @@ +import numpy as np import torch from tensordict import TensorDict @torch.no_grad def compute_MC_returns(data: TensorDict, gamma, critic=None): + # todo not as accurate as taking if critic is None: last_values = torch.zeros_like(data["rewards"][0]) else: last_values = critic.evaluate(data["critic_obs"][-1]) - data.update({"returns": torch.zeros_like(data["rewards"])}) - data["returns"][-1] = last_values * ~data["dones"][-1] + returns = torch.zeros_like(data["rewards"]) + returns[-1] = data["rewards"][-1] + gamma * last_values * ~data["terminated"][-1] for k in reversed(range(data["rewards"].shape[0] - 1)): not_done = ~data["dones"][k] - data["returns"][k] = ( - data["rewards"][k] + gamma * data["returns"][k + 1] * not_done - ) - data["returns"] = (data["returns"] - data["returns"].mean()) / ( - data["returns"].std() + 1e-8 - ) - return + returns[k] = data["rewards"][k] + gamma * returns[k + 1] * not_done + if critic is not None: + returns[k] += ( + gamma + * critic.evaluate(data["critic_obs"][k]) + * data["timed_out"][k] + * ~data["terminated"][k] + ) + return returns @torch.no_grad -def compute_generalized_advantages(data, gamma, lam, critic, last_values=None): - data.update({"values": critic.evaluate(data["critic_obs"])}) +def normalize(input, eps=1e-8): + return (input - input.mean()) / (input.std() + eps) - data.update({"advantages": torch.zeros_like(data["values"])}) +@torch.no_grad +def compute_generalized_advantages(data, gamma, lam, critic): + last_values = critic.evaluate(data["next_critic_obs"][-1]) + advantages = torch.zeros_like(data["values"]) if last_values is not None: # todo check this # since we don't have observations for the last step, need last value plugged in not_done = ~data["dones"][-1] - data["advantages"][-1] = ( + advantages[-1] = ( data["rewards"][-1] + gamma * data["values"][-1] * data["timed_out"][-1] + gamma * last_values * not_done @@ -47,23 +54,30 @@ def compute_generalized_advantages(data, gamma, lam, critic, last_values=None): + gamma * data["values"][k + 1] * not_done - data["values"][k] ) - data["advantages"][k] = ( - td_error + gamma * lam * not_done * data["advantages"][k + 1] - ) - - data["returns"] = data["advantages"] + data["values"] + advantages[k] = td_error + gamma * lam * not_done * advantages[k + 1] - data["advantages"] = (data["advantages"] - data["advantages"].mean()) / ( - data["advantages"].std() + 1e-8 - ) + return advantages # todo change num_epochs to num_batches @torch.no_grad -def create_uniform_generator(data, batch_size, num_epochs, keys=None): +def create_uniform_generator( + data, batch_size, num_epochs=1, max_gradient_steps=None, keys=None +): n, m = data.shape total_data = n * m + + if batch_size > total_data: + Warning("Batch size is larger than total data, using available data only.") + batch_size = total_data + num_batches_per_epoch = total_data // batch_size + if max_gradient_steps: + if max_gradient_steps < num_batches_per_epoch: + num_batches_per_epoch = max_gradient_steps + num_epochs = max_gradient_steps // num_batches_per_epoch + num_epochs = max(num_epochs, 1) + for epoch in range(num_epochs): indices = torch.randperm(total_data, device=data.device) for i in range(num_batches_per_epoch): @@ -71,3 +85,12 @@ def create_uniform_generator(data, batch_size, num_epochs, keys=None): indices[i * batch_size : (i + 1) * batch_size] ] yield 
batched_data + + +@torch.no_grad +def export_to_numpy(data, path): + # check if path ends iwth ".npz", and if not append it. + if not path.endswith(".npz"): + path += ".npz" + np.savez_compressed(path, **{key: val.cpu().numpy() for key, val in data.items()}) + return diff --git a/learning/utils/utils.py b/learning/utils/utils.py index 7a54c65a..d4eac980 100644 --- a/learning/utils/utils.py +++ b/learning/utils/utils.py @@ -17,3 +17,9 @@ def set_discount_from_horizon(dt, horizon): discount_factor = 1 - 1 / discrete_time_horizon return discount_factor + + +def polyak_update(online, target, polyak_factor): + for op, tp in zip(online.parameters(), target.parameters()): + tp.data.copy_((1.0 - polyak_factor) * op.data + polyak_factor * tp.data) + return target diff --git a/scripts/play.py b/scripts/play.py index 4486daae..37882f7e 100644 --- a/scripts/play.py +++ b/scripts/play.py @@ -9,16 +9,16 @@ def setup(args): env_cfg, train_cfg = task_registry.create_cfgs(args) - env_cfg.env.num_envs = min(env_cfg.env.num_envs, 16) + env_cfg.env.num_envs = 32 if hasattr(env_cfg, "push_robots"): env_cfg.push_robots.toggle = False if hasattr(env_cfg, "commands"): env_cfg.commands.resampling_time = 9999 - env_cfg.env.episode_length_s = 9999 + env_cfg.env.episode_length_s = 50 env_cfg.env.num_projectiles = 20 task_registry.make_gym_and_sim() + env_cfg.init_state.reset_mode = "reset_to_range" env = task_registry.make_env(args.task, env_cfg) - env.cfg.init_state.reset_mode = "reset_to_basic" train_cfg.runner.resume = True train_cfg.logging.enable_local_saving = False runner = task_registry.make_alg_runner(env, train_cfg) diff --git a/scripts/train.py b/scripts/train.py index 283717e8..5b55f188 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,5 +1,5 @@ from gym.envs import __init__ # noqa: F401 -from gym.utils import get_args, task_registry, randomize_episode_counters +from gym.utils import get_args, task_registry # , randomize_episode_counters from gym.utils.logging_and_saving import wandb_singleton from gym.utils.logging_and_saving import local_code_save_helper @@ -12,9 +12,7 @@ def setup(): task_registry.make_gym_and_sim() wandb_helper.setup_wandb(env_cfg=env_cfg, train_cfg=train_cfg, args=args) env = task_registry.make_env(name=args.task, env_cfg=env_cfg) - randomize_episode_counters(env) - - randomize_episode_counters(env) + # randomize_episode_counters(env) policy_runner = task_registry.make_alg_runner(env, train_cfg) local_code_save_helper.save_local_files_to_logs(train_cfg.log_dir) diff --git a/scripts/visualize_ppo.py b/scripts/visualize_ppo.py new file mode 100644 index 00000000..fd4e8aeb --- /dev/null +++ b/scripts/visualize_ppo.py @@ -0,0 +1,92 @@ +import matplotlib.pyplot as plt +from learning.modules.critic import Critic + +from learning.utils import ( + compute_generalized_advantages, + compute_MC_returns, +) +from learning.modules.lqrc.plotting import plot_pendulum_multiple_critics_w_data +from gym import LEGGED_GYM_ROOT_DIR +import os +import shutil + +import torch + +DEVICE = "cpu" + +# * Setup +experiment_name = "pendulum" +run_name = "obs_no_norm" + +log_dir = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", experiment_name, run_name) +plot_dir = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "V_plots", run_name) +os.makedirs(plot_dir, exist_ok=True) + +# * Critic Params +name = "PPO" +n_obs = 3 # [dof_pos_obs, dof_vel] +hidden_dims = [128, 64, 32] +activation = "tanh" +normalize_obs = False +n_envs = 4096 + +# * Params +gamma = 0.95 +lam = 1.0 +visualize_steps = 10 # just to show rollouts 
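When storage logging is enabled, the runners now also dump the logged data as a compressed .npz next to the .pt snapshot (via export_to_numpy above). The visualization scripts below read the .pt files with torch; the same arrays can be read back without torch using plain numpy. A small sketch, where the path is only an example built from the experiment/run names used in this script:

import numpy as np

path = "logs/pendulum/obs_no_norm/data_100.npz"  # example path; adjust to your log dir
with np.load(path) as data:
    dof_pos = data["dof_pos"]    # stacked over [steps, num_envs, ...]
    rewards = data["rewards"]    # [steps, num_envs]
    print(dof_pos.shape, rewards.mean())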
+n_trajs = 64 +rand_perm = torch.randperm(n_envs) +traj_idx = rand_perm[0:n_trajs] +test_idx = rand_perm[n_trajs : n_trajs + 1000] + +it_delta = 20 +it_total = 100 +it_range = range(it_delta, it_total + 1, it_delta) + +for it in it_range: + # load data + base_data = torch.load(os.path.join(log_dir, "data_{}.pt".format(it))).to(DEVICE) + + dof_pos = base_data["dof_pos"].detach().clone() + dof_vel = base_data["dof_vel"].detach().clone() + graphing_obs = torch.cat((dof_pos, dof_vel), dim=2) + + # compute ground-truth + graphing_data = {data_name: {} for data_name in ["obs", "values", "returns"]} + + episode_rollouts = compute_MC_returns(base_data, gamma) + print(f"Initializing value offset to: {episode_rollouts.mean().item()}") + graphing_data["obs"]["Ground Truth MC Returns"] = graphing_obs[0, :] + graphing_data["values"]["Ground Truth MC Returns"] = episode_rollouts[0, :] + graphing_data["returns"]["Ground Truth MC Returns"] = episode_rollouts[0, :] + + # load critic + model = torch.load(os.path.join(log_dir, "model_{}.pt".format(it))) + critic = Critic(n_obs, hidden_dims, activation, normalize_obs).to(DEVICE) + critic.load_state_dict(model["critic_state_dict"]) + + # compute values and returns + data = base_data.detach().clone() + data["values"] = critic.evaluate(data["critic_obs"]) + data["advantages"] = compute_generalized_advantages(data, gamma, lam, critic) + data["returns"] = data["advantages"] + data["values"] + + with torch.no_grad(): + graphing_data["obs"][name] = graphing_obs[0, :] + graphing_data["values"][name] = critic.evaluate(data[0, :]["critic_obs"]) + graphing_data["returns"][name] = data[0, :]["returns"] + + # generate plots + plot_pendulum_multiple_critics_w_data( + graphing_data["obs"], + graphing_data["values"], + graphing_data["returns"], + title=f"iteration{it}", + fn=plot_dir + f"/PPO_CRITIC_it{it}", + data=graphing_obs[:visualize_steps, traj_idx], + ) + + plt.close() + +this_file = os.path.join(LEGGED_GYM_ROOT_DIR, "scripts", "visualize_ppo.py") +shutil.copy(this_file, os.path.join(plot_dir, os.path.basename(this_file))) diff --git a/scripts/visualize_sac.py b/scripts/visualize_sac.py new file mode 100644 index 00000000..d24b0d3a --- /dev/null +++ b/scripts/visualize_sac.py @@ -0,0 +1,136 @@ +import matplotlib.pyplot as plt +from learning.modules.critic import Critic + +from learning.utils import ( + compute_generalized_advantages, + compute_MC_returns, +) +from learning.modules.lqrc.plotting import plot_pendulum_multiple_critics_w_data +from gym import LEGGED_GYM_ROOT_DIR +import os +import shutil +import numpy as np +import torch + +DEVICE = "cpu" + +# * Setup +experiment_name = "sac_pendulum" +run_name = "1024envs" + +log_dir = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", experiment_name, run_name) +plot_dir = os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "V_plots_sac", run_name) +os.makedirs(plot_dir, exist_ok=True) + +# * Critic Params +name = "SAC" +n_obs = 4 # [dof_pos_obs, dof_vel, action] +hidden_dims = [128, 64, 32] +activation = "elu" +normalize_obs = False +n_envs = 1024 + +# * Params +gamma = 0.95 +lam = 1.0 +episode_steps = 100 +visualize_steps = 10 # just to show rollouts +n_trajs = 64 +rand_perm = torch.randperm(n_envs) +traj_idx = rand_perm[0:n_trajs] +test_idx = rand_perm[n_trajs : n_trajs + 1000] + +it_delta = 2000 +it_total = 30_000 +it_range = range(it_delta, it_total + 1, it_delta) + +for it in it_range: + # load data + base_data = torch.load(os.path.join(log_dir, "data_{}.pt".format(it))).to(DEVICE) + # TODO: handle buffer differently? 
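Unlike the PPO dump, the SAC data file is the whole ReplayBuffer, whose leading dimension is a ring of max_storage // num_envs rows that add_index overwrites once full; slicing the last episode_steps rows (next line) is therefore only an approximation of the most recent rollout, which is what the TODO above refers to. A toy illustration of that ring behaviour, with deliberately tiny sizes:

import torch
from tensordict import TensorDict

max_length, num_envs = 4, 2           # toy sizes; the real buffer uses max_storage // num_envs rows
data = TensorDict(
    {"rewards": torch.zeros(max_length, num_envs)},
    batch_size=(max_length, num_envs),
)
add_index = 0
for step in range(6):                 # more transitions than rows
    if add_index >= max_length:       # wrap around, overwriting the oldest rows
        add_index = 0
    data["rewards"][add_index] = float(step)
    add_index += 1
print(data["rewards"][:, 0])          # tensor([4., 5., 2., 3.]): rows are no longer in time order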
+ base_data = base_data[-episode_steps:, :] + + dof_pos = base_data["dof_pos"].detach().clone() + dof_vel = base_data["dof_vel"].detach().clone() + graphing_obs = torch.cat((dof_pos, dof_vel), dim=2) + + # compute ground-truth + graphing_data = { + data_name: {} for data_name in ["obs", "values", "returns", "actions"] + } + + episode_rollouts = compute_MC_returns(base_data, gamma) + print(f"Initializing value offset to: {episode_rollouts.mean().item()}") + graphing_data["obs"]["Ground Truth MC Returns"] = graphing_obs[0, :] + graphing_data["values"]["Ground Truth MC Returns"] = episode_rollouts[0, :] + graphing_data["returns"]["Ground Truth MC Returns"] = episode_rollouts[0, :] + + # load model which includes both critics + model = torch.load(os.path.join(log_dir, "model_{}.pt".format(it))) + + # line search to find best action + data = base_data.detach().clone() + data_shape = data["critic_obs"].shape # [a, b, 3] + + # create a tensor of actions to evaluate + N = 41 + actions_space = torch.linspace(-2, 2, N).to(DEVICE) + actions = ( + actions_space.unsqueeze(0).unsqueeze(1).unsqueeze(-1) + ) # Shape: [1, 1, N, 1] + + # repeat the actions for each entry in data + actions = actions.repeat(data_shape[0], data_shape[1], 1, 1) # Shape: [a, b, N, 1] + + # repeat the data for each action + critic_obs = ( + data["critic_obs"].unsqueeze(2).repeat(1, 1, N, 1) + ) # Shape: [a, b, N, 3] + + # concatenate the actions to the data + critic_obs = torch.cat((critic_obs, actions), dim=3) # Shape: [a, b, N, 4] + + # evaluate the critic for all actions and entries + for critic_str in ["critic_1", "critic_2"]: + critic_name = name + " " + critic_str + critic = Critic(n_obs, hidden_dims, activation, normalize_obs).to(DEVICE) + critic.load_state_dict(model[critic_str + "_state_dict"]) + q_values = critic.evaluate(critic_obs) # Shape: [a, b, N] + + # find the best action for each entry + best_actions_idx = torch.argmax(q_values, dim=2) # Shape: [a, b] + best_actions = actions_space[best_actions_idx] # Shape: [a, b] + best_actions = best_actions.unsqueeze(-1) # Shape: [a, b, 1] + + # compute values and returns + best_obs = torch.cat( + (data["critic_obs"], best_actions), dim=2 + ) # Shape: [a, b, 4] + data["values"] = critic.evaluate(best_obs) # Shape: [a, b] + data["next_critic_obs"] = best_obs # needed for GAE + data["advantages"] = compute_generalized_advantages(data, gamma, lam, critic) + data["returns"] = data["advantages"] + data["values"] + + with torch.no_grad(): + graphing_data["obs"][critic_name] = graphing_obs[0, :] + graphing_data["values"][critic_name] = critic.evaluate(best_obs[0, :]) + graphing_data["returns"][critic_name] = data[0, :]["returns"] + graphing_data["actions"][critic_name] = best_actions[0, :] + + # generate plots + grid_size = int(np.sqrt(n_envs)) + plot_pendulum_multiple_critics_w_data( + graphing_data["obs"], + graphing_data["values"], + graphing_data["returns"], + title=f"iteration{it}", + fn=plot_dir + f"/{name}_CRITIC_it{it}", + data=graphing_obs[:visualize_steps, traj_idx], + grid_size=grid_size, + actions=graphing_data["actions"], + ) + + plt.close() + +this_file = os.path.join(LEGGED_GYM_ROOT_DIR, "scripts", "visualize_sac.py") +shutil.copy(this_file, os.path.join(plot_dir, os.path.basename(this_file)))
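For reference, a self-contained usage sketch of the new ReplayBuffer together with create_uniform_generator, mirroring how the SAC runner initializes storage from a dummy transition and then samples minibatches during the update; the keys, shapes, and sizes here are arbitrary toy values, not the configured ones:

import torch
from tensordict import TensorDict
from learning.storage import ReplayBuffer
from learning.utils.dict_utils import create_uniform_generator

num_envs = 8
dummy = TensorDict(
    {
        "critic_obs": torch.zeros(num_envs, 3),
        "actions": torch.zeros(num_envs, 1),
        "rewards": torch.zeros(num_envs),
        "dones": torch.zeros(num_envs, dtype=torch.bool),
    },
    batch_size=num_envs,
)
buffer = ReplayBuffer()
buffer.initialize(dummy, num_envs=num_envs, max_storage=256)  # 256 // 8 = 32 time rows
for _ in range(40):                       # more adds than rows, so the ring wraps
    buffer.add_transitions(dummy)
data = buffer.get_data()                  # TensorDict with batch_size [32, 8]
for batch in create_uniform_generator(data, batch_size=64, num_epochs=1):
    critic_obs, actions = batch["critic_obs"], batch["actions"]  # feed into the critic/actor update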
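learning/utils gains polyak_update for soft target-network tracking, but its call site is not part of this diff. A minimal sketch of how it would typically be wired into a SAC-style critic update; the deepcopy target and the 0.995 factor are placeholders, not values taken from this patch:

import copy
import torch.nn as nn
from learning.utils import polyak_update

critic_1 = nn.Linear(4, 1)                 # stand-in for the real critic network
target_critic_1 = copy.deepcopy(critic_1)  # target starts as a copy of the online net

# ... after each gradient step on critic_1 ...
polyak_update(critic_1, target_critic_1, polyak_factor=0.995)
# polyak_factor close to 1.0 keeps the target slow-moving:
# target <- 0.005 * online + 0.995 * target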