Merged

Commits (28)
4d9c49c - Only check for terminated episodes (araffin, Feb 17, 2025)
c98ebde - Start adding ortho init (araffin, Feb 18, 2025)
cd2ef11 - Add SimbaPolicy for PPO (araffin, Feb 18, 2025)
0461871 - Try adding ortho init to SAC (araffin, Feb 27, 2025)
3e42262 - Enable lr schedule for PPO (araffin, Feb 27, 2025)
101c08d - Allow to pass lr, prepare for adaptive lr (araffin, Feb 27, 2025)
ef2321d - Implement adaptive lr (araffin, Feb 27, 2025)
96bf978 - Add small test (araffin, Feb 27, 2025)
6e17def - Refactor adaptive lr (araffin, Feb 27, 2025)
5832702 - Add adaptive lr for SAC (araffin, Feb 27, 2025)
ab33983 - Fix qf_learning_rate (araffin, Feb 27, 2025)
163acbe - Revert "Fix qf_learning_rate" (araffin, Feb 28, 2025)
0d83720 - Revert "Add adaptive lr for SAC" (araffin, Feb 28, 2025)
85d4f23 - Revert kl div for SAC changes (araffin, Feb 28, 2025)
dc6bf4e - Revert dist.mode() in two lines (araffin, Feb 28, 2025)
f0235e6 - Cleanup code (araffin, Mar 3, 2025)
0e14aad - Add support for Gaussian actor for SAC (araffin, Mar 8, 2025)
1a8063a - Enable Gaussian actor for TQC (araffin, Mar 26, 2025)
cefbd78 - Log std too (araffin, Mar 26, 2025)
a809a01 - Avoid NaN in kl div approx (araffin, Apr 28, 2025)
c92d840 - Allow to use layer_norm in actor (araffin, Apr 28, 2025)
f54697a - Reformat (araffin, May 14, 2025)
7afa84f - Allow max grad norm for TQC and fix optimizer class (araffin, May 14, 2025)
3af7c93 - Comment out max grad norm (araffin, May 14, 2025)
3f9727b - Update to schedule classes (araffin, May 14, 2025)
d86e89d - Add lr schedule support for TQC (araffin, May 14, 2025)
c668fd1 - Revert experimental changes and add support for lr schedule for SAC (araffin, May 19, 2025)
22b3e54 - Add test for adaptive kl div, remove squash output param (araffin, May 19, 2025)
34 changes: 30 additions & 4 deletions sbx/common/off_policy_algorithm.py
@@ -4,6 +4,7 @@

import jax
import numpy as np
import optax
from gymnasium import spaces
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
@@ -15,6 +16,8 @@


class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
qf_learning_rate: float

def __init__(
self,
policy: type[BasePolicy],
@@ -75,8 +78,8 @@ def __init__(
)
# Will be updated later
self.key = jax.random.PRNGKey(0)
# Note: we do not allow schedule for it
self.qf_learning_rate = qf_learning_rate
# Note: we do not allow separate schedule for it
self.initial_qf_learning_rate = qf_learning_rate
self.param_resets = param_resets
self.reset_idx = 0

@@ -89,7 +92,7 @@ def _maybe_reset_params(self) -> None:
):
# Note: we are not resetting the entropy coeff
assert isinstance(self.qf_learning_rate, float)
self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) # type: ignore[operator]
self.reset_idx += 1

def _get_torch_save_params(self):
Expand All @@ -100,6 +103,29 @@ def _excluded_save_params(self) -> list[str]:
excluded.remove("policy")
return excluded

def _update_learning_rate( # type: ignore[override]

Copilot AI (May 19, 2025): [nitpick] The learning rate update logic is duplicated in both on-policy and off-policy algorithm classes; consider refactoring this into a shared utility to reduce code duplication.

self,
optimizers: Union[list[optax.OptState], optax.OptState],
learning_rate: float,
name: str = "learning_rate",
) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule
and the current progress remaining (from 1 to 0).

:param optimizers: An optimizer or a list of optimizers.
:param learning_rate: The current learning rate to apply
:param name: (Optional) A custom name for the lr (for instance qf_learning_rate)
"""
# Log the current learning rate
self.logger.record(f"train/{name}", learning_rate)

if not isinstance(optimizers, list):
optimizers = [optimizers]
for optimizer in optimizers:
# Note: the optimizer must have been defined with inject_hyperparams
optimizer.hyperparams["learning_rate"] = learning_rate

def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
super().set_random_seed(seed)
if seed is None:
@@ -116,7 +142,7 @@ def _setup_model(self) -> None:

self._setup_lr_schedule()
# By default qf_learning_rate = pi_learning_rate
self.qf_learning_rate = self.qf_learning_rate or self.lr_schedule(1)
self.qf_learning_rate = self.initial_qf_learning_rate or self.lr_schedule(1)
self.set_random_seed(self.seed)
# Make a local copy as we should not pickle
# the environment when using HerReplayBuffer
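
For reference, a minimal standalone sketch (not part of the diff) of the `optax.inject_hyperparams` mechanism that the new `_update_learning_rate` relies on; the parameter values are illustrative:

```python
import jax.numpy as jnp
import optax

# inject_hyperparams stores numeric hyperparameters inside the optimizer state,
# so they can be overwritten between gradient steps without rebuilding the optimizer.
tx = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)

params = {"w": jnp.ones(3)}
opt_state = tx.init(params)

# This is what _update_learning_rate does for each optimizer it receives:
opt_state.hyperparams["learning_rate"] = 1e-4

# Subsequent updates use the new learning rate
grads = {"w": jnp.ones(3)}
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```
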
30 changes: 24 additions & 6 deletions sbx/common/on_policy_algorithm.py
@@ -3,6 +3,7 @@
import gymnasium as gym
import jax
import numpy as np
import optax
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import RolloutBuffer
@@ -75,6 +76,27 @@ def _excluded_save_params(self) -> list[str]:
excluded.remove("policy")
return excluded

def _update_learning_rate( # type: ignore[override]
self,
optimizers: Union[list[optax.OptState], optax.OptState],
learning_rate: float,
) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule
and the current progress remaining (from 1 to 0).

:param optimizers:
An optimizer or a list of optimizers.
"""
# Log the current learning rate
self.logger.record("train/learning_rate", learning_rate)

if not isinstance(optimizers, list):
optimizers = [optimizers]
for optimizer in optimizers:
# Note: the optimizer must have been defined with inject_hyperparams
optimizer.hyperparams["learning_rate"] = learning_rate

def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
super().set_random_seed(seed)
if seed is None:
@@ -167,12 +189,8 @@ def collect_rollouts(

# Handle timeout by bootstraping with value function
# see GitHub issue #633
for idx, done in enumerate(dones):
if (
done
and infos[idx].get("terminal_observation") is not None
and infos[idx].get("TimeLimit.truncated", False)
):
for idx in dones.nonzero()[0]:
if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False):
terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0]
terminal_value = np.array(
self.vf.apply( # type: ignore[union-attr]
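
A small sketch (with illustrative values) of the `dones.nonzero()[0]` refactor above: only the indices of terminated episodes are iterated, instead of looping over every environment and checking its `done` flag:

```python
import numpy as np

dones = np.array([False, True, False, True])

# Indices of the environments whose episode just ended
print(dones.nonzero()[0])  # [1 3]

# Equivalent to the previous per-environment loop
print([idx for idx, done in enumerate(dones) if done])  # [1, 3]
```
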
26 changes: 26 additions & 0 deletions sbx/common/utils.py
@@ -0,0 +1,26 @@
from dataclasses import dataclass

import numpy as np


@dataclass
class KLAdaptiveLR:
"""Adaptive lr schedule, see https://arxiv.org/abs/1707.02286"""

# If set will trigger adaptive lr
target_kl: float
current_adaptive_lr: float
# Values taken from https://github.com/leggedrobotics/rsl_rl
min_learning_rate: float = 1e-5
max_learning_rate: float = 1e-2
kl_margin: float = 2.0
# Divide or multiply the lr by this factor
adaptive_lr_factor: float = 1.5

def update(self, kl_div: float) -> None:
if kl_div > self.target_kl * self.kl_margin:
self.current_adaptive_lr /= self.adaptive_lr_factor
elif kl_div < self.target_kl / self.kl_margin:
self.current_adaptive_lr *= self.adaptive_lr_factor

self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate)
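
A short usage sketch of the new `KLAdaptiveLR` helper; the `target_kl` value and the KL measurements below are made up for illustration:

```python
from sbx.common.utils import KLAdaptiveLR

adaptive_lr = KLAdaptiveLR(target_kl=0.01, current_adaptive_lr=3e-4)

# After an update epoch, feed in the mean approximate KL divergence:
adaptive_lr.update(kl_div=0.03)   # 0.03 > 2.0 * 0.01  -> lr divided by 1.5 (~2e-4)
adaptive_lr.update(kl_div=0.002)  # 0.002 < 0.01 / 2.0 -> lr multiplied by 1.5 (~3e-4)

# The result is always clipped to [min_learning_rate, max_learning_rate]
print(adaptive_lr.current_adaptive_lr)
```
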
2 changes: 1 addition & 1 deletion sbx/crossq/crossq.py
@@ -158,7 +158,7 @@ def _setup_model(self) -> None:
apply_fn=self.ent_coef.apply,
params=self.ent_coef.init(ent_key)["params"],
tx=optax.adam(
learning_rate=self.learning_rate,
learning_rate=self.lr_schedule(1),
),
)

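
On the CrossQ change above: `lr_schedule(1)` resolves the learning rate to its initial float value. A sketch of why that matters, assuming the usual SB3 schedule convention (called with progress_remaining, which goes from 1 to 0), since optax would interpret a callable learning rate as a step-based schedule:

```python
import optax

# SB3-style schedule: takes progress_remaining, from 1 (start) to 0 (end of training)
def lr_schedule(progress_remaining: float) -> float:
    return 3e-4 * progress_remaining

# Passing the callable directly would be misread by optax as a step-count schedule,
# so the initial value is extracted instead:
tx = optax.adam(learning_rate=lr_schedule(1))  # 3e-4, a plain float
```
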
4 changes: 2 additions & 2 deletions sbx/dqn/dqn.py
@@ -7,7 +7,7 @@
import numpy as np
import optax
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn
from stable_baselines3.common.utils import LinearSchedule

from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
@@ -83,7 +83,7 @@ def __init__(
def _setup_model(self) -> None:
super()._setup_model()

self.exploration_schedule = get_linear_fn(
self.exploration_schedule = LinearSchedule(
self.exploration_initial_eps,
self.exploration_final_eps,
self.exploration_fraction,
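
For the DQN change above, a usage sketch of `LinearSchedule`, assuming it mirrors the semantics of the old `get_linear_fn`; the numbers are illustrative:

```python
from stable_baselines3.common.utils import LinearSchedule

# The value moves linearly from start to end over the first `end_fraction` of
# training, then stays at end. Schedules are called with progress_remaining (1 -> 0).
eps_schedule = LinearSchedule(1.0, 0.05, 0.1)

print(eps_schedule(1.0))   # 1.0     -> start of training
print(eps_schedule(0.95))  # ~0.525  -> halfway through the exploration fraction
print(eps_schedule(0.5))   # 0.05    -> exploration fraction exhausted
```
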
37 changes: 26 additions & 11 deletions sbx/ppo/policies.py
@@ -44,6 +44,8 @@ class Actor(nn.Module):
# For MultiDiscrete
max_num_choices: int = 0
split_indices: np.ndarray = field(default_factory=lambda: np.array([]))
# Last layer with small scale

Copilot AI (May 19, 2025): [nitpick] Document the new ortho_init parameter in the class docstring and __init__ signature to clarify its purpose and how it affects weight initialization.

ortho_init: bool = False

def get_std(self) -> jnp.ndarray:
# Make it work with gSDE
@@ -65,7 +67,15 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def
x = nn.Dense(n_units)(x)
x = self.activation_fn(x)

action_logits = nn.Dense(self.action_dim)(x)
if self.ortho_init:
orthogonal_init = nn.initializers.orthogonal(scale=0.01)
bias_init = nn.initializers.zeros
action_logits = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x)

else:
action_logits = nn.Dense(self.action_dim)(x)

log_std = jnp.zeros(1)
if self.num_discrete_choices is None:
# Continuous actions
log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,))
@@ -118,6 +128,8 @@ def __init__(
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
share_features_extractor: bool = False,
actor_class: type[nn.Module] = Actor,
critic_class: type[nn.Module] = Critic,
):
if optimizer_kwargs is None:
# Small values to avoid NaN in Adam optimizer
@@ -146,6 +158,9 @@ def __init__(
else:
self.net_arch_pi = self.net_arch_vf = [64, 64]
self.use_sde = use_sde
self.ortho_init = ortho_init
self.actor_class = actor_class
self.critic_class = critic_class

self.key = self.noise_key = jax.random.PRNGKey(0)

@@ -189,38 +204,38 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
else:
raise NotImplementedError(f"{self.action_space}")

self.actor = Actor(
self.actor = self.actor_class(
net_arch=self.net_arch_pi,
log_std_init=self.log_std_init,
activation_fn=self.activation_fn,
ortho_init=self.ortho_init,
**actor_kwargs, # type: ignore[arg-type]
)
# Hack to make gSDE work without modifying internal SB3 code
self.actor.reset_noise = self.reset_noise

# Inject hyperparameters to be able to modify it later
# See https://stackoverflow.com/questions/78527164
# Note: eps=1e-5 for Adam
optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs)

self.actor_state = TrainState.create(
apply_fn=self.actor.apply,
params=self.actor.init(actor_key, obs),
tx=optax.chain(
optax.clip_by_global_norm(max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs, # , eps=1e-5
),
optimizer_class,
),
)

self.vf = Critic(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)
self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)

self.vf_state = TrainState.create(
apply_fn=self.vf.apply,
params=self.vf.init({"params": vf_key}, obs),
tx=optax.chain(
optax.clip_by_global_norm(max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs, # , eps=1e-5
),
optimizer_class,
),
)

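
Finally, a hypothetical usage sketch for the new `actor_class` / `critic_class` arguments of the PPO policy; `CustomActor` and the `policy_kwargs` plumbing shown here are assumptions for illustration, not part of the diff:

```python
from sbx import PPO
from sbx.ppo.policies import Actor


class CustomActor(Actor):
    """Placeholder subclass: override __call__ to change the architecture,
    e.g. add layer norm or a Simba-style residual block."""


model = PPO(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs={"actor_class": CustomActor},
    verbose=1,
)
```
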