From 488b439417743bb7a8ceb800c4dfcc05caf18a13 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Wed, 6 Aug 2025 23:06:33 +0200 Subject: [PATCH 1/8] updated sac to test if it works --- mighty/configs/algorithm/sac.yaml | 36 ++++++----- mighty/mighty_agents/sac.py | 2 +- .../mighty_exploration/stochastic_policy.py | 9 +-- mighty/mighty_models/sac.py | 14 ++++- mighty/mighty_update/sac_update.py | 60 +++++++++---------- 5 files changed, 67 insertions(+), 54 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index f1804932..07c80294 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -7,10 +7,11 @@ algorithm_kwargs: # Normalization normalize_obs: False normalize_reward: False + rescale_action: True # CRITICAL: Add this! Must be True for MuJoCo # Network sizes - n_policy_units: 256 - soft_update_weight: 0.005 + n_policy_units: 256 + soft_update_weight: 0.005 # tau in SAC terms # Replay buffer replay_buffer_class: @@ -19,31 +20,36 @@ algorithm_kwargs: capacity: 1e6 # Scheduling & batch-updates - batch_size: 256 - learning_starts: 5000 - update_every: 1 - n_gradient_steps: 1 + batch_size: 256 + learning_starts: 5000 # Good, matches CleanRL + update_every: 1 # Good, update every step + n_gradient_steps: 1 # Good - # Learning rates + # Learning rates - CRITICAL CHANGE policy_lr: 3e-4 - q_lr: 1e-3 - alpha_lr: 1e-3 + q_lr: 1e-3 # This is correct now (was 3e-4) + alpha_lr: 3e-4 # 3e-4 is better than 1e-3 for alpha # SAC hyperparameters gamma: 0.99 alpha: 0.2 auto_alpha: True - target_entropy: -6.0 # -action_dim for HalfCheetah (6 actions) + target_entropy: null # Let it auto-compute as -action_dim + + # Network architecture + hidden_sizes: [256, 256] # Explicitly specify + activation: relu + log_std_min: -5 + log_std_max: 2 # Policy configuration policy_class: mighty.mighty_exploration.StochasticPolicy policy_kwargs: - entropy_coefficient: 0.0 discrete: False - + # Remove entropy_coefficient - SAC handles alpha internally # SAC specific frequencies - policy_frequency: 2 # Delayed policy updates + policy_frequency: 2 # Can also try 1 for even better performance target_network_frequency: 1 # Update targets every step # Environment and training configuration @@ -55,5 +61,5 @@ max_episode_steps: 1000 # HalfCheetah episode length eval_frequency: 10000 # More frequent eval for single env save_frequency: 50000 # Save every 50k steps - -# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1 \ No newline at end of file +# Command to run: +# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1 \ No newline at end of file diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 2125c794..316eed79 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -38,7 +38,7 @@ def __init__( # --- Network architecture (optional override) --- hidden_sizes: Optional[List[int]] = None, activation: str = "relu", - log_std_min: float = -20, + log_std_min: float = -5, log_std_max: float = 2, # --- Logging & buffer --- render_progress: bool = True, diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index 63e995d0..b845075c 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -177,12 +177,9 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso # Special handling for SACModel elif isinstance(self.model, SACModel): action, z, mean, log_std = self.model(state, deterministic=False) - std = torch.exp(log_std) - dist = Normal(mean, std) - - log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) - weighted_log_prob = log_pz * self.entropy_coefficient - return action.detach().cpu().numpy(), weighted_log_prob + # CRITICAL: Use the model's policy_log_prob which includes tanh correction + log_prob = self.model.policy_log_prob(z, mean, log_std) + return action.detach().cpu().numpy(), log_prob else: raise RuntimeError( diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index dda12072..b27118c6 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -17,7 +17,7 @@ def __init__( self, obs_size: int, action_size: int, - log_std_min: float = -20, + log_std_min: float = -5, log_std_max: float = 2, **kwargs, ): @@ -124,7 +124,17 @@ def forward( feats = self.feature_extractor(state) x = self.policy_net(feats) mean, log_std = x.chunk(2, dim=-1) - log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + # log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + + # NEW - Soft clamping + log_std = torch.tanh(log_std) + log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) + + # This maps tanh output [-1, 1] to [log_std_min, log_std_max] + # When tanh(x) = -1: log_std = log_std_min + 0.5 * range * 0 = log_std_min + # When tanh(x) = 0: log_std = log_std_min + 0.5 * range * 1 = (log_std_min + log_std_max) / 2 + # When tanh(x) = 1: log_std = log_std_min + 0.5 * range * 2 = log_std_max + std = torch.exp(log_std) if deterministic: diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index aba3aa12..74f16082 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -127,36 +127,36 @@ def update(self, batch: TransitionBatch) -> Dict: alpha_loss = torch.tensor(0.0) if self.update_step % self.policy_frequency == 0: # do multiple policy updates to compensate for delay - for _ in range(self.policy_frequency): - # recompute alpha after q update - current_alpha = ( - self.log_alpha.exp().detach() if self.auto_alpha else self.alpha - ) - - a, z, mean, log_std = self.model(states) - logp = self.model.policy_log_prob(z, mean, log_std) - sa_pi = torch.cat([states, a], dim=-1) - q1_pi = self.model.q_net1(sa_pi) - q2_pi = self.model.q_net2(sa_pi) - q_pi = torch.min(q1_pi, q2_pi) - policy_loss = (current_alpha * logp - q_pi).mean() - - self.policy_optimizer.zero_grad() - policy_loss.backward() - self.policy_optimizer.step() - - # --- Entropy coefficient (alpha) update --- - if self.auto_alpha: - with torch.no_grad(): - _, _, _, _ = self.model(states) - # Use the logp from the policy update above - alpha_loss = -( - self.log_alpha * (logp.detach() + self.target_entropy) - ).mean() - self.alpha_optimizer.zero_grad() - alpha_loss.backward() - self.alpha_optimizer.step() - self.alpha = self.log_alpha.exp().item() + # for _ in range(self.policy_frequency): + # recompute alpha after q update + current_alpha = ( + self.log_alpha.exp().detach() if self.auto_alpha else self.alpha + ) + + a, z, mean, log_std = self.model(states) + logp = self.model.policy_log_prob(z, mean, log_std) + sa_pi = torch.cat([states, a], dim=-1) + q1_pi = self.model.q_net1(sa_pi) + q2_pi = self.model.q_net2(sa_pi) + q_pi = torch.min(q1_pi, q2_pi) + policy_loss = (current_alpha * logp - q_pi).mean() + + self.policy_optimizer.zero_grad() + policy_loss.backward() + self.policy_optimizer.step() + + # --- Entropy coefficient (alpha) update --- + if self.auto_alpha: + with torch.no_grad(): + _, _, _, _ = self.model(states) + # Use the logp from the policy update above + alpha_loss = -( + self.log_alpha * (logp.detach() + self.target_entropy) + ).mean() + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + self.alpha = self.log_alpha.exp().item() # --- Soft update targets --- if self.update_step % self.target_network_frequency == 0: From 707b28002ede91ad1cab154c5124aec7d7d04a05 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 02:23:55 +0200 Subject: [PATCH 2/8] update --- mighty/configs/algorithm/sac.yaml | 3 +- mighty/mighty_agents/sac.py | 8 +- .../mighty_exploration_policy.py | 37 +++++---- .../mighty_exploration/stochastic_policy.py | 75 +++++++++++-------- mighty/mighty_update/sac_update.py | 60 +++++++-------- 5 files changed, 105 insertions(+), 78 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index 07c80294..07361e57 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -19,13 +19,14 @@ algorithm_kwargs: replay_buffer_kwargs: capacity: 1e6 + # Scheduling & batch-updates batch_size: 256 learning_starts: 5000 # Good, matches CleanRL update_every: 1 # Good, update every step n_gradient_steps: 1 # Good - # Learning rates - CRITICAL CHANGE + # Learning rates policy_lr: 3e-4 q_lr: 1e-3 # This is correct now (was 3e-4) alpha_lr: 3e-4 # 3e-4 is better than 1e-3 for alpha diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 316eed79..d0feecbd 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -207,8 +207,12 @@ def process_transition( # Ensure metrics dict if metrics is None: metrics = {} - # Pack transition - transition = TransitionBatch(curr_s, action, reward, next_s, dones) + # Pack transition + # `terminated` is used for physics failures in environments like `MightyEnv` + # Based on https://github.com/DLR-RM/stable-baselines3/issues/284 + terminated = metrics["transition"]["terminated"] # physics‐failures + transition = TransitionBatch(curr_s, action, reward, next_s, terminated.astype(int)) + # Compute per-transition TD errors for logging td1, td2 = self.update_fn.calculate_td_error(transition) metrics["td_error1"] = td1.detach().cpu().numpy() diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py b/mighty/mighty_exploration/mighty_exploration_policy.py index 39948d5c..534693e8 100644 --- a/mighty/mighty_exploration/mighty_exploration_policy.py +++ b/mighty/mighty_exploration/mighty_exploration_policy.py @@ -8,29 +8,34 @@ import torch from torch.distributions import Categorical, Normal +from mighty.mighty_models import SACModel + def sample_nondeterministic_logprobs( - z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: + z: torch.Tensor, + mean: torch.Tensor, + log_std: torch.Tensor, + sac: bool = False +) -> torch.Tensor: + """ + Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)), + and if sac=True apply the tanh-squash correction to get log π(a). + """ std = torch.exp(log_std) # [batch, action_dim] dist = Normal(mean, std) + # base Gaussian log‐prob of z + log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1] - # For SAC, don't apply correction if sac: - return dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1] - # If not SAC, we need to apply the tanh correction - else: - log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1] - - # 2b) tanh‐correction = ∑ᵢ log(1 − tanh(zᵢ)² + ε) - eps = 1e-6 + # subtract the ∑_i log(d tanh/dz_i) = ∑ log(1 - tanh(z)^2) + eps = 1e-4 log_correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps).sum( dim=-1, keepdim=True ) # [batch, 1] - - # 2c) final log_prob of a = tanh(z) - log_prob = log_pz - log_correction # [batch, 1] - return log_prob + return log_pz - log_correction + else: + # PPO-style or other: no squash correction + return log_pz class MightyExplorationPolicy: @@ -111,8 +116,10 @@ def sample_func_logits(self, state_array): # ─── Continuous squashed‐Gaussian (4‐tuple) ────────────────────────── elif isinstance(out, tuple) and len(out) == 4: action = out[0] # [batch, action_dim] + + print(f'Self Model : {self.model}') log_prob = sample_nondeterministic_logprobs( - z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac" + z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel) ) return action.detach().cpu().numpy(), log_prob diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index b845075c..28e8cfdb 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -27,9 +27,13 @@ def __init__( :param entropy_coefficient: weight on entropy term :param discrete: whether the action space is discrete """ + + self.model = model + super().__init__(algo, model, discrete) self.entropy_coefficient = entropy_coefficient self.discrete = discrete + # --- override sample_action only for continuous SAC --- if not discrete and isinstance(model, SACModel): @@ -84,33 +88,24 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso # 4-tuple case (Tanh squashing): (action, z, mean, log_std) elif isinstance(model_output, tuple) and len(model_output) == 4: action, z, mean, log_std = model_output - log_prob = sample_nondeterministic_logprobs( - z=z, - mean=mean, - log_std=log_std, - sac=self.algo == "sac", - ) + + if not isinstance(self.model, SACModel): + + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + log_prob = self.model.policy_log_prob(z, mean, log_std) if return_logp: return action.detach().cpu().numpy(), log_prob else: - weighted_log_prob = log_prob * self.entropy_coefficient + weighted_log_prob = log_prob return action.detach().cpu().numpy(), weighted_log_prob - # Legacy 2-tuple case: (mean, std) - elif isinstance(model_output, tuple) and len(model_output) == 2: - mean, std = model_output - dist = Normal(mean, std) - z = dist.rsample() # [batch, action_dim] - action = torch.tanh(z) # [batch, action_dim] - - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac" - ) - entropy = dist.entropy().sum(dim=-1, keepdim=True) # [batch, 1] - weighted_log_prob = log_prob * entropy - return action.detach().cpu().numpy(), weighted_log_prob - # Check for model attribute-based approaches elif hasattr(self.model, "continuous_action") and getattr( self.model, "continuous_action" @@ -126,9 +121,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso elif len(model_output) == 4: # Tanh squashing mode: (action, z, mean, log_std) action, z, mean, log_std = model_output - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=log_std, sac=self.algo == "sac" - ) + if not isinstance(self.model, SACModel): + + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + log_prob = self.model.policy_log_prob(z, mean, log_std) else: raise ValueError( f"Unexpected model output length: {len(model_output)}" @@ -145,9 +147,15 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso if self.model.output_style == "squashed_gaussian": # Should be 4-tuple: (action, z, mean, log_std) action, z, mean, log_std = model_output - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=log_std, sac=self.algo == "sac" - ) + if not isinstance(self.model, SACModel): + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + log_prob = self.model.policy_log_prob(z, mean, log_std) if return_logp: return action.detach().cpu().numpy(), log_prob @@ -162,9 +170,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso z = dist.rsample() action = torch.tanh(z) - log_prob = sample_nondeterministic_logprobs( - z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac" - ) + if not isinstance(self.model, SACModel): + log_prob = sample_nondeterministic_logprobs( + z=z, + mean=mean, + log_std=log_std, + sac=False, + ) + else: + log_prob = self.model.policy_log_prob(z, mean, log_std) + entropy = dist.entropy().sum(dim=-1, keepdim=True) weighted_log_prob = log_prob * entropy return action.detach().cpu().numpy(), weighted_log_prob diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index 74f16082..e1c96db2 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -127,36 +127,36 @@ def update(self, batch: TransitionBatch) -> Dict: alpha_loss = torch.tensor(0.0) if self.update_step % self.policy_frequency == 0: # do multiple policy updates to compensate for delay - # for _ in range(self.policy_frequency): - # recompute alpha after q update - current_alpha = ( - self.log_alpha.exp().detach() if self.auto_alpha else self.alpha - ) - - a, z, mean, log_std = self.model(states) - logp = self.model.policy_log_prob(z, mean, log_std) - sa_pi = torch.cat([states, a], dim=-1) - q1_pi = self.model.q_net1(sa_pi) - q2_pi = self.model.q_net2(sa_pi) - q_pi = torch.min(q1_pi, q2_pi) - policy_loss = (current_alpha * logp - q_pi).mean() - - self.policy_optimizer.zero_grad() - policy_loss.backward() - self.policy_optimizer.step() - - # --- Entropy coefficient (alpha) update --- - if self.auto_alpha: - with torch.no_grad(): - _, _, _, _ = self.model(states) - # Use the logp from the policy update above - alpha_loss = -( - self.log_alpha * (logp.detach() + self.target_entropy) - ).mean() - self.alpha_optimizer.zero_grad() - alpha_loss.backward() - self.alpha_optimizer.step() - self.alpha = self.log_alpha.exp().item() + for _ in range(self.policy_frequency): + # recompute alpha after q update + current_alpha = ( + self.log_alpha.exp().detach() if self.auto_alpha else self.alpha + ) + + a, z, mean, log_std = self.model(states) + logp = self.model.policy_log_prob(z, mean, log_std) + sa_pi = torch.cat([states, a], dim=-1) + q1_pi = self.model.q_net1(sa_pi) + q2_pi = self.model.q_net2(sa_pi) + q_pi = torch.min(q1_pi, q2_pi) + policy_loss = (current_alpha * logp - q_pi).mean() + + self.policy_optimizer.zero_grad() + policy_loss.backward() + self.policy_optimizer.step() + + # --- Entropy coefficient (alpha) update --- + if self.auto_alpha: + with torch.no_grad(): + _, _, _, _ = self.model(states) + # Use the logp from the policy update above + alpha_loss = -( + self.log_alpha.exp() * (logp.detach() + self.target_entropy) + ).mean() + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + self.alpha = self.log_alpha.exp().item() # --- Soft update targets --- if self.update_step % self.target_network_frequency == 0: From 979742e1ede005acf3e8ea2f89efd89fcac04401 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 04:08:03 +0200 Subject: [PATCH 3/8] update --- mighty/configs/algorithm/sac.yaml | 2 +- mighty/mighty_agents/base_agent.py | 26 ++++++++++++++++++++++---- mighty/mighty_agents/sac.py | 2 ++ mighty/mighty_models/sac.py | 20 +++++++++++++++++++- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index 07361e57..658b025a 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -51,7 +51,7 @@ algorithm_kwargs: # SAC specific frequencies policy_frequency: 2 # Can also try 1 for even better performance - target_network_frequency: 1 # Update targets every step + target_network_frequency: 2 # Update targets every step # Environment and training configuration num_envs: 1 # Single environment diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 790a7c74..4a1672f5 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -141,6 +141,7 @@ def __init__( # noqa: PLR0915, PLR0912 normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, + handle_timeout_termination: bool = False, ): """Base agent initialization. @@ -301,6 +302,8 @@ def __init__( # noqa: PLR0915, PLR0912 for m in self.meta_modules.values(): m.seed(self.seed) self.steps = 0 + + self.handle_timeout_termination = handle_timeout_termination def _initialize_agent(self) -> None: """Agent/algorithm specific initializations.""" @@ -603,8 +606,23 @@ def run( # noqa: PLR0915 metrics["episode_reward"] = episode_reward action, log_prob = self.step(curr_s, metrics) - next_s, reward, terminated, truncated, _ = self.env.step(action) # type: ignore - dones = np.logical_or(terminated, truncated) + # 1) step the env as usual + next_s, reward, terminated, truncated, infos = self.env.step(action) + + # 2) decide which samples are true “done” + replay_dones = terminated # physics‐failure only + dones = np.logical_or(terminated, truncated) + + + # 3) optionally overwrite next_s on truncation + if self.handle_timeout_termination: + real_next_s = next_s.copy() + # infos["final_observation"] is a list/array of the last real obs + for i, tr in enumerate(truncated): + if tr: + real_next_s[i] = infos["final_observation"][i] + else: + real_next_s = next_s episode_reward += reward @@ -615,10 +633,10 @@ def run( # noqa: PLR0915 "reward": reward, "action": action, "state": curr_s, - "next_state": next_s, + "next_state": real_next_s, "terminated": terminated.astype(int), "truncated": truncated.astype(int), - "dones": dones.astype(int), + "dones": replay_dones.astype(int), "mean_episode_reward": last_episode_reward.mean() .cpu() .numpy() diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index d0feecbd..758f1919 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -57,6 +57,7 @@ def __init__( rescale_action: bool = False, # ← NEW Whether to rescale actions to the environment's action space policy_frequency: int = 2, # Frequency of policy updates target_network_frequency: int = 1, # Frequency of target network updates + handle_timeout_termination: bool = True, ): """Initialize SAC agent with tunable hyperparameters and backward-compatible names.""" if hidden_sizes is None: @@ -116,6 +117,7 @@ def __init__( rescale_action=rescale_action, batch_size=batch_size, learning_rate=policy_lr, # For compatibility with base class + handle_timeout_termination=handle_timeout_termination, ) # Initialize loss buffer for logging diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index b27118c6..d3902552 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -19,6 +19,8 @@ def __init__( action_size: int, log_std_min: float = -5, log_std_max: float = 2, + action_low: float = -1, + action_high: float = +1, **kwargs, ): super().__init__() @@ -29,6 +31,16 @@ def __init__( # This model is continuous only self.continuous_action = True + + # PR: register the per-dim scale and bias so we can rescale [-1,1]→[low,high]. + action_low = torch.as_tensor(action_low, dtype=torch.float32) + action_high = torch.as_tensor(action_high, dtype=torch.float32) + self.register_buffer( + "action_scale", (action_high - action_low) / 2.0 + ) + self.register_buffer( + "action_bias", (action_high + action_low) / 2.0 + ) head_kwargs = {"hidden_sizes": [256, 256], "activation": "relu"} feature_extractor_kwargs = { @@ -141,7 +153,13 @@ def forward( z = mean else: z = mean + std * torch.randn_like(mean) - action = torch.tanh(z) + + # tanh→[-1,1] + raw_action = torch.tanh(z) + + # **HERE** we rescale into [low,high] + action = raw_action * self.action_scale + self.action_bias + return action, z, mean, log_std def policy_log_prob( From 484d1f25894e25082828dd590c612833792a4b72 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 10:25:00 +0200 Subject: [PATCH 4/8] SAC updates --- mighty/configs/algorithm/sac.yaml | 2 +- mighty/mighty_models/sac.py | 37 ++++++++++++++++++++++-------- mighty/mighty_update/sac_update.py | 13 +++++++---- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/mighty/configs/algorithm/sac.yaml b/mighty/configs/algorithm/sac.yaml index 658b025a..07361e57 100644 --- a/mighty/configs/algorithm/sac.yaml +++ b/mighty/configs/algorithm/sac.yaml @@ -51,7 +51,7 @@ algorithm_kwargs: # SAC specific frequencies policy_frequency: 2 # Can also try 1 for even better performance - target_network_frequency: 2 # Update targets every step + target_network_frequency: 1 # Update targets every step # Environment and training configuration num_envs: 1 # Single environment diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index d3902552..a31756fa 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -73,7 +73,12 @@ def __init__( ) # Policy network outputs mean and log_std - self.policy_net = nn.Linear(out_dim, action_size * 2) + # CHANGE: Create separate policy network (actor) similar to CleanRL + self.policy_net = make_policy_head( + in_size=self.obs_size, + out_size=self.action_size * 2, # mean and log_std + **head_kwargs + ) # Twin Q-networks # — live Q-nets — @@ -133,20 +138,13 @@ def forward( mean: Gaussian mean log_std: Gaussian log std """ - feats = self.feature_extractor(state) - x = self.policy_net(feats) + x = self.policy_net(state) mean, log_std = x.chunk(2, dim=-1) - # log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) - # NEW - Soft clamping + # Soft clamping log_std = torch.tanh(log_std) log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1) - # This maps tanh output [-1, 1] to [log_std_min, log_std_max] - # When tanh(x) = -1: log_std = log_std_min + 0.5 * range * 0 = log_std_min - # When tanh(x) = 0: log_std = log_std_min + 0.5 * range * 1 = (log_std_min + log_std_max) / 2 - # When tanh(x) = 1: log_std = log_std_min + 0.5 * range * 2 = log_std_max - std = torch.exp(log_std) if deterministic: @@ -207,3 +205,22 @@ def make_q_head(in_size, hidden_sizes=None, activation="relu"): layers.append(nn.Linear(last_size, 1)) return nn.Sequential(*layers) + + +def make_policy_head(in_size, out_size, hidden_sizes=None, activation="relu"): + """Make policy head network (actor).""" + if hidden_sizes is None: + hidden_sizes = [] + + layers = [] + last_size = in_size + if isinstance(last_size, list): + last_size = last_size[0] + + for size in hidden_sizes: + layers.append(nn.Linear(last_size, size)) + layers.append(ACTIVATIONS[activation]()) + last_size = size + + layers.append(nn.Linear(last_size, out_size)) + return nn.Sequential(*layers) \ No newline at end of file diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index e1c96db2..ecd4a52e 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -41,7 +41,7 @@ def __init__( self.update_step = 0 if self.auto_alpha: - self.log_alpha = torch.nn.Parameter(torch.zeros(1, requires_grad=True)) + self.log_alpha = torch.zeros(1, requires_grad=True) self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr or q_lr) self.target_entropy = ( -float(self.action_dim) @@ -133,9 +133,12 @@ def update(self, batch: TransitionBatch) -> Dict: self.log_alpha.exp().detach() if self.auto_alpha else self.alpha ) + # FIX: Sample fresh actions for each policy update iteration + # This ensures stochasticity across iterations a, z, mean, log_std = self.model(states) logp = self.model.policy_log_prob(z, mean, log_std) sa_pi = torch.cat([states, a], dim=-1) + q1_pi = self.model.q_net1(sa_pi) q2_pi = self.model.q_net2(sa_pi) q_pi = torch.min(q1_pi, q2_pi) @@ -147,11 +150,13 @@ def update(self, batch: TransitionBatch) -> Dict: # --- Entropy coefficient (alpha) update --- if self.auto_alpha: + # CRITICAL FIX: Get fresh sample for alpha update with torch.no_grad(): - _, _, _, _ = self.model(states) - # Use the logp from the policy update above + _, z_alpha, mean_alpha, log_std_alpha = self.model(states) + logp_alpha = self.model.policy_log_prob(z_alpha, mean_alpha, log_std_alpha) + alpha_loss = -( - self.log_alpha.exp() * (logp.detach() + self.target_entropy) + self.log_alpha.exp() * (logp_alpha.detach() + self.target_entropy) ).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() From c667d0aade2a4af67a9a101ee87a38094d575ada Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 11:25:30 +0200 Subject: [PATCH 5/8] updated code + tests --- mighty/mighty_agents/dqn.py | 2 + mighty/mighty_agents/ppo.py | 2 + test/agents/test_sac_agent.py | 17 ++++++-- test/models/test_sac_networks.py | 71 ++++++++++++++++++++++++-------- 4 files changed, 71 insertions(+), 21 deletions(-) diff --git a/mighty/mighty_agents/dqn.py b/mighty/mighty_agents/dqn.py index 0ced9ce4..c6cc1cdb 100644 --- a/mighty/mighty_agents/dqn.py +++ b/mighty/mighty_agents/dqn.py @@ -69,6 +69,7 @@ def __init__( normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, # type: ignore + handle_timeout_termination: bool = False, ): """DQN initialization. @@ -154,6 +155,7 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, + handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/mighty/mighty_agents/ppo.py b/mighty/mighty_agents/ppo.py index 8f9a6fc1..8b974ee4 100644 --- a/mighty/mighty_agents/ppo.py +++ b/mighty/mighty_agents/ppo.py @@ -62,6 +62,7 @@ def __init__( normalize_reward: bool = False, rescale_action: bool = False, tanh_squash: bool = False, + handle_timeout_termination: bool = False, ): """Initialize the PPO agent. @@ -143,6 +144,7 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, + handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/test/agents/test_sac_agent.py b/test/agents/test_sac_agent.py index cf3d0d13..66e75b42 100644 --- a/test/agents/test_sac_agent.py +++ b/test/agents/test_sac_agent.py @@ -99,6 +99,13 @@ def test_update(self): dones = np.logical_or(terminated, truncated) # Process the transition (this adds to buffer) + # SAC agent expects metrics with transition info including terminated status + transition_metrics = { + "step": step, + "transition": { + "terminated": terminated, # Use the terminated from env.step() + } + } agent.process_transition( curr_s, action, @@ -106,7 +113,7 @@ def test_update(self): next_s, dones, log_prob.detach().cpu().numpy(), - {"step": step}, + transition_metrics, ) # Update current state @@ -272,7 +279,8 @@ def test_reproducibility(self): init_params = deepcopy(list(sac.model.parameters())) sac.run(20, 1) batch = sac.buffer.sample(20) - original_metrics = sac.update_agent(batch, 20) + # Fix: update_agent expects proper keyword arguments + original_metrics = sac.update_fn.update(batch) original_params = deepcopy(list(sac.model.parameters())) env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) @@ -303,7 +311,8 @@ def test_reproducibility(self): ) sac.run(20, 1) batch = sac.buffer.sample(20) - new_metrics = sac.update_agent(batch, 20) + # Fix: update_agent expects proper keyword arguments + new_metrics = sac.update_fn.update(batch) for old, new in zip( original_params[:10], list(sac.model.parameters())[:10], strict=False ): @@ -321,4 +330,4 @@ def test_reproducibility(self): original_metrics["Update/policy_loss"], new_metrics["Update/policy_loss"], ), "Policy loss should be the same with same seed" - clean(output_dir) + clean(output_dir) \ No newline at end of file diff --git a/test/models/test_sac_networks.py b/test/models/test_sac_networks.py index 6c8a123b..e913071f 100644 --- a/test/models/test_sac_networks.py +++ b/test/models/test_sac_networks.py @@ -14,14 +14,14 @@ def test_init(self): assert sac.obs_size == 8, "Obs size should be 8" assert sac.action_size == 3, "Action size should be 3" assert sac.activation == "tanh", "Passed activation should be tanh" - assert sac.log_std_min == -20, "Default log_std_min should be -20" + assert sac.log_std_min == -5, "Default log_std_min should be -5" # Fixed: was -20 assert sac.log_std_max == 2, "Default log_std_max should be 2" assert sac.continuous_action is True, "SAC should always be continuous" # Check network structure - updated for new architecture assert hasattr(sac, "feature_extractor"), "Should have feature extractor" - assert isinstance(sac.policy_net, nn.Linear), ( - "Policy network should be Linear (after feature extractor)" + assert isinstance(sac.policy_net, nn.Sequential), ( # Fixed: policy_net is Sequential, not Linear + "Policy network should be Sequential" ) assert isinstance(sac.q_net1, nn.Sequential), "Q-network 1 should be Sequential" assert isinstance(sac.q_net2, nn.Sequential), "Q-network 2 should be Sequential" @@ -116,7 +116,7 @@ def test_value_function_module(self): def test_forward_stochastic(self): """Test forward pass with stochastic policy.""" - sac = SACModel(obs_size=6, action_size=4) + sac = SACModel(obs_size=6, action_size=4, action_low=-2.0, action_high=3.0) dummy_state = torch.rand((10, 6)) action, z, mean, log_std = sac(dummy_state, deterministic=False) @@ -133,19 +133,20 @@ def test_forward_stochastic(self): assert torch.all(torch.isfinite(mean)), "Means should be finite" assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite" - # Check tanh constraint on actions - assert torch.all(action >= -1.0) and torch.all(action <= 1.0), ( - "Actions should be in [-1, 1] range" + # Check action bounds - should be in [action_low, action_high] range + assert torch.all(action >= -2.0) and torch.all(action <= 3.0), ( + "Actions should be in [-2.0, 3.0] range" ) # Check log_std clamping assert torch.all(log_std >= sac.log_std_min), "Log_std should be >= log_std_min" assert torch.all(log_std <= sac.log_std_max), "Log_std should be <= log_std_max" - # Check relationship: action = tanh(z) - expected_action = torch.tanh(z) + # Check relationship: raw_action = tanh(z), then scaled to [action_low, action_high] + raw_action = torch.tanh(z) + expected_action = raw_action * sac.action_scale + sac.action_bias assert torch.allclose(action, expected_action, atol=1e-6), ( - "Action should equal tanh(z)" + "Action should equal scaled tanh(z)" ) def test_forward_deterministic(self): @@ -164,10 +165,11 @@ def test_forward_deterministic(self): # In deterministic mode, z should equal mean assert torch.allclose(z, mean), "In deterministic mode, z should equal mean" - # Action should still be tanh(z) = tanh(mean) - expected_action = torch.tanh(mean) + # Action should be scaled tanh(mean) + raw_action = torch.tanh(mean) + expected_action = raw_action * sac.action_scale + sac.action_bias assert torch.allclose(action, expected_action), ( - "Action should equal tanh(mean) in deterministic mode" + "Action should equal scaled tanh(mean) in deterministic mode" ) def test_stochastic_vs_deterministic(self): @@ -209,9 +211,14 @@ def test_policy_log_prob(self): # Check shape assert log_prob.shape == (6, 1), "Log prob should have shape (6, 1)" - # Check that log probabilities are finite and reasonable + # Check that log probabilities are finite assert torch.all(torch.isfinite(log_prob)), "Log probs should be finite" - assert torch.all(log_prob <= 0.0), "Log probs should be <= 0" + + # Note: Log probabilities can be positive in some cases for transformed distributions + # The key constraint is that they should be reasonable values + # For SAC with tanh transformation, log probs can be positive due to the Jacobian correction + assert torch.all(log_prob > -50.0), "Log probs should not be extremely negative" + assert torch.all(log_prob < 50.0), "Log probs should not be extremely positive" # Test with deterministic actions (z = mean) log_prob_det = sac.policy_log_prob(mean, mean, log_std) @@ -223,7 +230,7 @@ def test_q_networks(self): """Test Q-network forward passes.""" sac = SACModel(obs_size=4, action_size=2) dummy_state = torch.rand((7, 4)) - dummy_action = torch.rand((7, 2)) + dummy_action = torch.rand((7, 2)) * 2 - 1 # Actions in [-1, 1] range # Concatenate state and action for Q-networks state_action = torch.cat([dummy_state, dummy_action], dim=-1) @@ -290,7 +297,7 @@ def test_gradient_flow(self): """Test that gradients flow properly through networks.""" sac = SACModel(obs_size=4, action_size=2) dummy_state = torch.rand((3, 4)) - dummy_action = torch.rand((3, 2)) + dummy_action = torch.rand((3, 2)) * 2 - 1 # Actions in [-1, 1] state_action = torch.cat([dummy_state, dummy_action], dim=-1) # Test policy network gradients @@ -350,3 +357,33 @@ def test_numerical_stability(self): assert torch.all(torch.isfinite(boundary_log_prob)), ( "Log probabilities should be finite for boundary actions" ) + + def test_action_scaling(self): + """Test that action scaling works correctly.""" + # Test with custom action bounds + action_low = -2.5 + action_high = 1.5 + sac = SACModel(obs_size=3, action_size=2, action_low=action_low, action_high=action_high) + + dummy_state = torch.rand((5, 3)) + action, z, mean, log_std = sac(dummy_state) + + # Actions should be within the specified bounds + assert torch.all(action >= action_low), f"Actions should be >= {action_low}" + assert torch.all(action <= action_high), f"Actions should be <= {action_high}" + + # Check the scaling math + raw_action = torch.tanh(z) + expected_scale = (action_high - action_low) / 2.0 + expected_bias = (action_high + action_low) / 2.0 + expected_action = raw_action * expected_scale + expected_bias + + assert torch.allclose(action, expected_action, atol=1e-6), ( + "Action scaling should match expected formula" + ) + assert torch.allclose(sac.action_scale, torch.tensor(expected_scale)), ( + "Action scale should be computed correctly" + ) + assert torch.allclose(sac.action_bias, torch.tensor(expected_bias)), ( + "Action bias should be computed correctly" + ) \ No newline at end of file From ade9d40cdd5b4e3e63b5f1983bf6ce74bfe1ae31 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 7 Aug 2025 11:26:40 +0200 Subject: [PATCH 6/8] removed FIX comments --- mighty/mighty_update/sac_update.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index ecd4a52e..83a5018c 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -133,7 +133,7 @@ def update(self, batch: TransitionBatch) -> Dict: self.log_alpha.exp().detach() if self.auto_alpha else self.alpha ) - # FIX: Sample fresh actions for each policy update iteration + # Sample fresh actions for each policy update iteration # This ensures stochasticity across iterations a, z, mean, log_std = self.model(states) logp = self.model.policy_log_prob(z, mean, log_std) @@ -150,7 +150,7 @@ def update(self, batch: TransitionBatch) -> Dict: # --- Entropy coefficient (alpha) update --- if self.auto_alpha: - # CRITICAL FIX: Get fresh sample for alpha update + # Get fresh sample for alpha update with torch.no_grad(): _, z_alpha, mean_alpha, log_std_alpha = self.model(states) logp_alpha = self.model.policy_log_prob(z_alpha, mean_alpha, log_std_alpha) From c4a6d819fed28def1f350c2c5770934c5465b97e Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Fri, 8 Aug 2025 11:56:20 +0200 Subject: [PATCH 7/8] updates for Merge --- mighty/mighty_agents/base_agent.py | 26 ++-- mighty/mighty_agents/dqn.py | 2 - mighty/mighty_agents/ppo.py | 2 - mighty/mighty_agents/sac.py | 5 +- .../mighty_exploration_policy.py | 6 +- .../mighty_exploration/stochastic_policy.py | 2 +- mighty/mighty_models/sac.py | 83 ++++++++---- test/models/test_sac_networks.py | 121 +++++++++++++----- 8 files changed, 161 insertions(+), 86 deletions(-) diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 4a1672f5..ec69808e 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -141,7 +141,6 @@ def __init__( # noqa: PLR0915, PLR0912 normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, - handle_timeout_termination: bool = False, ): """Base agent initialization. @@ -302,8 +301,6 @@ def __init__( # noqa: PLR0915, PLR0912 for m in self.meta_modules.values(): m.seed(self.seed) self.steps = 0 - - self.handle_timeout_termination = handle_timeout_termination def _initialize_agent(self) -> None: """Agent/algorithm specific initializations.""" @@ -606,24 +603,21 @@ def run( # noqa: PLR0915 metrics["episode_reward"] = episode_reward action, log_prob = self.step(curr_s, metrics) - # 1) step the env as usual + # step the env as usual next_s, reward, terminated, truncated, infos = self.env.step(action) - # 2) decide which samples are true “done” + # decide which samples are true “done” replay_dones = terminated # physics‐failure only - dones = np.logical_or(terminated, truncated) + dones = np.logical_or(terminated, truncated) - # 3) optionally overwrite next_s on truncation - if self.handle_timeout_termination: - real_next_s = next_s.copy() - # infos["final_observation"] is a list/array of the last real obs - for i, tr in enumerate(truncated): - if tr: - real_next_s[i] = infos["final_observation"][i] - else: - real_next_s = next_s - + # Overwrite next_s on truncation + # Based on https://github.com/DLR-RM/stable-baselines3/issues/284 + real_next_s = next_s.copy() + # infos["final_observation"] is a list/array of the last real obs + for i, tr in enumerate(truncated): + if tr: + real_next_s[i] = infos["final_observation"][i] episode_reward += reward # Log everything diff --git a/mighty/mighty_agents/dqn.py b/mighty/mighty_agents/dqn.py index c6cc1cdb..0ced9ce4 100644 --- a/mighty/mighty_agents/dqn.py +++ b/mighty/mighty_agents/dqn.py @@ -69,7 +69,6 @@ def __init__( normalize_obs: bool = False, normalize_reward: bool = False, rescale_action: bool = False, # type: ignore - handle_timeout_termination: bool = False, ): """DQN initialization. @@ -155,7 +154,6 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, - handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/mighty/mighty_agents/ppo.py b/mighty/mighty_agents/ppo.py index 8b974ee4..8f9a6fc1 100644 --- a/mighty/mighty_agents/ppo.py +++ b/mighty/mighty_agents/ppo.py @@ -62,7 +62,6 @@ def __init__( normalize_reward: bool = False, rescale_action: bool = False, tanh_squash: bool = False, - handle_timeout_termination: bool = False, ): """Initialize the PPO agent. @@ -144,7 +143,6 @@ def __init__( normalize_obs=normalize_obs, normalize_reward=normalize_reward, rescale_action=rescale_action, - handle_timeout_termination=handle_timeout_termination ) self.loss_buffer = { diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 758f1919..32753303 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -57,7 +57,6 @@ def __init__( rescale_action: bool = False, # ← NEW Whether to rescale actions to the environment's action space policy_frequency: int = 2, # Frequency of policy updates target_network_frequency: int = 1, # Frequency of target network updates - handle_timeout_termination: bool = True, ): """Initialize SAC agent with tunable hyperparameters and backward-compatible names.""" if hidden_sizes is None: @@ -117,7 +116,6 @@ def __init__( rescale_action=rescale_action, batch_size=batch_size, learning_rate=policy_lr, # For compatibility with base class - handle_timeout_termination=handle_timeout_termination, ) # Initialize loss buffer for logging @@ -209,9 +207,8 @@ def process_transition( # Ensure metrics dict if metrics is None: metrics = {} + # Pack transition - # `terminated` is used for physics failures in environments like `MightyEnv` - # Based on https://github.com/DLR-RM/stable-baselines3/issues/284 terminated = metrics["transition"]["terminated"] # physics‐failures transition = TransitionBatch(curr_s, action, reward, next_s, terminated.astype(int)) diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py b/mighty/mighty_exploration/mighty_exploration_policy.py index 534693e8..7af628a8 100644 --- a/mighty/mighty_exploration/mighty_exploration_policy.py +++ b/mighty/mighty_exploration/mighty_exploration_policy.py @@ -115,11 +115,9 @@ def sample_func_logits(self, state_array): # ─── Continuous squashed‐Gaussian (4‐tuple) ────────────────────────── elif isinstance(out, tuple) and len(out) == 4: - action = out[0] # [batch, action_dim] - - print(f'Self Model : {self.model}') + action = out[0] # [batch, action_dim] log_prob = sample_nondeterministic_logprobs( - z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel) + z=out[1], mean=out[2], log_std=out[3], sac= self.ago == "sac" ) return action.detach().cpu().numpy(), log_prob diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index 28e8cfdb..3b28306e 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -103,7 +103,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso if return_logp: return action.detach().cpu().numpy(), log_prob else: - weighted_log_prob = log_prob + weighted_log_prob = log_prob * self.entropy_coefficient return action.detach().cpu().numpy(), weighted_log_prob # Check for model attribute-based approaches diff --git a/mighty/mighty_models/sac.py b/mighty/mighty_models/sac.py index a31756fa..045d9d91 100644 --- a/mighty/mighty_models/sac.py +++ b/mighty/mighty_models/sac.py @@ -32,7 +32,7 @@ def __init__( # This model is continuous only self.continuous_action = True - # PR: register the per-dim scale and bias so we can rescale [-1,1]→[low,high]. + # Register the per-dim scale and bias so we can rescale [-1,1]→[low,high]. action_low = torch.as_tensor(action_low, dtype=torch.float32) action_high = torch.as_tensor(action_high, dtype=torch.float32) self.register_buffer( @@ -67,42 +67,75 @@ def __init__( self.hidden_sizes = feature_extractor_kwargs.get("hidden_sizes", [256, 256]) self.activation = feature_extractor_kwargs.get("activation", "relu") - # Shared feature extractor for policy - self.feature_extractor, out_dim = make_feature_extractor( + # Policy feature extractor and head + self.policy_feature_extractor, policy_feat_dim = make_feature_extractor( **feature_extractor_kwargs ) - - # Policy network outputs mean and log_std - # CHANGE: Create separate policy network (actor) similar to CleanRL - self.policy_net = make_policy_head( - in_size=self.obs_size, + + # Policy head: just the final output layer + self.policy_head = make_policy_head( + in_size=policy_feat_dim, out_size=self.action_size * 2, # mean and log_std - **head_kwargs + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - # Twin Q-networks - # — live Q-nets — - self.q_net1 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + # Create policy_net for backward compatibility + self.policy_net = nn.Sequential(self.policy_feature_extractor, self.policy_head) + + # Q-networks: feature extractors + heads + q_feature_extractor_kwargs = feature_extractor_kwargs.copy() + q_feature_extractor_kwargs["obs_shape"] = self.obs_size + self.action_size + + # Q-network 1 + self.q_feature_extractor1, q_feat_dim = make_feature_extractor(**q_feature_extractor_kwargs) + self.q_head1 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - self.q_net2 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + self.q_net1 = nn.Sequential(self.q_feature_extractor1, self.q_head1) + + # Q-network 2 + self.q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs) + self.q_head2 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) + self.q_net2 = nn.Sequential(self.q_feature_extractor2, self.q_head2) # Target Q-networks - self.target_q_net1 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + self.target_q_feature_extractor1, _ = make_feature_extractor(**q_feature_extractor_kwargs) + self.target_q_head1 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - self.target_q_net1.load_state_dict(self.q_net1.state_dict()) - self.target_q_net2 = make_q_head( - in_size=self.obs_size + self.action_size, **head_kwargs + self.target_q_net1 = nn.Sequential(self.target_q_feature_extractor1, self.target_q_head1) + + self.target_q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs) + self.target_q_head2 = make_q_head( + in_size=q_feat_dim, + hidden_sizes=[], # No hidden layers, just final linear layer + activation=head_kwargs["activation"] ) - self.target_q_net2.load_state_dict(self.q_net2.state_dict()) + self.target_q_net2 = nn.Sequential(self.target_q_feature_extractor2, self.target_q_head2) + + # Copy weights from live to target networks + self.target_q_feature_extractor1.load_state_dict(self.q_feature_extractor1.state_dict()) + self.target_q_head1.load_state_dict(self.q_head1.state_dict()) + self.target_q_feature_extractor2.load_state_dict(self.q_feature_extractor2.state_dict()) + self.target_q_head2.load_state_dict(self.q_head2.state_dict()) # Freeze target networks - for p in self.target_q_net1.parameters(): + for p in self.target_q_feature_extractor1.parameters(): + p.requires_grad = False + for p in self.target_q_head1.parameters(): + p.requires_grad = False + for p in self.target_q_feature_extractor2.parameters(): p.requires_grad = False - for p in self.target_q_net2.parameters(): + for p in self.target_q_head2.parameters(): p.requires_grad = False # Create a value function wrapper for compatibility @@ -133,7 +166,7 @@ def forward( Forward pass for policy sampling. Returns: - action: torch.Tensor in [-1,1] + action: torch.Tensor in rescaled range [action_low, action_high] z: raw Gaussian sample before tanh mean: Gaussian mean log_std: Gaussian log std @@ -155,7 +188,7 @@ def forward( # tanh→[-1,1] raw_action = torch.tanh(z) - # **HERE** we rescale into [low,high] + # Rescale into [action_low, action_high] action = raw_action * self.action_scale + self.action_bias return action, z, mean, log_std diff --git a/test/models/test_sac_networks.py b/test/models/test_sac_networks.py index e913071f..54622d2e 100644 --- a/test/models/test_sac_networks.py +++ b/test/models/test_sac_networks.py @@ -14,17 +14,22 @@ def test_init(self): assert sac.obs_size == 8, "Obs size should be 8" assert sac.action_size == 3, "Action size should be 3" assert sac.activation == "tanh", "Passed activation should be tanh" - assert sac.log_std_min == -5, "Default log_std_min should be -5" # Fixed: was -20 + assert sac.log_std_min == -5, "Default log_std_min should be -5" assert sac.log_std_max == 2, "Default log_std_max should be 2" assert sac.continuous_action is True, "SAC should always be continuous" - # Check network structure - updated for new architecture - assert hasattr(sac, "feature_extractor"), "Should have feature extractor" - assert isinstance(sac.policy_net, nn.Sequential), ( # Fixed: policy_net is Sequential, not Linear - "Policy network should be Sequential" - ) + # Check network structure - updated for feature extractor + head architecture + assert hasattr(sac, "policy_feature_extractor"), "Should have policy feature extractor" + assert hasattr(sac, "policy_head"), "Should have policy head" + assert isinstance(sac.policy_net, nn.Sequential), "Policy network should be Sequential" + + # Check Q-networks + assert hasattr(sac, "q_feature_extractor1"), "Should have Q1 feature extractor" + assert hasattr(sac, "q_head1"), "Should have Q1 head" assert isinstance(sac.q_net1, nn.Sequential), "Q-network 1 should be Sequential" assert isinstance(sac.q_net2, nn.Sequential), "Q-network 2 should be Sequential" + + # Check target networks assert isinstance(sac.target_q_net1, nn.Sequential), ( "Target Q-network 1 should be Sequential" ) @@ -36,23 +41,39 @@ def test_init(self): ) # Check that target networks have gradients disabled - for param in sac.target_q_net1.parameters(): + for param in sac.target_q_feature_extractor1.parameters(): + assert not param.requires_grad, ( + "Target Q1 feature extractor parameters should not require gradients" + ) + for param in sac.target_q_head1.parameters(): + assert not param.requires_grad, ( + "Target Q1 head parameters should not require gradients" + ) + for param in sac.target_q_feature_extractor2.parameters(): assert not param.requires_grad, ( - "Target Q-network 1 parameters should not require gradients" + "Target Q2 feature extractor parameters should not require gradients" ) - for param in sac.target_q_net2.parameters(): + for param in sac.target_q_head2.parameters(): assert not param.requires_grad, ( - "Target Q-network 2 parameters should not require gradients" + "Target Q2 head parameters should not require gradients" ) # Check that live networks have gradients enabled - for param in sac.q_net1.parameters(): + for param in sac.q_feature_extractor1.parameters(): assert param.requires_grad, ( - "Q-network 1 parameters should require gradients" + "Q1 feature extractor parameters should require gradients" ) - for param in sac.q_net2.parameters(): + for param in sac.q_head1.parameters(): assert param.requires_grad, ( - "Q-network 2 parameters should require gradients" + "Q1 head parameters should require gradients" + ) + for param in sac.q_feature_extractor2.parameters(): + assert param.requires_grad, ( + "Q2 feature extractor parameters should require gradients" + ) + for param in sac.q_head2.parameters(): + assert param.requires_grad, ( + "Q2 head parameters should require gradients" ) def test_init_custom_params(self): @@ -250,26 +271,45 @@ def test_target_networks_initialization(self): """Test that target networks are initialized with same weights as live networks.""" sac = SACModel(obs_size=3, action_size=2) - # Check that target networks have same weights as live networks initially + # Check that target feature extractors have same weights as live ones for p1, p_target1 in zip( - sac.q_net1.parameters(), sac.target_q_net1.parameters() + sac.q_feature_extractor1.parameters(), sac.target_q_feature_extractor1.parameters() ): assert torch.allclose(p1, p_target1), ( - "Target Q-net 1 should have same initial weights as Q-net 1" + "Target Q1 feature extractor should have same initial weights" ) for p2, p_target2 in zip( - sac.q_net2.parameters(), sac.target_q_net2.parameters() + sac.q_feature_extractor2.parameters(), sac.target_q_feature_extractor2.parameters() ): assert torch.allclose(p2, p_target2), ( - "Target Q-net 2 should have same initial weights as Q-net 2" + "Target Q2 feature extractor should have same initial weights" + ) + + # Check that target heads have same weights as live heads + for p1, p_target1 in zip( + sac.q_head1.parameters(), sac.target_q_head1.parameters() + ): + assert torch.allclose(p1, p_target1), ( + "Target Q1 head should have same initial weights as Q1 head" + ) + + for p2, p_target2 in zip( + sac.q_head2.parameters(), sac.target_q_head2.parameters() + ): + assert torch.allclose(p2, p_target2), ( + "Target Q2 head should have same initial weights as Q2 head" ) def test_twin_q_networks_independence(self): """Test that twin Q-networks are independent.""" sac = SACModel(obs_size=4, action_size=2) - # Check that Q-networks have different parameters (due to random initialization) + # Check that Q-networks have different objects (due to separate creation) + assert sac.q_feature_extractor1 is not sac.q_feature_extractor2, ( + "Q feature extractors should be separate objects" + ) + assert sac.q_head1 is not sac.q_head2, "Q heads should be separate objects" assert sac.q_net1 is not sac.q_net2, "Q-networks should be separate objects" assert sac.target_q_net1 is not sac.target_q_net2, ( "Target Q-networks should be separate objects" @@ -305,13 +345,15 @@ def test_gradient_flow(self): policy_loss = action.mean() # Dummy loss policy_loss.backward(retain_graph=True) - # Check that policy network has gradients - policy_has_grad = any(p.grad is not None for p in sac.policy_net.parameters()) - feature_has_grad = any( - p.grad is not None for p in sac.feature_extractor.parameters() + # Check that policy components have gradients + policy_feat_has_grad = any( + p.grad is not None for p in sac.policy_feature_extractor.parameters() ) - assert policy_has_grad or feature_has_grad, ( - "Policy network or feature extractor should have gradients" + policy_head_has_grad = any( + p.grad is not None for p in sac.policy_head.parameters() + ) + assert policy_feat_has_grad or policy_head_has_grad, ( + "Policy feature extractor or head should have gradients" ) # Test Q-network gradients @@ -320,15 +362,30 @@ def test_gradient_flow(self): q_loss = q1_value.mean() # Dummy loss q_loss.backward() - # Check that Q-network 1 has gradients - q1_has_grad = any(p.grad is not None for p in sac.q_net1.parameters()) - assert q1_has_grad, "Q-network 1 should have gradients" + # Check that Q1 components have gradients + q1_feat_has_grad = any( + p.grad is not None for p in sac.q_feature_extractor1.parameters() + ) + q1_head_has_grad = any( + p.grad is not None for p in sac.q_head1.parameters() + ) + assert q1_feat_has_grad or q1_head_has_grad, ( + "Q1 feature extractor or head should have gradients" + ) # Check that target networks don't have gradients - target_q1_has_grad = any( - p.grad is not None for p in sac.target_q_net1.parameters() + target_q1_feat_has_grad = any( + p.grad is not None for p in sac.target_q_feature_extractor1.parameters() + ) + target_q1_head_has_grad = any( + p.grad is not None for p in sac.target_q_head1.parameters() + ) + assert not target_q1_feat_has_grad, ( + "Target Q1 feature extractor should not have gradients" + ) + assert not target_q1_head_has_grad, ( + "Target Q1 head should not have gradients" ) - assert not target_q1_has_grad, "Target Q-network 1 should not have gradients" def test_numerical_stability(self): """Test numerical stability of log probability calculation.""" From a47ace993eefee9bf399a338258f86aed805e6e7 Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Fri, 8 Aug 2025 12:06:05 +0200 Subject: [PATCH 8/8] removed instance comparisons in stochastic and exploration policies --- mighty/mighty_agents/sac.py | 12 ++++++---- .../mighty_exploration_policy.py | 9 +++----- .../mighty_exploration/stochastic_policy.py | 23 +++++++++---------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 32753303..bd26e1d6 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -145,7 +145,7 @@ def _initialize_agent(self) -> None: # Exploration policy wrapper self.policy = self.policy_class( - algo=self, model=self.model, **self.policy_kwargs + algo="sac", model=self.model, **self.policy_kwargs ) # Updater @@ -207,11 +207,13 @@ def process_transition( # Ensure metrics dict if metrics is None: metrics = {} - - # Pack transition + + # Pack transition terminated = metrics["transition"]["terminated"] # physics‐failures - transition = TransitionBatch(curr_s, action, reward, next_s, terminated.astype(int)) - + transition = TransitionBatch( + curr_s, action, reward, next_s, terminated.astype(int) + ) + # Compute per-transition TD errors for logging td1, td2 = self.update_fn.calculate_td_error(transition) metrics["td_error1"] = td1.detach().cpu().numpy() diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py b/mighty/mighty_exploration/mighty_exploration_policy.py index 7af628a8..4d37e4a3 100644 --- a/mighty/mighty_exploration/mighty_exploration_policy.py +++ b/mighty/mighty_exploration/mighty_exploration_policy.py @@ -12,10 +12,7 @@ def sample_nondeterministic_logprobs( - z: torch.Tensor, - mean: torch.Tensor, - log_std: torch.Tensor, - sac: bool = False + z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False ) -> torch.Tensor: """ Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)), @@ -115,9 +112,9 @@ def sample_func_logits(self, state_array): # ─── Continuous squashed‐Gaussian (4‐tuple) ────────────────────────── elif isinstance(out, tuple) and len(out) == 4: - action = out[0] # [batch, action_dim] + action = out[0] # [batch, action_dim] log_prob = sample_nondeterministic_logprobs( - z=out[1], mean=out[2], log_std=out[3], sac= self.ago == "sac" + z=out[1], mean=out[2], log_std=out[3], sac=self.ago == "sac" ) return action.detach().cpu().numpy(), log_prob diff --git a/mighty/mighty_exploration/stochastic_policy.py b/mighty/mighty_exploration/stochastic_policy.py index 3b28306e..4c57c20b 100644 --- a/mighty/mighty_exploration/stochastic_policy.py +++ b/mighty/mighty_exploration/stochastic_policy.py @@ -27,13 +27,12 @@ def __init__( :param entropy_coefficient: weight on entropy term :param discrete: whether the action space is discrete """ - + self.model = model - + super().__init__(algo, model, discrete) self.entropy_coefficient = entropy_coefficient self.discrete = discrete - # --- override sample_action only for continuous SAC --- if not discrete and isinstance(model, SACModel): @@ -88,9 +87,9 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso # 4-tuple case (Tanh squashing): (action, z, mean, log_std) elif isinstance(model_output, tuple) and len(model_output) == 4: action, z, mean, log_std = model_output - - if not isinstance(self.model, SACModel): - + + if not self.algo == "sac": + log_prob = sample_nondeterministic_logprobs( z=z, mean=mean, @@ -121,8 +120,8 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso elif len(model_output) == 4: # Tanh squashing mode: (action, z, mean, log_std) action, z, mean, log_std = model_output - if not isinstance(self.model, SACModel): - + if not self.algo == "sac": + log_prob = sample_nondeterministic_logprobs( z=z, mean=mean, @@ -147,7 +146,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso if self.model.output_style == "squashed_gaussian": # Should be 4-tuple: (action, z, mean, log_std) action, z, mean, log_std = model_output - if not isinstance(self.model, SACModel): + if not self.algo == "sac": log_prob = sample_nondeterministic_logprobs( z=z, mean=mean, @@ -170,7 +169,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso z = dist.rsample() action = torch.tanh(z) - if not isinstance(self.model, SACModel): + if not self.algo == "sac": log_prob = sample_nondeterministic_logprobs( z=z, mean=mean, @@ -179,7 +178,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso ) else: log_prob = self.model.policy_log_prob(z, mean, log_std) - + entropy = dist.entropy().sum(dim=-1, keepdim=True) weighted_log_prob = log_prob * entropy return action.detach().cpu().numpy(), weighted_log_prob @@ -190,7 +189,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso ) # Special handling for SACModel - elif isinstance(self.model, SACModel): + elif self.algo == "sac" and isinstance(self.model, SACModel): action, z, mean, log_std = self.model(state, deterministic=False) # CRITICAL: Use the model's policy_log_prob which includes tanh correction log_prob = self.model.policy_log_prob(z, mean, log_std)