Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions mighty/configs/algorithm/sac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,50 @@ algorithm_kwargs:
# Normalization
normalize_obs: False
normalize_reward: False
rescale_action: True # CRITICAL: must be True for MuJoCo environments

# Network sizes
n_policy_units: 256
soft_update_weight: 0.005
n_policy_units: 256
soft_update_weight: 0.005 # tau in SAC terms

# Replay buffer
replay_buffer_class:
_target_: mighty.mighty_replay.MightyReplay
replay_buffer_kwargs:
capacity: 1e6


# Scheduling & batch-updates
batch_size: 256
learning_starts: 5000
update_every: 1
n_gradient_steps: 1
batch_size: 256
learning_starts: 5000 # Good, matches CleanRL
update_every: 1 # Good, update every step
n_gradient_steps: 1 # Good

# Learning rates
policy_lr: 3e-4
q_lr: 1e-3
alpha_lr: 1e-3
q_lr: 1e-3 # This is correct now (was 3e-4)
alpha_lr: 3e-4 # 3e-4 is better than 1e-3 for alpha

# SAC hyperparameters
gamma: 0.99
alpha: 0.2
auto_alpha: True
target_entropy: -6.0 # -action_dim for HalfCheetah (6 actions)
target_entropy: null # Let it auto-compute as -action_dim

# Network architecture
hidden_sizes: [256, 256] # Explicitly specify
activation: relu
log_std_min: -5
log_std_max: 2

# Policy configuration
policy_class: mighty.mighty_exploration.StochasticPolicy
policy_kwargs:
entropy_coefficient: 0.0
discrete: False

# Remove entropy_coefficient - SAC handles alpha internally

# SAC specific frequencies
policy_frequency: 2 # Delayed policy updates
policy_frequency: 2 # Can also try 1 for even better performance
target_network_frequency: 1 # Update targets every step

# Environment and training configuration
Expand All @@ -55,5 +62,5 @@ max_episode_steps: 1000 # HalfCheetah episode length
eval_frequency: 10000 # More frequent eval for single env
save_frequency: 50000 # Save every 50k steps


# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1
# Command to run:
# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1
20 changes: 16 additions & 4 deletions mighty/mighty_agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,21 @@ def run( # noqa: PLR0915
metrics["episode_reward"] = episode_reward

action, log_prob = self.step(curr_s, metrics)
next_s, reward, terminated, truncated, _ = self.env.step(action) # type: ignore
dones = np.logical_or(terminated, truncated)
# step the env as usual
next_s, reward, terminated, truncated, infos = self.env.step(action)

# decide which samples are true “done”
replay_dones = terminated # true terminations only; truncations are excluded
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment here is env-specific. Also inconsistent: dones are always overwritten to real termination regardless of what the flag says.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this default

dones = np.logical_or(terminated, truncated)


# Overwrite next_s on truncation
# Based on https://github.com/DLR-RM/stable-baselines3/issues/284
real_next_s = next_s.copy()
# infos["final_observation"] is a list/array of the last real obs
for i, tr in enumerate(truncated):
if tr:
real_next_s[i] = infos["final_observation"][i]
episode_reward += reward

# Log everything
Expand All @@ -615,10 +627,10 @@ def run( # noqa: PLR0915
"reward": reward,
"action": action,
"state": curr_s,
"next_state": next_s,
"next_state": real_next_s,
"terminated": terminated.astype(int),
"truncated": truncated.astype(int),
"dones": dones.astype(int),
"dones": replay_dones.astype(int),
"mean_episode_reward": last_episode_reward.mean()
.cpu()
.numpy()
Expand Down
11 changes: 8 additions & 3 deletions mighty/mighty_agents/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
# --- Network architecture (optional override) ---
hidden_sizes: Optional[List[int]] = None,
activation: str = "relu",
log_std_min: float = -20,
log_std_min: float = -5,
log_std_max: float = 2,
# --- Logging & buffer ---
render_progress: bool = True,
Expand Down Expand Up @@ -145,7 +145,7 @@ def _initialize_agent(self) -> None:

# Exploration policy wrapper
self.policy = self.policy_class(
algo=self, model=self.model, **self.policy_kwargs
algo="sac", model=self.model, **self.policy_kwargs
)

# Updater
Expand Down Expand Up @@ -207,8 +207,13 @@ def process_transition(
# Ensure metrics dict
if metrics is None:
metrics = {}

# Pack transition
transition = TransitionBatch(curr_s, action, reward, next_s, dones)
terminated = metrics["transition"]["terminated"] # physics‐failures
transition = TransitionBatch(
curr_s, action, reward, next_s, terminated.astype(int)
)

# Compute per-transition TD errors for logging
td1, td2 = self.update_fn.calculate_td_error(transition)
metrics["td_error1"] = td1.detach().cpu().numpy()
Expand Down
30 changes: 16 additions & 14 deletions mighty/mighty_exploration/mighty_exploration_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,31 @@
import torch
from torch.distributions import Categorical, Normal

from mighty.mighty_models import SACModel


def sample_nondeterministic_logprobs(
z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""
Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)),
and if sac=True apply the tanh-squash correction to get log π(a).
"""
std = torch.exp(log_std) # [batch, action_dim]
dist = Normal(mean, std)
# base Gaussian log‐prob of z
log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1]

# For SAC, don't apply correction
if sac:
return dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1]
# If not SAC, we need to apply the tanh correction
else:
log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1]

# 2b) tanh‐correction = ∑ᵢ log(1 − tanh(zᵢ)² + ε)
eps = 1e-6
# subtract the ∑_i log(d tanh/dz_i) = ∑ log(1 - tanh(z)^2)
eps = 1e-4
log_correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps).sum(
dim=-1, keepdim=True
) # [batch, 1]

# 2c) final log_prob of a = tanh(z)
log_prob = log_pz - log_correction # [batch, 1]
return log_prob
return log_pz - log_correction
else:
# PPO-style or other: no squash correction
return log_pz


class MightyExplorationPolicy:
Expand Down Expand Up @@ -112,7 +114,7 @@ def sample_func_logits(self, state_array):
elif isinstance(out, tuple) and len(out) == 4:
action = out[0] # [batch, action_dim]
log_prob = sample_nondeterministic_logprobs(
z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
)
return action.detach().cpu().numpy(), log_prob

Expand Down
83 changes: 47 additions & 36 deletions mighty/mighty_exploration/stochastic_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(
:param entropy_coefficient: weight on entropy term
:param discrete: whether the action space is discrete
"""

self.model = model

super().__init__(algo, model, discrete)
self.entropy_coefficient = entropy_coefficient
self.discrete = discrete
Expand Down Expand Up @@ -84,33 +87,24 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
# 4-tuple case (Tanh squashing): (action, z, mean, log_std)
elif isinstance(model_output, tuple) and len(model_output) == 4:
action, z, mean, log_std = model_output
log_prob = sample_nondeterministic_logprobs(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the reason for changing this: it's the same code but longer, and it locks us into a specific model class?

z=z,
mean=mean,
log_std=log_std,
sac=self.algo == "sac",
)

if not self.algo == "sac":

log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)

if return_logp:
return action.detach().cpu().numpy(), log_prob
else:
weighted_log_prob = log_prob * self.entropy_coefficient
return action.detach().cpu().numpy(), weighted_log_prob

# Legacy 2-tuple case: (mean, std)
elif isinstance(model_output, tuple) and len(model_output) == 2:
mean, std = model_output
dist = Normal(mean, std)
z = dist.rsample() # [batch, action_dim]
action = torch.tanh(z) # [batch, action_dim]

log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac"
)
entropy = dist.entropy().sum(dim=-1, keepdim=True) # [batch, 1]
weighted_log_prob = log_prob * entropy
return action.detach().cpu().numpy(), weighted_log_prob

# Check for model attribute-based approaches
elif hasattr(self.model, "continuous_action") and getattr(
self.model, "continuous_action"
Expand All @@ -126,9 +120,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
elif len(model_output) == 4:
# Tanh squashing mode: (action, z, mean, log_std)
action, z, mean, log_std = model_output
log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
)
if not self.algo == "sac":

log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)
else:
raise ValueError(
f"Unexpected model output length: {len(model_output)}"
Expand All @@ -145,9 +146,15 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
if self.model.output_style == "squashed_gaussian":
# Should be 4-tuple: (action, z, mean, log_std)
action, z, mean, log_std = model_output
log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
)
if not self.algo == "sac":
log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)

if return_logp:
return action.detach().cpu().numpy(), log_prob
Expand All @@ -162,9 +169,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
z = dist.rsample()
action = torch.tanh(z)

log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac"
)
if not self.algo == "sac":
log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)

entropy = dist.entropy().sum(dim=-1, keepdim=True)
weighted_log_prob = log_prob * entropy
return action.detach().cpu().numpy(), weighted_log_prob
Expand All @@ -175,14 +189,11 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
)

# Special handling for SACModel
elif isinstance(self.model, SACModel):
elif self.algo == "sac" and isinstance(self.model, SACModel):
action, z, mean, log_std = self.model(state, deterministic=False)
std = torch.exp(log_std)
dist = Normal(mean, std)

log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True)
weighted_log_prob = log_pz * self.entropy_coefficient
return action.detach().cpu().numpy(), weighted_log_prob
# CRITICAL: Use the model's policy_log_prob, which includes the tanh correction
log_prob = self.model.policy_log_prob(z, mean, log_std)
return action.detach().cpu().numpy(), log_prob

else:
raise RuntimeError(
Expand Down
Loading
Loading