From 1f5f73d770c718733a57c48d4eadd397d0ab93f4 Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 17:15:22 +0200 Subject: [PATCH 01/14] configs are now running --- .../dacbench/function_approximation_benchmark.yaml | 7 ------- .../environment/pufferlib_ocean/memory.yaml | 7 ------- .../environment/pufferlib_ocean/password.yaml | 2 +- .../environment/pufferlib_ocean/squared.yaml | 4 ++-- .../environment/pufferlib_ocean/stochastic.yaml | 2 +- mighty/configs/exploration/ez_greedy.yaml | 3 ++- mighty/mighty_agents/base_agent.py | 5 ++++- mighty/mighty_models/dqn.py | 1 + mighty/mighty_utils/wrappers.py | 14 ++++++++------ 9 files changed, 19 insertions(+), 26 deletions(-) delete mode 100644 mighty/configs/environment/dacbench/function_approximation_benchmark.yaml delete mode 100644 mighty/configs/environment/pufferlib_ocean/memory.yaml diff --git a/mighty/configs/environment/dacbench/function_approximation_benchmark.yaml b/mighty/configs/environment/dacbench/function_approximation_benchmark.yaml deleted file mode 100644 index 36087074..00000000 --- a/mighty/configs/environment/dacbench/function_approximation_benchmark.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# @package _global_ - -num_steps: 1e5 -env: FunctionApproximationBenchmark -env_kwargs: {benchmark: true, dimension: 1} -env_wrappers: [] -num_envs: 16 \ No newline at end of file diff --git a/mighty/configs/environment/pufferlib_ocean/memory.yaml b/mighty/configs/environment/pufferlib_ocean/memory.yaml deleted file mode 100644 index 3c6ff1fd..00000000 --- a/mighty/configs/environment/pufferlib_ocean/memory.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# @package _global_ - -num_steps: 50_000 -env: pufferlib.ocean.memory -env_kwargs: {} -env_wrappers: [] -num_envs: 1 \ No newline at end of file diff --git a/mighty/configs/environment/pufferlib_ocean/password.yaml b/mighty/configs/environment/pufferlib_ocean/password.yaml index 7c36a6c6..2dafd95e 100644 --- a/mighty/configs/environment/pufferlib_ocean/password.yaml 
+++ b/mighty/configs/environment/pufferlib_ocean/password.yaml @@ -4,4 +4,4 @@ num_steps: 50_000 env: pufferlib.ocean.password env_kwargs: {} env_wrappers: [] -num_envs: 1 \ No newline at end of file +num_envs: 64 \ No newline at end of file diff --git a/mighty/configs/environment/pufferlib_ocean/squared.yaml b/mighty/configs/environment/pufferlib_ocean/squared.yaml index 10abb6cb..7da47bad 100644 --- a/mighty/configs/environment/pufferlib_ocean/squared.yaml +++ b/mighty/configs/environment/pufferlib_ocean/squared.yaml @@ -3,5 +3,5 @@ num_steps: 50_000 env: pufferlib.ocean.squared env_kwargs: {} -env_wrappers: [mighty.utils.wrappers.FlattenVecObs] -num_envs: 1 \ No newline at end of file +env_wrappers: [mighty.mighty_utils.wrappers.FlattenVecObs] +num_envs: 64 \ No newline at end of file diff --git a/mighty/configs/environment/pufferlib_ocean/stochastic.yaml b/mighty/configs/environment/pufferlib_ocean/stochastic.yaml index d032ccce..4bb8008d 100644 --- a/mighty/configs/environment/pufferlib_ocean/stochastic.yaml +++ b/mighty/configs/environment/pufferlib_ocean/stochastic.yaml @@ -4,4 +4,4 @@ num_steps: 50_000 env: pufferlib.ocean.stochastic env_kwargs: {} env_wrappers: [] -num_envs: 1 \ No newline at end of file +num_envs: 64 \ No newline at end of file diff --git a/mighty/configs/exploration/ez_greedy.yaml b/mighty/configs/exploration/ez_greedy.yaml index 2e61df6b..45df0c10 100644 --- a/mighty/configs/exploration/ez_greedy.yaml +++ b/mighty/configs/exploration/ez_greedy.yaml @@ -1,3 +1,4 @@ # @package _global_ algorithm_kwargs: - policy_class: mighty.mighty_exploration.EZGreedy \ No newline at end of file + policy_class: mighty.mighty_exploration.EZGreedy + policy_kwargs: null \ No newline at end of file diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 790a7c74..eacee61c 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -13,7 +13,7 @@ import pandas as pd import torch import wandb 
-from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from rich import print from rich.layout import Layout from rich.live import Live @@ -323,6 +323,9 @@ def initialize_agent(self) -> None: if isinstance(self.buffer_class, type) and issubclass( self.buffer_class, PrioritizedReplay ): + self.buffer_kwargs = OmegaConf.to_container( + self.buffer_kwargs, resolve=True + ) # 1) Get observation-space shape try: obs_space = self.env.single_observation_space diff --git a/mighty/mighty_models/dqn.py b/mighty/mighty_models/dqn.py index 96f78863..9f7eaa97 100644 --- a/mighty/mighty_models/dqn.py +++ b/mighty/mighty_models/dqn.py @@ -25,6 +25,7 @@ def __init__(self, num_actions, obs_size, dueling=False, **kwargs): feature_extractor_kwargs.update(kwargs["feature_extractor_kwargs"]) # Make feature extractor + print(obs_size) self.feature_extractor, self.output_size = make_feature_extractor( **feature_extractor_kwargs ) diff --git a/mighty/mighty_utils/wrappers.py b/mighty/mighty_utils/wrappers.py index f8bc0747..70b93ed3 100644 --- a/mighty/mighty_utils/wrappers.py +++ b/mighty/mighty_utils/wrappers.py @@ -106,19 +106,21 @@ def __init__(self, env): """ super().__init__(env) - self.n_actions = len(self.env.single_action_space.nvec) - self.single_action_space = gym.spaces.Discrete( - np.prod(self.env.single_action_space.nvec) - ) + self.n_actions = len(self.env.action_space.nvec) + self.action_mapper = {} for idx, prod_idx in zip( - range(np.prod(self.env.single_action_space.nvec)), + range(np.prod(self.env.action_space.nvec)), itertools.product( - *[np.arange(val) for val in self.env.single_action_space.nvec] + *[np.arange(val) for val in self.env.action_space.nvec] ), ): self.action_mapper[idx] = prod_idx + self.action_space = gym.spaces.Discrete( + int(np.prod(self.env.action_space.nvec)) + ) + def step(self, action): """Maps discrete action value to array.""" action = [self.action_mapper[a] for a in action] From 
a3cce9f0e91047e6794fe0c34d0f7646d26fbcce Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 17:15:31 +0200 Subject: [PATCH 02/14] delete superflous configs --- mighty/configs/ppo_smac.yaml | 51 ------------------- mighty/configs/sac_smac.yaml | 51 ------------------- mighty/configs/search_space/dqn_rs.yaml | 15 ------ mighty/configs/search_space/dqn_template.yaml | 11 ---- mighty/configs/search_space/ppo_rs.yaml | 41 --------------- mighty/configs/search_space/sac_rs.yaml | 9 ---- mighty/configs/sweep_ppo_pbt.yaml | 44 ---------------- mighty/configs/sweep_rs.yaml | 38 -------------- 8 files changed, 260 deletions(-) delete mode 100644 mighty/configs/ppo_smac.yaml delete mode 100644 mighty/configs/sac_smac.yaml delete mode 100644 mighty/configs/search_space/dqn_rs.yaml delete mode 100644 mighty/configs/search_space/dqn_template.yaml delete mode 100644 mighty/configs/search_space/ppo_rs.yaml delete mode 100644 mighty/configs/search_space/sac_rs.yaml delete mode 100644 mighty/configs/sweep_ppo_pbt.yaml delete mode 100644 mighty/configs/sweep_rs.yaml diff --git a/mighty/configs/ppo_smac.yaml b/mighty/configs/ppo_smac.yaml deleted file mode 100644 index 40da7c69..00000000 --- a/mighty/configs/ppo_smac.yaml +++ /dev/null @@ -1,51 +0,0 @@ -defaults: - - _self_ - - /cluster: local - - algorithm: ppo_mujoco - - environment: gymnasium/pendulum - - search_space: ppo_rs - - override hydra/job_logging: colorlog - - override hydra/hydra_logging: colorlog - - override hydra/help: mighty_help - - override hydra/sweeper: HyperSMAC # use Hypersweeper’s RandomSearch - -runner: standard -debug: false -seed: 0 -output_dir: sweep_smac -wandb_project: null -tensorboard_file: null -experiment_name: ppo_smac - -budget: 200000 # Budget for the hyperparameter search - -algorithm_kwargs: {} - -# Training -eval_every_n_steps: 1e4 # After how many steps to evaluate. 
-n_episodes_eval: 10 -checkpoint: null # Path to load model checkpoint -save_model_every_n_steps: 5e5 - -hydra: - sweeper: - n_trials: 10 - budget_variable: budget - sweeper_kwargs: - seeds: [0] - optimizer_kwargs: - smac_facade: - _target_: smac.facade.blackbox_facade.BlackBoxFacade - _partial_: true - logging_level: 20 # 10 DEBUG, 20 INFO - scenario: - seed: 42 - n_trials: ${hydra.sweeper.n_trials} - deterministic: true - n_workers: 4 - output_directory: ${hydra.sweep.dir} - search_space: ${search_space} - run: - dir: ${output_dir}/${experiment_name}_${seed} - sweep: - dir: ${output_dir}/${experiment_name}_${seed} diff --git a/mighty/configs/sac_smac.yaml b/mighty/configs/sac_smac.yaml deleted file mode 100644 index 613efd26..00000000 --- a/mighty/configs/sac_smac.yaml +++ /dev/null @@ -1,51 +0,0 @@ -defaults: - - _self_ - - /cluster: local - - algorithm: sac_mujoco - - environment: gymnasium/pendulum - - search_space: sac_rs - - override hydra/job_logging: colorlog - - override hydra/hydra_logging: colorlog - - override hydra/help: mighty_help - - override hydra/sweeper: HyperSMAC # use Hypersweeper’s RandomSearch - -runner: standard -debug: false -seed: 0 -output_dir: sweep_smac -wandb_project: null -tensorboard_file: null -experiment_name: ppo_smac - -budget: 200000 # Budget for the hyperparameter search - -algorithm_kwargs: {} - -# Training -eval_every_n_steps: 1e4 # After how many steps to evaluate. 
-n_episodes_eval: 10 -checkpoint: null # Path to load model checkpoint -save_model_every_n_steps: 5e5 - -hydra: - sweeper: - n_trials: 10 - budget_variable: budget - sweeper_kwargs: - seeds: [0] - optimizer_kwargs: - smac_facade: - _target_: smac.facade.blackbox_facade.BlackBoxFacade - _partial_: true - logging_level: 20 # 10 DEBUG, 20 INFO - scenario: - seed: 42 - n_trials: ${hydra.sweeper.n_trials} - deterministic: true - n_workers: 4 - output_directory: ${hydra.sweep.dir} - search_space: ${search_space} - run: - dir: ${output_dir}/${experiment_name}_${seed} - sweep: - dir: ${output_dir}/${experiment_name}_${seed} diff --git a/mighty/configs/search_space/dqn_rs.yaml b/mighty/configs/search_space/dqn_rs.yaml deleted file mode 100644 index 2a910e72..00000000 --- a/mighty/configs/search_space/dqn_rs.yaml +++ /dev/null @@ -1,15 +0,0 @@ -hyperparameters: - algorithm_kwargs.learning_rate: - type: uniform_float - upper: 0.1 - lower: 1.0e-06 - default: 0.0003 - log: true - algorithm_kwargs.gamma: - type: uniform_float - lower: 0.9 - upper: 0.9999 - log: false - algorithm_kwargs.batch_size: - type: categorical - choices: [32, 64, 128, 256] \ No newline at end of file diff --git a/mighty/configs/search_space/dqn_template.yaml b/mighty/configs/search_space/dqn_template.yaml deleted file mode 100644 index 51d23767..00000000 --- a/mighty/configs/search_space/dqn_template.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# @package hydra.sweeper.search_space -hyperparameters: - algorithm_kwargs.n_units: - type: ordinal - sequence: [4,8,16,32,64,128,256,512] - algorithm_kwargs.soft_update_weight: - type: uniform_float - lower: 0 - upper: 1 - default_value: 1 - diff --git a/mighty/configs/search_space/ppo_rs.yaml b/mighty/configs/search_space/ppo_rs.yaml deleted file mode 100644 index 9ae950a9..00000000 --- a/mighty/configs/search_space/ppo_rs.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# configs/search_space/ppo_rs.yaml -hyperparameters: - # match the keys under algorithm_kwargs in your PPO config 
- algorithm_kwargs.learning_rate: - type: uniform_float - lower: 1e-5 - upper: 1e-3 - log: true - algorithm_kwargs.batch_size: - type: categorical - choices: [8192, 16384, 32768] - algorithm_kwargs.n_gradient_steps: - type: uniform_int - lower: 1 - upper: 20 - log: false - algorithm_kwargs.gamma: - type: uniform_float - lower: 0.9 - upper: 0.9999 - log: false - algorithm_kwargs.ppo_clip: - type: uniform_float - lower: 0.1 - upper: 0.3 - log: false - algorithm_kwargs.value_loss_coef: - type: uniform_float - lower: 0.1 - upper: 1.0 - log: false - algorithm_kwargs.entropy_coef: - type: uniform_float - lower: 0.0 - upper: 0.1 - log: false - algorithm_kwargs.max_grad_norm: - type: uniform_float - lower: 0.1 - upper: 1.0 - log: false diff --git a/mighty/configs/search_space/sac_rs.yaml b/mighty/configs/search_space/sac_rs.yaml deleted file mode 100644 index fdaa3d87..00000000 --- a/mighty/configs/search_space/sac_rs.yaml +++ /dev/null @@ -1,9 +0,0 @@ -hyperparameters: - algorithm_kwargs.learning_rate: - type: uniform_float - lower: 0.000001 - upper: 0.01 - log: true - algorithm_kwargs.batch_size: - type: categorical - choices: [32, 64, 128, 256] \ No newline at end of file diff --git a/mighty/configs/sweep_ppo_pbt.yaml b/mighty/configs/sweep_ppo_pbt.yaml deleted file mode 100644 index 3aba687f..00000000 --- a/mighty/configs/sweep_ppo_pbt.yaml +++ /dev/null @@ -1,44 +0,0 @@ -defaults: - - _self_ - - /cluster: local - - algorithm: ppo - - environment: gymnasium/pendulum - - search_space: ppo_rs - - override hydra/job_logging: colorlog - - override hydra/hydra_logging: colorlog - - override hydra/help: mighty_help - - override hydra/sweeper: HyperPBT # use Hypersweeper’s RandomSearch - -runner: standard -debug: false -seed: 0 -output_dir: sweep_pbt -wandb_project: null -tensorboard_file: null -experiment_name: mighty_experiment - -algorithm_kwargs: {} - -# Training -eval_every_n_steps: 1e4 # After how many steps to evaluate. 
-n_episodes_eval: 10 -checkpoint: null # Path to load model checkpoint -save_model_every_n_steps: 5e5 - -hydra: - sweeper: - budget: 100000 - budget_variable: 100000 - loading_variable: load - saving_variable: save - sweeper_kwargs: - optimizer_kwargs: - population_size: 10 - config_interval: 1e4 - checkpoint_tf: true - load_tf: true - search_space: ${search_space} - run: - dir: ${output_dir}/${experiment_name}_${seed} - sweep: - dir: ${output_dir}/${experiment_name}_${seed} \ No newline at end of file diff --git a/mighty/configs/sweep_rs.yaml b/mighty/configs/sweep_rs.yaml deleted file mode 100644 index 650c3545..00000000 --- a/mighty/configs/sweep_rs.yaml +++ /dev/null @@ -1,38 +0,0 @@ -defaults: - - _self_ - - /cluster: local - - algorithm: ppo - - environment: gymnasium/pendulum - - search_space: ppo_rs - - override hydra/job_logging: colorlog - - override hydra/hydra_logging: colorlog - - override hydra/help: mighty_help - - override hydra/sweeper: HyperRS # use Hypersweeper’s RandomSearch - -runner: standard -debug: false -seed: 0 -output_dir: sweep_rs -wandb_project: null -tensorboard_file: null -experiment_name: dqn_sweep - -algorithm_kwargs: {} - -# Training -eval_every_n_steps: 1e4 # After how many steps to evaluate. 
-n_episodes_eval: 10 -checkpoint: null # Path to load model checkpoint -save_model_every_n_steps: 5e5 - -hydra: - sweeper: - n_trials: 10 - sweeper_kwargs: - max_parallelization: 0.8 - max_budget: 100000 - search_space: ${search_space} - run: - dir: ${output_dir}/${experiment_name}_${seed} - sweep: - dir: ${output_dir}/${experiment_name}_${seed} \ No newline at end of file From d0a425b6cd4b1dd2b7a9e7834804b424537bad38 Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 17:21:56 +0200 Subject: [PATCH 03/14] check for buffer kwargs type --- mighty/mighty_agents/base_agent.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index eacee61c..70cc1c89 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -323,9 +323,10 @@ def initialize_agent(self) -> None: if isinstance(self.buffer_class, type) and issubclass( self.buffer_class, PrioritizedReplay ): - self.buffer_kwargs = OmegaConf.to_container( - self.buffer_kwargs, resolve=True - ) + if isinstance(self.buffer_kwargs, DictConfig): + self.buffer_kwargs = OmegaConf.to_container( + self.buffer_kwargs, resolve=True + ) # 1) Get observation-space shape try: obs_space = self.env.single_observation_space From 488b439417743bb7a8ceb800c4dfcc05caf18a13 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Wed, 6 Aug 2025 23:06:33 +0200 Subject: [PATCH 04/14] updated sac to test if it works --- mighty/configs/algorithm/sac.yaml | 36 ++++++----- mighty/mighty_agents/sac.py | 2 +- .../mighty_exploration/stochastic_policy.py | 9 +-- mighty/mighty_models/sac.py | 14 ++++- mighty/mighty_update/sac_update.py | 60 +++++++++---------- 5 files changed, 67 insertions(+), 54 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index f1804932..07c80294 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -7,10 +7,11 @@ 
algorithm_kwargs: # Normalization normalize_obs: False normalize_reward: False + rescale_action: True # CRITICAL: Add this! Must be True for MuJoCo # Network sizes - n_policy_units: 256 - soft_update_weight: 0.005 + n_policy_units: 256 + soft_update_weight: 0.005 # tau in SAC terms # Replay buffer replay_buffer_class: @@ -19,31 +20,36 @@ algorithm_kwargs: capacity: 1e6 # Scheduling & batch-updates - batch_size: 256 - learning_starts: 5000 - update_every: 1 - n_gradient_steps: 1 + batch_size: 256 + learning_starts: 5000 # Good, matches CleanRL + update_every: 1 # Good, update every step + n_gradient_steps: 1 # Good - # Learning rates + # Learning rates - CRITICAL CHANGE policy_lr: 3e-4 - q_lr: 1e-3 - alpha_lr: 1e-3 + q_lr: 1e-3 # This is correct now (was 3e-4) + alpha_lr: 3e-4 # 3e-4 is better than 1e-3 for alpha # SAC hyperparameters gamma: 0.99 alpha: 0.2 auto_alpha: True - target_entropy: -6.0 # -action_dim for HalfCheetah (6 actions) + target_entropy: null # Let it auto-compute as -action_dim + + # Network architecture + hidden_sizes: [256, 256] # Explicitly specify + activation: relu + log_std_min: -5 + log_std_max: 2 # Policy configuration policy_class: mighty.mighty_exploration.StochasticPolicy policy_kwargs: - entropy_coefficient: 0.0 discrete: False - + # Remove entropy_coefficient - SAC handles alpha internally # SAC specific frequencies - policy_frequency: 2 # Delayed policy updates + policy_frequency: 2 # Can also try 1 for even better performance target_network_frequency: 1 # Update targets every step # Environment and training configuration @@ -55,5 +61,5 @@ max_episode_steps: 1000 # HalfCheetah episode length eval_frequency: 10000 # More frequent eval for single env save_frequency: 50000 # Save every 50k steps - -# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1 \ No newline at end of file +# Command to run: +# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1 \ No newline at 
end of file diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 2125c794..316eed79 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -38,7 +38,7 @@ def __init__( # --- Network architecture (optional override) --- hidden_sizes: Optional[List[int]] = None, activation: str = "relu", - log_std_min: float = -20, + log_std_min: float = -5, log_std_max: float = 2, # --- Logging & buffer --- render_progress: bool = True, diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index 63e995d0..b845075c 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -177,12 +177,9 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso # Special handling for SACModel elif isinstance(self.model, SACModel): action, z, mean, log_std = self.model(state, deterministic=False) - std = torch.exp(log_std) - dist = Normal(mean, std) - - log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) - weighted_log_prob = log_pz * self.entropy_coefficient - return action.detach().cpu().numpy(), weighted_log_prob + # CRITICAL: Use the model's policy_log_prob which includes tanh correction + log_prob = self.model.policy_log_prob(z, mean, log_std) + return action.detach().cpu().numpy(), log_prob else: raise RuntimeError( diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index dda12072..b27118c6 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -17,7 +17,7 @@ def __init__( self, obs_size: int, action_size: int, - log_std_min: float = -20, + log_std_min: float = -5, log_std_max: float = 2, **kwargs, ): @@ -124,7 +124,17 @@ def forward( feats = self.feature_extractor(state) x = self.policy_net(feats) mean, log_std = x.chunk(2, dim=-1) - log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + # log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + 
+ # NEW - Soft clamping + log_std = torch.tanh(log_std) + log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) + + # This maps tanh output [-1, 1] to [log_std_min, log_std_max] + # When tanh(x) = -1: log_std = log_std_min + 0.5 * range * 0 = log_std_min + # When tanh(x) = 0: log_std = log_std_min + 0.5 * range * 1 = (log_std_min + log_std_max) / 2 + # When tanh(x) = 1: log_std = log_std_min + 0.5 * range * 2 = log_std_max + std = torch.exp(log_std) if deterministic: diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index aba3aa12..74f16082 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -127,36 +127,36 @@ def update(self, batch: TransitionBatch) -> Dict: alpha_loss = torch.tensor(0.0) if self.update_step % self.policy_frequency == 0: # do multiple policy updates to compensate for delay - for _ in range(self.policy_frequency): - # recompute alpha after q update - current_alpha = ( - self.log_alpha.exp().detach() if self.auto_alpha else self.alpha - ) - - a, z, mean, log_std = self.model(states) - logp = self.model.policy_log_prob(z, mean, log_std) - sa_pi = torch.cat([states, a], dim=-1) - q1_pi = self.model.q_net1(sa_pi) - q2_pi = self.model.q_net2(sa_pi) - q_pi = torch.min(q1_pi, q2_pi) - policy_loss = (current_alpha * logp - q_pi).mean() - - self.policy_optimizer.zero_grad() - policy_loss.backward() - self.policy_optimizer.step() - - # --- Entropy coefficient (alpha) update --- - if self.auto_alpha: - with torch.no_grad(): - _, _, _, _ = self.model(states) - # Use the logp from the policy update above - alpha_loss = -( - self.log_alpha * (logp.detach() + self.target_entropy) - ).mean() - self.alpha_optimizer.zero_grad() - alpha_loss.backward() - self.alpha_optimizer.step() - self.alpha = self.log_alpha.exp().item() + # for _ in range(self.policy_frequency): + # recompute alpha after q update + current_alpha = ( + self.log_alpha.exp().detach() if 
self.auto_alpha else self.alpha + ) + + a, z, mean, log_std = self.model(states) + logp = self.model.policy_log_prob(z, mean, log_std) + sa_pi = torch.cat([states, a], dim=-1) + q1_pi = self.model.q_net1(sa_pi) + q2_pi = self.model.q_net2(sa_pi) + q_pi = torch.min(q1_pi, q2_pi) + policy_loss = (current_alpha * logp - q_pi).mean() + + self.policy_optimizer.zero_grad() + policy_loss.backward() + self.policy_optimizer.step() + + # --- Entropy coefficient (alpha) update --- + if self.auto_alpha: + with torch.no_grad(): + _, _, _, _ = self.model(states) + # Use the logp from the policy update above + alpha_loss = -( + self.log_alpha * (logp.detach() + self.target_entropy) + ).mean() + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + self.alpha = self.log_alpha.exp().item() # --- Soft update targets --- if self.update_step % self.target_network_frequency == 0: From 707b28002ede91ad1cab154c5124aec7d7d04a05 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 02:23:55 +0200 Subject: [PATCH 05/14] update --- mighty/configs/algorithm/sac.yaml | 3 +- mighty/mighty_agents/sac.py | 8 +- .../mighty_exploration_policy.py | 37 +++++---- .../mighty_exploration/stochastic_policy.py | 75 +++++++++++-------- mighty/mighty_update/sac_update.py | 60 +++++++-------- 5 files changed, 105 insertions(+), 78 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index 07c80294..07361e57 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -19,13 +19,14 @@ algorithm_kwargs: replay_buffer_kwargs: capacity: 1e6 + # Scheduling & batch-updates batch_size: 256 learning_starts: 5000 # Good, matches CleanRL update_every: 1 # Good, update every step n_gradient_steps: 1 # Good - # Learning rates - CRITICAL CHANGE + # Learning rates policy_lr: 3e-4 q_lr: 1e-3 # This is correct now (was 3e-4) alpha_lr: 3e-4 # 3e-4 is better than 1e-3 for alpha diff --git 
a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 316eed79..d0feecbd 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -207,8 +207,12 @@ def process_transition( # Ensure metrics dict if metrics is None: metrics = {} - # Pack transition - transition = TransitionBatch(curr_s, action, reward, next_s, dones) + # Pack transition + # `terminated` is used for physics failures in environments like `MightyEnv` + # Based on https://github.com/DLR-RM/stable-baselines3/issues/284 + terminated = metrics["transition"]["terminated"] # physics‐failures + transition = TransitionBatch(curr_s, action, reward, next_s, terminated.astype(int)) + # Compute per-transition TD errors for logging td1, td2 = self.update_fn.calculate_td_error(transition) metrics["td_error1"] = td1.detach().cpu().numpy() diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py b/mighty/mighty_exploration/mighty_exploration_policy.py index 39948d5c..534693e8 100644 --- a/mighty/mighty_exploration/mighty_exploration_policy.py +++ b/mighty/mighty_exploration/mighty_exploration_policy.py @@ -8,29 +8,34 @@ import torch from torch.distributions import Categorical, Normal +from mighty.mighty_models import SACModel + def sample_nondeterministic_logprobs( - z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: + z: torch.Tensor, + mean: torch.Tensor, + log_std: torch.Tensor, + sac: bool = False +) -> torch.Tensor: + """ + Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)), + and if sac=True apply the tanh-squash correction to get log π(a). 
+ """ std = torch.exp(log_std) # [batch, action_dim] dist = Normal(mean, std) + # base Gaussian log‐prob of z + log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1] - # For SAC, don't apply correction if sac: - return dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1] - # If not SAC, we need to apply the tanh correction - else: - log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1] - - # 2b) tanh‐correction = ∑ᵢ log(1 − tanh(zᵢ)² + ε) - eps = 1e-6 + # subtract the ∑_i log(d tanh/dz_i) = ∑ log(1 - tanh(z)^2) + eps = 1e-4 log_correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps).sum( dim=-1, keepdim=True ) # [batch, 1] - - # 2c) final log_prob of a = tanh(z) - log_prob = log_pz - log_correction # [batch, 1] - return log_prob + return log_pz - log_correction + else: + # PPO-style or other: no squash correction + return log_pz class MightyExplorationPolicy: @@ -111,8 +116,10 @@ def sample_func_logits(self, state_array): # ─── Continuous squashed‐Gaussian (4‐tuple) ────────────────────────── elif isinstance(out, tuple) and len(out) == 4: action = out[0] # [batch, action_dim] + + print(f'Self Model : {self.model}') log_prob = sample_nondeterministic_logprobs( - z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac" + z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel) ) return action.detach().cpu().numpy(), log_prob diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index b845075c..28e8cfdb 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -27,9 +27,13 @@ def __init__( :param entropy_coefficient: weight on entropy term :param discrete: whether the action space is discrete """ + + self.model = model + super().__init__(algo, model, discrete) self.entropy_coefficient = entropy_coefficient self.discrete = discrete + # --- override sample_action only for continuous SAC --- if not discrete and 
isinstance(model, SACModel): @@ -84,33 +88,24 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso # 4-tuple case (Tanh squashing): (action, z, mean, log_std) elif isinstance(model_output, tuple) and len(model_output) == 4: action, z, mean, log_std = model_output - log_prob = sample_nondeterministic_logprobs( - z=z, - mean=mean, - log_std=log_std, - sac=self.algo == "sac", - ) + + if not isinstance(self.model, SACModel): + + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + log_prob = self.model.policy_log_prob(z, mean, log_std) if return_logp: return action.detach().cpu().numpy(), log_prob else: - weighted_log_prob = log_prob * self.entropy_coefficient + weighted_log_prob = log_prob return action.detach().cpu().numpy(), weighted_log_prob - # Legacy 2-tuple case: (mean, std) - elif isinstance(model_output, tuple) and len(model_output) == 2: - mean, std = model_output - dist = Normal(mean, std) - z = dist.rsample() # [batch, action_dim] - action = torch.tanh(z) # [batch, action_dim] - - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac" - ) - entropy = dist.entropy().sum(dim=-1, keepdim=True) # [batch, 1] - weighted_log_prob = log_prob * entropy - return action.detach().cpu().numpy(), weighted_log_prob - # Check for model attribute-based approaches elif hasattr(self.model, "continuous_action") and getattr( self.model, "continuous_action" @@ -126,9 +121,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso elif len(model_output) == 4: # Tanh squashing mode: (action, z, mean, log_std) action, z, mean, log_std = model_output - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=log_std, sac=self.algo == "sac" - ) + if not isinstance(self.model, SACModel): + + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + 
log_prob = self.model.policy_log_prob(z, mean, log_std) else: raise ValueError( f"Unexpected model output length: {len(model_output)}" @@ -145,9 +147,15 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso if self.model.output_style == "squashed_gaussian": # Should be 4-tuple: (action, z, mean, log_std) action, z, mean, log_std = model_output - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=log_std, sac=self.algo == "sac" - ) + if not isinstance(self.model, SACModel): + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + log_prob = self.model.policy_log_prob(z, mean, log_std) if return_logp: return action.detach().cpu().numpy(), log_prob @@ -162,9 +170,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso z = dist.rsample() action = torch.tanh(z) - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac" - ) + if not isinstance(self.model, SACModel): + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + log_prob = self.model.policy_log_prob(z, mean, log_std) + entropy = dist.entropy().sum(dim=-1, keepdim=True) weighted_log_prob = log_prob * entropy return action.detach().cpu().numpy(), weighted_log_prob diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index 74f16082..e1c96db2 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -127,36 +127,36 @@ def update(self, batch: TransitionBatch) -> Dict: alpha_loss = torch.tensor(0.0) if self.update_step % self.policy_frequency == 0: # do multiple policy updates to compensate for delay - # for _ in range(self.policy_frequency): - # recompute alpha after q update - current_alpha = ( - self.log_alpha.exp().detach() if self.auto_alpha else self.alpha - ) - - a, z, mean, log_std = self.model(states) 
- logp = self.model.policy_log_prob(z, mean, log_std) - sa_pi = torch.cat([states, a], dim=-1) - q1_pi = self.model.q_net1(sa_pi) - q2_pi = self.model.q_net2(sa_pi) - q_pi = torch.min(q1_pi, q2_pi) - policy_loss = (current_alpha * logp - q_pi).mean() - - self.policy_optimizer.zero_grad() - policy_loss.backward() - self.policy_optimizer.step() - - # --- Entropy coefficient (alpha) update --- - if self.auto_alpha: - with torch.no_grad(): - _, _, _, _ = self.model(states) - # Use the logp from the policy update above - alpha_loss = -( - self.log_alpha * (logp.detach() + self.target_entropy) - ).mean() - self.alpha_optimizer.zero_grad() - alpha_loss.backward() - self.alpha_optimizer.step() - self.alpha = self.log_alpha.exp().item() + for _ in range(self.policy_frequency): + # recompute alpha after q update + current_alpha = ( + self.log_alpha.exp().detach() if self.auto_alpha else self.alpha + ) + + a, z, mean, log_std = self.model(states) + logp = self.model.policy_log_prob(z, mean, log_std) + sa_pi = torch.cat([states, a], dim=-1) + q1_pi = self.model.q_net1(sa_pi) + q2_pi = self.model.q_net2(sa_pi) + q_pi = torch.min(q1_pi, q2_pi) + policy_loss = (current_alpha * logp - q_pi).mean() + + self.policy_optimizer.zero_grad() + policy_loss.backward() + self.policy_optimizer.step() + + # --- Entropy coefficient (alpha) update --- + if self.auto_alpha: + with torch.no_grad(): + _, _, _, _ = self.model(states) + # Use the logp from the policy update above + alpha_loss = -( + self.log_alpha.exp() * (logp.detach() + self.target_entropy) + ).mean() + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + self.alpha = self.log_alpha.exp().item() # --- Soft update targets --- if self.update_step % self.target_network_frequency == 0: From 979742e1ede005acf3e8ea2f89efd89fcac04401 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 04:08:03 +0200 Subject: [PATCH 06/14] update --- mighty/configs/algorithm/sac.yaml | 2 +- 
mighty/mighty_agents/base_agent.py | 26 ++++++++++++++++++++++---- mighty/mighty_agents/sac.py | 2 ++ mighty/mighty_models/sac.py | 20 +++++++++++++++++++- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index 07361e57..658b025a 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -51,7 +51,7 @@ algorithm_kwargs: # SAC specific frequencies policy_frequency: 2 # Can also try 1 for even better performance - target_network_frequency: 1 # Update targets every step + target_network_frequency: 2 # Update targets every step # Environment and training configuration num_envs: 1 # Single environment diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 790a7c74..4a1672f5 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -141,6 +141,7 @@ def __init__( # noqa: PLR0915, PLR0912 normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, + handle_timeout_termination: bool = False, ): """Base agent initialization. 
@@ -301,6 +302,8 @@ def __init__( # noqa: PLR0915, PLR0912 for m in self.meta_modules.values(): m.seed(self.seed) self.steps = 0 + + self.handle_timeout_termination = handle_timeout_termination def _initialize_agent(self) -> None: """Agent/algorithm specific initializations.""" @@ -603,8 +606,23 @@ def run( # noqa: PLR0915 metrics["episode_reward"] = episode_reward action, log_prob = self.step(curr_s, metrics) - next_s, reward, terminated, truncated, _ = self.env.step(action) # type: ignore - dones = np.logical_or(terminated, truncated) + # 1) step the env as usual + next_s, reward, terminated, truncated, infos = self.env.step(action) + + # 2) decide which samples are true “done” + replay_dones = terminated # physics‐failure only + dones = np.logical_or(terminated, truncated) + + + # 3) optionally overwrite next_s on truncation + if self.handle_timeout_termination: + real_next_s = next_s.copy() + # infos["final_observation"] is a list/array of the last real obs + for i, tr in enumerate(truncated): + if tr: + real_next_s[i] = infos["final_observation"][i] + else: + real_next_s = next_s episode_reward += reward @@ -615,10 +633,10 @@ def run( # noqa: PLR0915 "reward": reward, "action": action, "state": curr_s, - "next_state": next_s, + "next_state": real_next_s, "terminated": terminated.astype(int), "truncated": truncated.astype(int), - "dones": dones.astype(int), + "dones": replay_dones.astype(int), "mean_episode_reward": last_episode_reward.mean() .cpu() .numpy() diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index d0feecbd..758f1919 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -57,6 +57,7 @@ def __init__( rescale_action: bool = False, # ← NEW Whether to rescale actions to the environment's action space policy_frequency: int = 2, # Frequency of policy updates target_network_frequency: int = 1, # Frequency of target network updates + handle_timeout_termination: bool = True, ): """Initialize SAC agent with 
tunable hyperparameters and backward-compatible names.""" if hidden_sizes is None: @@ -116,6 +117,7 @@ def __init__( rescale_action=rescale_action, batch_size=batch_size, learning_rate=policy_lr, # For compatibility with base class + handle_timeout_termination=handle_timeout_termination, ) # Initialize loss buffer for logging diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index b27118c6..d3902552 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -19,6 +19,8 @@ def __init__( action_size: int, log_std_min: float = -5, log_std_max: float = 2, + action_low: float = -1, + action_high: float = +1, **kwargs, ): super().__init__() @@ -29,6 +31,16 @@ def __init__( # This model is continuous only self.continuous_action = True + + # PR: register the per-dim scale and bias so we can rescale [-1,1]→[low,high]. + action_low = torch.as_tensor(action_low, dtype=torch.float32) + action_high = torch.as_tensor(action_high, dtype=torch.float32) + self.register_buffer( + "action_scale", (action_high - action_low) / 2.0 + ) + self.register_buffer( + "action_bias", (action_high + action_low) / 2.0 + ) head_kwargs = {"hidden_sizes": [256, 256], "activation": "relu"} feature_extractor_kwargs = { @@ -141,7 +153,13 @@ def forward( z = mean else: z = mean + std * torch.randn_like(mean) - action = torch.tanh(z) + + # tanh→[-1,1] + raw_action = torch.tanh(z) + + # **HERE** we rescale into [low,high] + action = raw_action * self.action_scale + self.action_bias + return action, z, mean, log_std def policy_log_prob( From 29a41d5788e982fd8d43d9858993cc5b19f94feb Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 7 Aug 2025 10:11:04 +0200 Subject: [PATCH 07/14] Remove print statement --- mighty/mighty_models/dqn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mighty/mighty_models/dqn.py b/mighty/mighty_models/dqn.py index 9f7eaa97..96f78863 100644 --- a/mighty/mighty_models/dqn.py +++ b/mighty/mighty_models/dqn.py @@ -25,7 +25,6 @@ def 
__init__(self, num_actions, obs_size, dueling=False, **kwargs): feature_extractor_kwargs.update(kwargs["feature_extractor_kwargs"]) # Make feature extractor - print(obs_size) self.feature_extractor, self.output_size = make_feature_extractor( **feature_extractor_kwargs ) From 484d1f25894e25082828dd590c612833792a4b72 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 10:25:00 +0200 Subject: [PATCH 08/14] SAC updates --- mighty/configs/algorithm/sac.yaml | 2 +- mighty/mighty_models/sac.py | 37 ++++++++++++++++++++++-------- mighty/mighty_update/sac_update.py | 13 +++++++---- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index 658b025a..07361e57 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -51,7 +51,7 @@ algorithm_kwargs: # SAC specific frequencies policy_frequency: 2 # Can also try 1 for even better performance - target_network_frequency: 2 # Update targets every step + target_network_frequency: 1 # Update targets every step # Environment and training configuration num_envs: 1 # Single environment diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index d3902552..a31756fa 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -73,7 +73,12 @@ def __init__( ) # Policy network outputs mean and log_std - self.policy_net = nn.Linear(out_dim, action_size * 2) + # CHANGE: Create separate policy network (actor) similar to CleanRL + self.policy_net = make_policy_head( + in_size=self.obs_size, + out_size=self.action_size * 2, # mean and log_std + **head_kwargs + ) # Twin Q-networks # — live Q-nets — @@ -133,20 +138,13 @@ def forward( mean: Gaussian mean log_std: Gaussian log std """ - feats = self.feature_extractor(state) - x = self.policy_net(feats) + x = self.policy_net(state) mean, log_std = x.chunk(2, dim=-1) - # log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) - # NEW 
- Soft clamping + # Soft clamping log_std = torch.tanh(log_std) log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) - # This maps tanh output [-1, 1] to [log_std_min, log_std_max] - # When tanh(x) = -1: log_std = log_std_min + 0.5 * range * 0 = log_std_min - # When tanh(x) = 0: log_std = log_std_min + 0.5 * range * 1 = (log_std_min + log_std_max) / 2 - # When tanh(x) = 1: log_std = log_std_min + 0.5 * range * 2 = log_std_max - std = torch.exp(log_std) if deterministic: @@ -207,3 +205,22 @@ def make_q_head(in_size, hidden_sizes=None, activation="relu"): layers.append(nn.Linear(last_size, 1)) return nn.Sequential(*layers) + + +def make_policy_head(in_size, out_size, hidden_sizes=None, activation="relu"): + """Make policy head network (actor).""" + if hidden_sizes is None: + hidden_sizes = [] + + layers = [] + last_size = in_size + if isinstance(last_size, list): + last_size = last_size[0] + + for size in hidden_sizes: + layers.append(nn.Linear(last_size, size)) + layers.append(ACTIVATIONS[activation]()) + last_size = size + + layers.append(nn.Linear(last_size, out_size)) + return nn.Sequential(*layers) \ No newline at end of file diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index e1c96db2..ecd4a52e 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -41,7 +41,7 @@ def __init__( self.update_step = 0 if self.auto_alpha: - self.log_alpha = torch.nn.Parameter(torch.zeros(1, requires_grad=True)) + self.log_alpha = torch.zeros(1, requires_grad=True) self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr or q_lr) self.target_entropy = ( -float(self.action_dim) @@ -133,9 +133,12 @@ def update(self, batch: TransitionBatch) -> Dict: self.log_alpha.exp().detach() if self.auto_alpha else self.alpha ) + # FIX: Sample fresh actions for each policy update iteration + # This ensures stochasticity across iterations a, z, mean, log_std = self.model(states) logp 
= self.model.policy_log_prob(z, mean, log_std) sa_pi = torch.cat([states, a], dim=-1) + q1_pi = self.model.q_net1(sa_pi) q2_pi = self.model.q_net2(sa_pi) q_pi = torch.min(q1_pi, q2_pi) @@ -147,11 +150,13 @@ def update(self, batch: TransitionBatch) -> Dict: # --- Entropy coefficient (alpha) update --- if self.auto_alpha: + # CRITICAL FIX: Get fresh sample for alpha update with torch.no_grad(): - _, _, _, _ = self.model(states) - # Use the logp from the policy update above + _, z_alpha, mean_alpha, log_std_alpha = self.model(states) + logp_alpha = self.model.policy_log_prob(z_alpha, mean_alpha, log_std_alpha) + alpha_loss = -( - self.log_alpha.exp() * (logp.detach() + self.target_entropy) + self.log_alpha.exp() * (logp_alpha.detach() + self.target_entropy) ).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() From c667d0aade2a4af67a9a101ee87a38094d575ada Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 11:25:30 +0200 Subject: [PATCH 09/14] updated code + tests --- mighty/mighty_agents/dqn.py | 2 + mighty/mighty_agents/ppo.py | 2 + test/agents/test_sac_agent.py | 17 ++++++-- test/models/test_sac_networks.py | 71 ++++++++++++++++++++++++-------- 4 files changed, 71 insertions(+), 21 deletions(-) diff --git a/mighty/mighty_agents/dqn.py b/mighty/mighty_agents/dqn.py index 0ced9ce4..c6cc1cdb 100644 --- a/mighty/mighty_agents/dqn.py +++ b/mighty/mighty_agents/dqn.py @@ -69,6 +69,7 @@ def __init__( normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, # type: ignore + handle_timeout_termination: bool = False, ): """DQN initialization. 
@@ -154,6 +155,7 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, + handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/mighty/mighty_agents/ppo.py b/mighty/mighty_agents/ppo.py index 8f9a6fc1..8b974ee4 100644 --- a/mighty/mighty_agents/ppo.py +++ b/mighty/mighty_agents/ppo.py @@ -62,6 +62,7 @@ def __init__( normalize_reward: bool = False, rescale_action: bool = False, tanh_squash: bool = False, + handle_timeout_termination: bool = False, ): """Initialize the PPO agent. @@ -143,6 +144,7 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, + handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/test/agents/test_sac_agent.py b/test/agents/test_sac_agent.py index cf3d0d13..66e75b42 100644 --- a/test/agents/test_sac_agent.py +++ b/test/agents/test_sac_agent.py @@ -99,6 +99,13 @@ def test_update(self): dones = np.logical_or(terminated, truncated) # Process the transition (this adds to buffer) + # SAC agent expects metrics with transition info including terminated status + transition_metrics = { + "step": step, + "transition": { + "terminated": terminated, # Use the terminated from env.step() + } + } agent.process_transition( curr_s, action, @@ -106,7 +113,7 @@ def test_update(self): next_s, dones, log_prob.detach().cpu().numpy(), - {"step": step}, + transition_metrics, ) # Update current state @@ -272,7 +279,8 @@ def test_reproducibility(self): init_params = deepcopy(list(sac.model.parameters())) sac.run(20, 1) batch = sac.buffer.sample(20) - original_metrics = sac.update_agent(batch, 20) + # Fix: update_agent expects proper keyword arguments + original_metrics = sac.update_fn.update(batch) original_params = deepcopy(list(sac.model.parameters())) env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) @@ -303,7 +311,8 @@ def test_reproducibility(self): ) 
sac.run(20, 1) batch = sac.buffer.sample(20) - new_metrics = sac.update_agent(batch, 20) + # Fix: update_agent expects proper keyword arguments + new_metrics = sac.update_fn.update(batch) for old, new in zip( original_params[:10], list(sac.model.parameters())[:10], strict=False ): @@ -321,4 +330,4 @@ def test_reproducibility(self): original_metrics["Update/policy_loss"], new_metrics["Update/policy_loss"], ), "Policy loss should be the same with same seed" - clean(output_dir) + clean(output_dir) \ No newline at end of file diff --git a/test/models/test_sac_networks.py b/test/models/test_sac_networks.py index 6c8a123b..e913071f 100644 --- a/test/models/test_sac_networks.py +++ b/test/models/test_sac_networks.py @@ -14,14 +14,14 @@ def test_init(self): assert sac.obs_size == 8, "Obs size should be 8" assert sac.action_size == 3, "Action size should be 3" assert sac.activation == "tanh", "Passed activation should be tanh" - assert sac.log_std_min == -20, "Default log_std_min should be -20" + assert sac.log_std_min == -5, "Default log_std_min should be -5" # Fixed: was -20 assert sac.log_std_max == 2, "Default log_std_max should be 2" assert sac.continuous_action is True, "SAC should always be continuous" # Check network structure - updated for new architecture assert hasattr(sac, "feature_extractor"), "Should have feature extractor" - assert isinstance(sac.policy_net, nn.Linear), ( - "Policy network should be Linear (after feature extractor)" + assert isinstance(sac.policy_net, nn.Sequential), ( # Fixed: policy_net is Sequential, not Linear + "Policy network should be Sequential" ) assert isinstance(sac.q_net1, nn.Sequential), "Q-network 1 should be Sequential" assert isinstance(sac.q_net2, nn.Sequential), "Q-network 2 should be Sequential" @@ -116,7 +116,7 @@ def test_value_function_module(self): def test_forward_stochastic(self): """Test forward pass with stochastic policy.""" - sac = SACModel(obs_size=6, action_size=4) + sac = SACModel(obs_size=6, action_size=4, 
action_low=-2.0, action_high=3.0) dummy_state = torch.rand((10, 6)) action, z, mean, log_std = sac(dummy_state, deterministic=False) @@ -133,19 +133,20 @@ def test_forward_stochastic(self): assert torch.all(torch.isfinite(mean)), "Means should be finite" assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite" - # Check tanh constraint on actions - assert torch.all(action >= -1.0) and torch.all(action <= 1.0), ( - "Actions should be in [-1, 1] range" + # Check action bounds - should be in [action_low, action_high] range + assert torch.all(action >= -2.0) and torch.all(action <= 3.0), ( + "Actions should be in [-2.0, 3.0] range" ) # Check log_std clamping assert torch.all(log_std >= sac.log_std_min), "Log_std should be >= log_std_min" assert torch.all(log_std <= sac.log_std_max), "Log_std should be <= log_std_max" - # Check relationship: action = tanh(z) - expected_action = torch.tanh(z) + # Check relationship: raw_action = tanh(z), then scaled to [action_low, action_high] + raw_action = torch.tanh(z) + expected_action = raw_action * sac.action_scale + sac.action_bias assert torch.allclose(action, expected_action, atol=1e-6), ( - "Action should equal tanh(z)" + "Action should equal scaled tanh(z)" ) def test_forward_deterministic(self): @@ -164,10 +165,11 @@ def test_forward_deterministic(self): # In deterministic mode, z should equal mean assert torch.allclose(z, mean), "In deterministic mode, z should equal mean" - # Action should still be tanh(z) = tanh(mean) - expected_action = torch.tanh(mean) + # Action should be scaled tanh(mean) + raw_action = torch.tanh(mean) + expected_action = raw_action * sac.action_scale + sac.action_bias assert torch.allclose(action, expected_action), ( - "Action should equal tanh(mean) in deterministic mode" + "Action should equal scaled tanh(mean) in deterministic mode" ) def test_stochastic_vs_deterministic(self): @@ -209,9 +211,14 @@ def test_policy_log_prob(self): # Check shape assert log_prob.shape == (6, 1), "Log 
prob should have shape (6, 1)" - # Check that log probabilities are finite and reasonable + # Check that log probabilities are finite assert torch.all(torch.isfinite(log_prob)), "Log probs should be finite" - assert torch.all(log_prob <= 0.0), "Log probs should be <= 0" + + # Note: Log probabilities can be positive in some cases for transformed distributions + # The key constraint is that they should be reasonable values + # For SAC with tanh transformation, log probs can be positive due to the Jacobian correction + assert torch.all(log_prob > -50.0), "Log probs should not be extremely negative" + assert torch.all(log_prob < 50.0), "Log probs should not be extremely positive" # Test with deterministic actions (z = mean) log_prob_det = sac.policy_log_prob(mean, mean, log_std) @@ -223,7 +230,7 @@ def test_q_networks(self): """Test Q-network forward passes.""" sac = SACModel(obs_size=4, action_size=2) dummy_state = torch.rand((7, 4)) - dummy_action = torch.rand((7, 2)) + dummy_action = torch.rand((7, 2)) * 2 - 1 # Actions in [-1, 1] range # Concatenate state and action for Q-networks state_action = torch.cat([dummy_state, dummy_action], dim=-1) @@ -290,7 +297,7 @@ def test_gradient_flow(self): """Test that gradients flow properly through networks.""" sac = SACModel(obs_size=4, action_size=2) dummy_state = torch.rand((3, 4)) - dummy_action = torch.rand((3, 2)) + dummy_action = torch.rand((3, 2)) * 2 - 1 # Actions in [-1, 1] state_action = torch.cat([dummy_state, dummy_action], dim=-1) # Test policy network gradients @@ -350,3 +357,33 @@ def test_numerical_stability(self): assert torch.all(torch.isfinite(boundary_log_prob)), ( "Log probabilities should be finite for boundary actions" ) + + def test_action_scaling(self): + """Test that action scaling works correctly.""" + # Test with custom action bounds + action_low = -2.5 + action_high = 1.5 + sac = SACModel(obs_size=3, action_size=2, action_low=action_low, action_high=action_high) + + dummy_state = torch.rand((5, 3)) 
+ action, z, mean, log_std = sac(dummy_state) + + # Actions should be within the specified bounds + assert torch.all(action >= action_low), f"Actions should be >= {action_low}" + assert torch.all(action <= action_high), f"Actions should be <= {action_high}" + + # Check the scaling math + raw_action = torch.tanh(z) + expected_scale = (action_high - action_low) / 2.0 + expected_bias = (action_high + action_low) / 2.0 + expected_action = raw_action * expected_scale + expected_bias + + assert torch.allclose(action, expected_action, atol=1e-6), ( + "Action scaling should match expected formula" + ) + assert torch.allclose(sac.action_scale, torch.tensor(expected_scale)), ( + "Action scale should be computed correctly" + ) + assert torch.allclose(sac.action_bias, torch.tensor(expected_bias)), ( + "Action bias should be computed correctly" + ) \ No newline at end of file From ade9d40cdd5b4e3e63b5f1983bf6ce74bfe1ae31 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 11:26:40 +0200 Subject: [PATCH 10/14] removed FIX comments --- mighty/mighty_update/sac_update.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index ecd4a52e..83a5018c 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -133,7 +133,7 @@ def update(self, batch: TransitionBatch) -> Dict: self.log_alpha.exp().detach() if self.auto_alpha else self.alpha ) - # FIX: Sample fresh actions for each policy update iteration + # Sample fresh actions for each policy update iteration # This ensures stochasticity across iterations a, z, mean, log_std = self.model(states) logp = self.model.policy_log_prob(z, mean, log_std) @@ -150,7 +150,7 @@ def update(self, batch: TransitionBatch) -> Dict: # --- Entropy coefficient (alpha) update --- if self.auto_alpha: - # CRITICAL FIX: Get fresh sample for alpha update + # Get fresh sample for alpha update with torch.no_grad(): _, z_alpha, 
mean_alpha, log_std_alpha = self.model(states) logp_alpha = self.model.policy_log_prob(z_alpha, mean_alpha, log_std_alpha) From c4a6d819fed28def1f350c2c5770934c5465b97e Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Fri, 8 Aug 2025 11:56:20 +0200 Subject: [PATCH 11/14] updates for Merge --- mighty/mighty_agents/base_agent.py | 26 ++-- mighty/mighty_agents/dqn.py | 2 - mighty/mighty_agents/ppo.py | 2 - mighty/mighty_agents/sac.py | 5 +- .../mighty_exploration_policy.py | 6 +- .../mighty_exploration/stochastic_policy.py | 2 +- mighty/mighty_models/sac.py | 83 ++++++++---- test/models/test_sac_networks.py | 121 +++++++++++++----- 8 files changed, 161 insertions(+), 86 deletions(-) diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 4a1672f5..ec69808e 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -141,7 +141,6 @@ def __init__( # noqa: PLR0915, PLR0912 normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, - handle_timeout_termination: bool = False, ): """Base agent initialization. 
@@ -302,8 +301,6 @@ def __init__( # noqa: PLR0915, PLR0912 for m in self.meta_modules.values(): m.seed(self.seed) self.steps = 0 - - self.handle_timeout_termination = handle_timeout_termination def _initialize_agent(self) -> None: """Agent/algorithm specific initializations.""" @@ -606,24 +603,21 @@ def run( # noqa: PLR0915 metrics["episode_reward"] = episode_reward action, log_prob = self.step(curr_s, metrics) - # 1) step the env as usual + # step the env as usual next_s, reward, terminated, truncated, infos = self.env.step(action) - # 2) decide which samples are true “done” + # decide which samples are true “done” replay_dones = terminated # physics‐failure only - dones = np.logical_or(terminated, truncated) + dones = np.logical_or(terminated, truncated) - # 3) optionally overwrite next_s on truncation - if self.handle_timeout_termination: - real_next_s = next_s.copy() - # infos["final_observation"] is a list/array of the last real obs - for i, tr in enumerate(truncated): - if tr: - real_next_s[i] = infos["final_observation"][i] - else: - real_next_s = next_s - + # Overwrite next_s on truncation + # Based on https://github.com/DLR-RM/stable-baselines3/issues/284 + real_next_s = next_s.copy() + # infos["final_observation"] is a list/array of the last real obs + for i, tr in enumerate(truncated): + if tr: + real_next_s[i] = infos["final_observation"][i] episode_reward += reward # Log everything diff --git a/mighty/mighty_agents/dqn.py b/mighty/mighty_agents/dqn.py index c6cc1cdb..0ced9ce4 100644 --- a/mighty/mighty_agents/dqn.py +++ b/mighty/mighty_agents/dqn.py @@ -69,7 +69,6 @@ def __init__( normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, # type: ignore - handle_timeout_termination: bool = False, ): """DQN initialization. 
@@ -155,7 +154,6 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, - handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/mighty/mighty_agents/ppo.py b/mighty/mighty_agents/ppo.py index 8b974ee4..8f9a6fc1 100644 --- a/mighty/mighty_agents/ppo.py +++ b/mighty/mighty_agents/ppo.py @@ -62,7 +62,6 @@ def __init__( normalize_reward: bool = False, rescale_action: bool = False, tanh_squash: bool = False, - handle_timeout_termination: bool = False, ): """Initialize the PPO agent. @@ -144,7 +143,6 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, - handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 758f1919..32753303 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -57,7 +57,6 @@ def __init__( rescale_action: bool = False, # ← NEW Whether to rescale actions to the environment's action space policy_frequency: int = 2, # Frequency of policy updates target_network_frequency: int = 1, # Frequency of target network updates - handle_timeout_termination: bool = True, ): """Initialize SAC agent with tunable hyperparameters and backward-compatible names.""" if hidden_sizes is None: @@ -117,7 +116,6 @@ def __init__( rescale_action=rescale_action, batch_size=batch_size, learning_rate=policy_lr, # For compatibility with base class - handle_timeout_termination=handle_timeout_termination, ) # Initialize loss buffer for logging @@ -209,9 +207,8 @@ def process_transition( # Ensure metrics dict if metrics is None: metrics = {} + # Pack transition - # `terminated` is used for physics failures in environments like `MightyEnv` - # Based on https://github.com/DLR-RM/stable-baselines3/issues/284 terminated = metrics["transition"]["terminated"] # physics‐failures transition = TransitionBatch(curr_s, action, 
reward, next_s, terminated.astype(int)) diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py b/mighty/mighty_exploration/mighty_exploration_policy.py index 534693e8..7af628a8 100644 --- a/mighty/mighty_exploration/mighty_exploration_policy.py +++ b/mighty/mighty_exploration/mighty_exploration_policy.py @@ -115,11 +115,9 @@ def sample_func_logits(self, state_array): # ─── Continuous squashed‐Gaussian (4‐tuple) ────────────────────────── elif isinstance(out, tuple) and len(out) == 4: - action = out[0] # [batch, action_dim] - - print(f'Self Model : {self.model}') + action = out[0] # [batch, action_dim] log_prob = sample_nondeterministic_logprobs( - z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel) + z=out[1], mean=out[2], log_std=out[3], sac= self.algo == "sac" ) return action.detach().cpu().numpy(), log_prob diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index 28e8cfdb..3b28306e 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -103,7 +103,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso if return_logp: return action.detach().cpu().numpy(), log_prob else: - weighted_log_prob = log_prob + weighted_log_prob = log_prob * self.entropy_coefficient return action.detach().cpu().numpy(), weighted_log_prob # Check for model attribute-based approaches diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index a31756fa..045d9d91 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -32,7 +32,7 @@ def __init__( # This model is continuous only self.continuous_action = True - # PR: register the per-dim scale and bias so we can rescale [-1,1]→[low,high]. + # Register the per-dim scale and bias so we can rescale [-1,1]→[low,high].
action_low = torch.as_tensor(action_low, dtype=torch.float32) action_high = torch.as_tensor(action_high, dtype=torch.float32) self.register_buffer( @@ -67,42 +67,75 @@ def __init__( self.hidden_sizes = feature_extractor_kwargs.get("hidden_sizes", [256, 256]) self.activation = feature_extractor_kwargs.get("activation", "relu") - # Shared feature extractor for policy - self.feature_extractor, out_dim = make_feature_extractor( + # Policy feature extractor and head + self.policy_feature_extractor, policy_feat_dim = make_feature_extractor( **feature_extractor_kwargs ) - - # Policy network outputs mean and log_std - # CHANGE: Create separate policy network (actor) similar to CleanRL - self.policy_net = make_policy_head( - in_size=self.obs_size, + + # Policy head: just the final output layer + self.policy_head = make_policy_head( + in_size=policy_feat_dim, out_size=self.action_size * 2, # mean and log_std - **head_kwargs + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - # Twin Q-networks - # — live Q-nets — - self.q_net1 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + # Create policy_net for backward compatibility + self.policy_net = nn.Sequential(self.policy_feature_extractor, self.policy_head) + + # Q-networks: feature extractors + heads + q_feature_extractor_kwargs = feature_extractor_kwargs.copy() + q_feature_extractor_kwargs["obs_shape"] = self.obs_size + self.action_size + + # Q-network 1 + self.q_feature_extractor1, q_feat_dim = make_feature_extractor(**q_feature_extractor_kwargs) + self.q_head1 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - self.q_net2 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + self.q_net1 = nn.Sequential(self.q_feature_extractor1, self.q_head1) + + # Q-network 2 + self.q_feature_extractor2, _ = 
make_feature_extractor(**q_feature_extractor_kwargs) + self.q_head2 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) + self.q_net2 = nn.Sequential(self.q_feature_extractor2, self.q_head2) # Target Q-networks - self.target_q_net1 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + self.target_q_feature_extractor1, _ = make_feature_extractor(**q_feature_extractor_kwargs) + self.target_q_head1 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - self.target_q_net1.load_state_dict(self.q_net1.state_dict()) - self.target_q_net2 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + self.target_q_net1 = nn.Sequential(self.target_q_feature_extractor1, self.target_q_head1) + + self.target_q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs) + self.target_q_head2 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - self.target_q_net2.load_state_dict(self.q_net2.state_dict()) + self.target_q_net2 = nn.Sequential(self.target_q_feature_extractor2, self.target_q_head2) + + # Copy weights from live to target networks + self.target_q_feature_extractor1.load_state_dict(self.q_feature_extractor1.state_dict()) + self.target_q_head1.load_state_dict(self.q_head1.state_dict()) + self.target_q_feature_extractor2.load_state_dict(self.q_feature_extractor2.state_dict()) + self.target_q_head2.load_state_dict(self.q_head2.state_dict()) # Freeze target networks - for p in self.target_q_net1.parameters(): + for p in self.target_q_feature_extractor1.parameters(): + p.requires_grad = False + for p in self.target_q_head1.parameters(): + p.requires_grad = False + for p in self.target_q_feature_extractor2.parameters(): p.requires_grad = False - for p in 
self.target_q_net2.parameters(): + for p in self.target_q_head2.parameters(): p.requires_grad = False # Create a value function wrapper for compatibility @@ -133,7 +166,7 @@ def forward( Forward pass for policy sampling. Returns: - action: torch.Tensor in [-1,1] + action: torch.Tensor in rescaled range [action_low, action_high] z: raw Gaussian sample before tanh mean: Gaussian mean log_std: Gaussian log std @@ -155,7 +188,7 @@ def forward( # tanh→[-1,1] raw_action = torch.tanh(z) - # **HERE** we rescale into [low,high] + # Rescale into [action_low, action_high] action = raw_action * self.action_scale + self.action_bias return action, z, mean, log_std diff --git a/test/models/test_sac_networks.py b/test/models/test_sac_networks.py index e913071f..54622d2e 100644 --- a/test/models/test_sac_networks.py +++ b/test/models/test_sac_networks.py @@ -14,17 +14,22 @@ def test_init(self): assert sac.obs_size == 8, "Obs size should be 8" assert sac.action_size == 3, "Action size should be 3" assert sac.activation == "tanh", "Passed activation should be tanh" - assert sac.log_std_min == -5, "Default log_std_min should be -5" # Fixed: was -20 + assert sac.log_std_min == -5, "Default log_std_min should be -5" assert sac.log_std_max == 2, "Default log_std_max should be 2" assert sac.continuous_action is True, "SAC should always be continuous" - # Check network structure - updated for new architecture - assert hasattr(sac, "feature_extractor"), "Should have feature extractor" - assert isinstance(sac.policy_net, nn.Sequential), ( # Fixed: policy_net is Sequential, not Linear - "Policy network should be Sequential" - ) + # Check network structure - updated for feature extractor + head architecture + assert hasattr(sac, "policy_feature_extractor"), "Should have policy feature extractor" + assert hasattr(sac, "policy_head"), "Should have policy head" + assert isinstance(sac.policy_net, nn.Sequential), "Policy network should be Sequential" + + # Check Q-networks + assert hasattr(sac, 
"q_feature_extractor1"), "Should have Q1 feature extractor" + assert hasattr(sac, "q_head1"), "Should have Q1 head" assert isinstance(sac.q_net1, nn.Sequential), "Q-network 1 should be Sequential" assert isinstance(sac.q_net2, nn.Sequential), "Q-network 2 should be Sequential" + + # Check target networks assert isinstance(sac.target_q_net1, nn.Sequential), ( "Target Q-network 1 should be Sequential" ) @@ -36,23 +41,39 @@ def test_init(self): ) # Check that target networks have gradients disabled - for param in sac.target_q_net1.parameters(): + for param in sac.target_q_feature_extractor1.parameters(): + assert not param.requires_grad, ( + "Target Q1 feature extractor parameters should not require gradients" + ) + for param in sac.target_q_head1.parameters(): + assert not param.requires_grad, ( + "Target Q1 head parameters should not require gradients" + ) + for param in sac.target_q_feature_extractor2.parameters(): assert not param.requires_grad, ( - "Target Q-network 1 parameters should not require gradients" + "Target Q2 feature extractor parameters should not require gradients" ) - for param in sac.target_q_net2.parameters(): + for param in sac.target_q_head2.parameters(): assert not param.requires_grad, ( - "Target Q-network 2 parameters should not require gradients" + "Target Q2 head parameters should not require gradients" ) # Check that live networks have gradients enabled - for param in sac.q_net1.parameters(): + for param in sac.q_feature_extractor1.parameters(): assert param.requires_grad, ( - "Q-network 1 parameters should require gradients" + "Q1 feature extractor parameters should require gradients" ) - for param in sac.q_net2.parameters(): + for param in sac.q_head1.parameters(): assert param.requires_grad, ( - "Q-network 2 parameters should require gradients" + "Q1 head parameters should require gradients" + ) + for param in sac.q_feature_extractor2.parameters(): + assert param.requires_grad, ( + "Q2 feature extractor parameters should require 
gradients" + ) + for param in sac.q_head2.parameters(): + assert param.requires_grad, ( + "Q2 head parameters should require gradients" ) def test_init_custom_params(self): @@ -250,26 +271,45 @@ def test_target_networks_initialization(self): """Test that target networks are initialized with same weights as live networks.""" sac = SACModel(obs_size=3, action_size=2) - # Check that target networks have same weights as live networks initially + # Check that target feature extractors have same weights as live ones for p1, p_target1 in zip( - sac.q_net1.parameters(), sac.target_q_net1.parameters() + sac.q_feature_extractor1.parameters(), sac.target_q_feature_extractor1.parameters() ): assert torch.allclose(p1, p_target1), ( - "Target Q-net 1 should have same initial weights as Q-net 1" + "Target Q1 feature extractor should have same initial weights" ) for p2, p_target2 in zip( - sac.q_net2.parameters(), sac.target_q_net2.parameters() + sac.q_feature_extractor2.parameters(), sac.target_q_feature_extractor2.parameters() ): assert torch.allclose(p2, p_target2), ( - "Target Q-net 2 should have same initial weights as Q-net 2" + "Target Q2 feature extractor should have same initial weights" + ) + + # Check that target heads have same weights as live heads + for p1, p_target1 in zip( + sac.q_head1.parameters(), sac.target_q_head1.parameters() + ): + assert torch.allclose(p1, p_target1), ( + "Target Q1 head should have same initial weights as Q1 head" + ) + + for p2, p_target2 in zip( + sac.q_head2.parameters(), sac.target_q_head2.parameters() + ): + assert torch.allclose(p2, p_target2), ( + "Target Q2 head should have same initial weights as Q2 head" ) def test_twin_q_networks_independence(self): """Test that twin Q-networks are independent.""" sac = SACModel(obs_size=4, action_size=2) - # Check that Q-networks have different parameters (due to random initialization) + # Check that Q-networks have different objects (due to separate creation) + assert sac.q_feature_extractor1 
is not sac.q_feature_extractor2, ( + "Q feature extractors should be separate objects" + ) + assert sac.q_head1 is not sac.q_head2, "Q heads should be separate objects" assert sac.q_net1 is not sac.q_net2, "Q-networks should be separate objects" assert sac.target_q_net1 is not sac.target_q_net2, ( "Target Q-networks should be separate objects" @@ -305,13 +345,15 @@ def test_gradient_flow(self): policy_loss = action.mean() # Dummy loss policy_loss.backward(retain_graph=True) - # Check that policy network has gradients - policy_has_grad = any(p.grad is not None for p in sac.policy_net.parameters()) - feature_has_grad = any( - p.grad is not None for p in sac.feature_extractor.parameters() + # Check that policy components have gradients + policy_feat_has_grad = any( + p.grad is not None for p in sac.policy_feature_extractor.parameters() ) - assert policy_has_grad or feature_has_grad, ( - "Policy network or feature extractor should have gradients" + policy_head_has_grad = any( + p.grad is not None for p in sac.policy_head.parameters() + ) + assert policy_feat_has_grad or policy_head_has_grad, ( + "Policy feature extractor or head should have gradients" ) # Test Q-network gradients @@ -320,15 +362,30 @@ def test_gradient_flow(self): q_loss = q1_value.mean() # Dummy loss q_loss.backward() - # Check that Q-network 1 has gradients - q1_has_grad = any(p.grad is not None for p in sac.q_net1.parameters()) - assert q1_has_grad, "Q-network 1 should have gradients" + # Check that Q1 components have gradients + q1_feat_has_grad = any( + p.grad is not None for p in sac.q_feature_extractor1.parameters() + ) + q1_head_has_grad = any( + p.grad is not None for p in sac.q_head1.parameters() + ) + assert q1_feat_has_grad or q1_head_has_grad, ( + "Q1 feature extractor or head should have gradients" + ) # Check that target networks don't have gradients - target_q1_has_grad = any( - p.grad is not None for p in sac.target_q_net1.parameters() + target_q1_feat_has_grad = any( + p.grad is not 
None for p in sac.target_q_feature_extractor1.parameters() + ) + target_q1_head_has_grad = any( + p.grad is not None for p in sac.target_q_head1.parameters() + ) + assert not target_q1_feat_has_grad, ( + "Target Q1 feature extractor should not have gradients" + ) + assert not target_q1_head_has_grad, ( + "Target Q1 head should not have gradients" ) - assert not target_q1_has_grad, "Target Q-network 1 should not have gradients" def test_numerical_stability(self): """Test numerical stability of log probability calculation.""" From a47ace993eefee9bf399a338258f86aed805e6e7 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Fri, 8 Aug 2025 12:06:05 +0200 Subject: [PATCH 12/14] removed instance comparisons in stochastic and exploration policies --- mighty/mighty_agents/sac.py | 12 ++++++---- .../mighty_exploration_policy.py | 9 +++----- .../mighty_exploration/stochastic_policy.py | 23 +++++++++---------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 32753303..bd26e1d6 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -145,7 +145,7 @@ def _initialize_agent(self) -> None: # Exploration policy wrapper self.policy = self.policy_class( - algo=self, model=self.model, **self.policy_kwargs + algo="sac", model=self.model, **self.policy_kwargs ) # Updater @@ -207,11 +207,13 @@ def process_transition( # Ensure metrics dict if metrics is None: metrics = {} - - # Pack transition + + # Pack transition terminated = metrics["transition"]["terminated"] # physics‐failures - transition = TransitionBatch(curr_s, action, reward, next_s, terminated.astype(int)) - + transition = TransitionBatch( + curr_s, action, reward, next_s, terminated.astype(int) + ) + # Compute per-transition TD errors for logging td1, td2 = self.update_fn.calculate_td_error(transition) metrics["td_error1"] = td1.detach().cpu().numpy() diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py 
b/mighty/mighty_exploration/mighty_exploration_policy.py index 7af628a8..4d37e4a3 100644 --- a/mighty/mighty_exploration/mighty_exploration_policy.py +++ b/mighty/mighty_exploration/mighty_exploration_policy.py @@ -12,10 +12,7 @@ def sample_nondeterministic_logprobs( - z: torch.Tensor, - mean: torch.Tensor, - log_std: torch.Tensor, - sac: bool = False + z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False ) -> torch.Tensor: """ Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)), @@ -115,9 +112,9 @@ def sample_func_logits(self, state_array): # ─── Continuous squashed‐Gaussian (4‐tuple) ────────────────────────── elif isinstance(out, tuple) and len(out) == 4: - action = out[0] # [batch, action_dim] + action = out[0] # [batch, action_dim] log_prob = sample_nondeterministic_logprobs( - z=out[1], mean=out[2], log_std=out[3], sac= self.ago == "sac" + z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac" ) return action.detach().cpu().numpy(), log_prob diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index 3b28306e..4c57c20b 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -27,13 +27,12 @@ def __init__( :param entropy_coefficient: weight on entropy term :param discrete: whether the action space is discrete """ - + self.model = model - + super().__init__(algo, model, discrete) self.entropy_coefficient = entropy_coefficient self.discrete = discrete - # --- override sample_action only for continuous SAC --- if not discrete and isinstance(model, SACModel): @@ -88,9 +87,9 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso # 4-tuple case (Tanh squashing): (action, z, mean, log_std) elif isinstance(model_output, tuple) and len(model_output) == 4: action, z, mean, log_std = model_output - - if not isinstance(self.model, SACModel): - + + if not self.algo == "sac": + log_prob =
sample_nondeterministic_logprobs( z=z, mean=mean, @@ -121,8 +120,8 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso elif len(model_output) == 4: # Tanh squashing mode: (action, z, mean, log_std) action, z, mean, log_std = model_output - if not isinstance(self.model, SACModel): - + if not self.algo == "sac": + log_prob = sample_nondeterministic_logprobs( z=z, mean=mean, @@ -147,7 +146,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso if self.model.output_style == "squashed_gaussian": # Should be 4-tuple: (action, z, mean, log_std) action, z, mean, log_std = model_output - if not isinstance(self.model, SACModel): + if not self.algo == "sac": log_prob = sample_nondeterministic_logprobs( z=z, mean=mean, @@ -170,7 +169,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso z = dist.rsample() action = torch.tanh(z) - if not isinstance(self.model, SACModel): + if not self.algo == "sac": log_prob = sample_nondeterministic_logprobs( z=z, mean=mean, @@ -179,7 +178,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso ) else: log_prob = self.model.policy_log_prob(z, mean, log_std) - + entropy = dist.entropy().sum(dim=-1, keepdim=True) weighted_log_prob = log_prob * entropy return action.detach().cpu().numpy(), weighted_log_prob @@ -190,7 +189,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso ) # Special handling for SACModel - elif isinstance(self.model, SACModel): + elif self.algo == "sac" and isinstance(self.model, SACModel): action, z, mean, log_std = self.model(state, deterministic=False) # CRITICAL: Use the model's policy_log_prob which includes tanh correction log_prob = self.model.policy_log_prob(z, mean, log_std) From 80923ff9ad9e49b86840b42873d6bdd4d8247f0a Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Fri, 8 Aug 2025 13:40:48 +0200 Subject: [PATCH 13/14] version update --- pyproject.toml | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 05e97863..052e953b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "Mighty-RL" -version = "0.0.1" +version = "1.0.0" description = "A modular, meta-learning-ready RL library." authors = [{ name = "AutoRL@LUHAI", email = "a.mohan@ai.uni-hannover.de" }] readme = "README.md" @@ -91,4 +91,4 @@ explicit = true omit = [ "mighty/mighty_utils/plotting.py", "mighty/mighty_utils/test_helpers.py" - ] \ No newline at end of file + ] From dd5014d7cfdabc2acde736560e66442e76dc6fa8 Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Fri, 8 Aug 2025 13:50:00 +0200 Subject: [PATCH 14/14] Update publish-release.yaml --- .github/workflows/publish-release.yaml | 33 +------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/.github/workflows/publish-release.yaml b/.github/workflows/publish-release.yaml index 6413c59d..066597cb 100644 --- a/.github/workflows/publish-release.yaml +++ b/.github/workflows/publish-release.yaml @@ -13,37 +13,6 @@ on: types: [created] jobs: - test: - name: publish-release - runs-on: "ubuntu-latest" - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v5 - with: - # Install a specific version of uv. - version: "0.6.14" - - - name: "Set up Python" - uses: actions/setup-python@v5 - with: - python-version-file: "pyproject.toml" - - - name: Install ${{ env.package-name }} - run: make install-dev - - - name: Store git status - id: status-before - shell: bash - run: | - echo "::set-output name=BEFORE::$(git status --porcelain -b)" - - - name: Tests - run: make test - pypi-publish: name: Upload release to PyPI runs-on: ubuntu-latest @@ -71,4 +40,4 @@ jobs: run: uv build - name: Publish package distributions to PyPI - run: uv publish \ No newline at end of file + run: uv publish