From 4d9c49c05d7b127cf5fc658a9f0bb464902fea77 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 17 Feb 2025 11:59:47 +0100 Subject: [PATCH 01/28] Only check for terminated episodes --- sbx/common/on_policy_algorithm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index 36b8ea6..bb371c2 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -167,12 +167,8 @@ def collect_rollouts( # Handle timeout by bootstraping with value function # see GitHub issue #633 - for idx, done in enumerate(dones): - if ( - done - and infos[idx].get("terminal_observation") is not None - and infos[idx].get("TimeLimit.truncated", False) - ): + for idx in dones.nonzero()[0]: + if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False): terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0] terminal_value = np.array( self.vf.apply( # type: ignore[union-attr] From c98ebdee1e00983609e771a433794d2c2780e961 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 18 Feb 2025 08:00:09 +0100 Subject: [PATCH 02/28] Start adding ortho init --- sbx/ppo/policies.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 570849e..cb93727 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -44,6 +44,8 @@ class Actor(nn.Module): # For MultiDiscrete max_num_choices: int = 0 split_indices: np.ndarray = field(default_factory=lambda: np.array([])) + # Last layer with small scale + ortho_init: bool = False def get_std(self) -> jnp.ndarray: # Make it work with gSDE @@ -65,7 +67,14 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def x = nn.Dense(n_units)(x) x = self.activation_fn(x) - action_logits = nn.Dense(self.action_dim)(x) + if self.ortho_init: + orthogonal_init = nn.initializers.orthogonal(scale=0.01) + bias_init = nn.initializers.zeros + action_logits = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) + + else: + action_logits = nn.Dense(self.action_dim)(x) + if self.num_discrete_choices is None: # Continuous actions log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) @@ -146,6 +155,7 @@ def __init__( else: self.net_arch_pi = self.net_arch_vf = [64, 64] self.use_sde = use_sde + self.ortho_init = ortho_init self.key = self.noise_key = jax.random.PRNGKey(0) @@ -193,6 +203,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> net_arch=self.net_arch_pi, log_std_init=self.log_std_init, activation_fn=self.activation_fn, + ortho_init=self.ortho_init, **actor_kwargs, # type: ignore[arg-type] ) # Hack to make gSDE work without modifying internal SB3 code From cd2ef1129bb702124f68fd8c79ed10f186348d71 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 18 Feb 2025 08:21:37 +0100 Subject: [PATCH 03/28] Add SimbaPolicy for PPO --- sbx/ppo/policies.py | 111 +++++++++++++++++++++++++++++++++++++++++++- sbx/ppo/ppo.py | 8 ++-- tests/test_run.py | 15 ++++++ 3 files changed, 129 insertions(+), 5 deletions(-) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index cb93727..d85d77a 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -14,6 +14,7 @@ from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule +from sbx.common.jax_layers import SimbaResidualBlock from sbx.common.policies import 
BaseJaxPolicy, Flatten tfd = tfp.distributions @@ -34,6 +35,23 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return x +class SimbaCritic(nn.Module): + net_arch: Sequence[int] + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + scale_factor: int = 4 + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = Flatten()(x) + x = nn.Dense(self.net_arch[0])(x) + for n_units in self.net_arch: + x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x) + + x = nn.LayerNorm()(x) + x = nn.Dense(1)(x) + return x + + class Actor(nn.Module): action_dim: int net_arch: Sequence[int] @@ -106,6 +124,47 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def return dist +class SimbaActor(nn.Module): + action_dim: int + net_arch: Sequence[int] + log_std_init: float = 0.0 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + # For Discrete, MultiDiscrete and MultiBinary actions + num_discrete_choices: Optional[Union[int, Sequence[int]]] = None + # For MultiDiscrete + max_num_choices: int = 0 + # Last layer with small scale + ortho_init: bool = False + scale_factor: int = 4 + + def get_std(self) -> jnp.ndarray: + # Make it work with gSDE + return jnp.array(0.0) + + @nn.compact + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + x = Flatten()(x) + + x = nn.Dense(self.net_arch[0])(x) + for n_units in self.net_arch: + x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x) + x = nn.LayerNorm()(x) + + if self.ortho_init: + orthogonal_init = nn.initializers.orthogonal(scale=0.01) + bias_init = nn.initializers.zeros + mean_action = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) + + else: + mean_action = nn.Dense(self.action_dim)(x) + + # Continuous actions + log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) + dist = tfd.MultivariateNormalDiag(loc=mean_action, scale_diag=jnp.exp(log_std)) + + return dist + + class PPOPolicy(BaseJaxPolicy): def __init__( self, @@ -127,6 +186,8 @@ def __init__( optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, optimizer_kwargs: Optional[dict[str, Any]] = None, share_features_extractor: bool = False, + actor_class: type[nn.Module] = Actor, + critic_class: type[nn.Module] = Critic, ): if optimizer_kwargs is None: # Small values to avoid NaN in Adam optimizer @@ -156,6 +217,8 @@ def __init__( self.net_arch_pi = self.net_arch_vf = [64, 64] self.use_sde = use_sde self.ortho_init = ortho_init + self.actor_class = actor_class + self.critic_class = critic_class self.key = self.noise_key = jax.random.PRNGKey(0) @@ -199,7 +262,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> else: raise NotImplementedError(f"{self.action_space}") - self.actor = Actor( + self.actor = self.actor_class( net_arch=self.net_arch_pi, log_std_init=self.log_std_init, activation_fn=self.activation_fn, @@ -221,7 +284,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> ), ) - self.vf = Critic(net_arch=self.net_arch_vf, activation_fn=self.activation_fn) + self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn) self.vf_state = TrainState.create( apply_fn=self.vf.apply, @@ -268,3 +331,47 @@ def _predict_all(actor_state, vf_state, observations, key): log_probs = dist.log_prob(actions) values = vf_state.apply_fn(vf_state.params, observations).flatten() return actions, log_probs, values + + +class 
SimbaPPOPolicy(PPOPolicy): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + ortho_init: bool = False, + log_std_init: float = 0, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh, + use_sde: bool = False, + use_expln: bool = False, + clip_mean: float = 2, + features_extractor_class=None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = False, + actor_class: type[nn.Module] = SimbaActor, + critic_class: type[nn.Module] = SimbaCritic, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + ortho_init, + log_std_init, + activation_fn, + use_sde, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + share_features_extractor, + actor_class, + critic_class, + ) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index acbdea6..953d199 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -11,7 +11,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax -from sbx.ppo.policies import PPOPolicy +from sbx.ppo.policies import PPOPolicy, SimbaPPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -70,6 +70,8 @@ class PPO(OnPolicyAlgorithmJax): policy_aliases: ClassVar[dict[str, type[PPOPolicy]]] = { # type: ignore[assignment] "MlpPolicy": PPOPolicy, + # Residual net, from https://github.com/SonyResearch/simba + "SimbaPolicy": SimbaPPOPolicy, # "CnnPolicy": ActorCriticCnnPolicy, # "MultiInputPolicy": MultiInputActorCriticPolicy, } @@ -179,8 +181,8 @@ def _setup_model(self) -> None: self.key, ent_key = jax.random.split(self.key, 2) - self.actor = self.policy.actor - self.vf = self.policy.vf + self.actor = self.policy.actor # type: ignore[assignment] + self.vf = self.policy.vf # type: ignore[assignment] # Initialize schedules for policy/value clipping self.clip_range_schedule = get_schedule_fn(self.clip_range) diff --git a/tests/test_run.py b/tests/test_run.py index 6d3b5b9..1a03eb0 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -160,6 +160,21 @@ def test_ppo(tmp_path, env_id: str) -> None: check_save_load(model, PPO, tmp_path) +def test_simba_ppo(tmp_path) -> None: + model = PPO( + "SimbaPolicy", + "Pendulum-v1", + verbose=1, + n_steps=32, + batch_size=32, + n_epochs=2, + policy_kwargs=dict(activation_fn=nn.leaky_relu, net_arch=[64]), + ) + model.learn(64, progress_bar=True) + + check_save_load(model, PPO, tmp_path) + + def test_dqn(tmp_path) -> None: model = DQN( "MlpPolicy", From 0461871190a40d4c1c0cbc14e073dba976e22b81 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 08:47:49 +0100 Subject: [PATCH 04/28] Try adding ortho init to SAC --- sbx/common/off_policy_algorithm.py | 2 +- sbx/common/policies.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index b7b385c..ba49e1f 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -89,7 +89,7 @@ def _maybe_reset_params(self) -> None: ): # Note: we are not resetting the entropy coeff assert 
isinstance(self.qf_learning_rate, float) - self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) + self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) # type: ignore[operator] self.reset_idx += 1 def _get_torch_save_params(self): diff --git a/sbx/common/policies.py b/sbx/common/policies.py index ce23d09..c056046 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -242,6 +242,8 @@ class SquashedGaussianActor(nn.Module): log_std_min: float = -20 log_std_max: float = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + ortho_init: bool = False + log_std_init: float = -1.2 # log(0.3) def get_std(self): # Make it work with gSDE @@ -253,8 +255,18 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def for n_units in self.net_arch: x = nn.Dense(n_units)(x) x = self.activation_fn(x) - mean = nn.Dense(self.action_dim)(x) - log_std = nn.Dense(self.action_dim)(x) + + if self.ortho_init: + orthogonal_init = nn.initializers.orthogonal(scale=0.01) + # orthogonal_init = nn.initializers.uniform(scale=0.01) + # orthogonal_init = nn.initializers.normal(stddev=0.01) + bias_init = nn.initializers.zeros + mean = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) + log_std = self.param("log_std", nn.initializers.constant(self.log_std_init), (self.action_dim,)) + # log_std = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) + else: + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), From 3e422625e8a2dc8b9cc77ef9a594a89079b97b58 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 09:22:10 +0100 Subject: [PATCH 05/28] Enable lr schedule for PPO --- sbx/common/on_policy_algorithm.py | 18 ++++++++++++++++++ sbx/common/policies.py | 2 +- sbx/common/utils.py | 13 +++++++++++++ sbx/ppo/policies.py | 15 +++++++-------- sbx/ppo/ppo.py | 2 +- 5 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 sbx/common/utils.py diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index bb371c2..9d5bd30 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -3,6 +3,7 @@ import gymnasium as gym import jax import numpy as np +import optax import torch as th from gymnasium import spaces from stable_baselines3.common.buffers import RolloutBuffer @@ -12,6 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, Schedule from stable_baselines3.common.vec_env import VecEnv +from sbx.common.utils import update_learning_rate from sbx.ppo.policies import Actor, Critic, PPOPolicy OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithmJax") @@ -75,6 +77,22 @@ def _excluded_save_params(self) -> list[str]: excluded.remove("policy") return excluded + def _update_learning_rate(self, optimizers: Union[list[optax.OptState], optax.OptState]) -> None: + """ + Update the optimizers learning rate using the current learning rate schedule + and the current progress remaining (from 1 to 0). + + :param optimizers: + An optimizer or a list of optimizers. 
+ """ + # Log the current learning rate + self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining)) + + if not isinstance(optimizers, list): + optimizers = [optimizers] + for optimizer in optimizers: + update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining)) + def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] super().set_random_seed(seed) if seed is None: diff --git a/sbx/common/policies.py b/sbx/common/policies.py index c056046..9a44bed 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -243,7 +243,7 @@ class SquashedGaussianActor(nn.Module): log_std_max: float = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu ortho_init: bool = False - log_std_init: float = -1.2 # log(0.3) + log_std_init: float = -1.2 # log(0.3) def get_std(self): # Make it work with gSDE diff --git a/sbx/common/utils.py b/sbx/common/utils.py new file mode 100644 index 0000000..d0012bb --- /dev/null +++ b/sbx/common/utils.py @@ -0,0 +1,13 @@ +import optax + + +def update_learning_rate(opt_state: optax.OptState, learning_rate: float) -> None: + """ + Update the learning rate for a given optimizer. + Useful when doing linear schedule. + + :param optimizer: Optax optimizer state + :param learning_rate: New learning rate value + """ + # Note: the optimizer must have been defined with inject_hyperparams + opt_state.hyperparams["learning_rate"] = learning_rate diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index d85d77a..aad4c76 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -272,15 +272,17 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # Inject hyperparameters to be able to modify it later + # See https://stackoverflow.com/questions/78527164 + # Note: eps=1e-5 for Adam + optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) + self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), - self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, # , eps=1e-5 - ), + optimizer_class, ), ) @@ -291,10 +293,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> params=self.vf.init({"params": vf_key}, obs), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), - self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, # , eps=1e-5 - ), + optimizer_class, ), ) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 953d199..3489c3f 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -252,7 +252,7 @@ def train(self) -> None: Update policy using the currently gathered rollout buffer. 
""" # Update optimizer learning rate - # self._update_learning_rate(self.policy.optimizer) + self._update_learning_rate([self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]]) # Compute current clip range clip_range = self.clip_range_schedule(self._current_progress_remaining) From 101c08de441c8f541ae04f09987e12d805077cab Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 09:34:05 +0100 Subject: [PATCH 06/28] Allow to pass lr, prepare for adaptive lr --- sbx/common/on_policy_algorithm.py | 6 +++--- sbx/ppo/ppo.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index 9d5bd30..b612568 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -77,7 +77,7 @@ def _excluded_save_params(self) -> list[str]: excluded.remove("policy") return excluded - def _update_learning_rate(self, optimizers: Union[list[optax.OptState], optax.OptState]) -> None: + def _update_learning_rate(self, optimizers: Union[list[optax.OptState], optax.OptState], learning_rate: float) -> None: """ Update the optimizers learning rate using the current learning rate schedule and the current progress remaining (from 1 to 0). @@ -86,12 +86,12 @@ def _update_learning_rate(self, optimizers: Union[list[optax.OptState], optax.Op An optimizer or a list of optimizers. """ # Log the current learning rate - self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining)) + self.logger.record("train/learning_rate", learning_rate) if not isinstance(optimizers, list): optimizers = [optimizers] for optimizer in optimizers: - update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining)) + update_learning_rate(optimizer, learning_rate) def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] super().set_random_seed(seed) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 3489c3f..51dcefb 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -252,7 +252,10 @@ def train(self) -> None: Update policy using the currently gathered rollout buffer. 
""" # Update optimizer learning rate - self._update_learning_rate([self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]]) + self._update_learning_rate( + [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]], + learning_rate=self.lr_schedule(self._current_progress_remaining), + ) # Compute current clip range clip_range = self.clip_range_schedule(self._current_progress_remaining) From ef2321de1e0f7eef60b7a62d73b82d85bfd59baf Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 10:29:56 +0100 Subject: [PATCH 07/28] Implement adaptive lr --- sbx/common/on_policy_algorithm.py | 6 ++- sbx/ppo/ppo.py | 67 +++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index b612568..48cbff1 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -77,7 +77,11 @@ def _excluded_save_params(self) -> list[str]: excluded.remove("policy") return excluded - def _update_learning_rate(self, optimizers: Union[list[optax.OptState], optax.OptState], learning_rate: float) -> None: + def _update_learning_rate( # type: ignore[override] + self, + optimizers: Union[list[optax.OptState], optax.OptState], + learning_rate: float, + ) -> None: """ Update the optimizers learning rate using the current learning rate schedule and the current progress remaining (from 1 to 0). diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 51dcefb..5fab4b4 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -54,9 +54,8 @@ class PPO(OnPolicyAlgorithmJax): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param target_kl: Limit the KL divergence between updates, - because the clipping is not enough to prevent large update - see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + :param target_kl: Update the learning rate based on a desired KL divergence. + Note: this will overwrite any lr schedule. By default, there is no limit on the kl div. 
:param tensorboard_log: the log location for tensorboard (if None, no logging) :param policy_kwargs: additional arguments to be passed to the policy on creation @@ -76,6 +75,7 @@ class PPO(OnPolicyAlgorithmJax): # "MultiInputPolicy": MultiInputActorCriticPolicy, } policy: PPOPolicy # type: ignore[assignment] + current_adaptive_lr: float def __init__( self, @@ -161,7 +161,16 @@ def __init__( self.clip_range = clip_range self.clip_range_vf = clip_range_vf self.normalize_advantage = normalize_advantage + # If set will trigger adaptive lr self.target_kl = target_kl + if target_kl is not None and self.verbose > 0: + print(f"Using adaptive learning rate with {target_kl=}, any other lr schedule will be skipped.") + # Values taken from https://github.com/leggedrobotics/rsl_rl + self.min_learning_rate = 1e-5 + self.max_learning_rate = 1e-2 + self.kl_margin = 2.0 + # Divide or multiple the lr by this factor + self.adaptive_lr_factor = 1.5 if _init_setup_model: self._setup_model() @@ -169,6 +178,8 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() + self.current_adaptive_lr = self.lr_schedule(1.0) + if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type] self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, @@ -231,9 +242,9 @@ def actor_loss(params): entropy_loss = -jnp.mean(entropy) total_policy_loss = policy_loss + ent_coef * entropy_loss - return total_policy_loss + return total_policy_loss, ratio - pg_loss_value, grads = jax.value_and_grad(actor_loss, has_aux=False)(actor_state.params) + (pg_loss_value, ratio), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params) actor_state = actor_state.apply_gradients(grads=grads) def critic_loss(params): @@ -245,31 +256,36 @@ def critic_loss(params): vf_state = vf_state.apply_gradients(grads=grads) # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss - return (actor_state, vf_state), (pg_loss_value, vf_loss_value) + return (actor_state, vf_state), (pg_loss_value, vf_loss_value, ratio) def train(self) -> None: """ Update policy using the currently gathered rollout buffer. 
""" # Update optimizer learning rate - self._update_learning_rate( - [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]], - learning_rate=self.lr_schedule(self._current_progress_remaining), - ) + if self.target_kl is None: + self._update_learning_rate( + [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]], + learning_rate=self.lr_schedule(self._current_progress_remaining), + ) # Compute current clip range clip_range = self.clip_range_schedule(self._current_progress_remaining) + n_updates = 0 + mean_clip_fraction = 0.0 + mean_kl_div = 0.0 # train for n_epochs epochs for _ in range(self.n_epochs): # JIT only one update for rollout_data in self.rollout_buffer.get(self.batch_size): # type: ignore[attr-defined] + n_updates += 1 if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to int actions = rollout_data.actions.flatten().numpy().astype(np.int32) else: actions = rollout_data.actions.numpy() - (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss) = self._one_update( + (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss, ratio) = self._one_update( actor_state=self.policy.actor_state, vf_state=self.policy.vf_state, observations=rollout_data.observations.numpy(), @@ -283,6 +299,31 @@ def train(self) -> None: normalize_advantage=self.normalize_advantage, ) + # Calculate approximate form of reverse KL Divergence for adaptive lr + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + approx_kl_div = jnp.mean((ratio - 1.0) - jnp.log(ratio)).item() + clip_fraction = jnp.mean(jnp.abs(ratio - 1) > clip_range).item() + # Compute average + mean_clip_fraction += (clip_fraction - mean_clip_fraction) / n_updates + mean_kl_div += (approx_kl_div - mean_kl_div) / n_updates + + # Adaptive lr schedule, see https://arxiv.org/abs/1707.02286 + if self.target_kl is not None: + if approx_kl_div > self.target_kl * self.kl_margin: + self.current_adaptive_lr /= self.adaptive_lr_factor + elif approx_kl_div < self.target_kl / self.kl_margin: + self.current_adaptive_lr *= self.adaptive_lr_factor + + self.current_adaptive_lr = np.clip( + self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate + ) + + self._update_learning_rate( + [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]], + learning_rate=self.current_adaptive_lr, + ) self._n_updates += self.n_epochs explained_var = explained_variance( self.rollout_buffer.values.flatten(), # type: ignore[attr-defined] @@ -294,8 +335,8 @@ def train(self) -> None: # self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) # TODO: use mean instead of one point self.logger.record("train/value_loss", value_loss.item()) - # self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) - # self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/approx_kl", mean_kl_div) + self.logger.record("train/clip_fraction", mean_clip_fraction) self.logger.record("train/pg_loss", pg_loss.item()) self.logger.record("train/explained_variance", explained_var) try: From 96bf97804b74beda47bd5c5701c14392f756f126 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 10:32:40 +0100 Subject: [PATCH 08/28] Add small test --- tests/test_run.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_run.py b/tests/test_run.py 
index 1a03eb0..9fcb04a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -169,6 +169,8 @@ def test_simba_ppo(tmp_path) -> None: batch_size=32, n_epochs=2, policy_kwargs=dict(activation_fn=nn.leaky_relu, net_arch=[64]), + # Test adaptive lr + target_kl=0.01, ) model.learn(64, progress_bar=True) From 6e17def7c2fb0a53ef2ebd354a962b0c547da177 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 19:31:21 +0100 Subject: [PATCH 09/28] Refactor adaptive lr --- sbx/common/policies.py | 13 +++++++------ sbx/common/utils.py | 44 ++++++++++++++++++++++++++++++++++++++++++ sbx/crossq/crossq.py | 4 ++-- sbx/crossq/policies.py | 15 +++++++------- sbx/ppo/policies.py | 11 ++++++----- sbx/ppo/ppo.py | 25 +++++++----------------- sbx/sac/sac.py | 8 ++++++-- sbx/tqc/tqc.py | 4 ++-- 8 files changed, 82 insertions(+), 42 deletions(-) diff --git a/sbx/common/policies.py b/sbx/common/policies.py index 9a44bed..d5bb996 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -38,14 +38,15 @@ def __init__(self, *args, **kwargs): @staticmethod @jax.jit def sample_action(actor_state, observations, key): - dist = actor_state.apply_fn(actor_state.params, observations) + dist, _, _ = actor_state.apply_fn(actor_state.params, observations) action = dist.sample(seed=key) return action @staticmethod @jax.jit def select_action(actor_state, observations): - return actor_state.apply_fn(actor_state.params, observations).mode() + dist, _, _ = actor_state.apply_fn(actor_state.params, observations) + return dist.mode() @no_type_check def predict( @@ -250,7 +251,7 @@ def get_std(self): return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) @@ -271,7 +272,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist + return dist, mean, log_std class SimbaSquashedGaussianActor(nn.Module): @@ -290,7 +291,7 @@ def get_std(self): return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] x = Flatten()(x) # Note: simba was using kernel_init=orthogonal_init(1) @@ -305,4 +306,4 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist + return dist, mean, log_std diff --git a/sbx/common/utils.py b/sbx/common/utils.py index d0012bb..8619a88 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,3 +1,7 @@ +from dataclasses import dataclass + +import jax.numpy as jnp +import numpy as np import optax @@ -11,3 +15,43 @@ def update_learning_rate(opt_state: optax.OptState, learning_rate: float) -> Non """ # Note: the optimizer must have been defined with inject_hyperparams opt_state.hyperparams["learning_rate"] = learning_rate + + +def kl_div_gaussian( + old_std: jnp.ndarray, + old_mean: jnp.ndarray, + new_std: jnp.ndarray, + new_mean: jnp.ndarray, + eps: float = 1e-5, +) -> float: + # See https://stats.stackexchange.com/questions/7440/ + # We have independent Gaussian for each action dim 
+ return ( + jnp.sum( + jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5, + axis=-1, + ) + .mean() + .item() + ) + + +@dataclass +class KlAdaptiveLR: + # If set will trigger adaptive lr + target_kl: float + current_adaptive_lr: float + # Values taken from https://github.com/leggedrobotics/rsl_rl + min_learning_rate: float = 1e-5 + max_learning_rate: float = 1e-2 + kl_margin: float = 2.0 + # Divide or multiple the lr by this factor + adaptive_lr_factor: float = 1.5 + + def update(self, kl_div: float) -> None: + if kl_div > self.target_kl * self.kl_margin: + self.current_adaptive_lr /= self.adaptive_lr_factor + elif kl_div < self.target_kl / self.kl_margin: + self.current_adaptive_lr *= self.adaptive_lr_factor + + self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate) diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index ee9d74b..7d0bcd9 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -255,7 +255,7 @@ def update_critic( ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) # sample action from the actor - dist = actor_state.apply_fn( + dist, _, _ = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, next_observations, train=False, @@ -330,7 +330,7 @@ def update_actor( def actor_loss( params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - dist, state_updates = actor_state.apply_fn( + (dist, _, _), state_updates = actor_state.apply_fn( {"params": params, "batch_stats": batch_stats}, observations, mutable=["batch_stats"], diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py index c4e2986..8de9ff3 100644 --- a/sbx/crossq/policies.py +++ b/sbx/crossq/policies.py @@ -182,7 +182,7 @@ def get_std(self): return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray, train: bool = False) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] x = Flatten()(x) norm_layer = partial( BatchRenorm, @@ -208,7 +208,7 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist + return dist, mean, log_std class Actor(nn.Module): @@ -226,7 +226,7 @@ def get_std(self): return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray, train: bool = False) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] x = Flatten()(x) if self.use_batch_norm: x = BatchRenorm( @@ -254,7 +254,7 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist + return dist, mean, log_std class CrossQPolicy(BaseJaxPolicy): @@ -426,7 +426,7 @@ def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: @staticmethod @jax.jit def sample_action(actor_state, observations, key): - dist = actor_state.apply_fn( + dist, _, _ = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, observations, train=False, @@ -437,11 +437,12 @@ def 
sample_action(actor_state, observations, key): @staticmethod @jax.jit def select_action(actor_state, observations): - return actor_state.apply_fn( + dist, _, _ = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, observations, train=False, - ).mode() + ) + return dist.mode() def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: # type: ignore[override] if deterministic: diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index aad4c76..731a33c 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -78,7 +78,7 @@ def __post_init__(self) -> None: super().__post_init__() @nn.compact - def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] x = Flatten()(x) for n_units in self.net_arch: @@ -93,6 +93,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def else: action_logits = nn.Dense(self.action_dim)(x) + log_std = jnp.zeros(1) if self.num_discrete_choices is None: # Continuous actions log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) @@ -121,7 +122,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def dist = tfp.distributions.Independent( tfp.distributions.Categorical(logits=logits_padded), reinterpreted_batch_ndims=1 ) - return dist + return dist, action_logits, log_std class SimbaActor(nn.Module): @@ -142,7 +143,7 @@ def get_std(self) -> jnp.ndarray: return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] x = Flatten()(x) x = nn.Dense(self.net_arch[0])(x) @@ -162,7 +163,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) dist = tfd.MultivariateNormalDiag(loc=mean_action, scale_diag=jnp.exp(log_std)) - return dist + return dist, mean_action, log_std class PPOPolicy(BaseJaxPolicy): @@ -325,7 +326,7 @@ def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray: @staticmethod @jax.jit def _predict_all(actor_state, vf_state, observations, key): - dist = actor_state.apply_fn(actor_state.params, observations) + dist, _, _ = actor_state.apply_fn(actor_state.params, observations) actions = dist.sample(seed=key) log_probs = dist.log_prob(actions) values = vf_state.apply_fn(vf_state.params, observations).flatten() diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 5fab4b4..34035d8 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -11,6 +11,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax +from sbx.common.utils import KlAdaptiveLR from sbx.ppo.policies import PPOPolicy, SimbaPPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -75,7 +76,7 @@ class PPO(OnPolicyAlgorithmJax): # "MultiInputPolicy": MultiInputActorCriticPolicy, } policy: PPOPolicy # type: ignore[assignment] - current_adaptive_lr: float + adaptive_lr: KlAdaptiveLR def __init__( self, @@ -165,12 +166,6 @@ def __init__( self.target_kl = target_kl if target_kl is not None and self.verbose > 0: print(f"Using adaptive learning rate with {target_kl=}, any other lr schedule will be skipped.") - # Values taken from 
https://github.com/leggedrobotics/rsl_rl - self.min_learning_rate = 1e-5 - self.max_learning_rate = 1e-2 - self.kl_margin = 2.0 - # Divide or multiple the lr by this factor - self.adaptive_lr_factor = 1.5 if _init_setup_model: self._setup_model() @@ -178,7 +173,8 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() - self.current_adaptive_lr = self.lr_schedule(1.0) + if self.target_kl is not None: + self.adaptive_lr = KlAdaptiveLR(self.target_kl, self.lr_schedule(1.0)) if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type] self.policy = self.policy_class( # type: ignore[assignment] @@ -224,7 +220,7 @@ def _one_update( advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) def actor_loss(params): - dist = actor_state.apply_fn(params, observations) + dist, _, _ = actor_state.apply_fn(params, observations) log_prob = dist.log_prob(actions) entropy = dist.entropy() @@ -311,18 +307,11 @@ def train(self) -> None: # Adaptive lr schedule, see https://arxiv.org/abs/1707.02286 if self.target_kl is not None: - if approx_kl_div > self.target_kl * self.kl_margin: - self.current_adaptive_lr /= self.adaptive_lr_factor - elif approx_kl_div < self.target_kl / self.kl_margin: - self.current_adaptive_lr *= self.adaptive_lr_factor - - self.current_adaptive_lr = np.clip( - self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate - ) + self.adaptive_lr.update(approx_kl_div) self._update_learning_rate( [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]], - learning_rate=self.current_adaptive_lr, + learning_rate=self.adaptive_lr.current_adaptive_lr, ) self._n_updates += self.n_epochs explained_var = explained_variance( diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index a6826c5..403a6d4 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -16,6 +16,7 @@ from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState +from sbx.common.utils import kl_div_gaussian from sbx.sac.policies import SACPolicy, SimbaSACPolicy @@ -70,6 +71,7 @@ def __init__( replay_buffer_kwargs: Optional[dict[str, Any]] = None, ent_coef: Union[str, float] = "auto", target_entropy: Union[Literal["auto"], float] = "auto", + target_kl: Optional[float] = None, use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, @@ -113,6 +115,8 @@ def __init__( self.policy_delay = policy_delay self.ent_coef_init = ent_coef self.target_entropy = target_entropy + self.old_mean: Optional[jnp.ndarray] = None + self.old_std: Optional[jnp.ndarray] = None if _init_setup_model: self._setup_model() @@ -257,7 +261,7 @@ def update_critic( ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) # sample action from the actor - dist = actor_state.apply_fn(actor_state.params, next_observations) + dist, _, _ = actor_state.apply_fn(actor_state.params, next_observations) next_state_actions = dist.sample(seed=noise_key) next_log_prob = dist.log_prob(next_state_actions) @@ -302,7 +306,7 @@ def update_actor( key, dropout_key, noise_key = jax.random.split(key, 3) def actor_loss(params: flax.core.FrozenDict) -> tuple[jax.Array, jax.Array]: - dist = actor_state.apply_fn(params, observations) + dist, _, _ = actor_state.apply_fn(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 91b5c9b..d90d9ed 100644 --- a/sbx/tqc/tqc.py +++ 
b/sbx/tqc/tqc.py @@ -266,7 +266,7 @@ def update_critic( key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) key, dropout_key_3, dropout_key_4 = jax.random.split(key, 3) # sample action from the actor - dist = actor_state.apply_fn(actor_state.params, next_observations) + dist, _, _ = actor_state.apply_fn(actor_state.params, next_observations) next_state_actions = dist.sample(seed=noise_key) next_log_prob = dist.log_prob(next_state_actions) @@ -345,7 +345,7 @@ def update_actor( key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4) def actor_loss(params: flax.core.FrozenDict) -> tuple[jax.Array, jax.Array]: - dist = actor_state.apply_fn(params, observations) + dist, _, _ = actor_state.apply_fn(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) From 58327023e4c55ad78e3c084338ed0835c9a76606 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 20:17:04 +0100 Subject: [PATCH 10/28] Add adaptive lr for SAC --- sbx/common/utils.py | 32 ++++++++++++++++++++------------ sbx/ppo/policies.py | 5 ++++- sbx/ppo/ppo.py | 2 +- sbx/sac/policies.py | 17 +++++++++-------- sbx/sac/sac.py | 44 +++++++++++++++++++++++++++++++++++++++++--- tests/test_run.py | 4 ++-- 6 files changed, 77 insertions(+), 27 deletions(-) diff --git a/sbx/common/utils.py b/sbx/common/utils.py index 8619a88..26b8e6d 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,5 +1,6 @@ from dataclasses import dataclass +import jax import jax.numpy as jnp import numpy as np import optax @@ -17,27 +18,30 @@ def update_learning_rate(opt_state: optax.OptState, learning_rate: float) -> Non opt_state.hyperparams["learning_rate"] = learning_rate +@jax.jit def kl_div_gaussian( - old_std: jnp.ndarray, old_mean: jnp.ndarray, - new_std: jnp.ndarray, + old_std: jnp.ndarray, new_mean: jnp.ndarray, + new_std: jnp.ndarray, eps: float = 1e-5, -) -> float: +) -> jnp.ndarray: # See https://stats.stackexchange.com/questions/7440/ # We have independent Gaussian for each action dim - return ( - jnp.sum( - jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5, - axis=-1, - ) - .mean() - .item() - ) + # TODO: double check dimensions + # return jnp.sum( + # jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5, + # axis=-1, + # ).mean() + # Another approximation, mean over batch and action dim + old_mean, old_std, new_mean, new_std = old_mean.mean(), old_std.mean(), new_mean.mean(), new_std.mean() + return jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5 @dataclass class KlAdaptiveLR: + """Adaptive learning rate schedule, see https://arxiv.org/abs/1707.02286""" + # If set will trigger adaptive lr target_kl: float current_adaptive_lr: float @@ -54,4 +58,8 @@ def update(self, kl_div: float) -> None: elif kl_div < self.target_kl / self.kl_margin: self.current_adaptive_lr *= self.adaptive_lr_factor - self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate) + self.current_adaptive_lr = np.clip( + self.current_adaptive_lr, + self.min_learning_rate, + self.max_learning_rate, + ) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 731a33c..7d7dac8 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -276,7 +276,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> # 
Inject hyperparameters to be able to modify it later # See https://stackoverflow.com/questions/78527164 # Note: eps=1e-5 for Adam - optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) + optimizer_class = optax.inject_hyperparams(self.optimizer_class)( + learning_rate=lr_schedule(1), + **self.optimizer_kwargs, + ) self.actor_state = TrainState.create( apply_fn=self.actor.apply, diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 34035d8..9b0632d 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -162,7 +162,7 @@ def __init__( self.clip_range = clip_range self.clip_range_vf = clip_range_vf self.normalize_advantage = normalize_advantage - # If set will trigger adaptive lr + # If set, will trigger adaptive lr self.target_kl = target_kl if target_kl is not None and self.verbose > 0: print(f"Using adaptive learning rate with {target_kl=}, any other lr schedule will be skipped.") diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 95c319f..86e3866 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -95,13 +95,17 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # Inject hyperparameters to be able to modify it later + # See https://stackoverflow.com/questions/78527164 + optimizer_class = optax.inject_hyperparams(self.optimizer_class)( + learning_rate=lr_schedule(1), + **self.optimizer_kwargs, + ) + self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + tx=optimizer_class, ) self.qf = self.vector_critic_class( @@ -124,10 +128,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=self.optimizer_class( - learning_rate=qf_learning_rate, # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + tx=optimizer_class, ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index 403a6d4..cdd9079 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -16,7 +16,7 @@ from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState -from sbx.common.utils import kl_div_gaussian +from sbx.common.utils import KlAdaptiveLR, kl_div_gaussian, update_learning_rate from sbx.sac.policies import SACPolicy, SimbaSACPolicy @@ -51,6 +51,7 @@ class SAC(OffPolicyAlgorithmJax): policy: SACPolicy action_space: spaces.Box # type: ignore[assignment] + adaptive_lr: KlAdaptiveLR def __init__( self, @@ -117,6 +118,10 @@ def __init__( self.target_entropy = target_entropy self.old_mean: Optional[jnp.ndarray] = None self.old_std: Optional[jnp.ndarray] = None + # If set, will trigger adaptive lr + self.target_kl = target_kl + if target_kl is not None and self.verbose > 0: + print(f"Using adaptive learning rate with {target_kl=}, any other lr schedule will be skipped.") if _init_setup_model: self._setup_model() @@ -124,6 +129,9 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() + if self.target_kl is not None: + self.adaptive_lr = KlAdaptiveLR(self.target_kl, self.lr_schedule(1.0)) + if not hasattr(self, "policy") or self.policy is None: self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, @@ -224,7 +232,7 @@ def 
train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value), + (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value, new_mean, new_log_std), ) = self._train( self.gamma, self.tau, @@ -245,6 +253,28 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + if self.old_mean is None or self.old_std is None: + # First iteration + self.old_mean, self.old_std = new_mean, jnp.exp(new_log_std) + return + + # TODO: skip also when actor is not updated (delayed update) + new_std = jnp.exp(new_log_std) + approx_kl_div = kl_div_gaussian(self.old_mean, self.old_std, new_mean, new_std).item() + self.old_mean, self.old_std = new_mean, new_std + self.logger.record("train/approx_kl", approx_kl_div) + self.logger.record("train/approx_std", new_std.mean().item()) + self.logger.record("train/approx_mean", new_mean.mean().item()) + + if self.target_kl is not None: + # TODO: adaptive lr need to be in the inner loop? + self.adaptive_lr.update(approx_kl_div) + # Log the current learning rate + self.logger.record("train/learning_rate", self.adaptive_lr.current_adaptive_lr) + + for optimizer in [self.policy.actor_state.opt_state, self.policy.qf_state.opt_state]: + update_learning_rate(optimizer, self.adaptive_lr.current_adaptive_lr) + @staticmethod @jax.jit def update_critic( @@ -261,7 +291,7 @@ def update_critic( ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) # sample action from the actor - dist, _, _ = actor_state.apply_fn(actor_state.params, next_observations) + dist, new_mean, new_log_std = actor_state.apply_fn(actor_state.params, next_observations) next_state_actions = dist.sample(seed=noise_key) next_log_prob = dist.log_prob(next_state_actions) @@ -292,6 +322,7 @@ def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: qf_state, (qf_loss_value, ent_coef_value), key, + (new_mean, new_log_std), ) @staticmethod @@ -397,6 +428,8 @@ def _train( "qf_loss": jnp.array(0.0), "ent_coef_loss": jnp.array(0.0), "ent_coef_value": jnp.array(0.0), + "new_mean": jnp.zeros((batch_size, data.actions.shape[1])), + "new_log_std": jnp.zeros((batch_size, data.actions.shape[1])), }, } @@ -417,6 +450,7 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: qf_state, (qf_loss_value, ent_coef_value), key, + (new_mean, new_log_std), ) = cls.update_critic( gamma, actor_state, @@ -449,6 +483,8 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value, "ent_coef_value": ent_coef_value, + "new_mean": new_mean, + "new_log_std": new_log_std, } return { @@ -471,5 +507,7 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"], update_carry["info"]["ent_coef_value"], + update_carry["info"]["new_mean"], + update_carry["info"]["new_log_std"], ), ) diff --git a/tests/test_run.py b/tests/test_run.py index 9fcb04a..4cb4c2f 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -192,6 +192,6 @@ def test_dqn(tmp_path) -> None: @pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer]) def test_dict(replay_buffer_class: Optional[type[HerReplayBuffer]]) -> None: env = BitFlippingEnv(n_bits=2, continuous=True) - model = 
SAC("MultiInputPolicy", env, replay_buffer_class=replay_buffer_class) + model = SAC("MultiInputPolicy", env, target_kl=0.01, replay_buffer_class=replay_buffer_class) - model.learn(200, progress_bar=True) + model.learn(200, progress_bar=False) From ab3398382024ef4cea0f545b0c384d09595a281f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 27 Feb 2025 20:42:24 +0100 Subject: [PATCH 11/28] Fix qf_learning_rate --- sbx/sac/policies.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 86e3866..c06c53a 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -97,7 +97,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) # Inject hyperparameters to be able to modify it later # See https://stackoverflow.com/questions/78527164 - optimizer_class = optax.inject_hyperparams(self.optimizer_class)( + actor_optimizer_class = optax.inject_hyperparams(self.optimizer_class)( learning_rate=lr_schedule(1), **self.optimizer_kwargs, ) @@ -105,7 +105,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=optimizer_class, + tx=actor_optimizer_class, ) self.qf = self.vector_critic_class( @@ -116,6 +116,11 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) activation_fn=self.activation_fn, ) + qf_optimizer_class = optax.inject_hyperparams(self.optimizer_class)( + learning_rate=qf_learning_rate, + **self.optimizer_kwargs, + ) + self.qf_state = RLTrainState.create( apply_fn=self.qf.apply, params=self.qf.init( @@ -128,7 +133,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=optimizer_class, + tx=qf_optimizer_class, ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] From 163acbe56cdb7d8a9f1139beaff67c93f763be27 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 28 Feb 2025 08:11:15 +0100 Subject: [PATCH 12/28] Revert "Fix qf_learning_rate" This reverts commit ab3398382024ef4cea0f545b0c384d09595a281f. 
--- sbx/sac/policies.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index c06c53a..86e3866 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -97,7 +97,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) # Inject hyperparameters to be able to modify it later # See https://stackoverflow.com/questions/78527164 - actor_optimizer_class = optax.inject_hyperparams(self.optimizer_class)( + optimizer_class = optax.inject_hyperparams(self.optimizer_class)( learning_rate=lr_schedule(1), **self.optimizer_kwargs, ) @@ -105,7 +105,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=actor_optimizer_class, + tx=optimizer_class, ) self.qf = self.vector_critic_class( @@ -116,11 +116,6 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) activation_fn=self.activation_fn, ) - qf_optimizer_class = optax.inject_hyperparams(self.optimizer_class)( - learning_rate=qf_learning_rate, - **self.optimizer_kwargs, - ) - self.qf_state = RLTrainState.create( apply_fn=self.qf.apply, params=self.qf.init( @@ -133,7 +128,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=qf_optimizer_class, + tx=optimizer_class, ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] From 0d8372058bc145c7708c82c8060d8d3355749a7c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 28 Feb 2025 08:11:26 +0100 Subject: [PATCH 13/28] Revert "Add adaptive lr for SAC" This reverts commit 58327023e4c55ad78e3c084338ed0835c9a76606. --- sbx/common/utils.py | 32 ++++++++++++-------------------- sbx/ppo/policies.py | 5 +---- sbx/ppo/ppo.py | 2 +- sbx/sac/policies.py | 17 ++++++++--------- sbx/sac/sac.py | 44 +++----------------------------------------- tests/test_run.py | 4 ++-- 6 files changed, 27 insertions(+), 77 deletions(-) diff --git a/sbx/common/utils.py b/sbx/common/utils.py index 26b8e6d..8619a88 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -import jax import jax.numpy as jnp import numpy as np import optax @@ -18,30 +17,27 @@ def update_learning_rate(opt_state: optax.OptState, learning_rate: float) -> Non opt_state.hyperparams["learning_rate"] = learning_rate -@jax.jit def kl_div_gaussian( - old_mean: jnp.ndarray, old_std: jnp.ndarray, - new_mean: jnp.ndarray, + old_mean: jnp.ndarray, new_std: jnp.ndarray, + new_mean: jnp.ndarray, eps: float = 1e-5, -) -> jnp.ndarray: +) -> float: # See https://stats.stackexchange.com/questions/7440/ # We have independent Gaussian for each action dim - # TODO: double check dimensions - # return jnp.sum( - # jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5, - # axis=-1, - # ).mean() - # Another approximation, mean over batch and action dim - old_mean, old_std, new_mean, new_std = old_mean.mean(), old_std.mean(), new_mean.mean(), new_std.mean() - return jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5 + return ( + jnp.sum( + jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5, + axis=-1, + ) + .mean() + .item() + ) @dataclass class KlAdaptiveLR: - """Adaptive learning rate schedule, see 
https://arxiv.org/abs/1707.02286""" - # If set will trigger adaptive lr target_kl: float current_adaptive_lr: float @@ -58,8 +54,4 @@ def update(self, kl_div: float) -> None: elif kl_div < self.target_kl / self.kl_margin: self.current_adaptive_lr *= self.adaptive_lr_factor - self.current_adaptive_lr = np.clip( - self.current_adaptive_lr, - self.min_learning_rate, - self.max_learning_rate, - ) + self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 7d7dac8..731a33c 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -276,10 +276,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> # Inject hyperparameters to be able to modify it later # See https://stackoverflow.com/questions/78527164 # Note: eps=1e-5 for Adam - optimizer_class = optax.inject_hyperparams(self.optimizer_class)( - learning_rate=lr_schedule(1), - **self.optimizer_kwargs, - ) + optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) self.actor_state = TrainState.create( apply_fn=self.actor.apply, diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 9b0632d..34035d8 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -162,7 +162,7 @@ def __init__( self.clip_range = clip_range self.clip_range_vf = clip_range_vf self.normalize_advantage = normalize_advantage - # If set, will trigger adaptive lr + # If set will trigger adaptive lr self.target_kl = target_kl if target_kl is not None and self.verbose > 0: print(f"Using adaptive learning rate with {target_kl=}, any other lr schedule will be skipped.") diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 86e3866..95c319f 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -95,17 +95,13 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise - # Inject hyperparameters to be able to modify it later - # See https://stackoverflow.com/questions/78527164 - optimizer_class = optax.inject_hyperparams(self.optimizer_class)( - learning_rate=lr_schedule(1), - **self.optimizer_kwargs, - ) - self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=optimizer_class, + tx=self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ), ) self.qf = self.vector_critic_class( @@ -128,7 +124,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=optimizer_class, + tx=self.optimizer_class( + learning_rate=qf_learning_rate, # type: ignore[call-arg] + **self.optimizer_kwargs, + ), ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index cdd9079..403a6d4 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -16,7 +16,7 @@ from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState -from sbx.common.utils import KlAdaptiveLR, kl_div_gaussian, update_learning_rate +from sbx.common.utils import kl_div_gaussian from sbx.sac.policies import SACPolicy, SimbaSACPolicy @@ -51,7 +51,6 @@ class SAC(OffPolicyAlgorithmJax): policy: SACPolicy action_space: spaces.Box # type: ignore[assignment] - adaptive_lr: KlAdaptiveLR def __init__( self, @@ -118,10 +117,6 @@ def 
__init__( self.target_entropy = target_entropy self.old_mean: Optional[jnp.ndarray] = None self.old_std: Optional[jnp.ndarray] = None - # If set, will trigger adaptive lr - self.target_kl = target_kl - if target_kl is not None and self.verbose > 0: - print(f"Using adaptive learning rate with {target_kl=}, any other lr schedule will be skipped.") if _init_setup_model: self._setup_model() @@ -129,9 +124,6 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() - if self.target_kl is not None: - self.adaptive_lr = KlAdaptiveLR(self.target_kl, self.lr_schedule(1.0)) - if not hasattr(self, "policy") or self.policy is None: self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, @@ -232,7 +224,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value, new_mean, new_log_std), + (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value), ) = self._train( self.gamma, self.tau, @@ -253,28 +245,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) - if self.old_mean is None or self.old_std is None: - # First iteration - self.old_mean, self.old_std = new_mean, jnp.exp(new_log_std) - return - - # TODO: skip also when actor is not updated (delayed update) - new_std = jnp.exp(new_log_std) - approx_kl_div = kl_div_gaussian(self.old_mean, self.old_std, new_mean, new_std).item() - self.old_mean, self.old_std = new_mean, new_std - self.logger.record("train/approx_kl", approx_kl_div) - self.logger.record("train/approx_std", new_std.mean().item()) - self.logger.record("train/approx_mean", new_mean.mean().item()) - - if self.target_kl is not None: - # TODO: adaptive lr need to be in the inner loop? 
- self.adaptive_lr.update(approx_kl_div) - # Log the current learning rate - self.logger.record("train/learning_rate", self.adaptive_lr.current_adaptive_lr) - - for optimizer in [self.policy.actor_state.opt_state, self.policy.qf_state.opt_state]: - update_learning_rate(optimizer, self.adaptive_lr.current_adaptive_lr) - @staticmethod @jax.jit def update_critic( @@ -291,7 +261,7 @@ def update_critic( ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) # sample action from the actor - dist, new_mean, new_log_std = actor_state.apply_fn(actor_state.params, next_observations) + dist, _, _ = actor_state.apply_fn(actor_state.params, next_observations) next_state_actions = dist.sample(seed=noise_key) next_log_prob = dist.log_prob(next_state_actions) @@ -322,7 +292,6 @@ def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: qf_state, (qf_loss_value, ent_coef_value), key, - (new_mean, new_log_std), ) @staticmethod @@ -428,8 +397,6 @@ def _train( "qf_loss": jnp.array(0.0), "ent_coef_loss": jnp.array(0.0), "ent_coef_value": jnp.array(0.0), - "new_mean": jnp.zeros((batch_size, data.actions.shape[1])), - "new_log_std": jnp.zeros((batch_size, data.actions.shape[1])), }, } @@ -450,7 +417,6 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: qf_state, (qf_loss_value, ent_coef_value), key, - (new_mean, new_log_std), ) = cls.update_critic( gamma, actor_state, @@ -483,8 +449,6 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value, "ent_coef_value": ent_coef_value, - "new_mean": new_mean, - "new_log_std": new_log_std, } return { @@ -507,7 +471,5 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"], update_carry["info"]["ent_coef_value"], - update_carry["info"]["new_mean"], - update_carry["info"]["new_log_std"], ), ) diff --git a/tests/test_run.py b/tests/test_run.py index 4cb4c2f..9fcb04a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -192,6 +192,6 @@ def test_dqn(tmp_path) -> None: @pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer]) def test_dict(replay_buffer_class: Optional[type[HerReplayBuffer]]) -> None: env = BitFlippingEnv(n_bits=2, continuous=True) - model = SAC("MultiInputPolicy", env, target_kl=0.01, replay_buffer_class=replay_buffer_class) + model = SAC("MultiInputPolicy", env, replay_buffer_class=replay_buffer_class) - model.learn(200, progress_bar=False) + model.learn(200, progress_bar=True) From 85d4f2310cf52201fbcf0f0dacf5a69120d69bb3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 28 Feb 2025 08:23:05 +0100 Subject: [PATCH 14/28] Revert kl div for SAC changes --- sbx/common/policies.py | 12 ++++++------ sbx/common/utils.py | 22 ++-------------------- sbx/crossq/crossq.py | 4 ++-- sbx/crossq/policies.py | 12 ++++++------ sbx/ppo/policies.py | 10 +++++----- sbx/ppo/ppo.py | 2 +- sbx/sac/sac.py | 8 ++------ sbx/tqc/tqc.py | 4 ++-- 8 files changed, 26 insertions(+), 48 deletions(-) diff --git a/sbx/common/policies.py b/sbx/common/policies.py index d5bb996..6ba4314 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -38,14 +38,14 @@ def __init__(self, *args, **kwargs): @staticmethod @jax.jit def sample_action(actor_state, observations, key): - dist, _, _ = actor_state.apply_fn(actor_state.params, observations) + dist = actor_state.apply_fn(actor_state.params, observations) action = dist.sample(seed=key) return 
action @staticmethod @jax.jit def select_action(actor_state, observations): - dist, _, _ = actor_state.apply_fn(actor_state.params, observations) + dist = actor_state.apply_fn(actor_state.params, observations) return dist.mode() @no_type_check @@ -251,7 +251,7 @@ def get_std(self): return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) @@ -272,7 +272,7 @@ def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.n dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist, mean, log_std + return dist class SimbaSquashedGaussianActor(nn.Module): @@ -291,7 +291,7 @@ def get_std(self): return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] x = Flatten()(x) # Note: simba was using kernel_init=orthogonal_init(1) @@ -306,4 +306,4 @@ def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.n dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist, mean, log_std + return dist diff --git a/sbx/common/utils.py b/sbx/common/utils.py index 8619a88..c8f7028 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -import jax.numpy as jnp import numpy as np import optax @@ -17,27 +16,10 @@ def update_learning_rate(opt_state: optax.OptState, learning_rate: float) -> Non opt_state.hyperparams["learning_rate"] = learning_rate -def kl_div_gaussian( - old_std: jnp.ndarray, - old_mean: jnp.ndarray, - new_std: jnp.ndarray, - new_mean: jnp.ndarray, - eps: float = 1e-5, -) -> float: - # See https://stats.stackexchange.com/questions/7440/ - # We have independent Gaussian for each action dim - return ( - jnp.sum( - jnp.log(new_std / old_std + eps) + ((old_std) ** 2 + (old_mean - new_mean) ** 2) / (2.0 * (new_std) ** 2) - 0.5, - axis=-1, - ) - .mean() - .item() - ) - - @dataclass class KlAdaptiveLR: + """Adaptive lr schedule, see https://arxiv.org/abs/1707.02286""" + # If set will trigger adaptive lr target_kl: float current_adaptive_lr: float diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index 7d0bcd9..ee9d74b 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -255,7 +255,7 @@ def update_critic( ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) # sample action from the actor - dist, _, _ = actor_state.apply_fn( + dist = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, next_observations, train=False, @@ -330,7 +330,7 @@ def update_actor( def actor_loss( params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: - (dist, _, _), state_updates = actor_state.apply_fn( + dist, state_updates = actor_state.apply_fn( {"params": params, "batch_stats": batch_stats}, observations, mutable=["batch_stats"], diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py index 8de9ff3..c0456c2 100644 --- a/sbx/crossq/policies.py +++ b/sbx/crossq/policies.py @@ -182,7 +182,7 @@ def get_std(self): return jnp.array(0.0) 
@nn.compact - def __call__(self, x: jnp.ndarray, train: bool = False) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # type: ignore[name-defined] x = Flatten()(x) norm_layer = partial( BatchRenorm, @@ -208,7 +208,7 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tuple[tfd.Distributio dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist, mean, log_std + return dist class Actor(nn.Module): @@ -226,7 +226,7 @@ def get_std(self): return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray, train: bool = False) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # type: ignore[name-defined] x = Flatten()(x) if self.use_batch_norm: x = BatchRenorm( @@ -254,7 +254,7 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tuple[tfd.Distributio dist = TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) - return dist, mean, log_std + return dist class CrossQPolicy(BaseJaxPolicy): @@ -426,7 +426,7 @@ def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: @staticmethod @jax.jit def sample_action(actor_state, observations, key): - dist, _, _ = actor_state.apply_fn( + dist = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, observations, train=False, @@ -437,7 +437,7 @@ def sample_action(actor_state, observations, key): @staticmethod @jax.jit def select_action(actor_state, observations): - dist, _, _ = actor_state.apply_fn( + dist = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, observations, train=False, diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 731a33c..395fc94 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -78,7 +78,7 @@ def __post_init__(self) -> None: super().__post_init__() @nn.compact - def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] x = Flatten()(x) for n_units in self.net_arch: @@ -122,7 +122,7 @@ def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.n dist = tfp.distributions.Independent( tfp.distributions.Categorical(logits=logits_padded), reinterpreted_batch_ndims=1 ) - return dist, action_logits, log_std + return dist class SimbaActor(nn.Module): @@ -143,7 +143,7 @@ def get_std(self) -> jnp.ndarray: return jnp.array(0.0) @nn.compact - def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.ndarray]: # type: ignore[name-defined] + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] x = Flatten()(x) x = nn.Dense(self.net_arch[0])(x) @@ -163,7 +163,7 @@ def __call__(self, x: jnp.ndarray) -> tuple[tfd.Distribution, jnp.ndarray, jnp.n log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) dist = tfd.MultivariateNormalDiag(loc=mean_action, scale_diag=jnp.exp(log_std)) - return dist, mean_action, log_std + return dist class PPOPolicy(BaseJaxPolicy): @@ -326,7 +326,7 @@ def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray: @staticmethod @jax.jit def _predict_all(actor_state, vf_state, observations, key): - 
dist, _, _ = actor_state.apply_fn(actor_state.params, observations) + dist = actor_state.apply_fn(actor_state.params, observations) actions = dist.sample(seed=key) log_probs = dist.log_prob(actions) values = vf_state.apply_fn(vf_state.params, observations).flatten() diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 34035d8..e8a47bc 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -220,7 +220,7 @@ def _one_update( advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) def actor_loss(params): - dist, _, _ = actor_state.apply_fn(params, observations) + dist = actor_state.apply_fn(params, observations) log_prob = dist.log_prob(actions) entropy = dist.entropy() diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index 403a6d4..a6826c5 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -16,7 +16,6 @@ from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState -from sbx.common.utils import kl_div_gaussian from sbx.sac.policies import SACPolicy, SimbaSACPolicy @@ -71,7 +70,6 @@ def __init__( replay_buffer_kwargs: Optional[dict[str, Any]] = None, ent_coef: Union[str, float] = "auto", target_entropy: Union[Literal["auto"], float] = "auto", - target_kl: Optional[float] = None, use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, @@ -115,8 +113,6 @@ def __init__( self.policy_delay = policy_delay self.ent_coef_init = ent_coef self.target_entropy = target_entropy - self.old_mean: Optional[jnp.ndarray] = None - self.old_std: Optional[jnp.ndarray] = None if _init_setup_model: self._setup_model() @@ -261,7 +257,7 @@ def update_critic( ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) # sample action from the actor - dist, _, _ = actor_state.apply_fn(actor_state.params, next_observations) + dist = actor_state.apply_fn(actor_state.params, next_observations) next_state_actions = dist.sample(seed=noise_key) next_log_prob = dist.log_prob(next_state_actions) @@ -306,7 +302,7 @@ def update_actor( key, dropout_key, noise_key = jax.random.split(key, 3) def actor_loss(params: flax.core.FrozenDict) -> tuple[jax.Array, jax.Array]: - dist, _, _ = actor_state.apply_fn(params, observations) + dist = actor_state.apply_fn(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index d90d9ed..91b5c9b 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -266,7 +266,7 @@ def update_critic( key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) key, dropout_key_3, dropout_key_4 = jax.random.split(key, 3) # sample action from the actor - dist, _, _ = actor_state.apply_fn(actor_state.params, next_observations) + dist = actor_state.apply_fn(actor_state.params, next_observations) next_state_actions = dist.sample(seed=noise_key) next_log_prob = dist.log_prob(next_state_actions) @@ -345,7 +345,7 @@ def update_actor( key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4) def actor_loss(params: flax.core.FrozenDict) -> tuple[jax.Array, jax.Array]: - dist, _, _ = actor_state.apply_fn(params, observations) + dist = actor_state.apply_fn(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) From dc6bf4e1d0df4be2a367d4cc624090ec198c4aa6 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 28 Feb 2025 08:25:47 +0100 Subject: [PATCH 15/28] Revert dist.mode() in two lines 
--- sbx/common/policies.py | 3 +-- sbx/crossq/policies.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sbx/common/policies.py b/sbx/common/policies.py index 6ba4314..9a44bed 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -45,8 +45,7 @@ def sample_action(actor_state, observations, key): @staticmethod @jax.jit def select_action(actor_state, observations): - dist = actor_state.apply_fn(actor_state.params, observations) - return dist.mode() + return actor_state.apply_fn(actor_state.params, observations).mode() @no_type_check def predict( diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py index c0456c2..c4e2986 100644 --- a/sbx/crossq/policies.py +++ b/sbx/crossq/policies.py @@ -437,12 +437,11 @@ def sample_action(actor_state, observations, key): @staticmethod @jax.jit def select_action(actor_state, observations): - dist = actor_state.apply_fn( + return actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, observations, train=False, - ) - return dist.mode() + ).mode() def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: # type: ignore[override] if deterministic: From f0235e60aa4f895231cae37a5da69d09b1bdaa35 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 3 Mar 2025 15:10:46 +0100 Subject: [PATCH 16/28] Cleanup code --- sbx/common/on_policy_algorithm.py | 4 ++-- sbx/common/utils.py | 17 ++--------------- sbx/ppo/ppo.py | 6 +++--- 3 files changed, 7 insertions(+), 20 deletions(-) diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index 48cbff1..d0c7d97 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -13,7 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, Schedule from stable_baselines3.common.vec_env import VecEnv -from sbx.common.utils import update_learning_rate from sbx.ppo.policies import Actor, Critic, PPOPolicy OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithmJax") @@ -95,7 +94,8 @@ def _update_learning_rate( # type: ignore[override] if not isinstance(optimizers, list): optimizers = [optimizers] for optimizer in optimizers: - update_learning_rate(optimizer, learning_rate) + # Note: the optimizer must have been defined with inject_hyperparams + optimizer.hyperparams["learning_rate"] = learning_rate def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] super().set_random_seed(seed) diff --git a/sbx/common/utils.py b/sbx/common/utils.py index c8f7028..447c4f3 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,23 +1,10 @@ from dataclasses import dataclass import numpy as np -import optax - - -def update_learning_rate(opt_state: optax.OptState, learning_rate: float) -> None: - """ - Update the learning rate for a given optimizer. - Useful when doing linear schedule. 
- - :param optimizer: Optax optimizer state - :param learning_rate: New learning rate value - """ - # Note: the optimizer must have been defined with inject_hyperparams - opt_state.hyperparams["learning_rate"] = learning_rate @dataclass -class KlAdaptiveLR: +class KLAdaptiveLR: """Adaptive lr schedule, see https://arxiv.org/abs/1707.02286""" # If set will trigger adaptive lr @@ -27,7 +14,7 @@ class KlAdaptiveLR: min_learning_rate: float = 1e-5 max_learning_rate: float = 1e-2 kl_margin: float = 2.0 - # Divide or multiple the lr by this factor + # Divide or multiply the lr by this factor adaptive_lr_factor: float = 1.5 def update(self, kl_div: float) -> None: diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index e8a47bc..6665321 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -11,7 +11,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax -from sbx.common.utils import KlAdaptiveLR +from sbx.common.utils import KLAdaptiveLR from sbx.ppo.policies import PPOPolicy, SimbaPPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -76,7 +76,7 @@ class PPO(OnPolicyAlgorithmJax): # "MultiInputPolicy": MultiInputActorCriticPolicy, } policy: PPOPolicy # type: ignore[assignment] - adaptive_lr: KlAdaptiveLR + adaptive_lr: KLAdaptiveLR def __init__( self, @@ -174,7 +174,7 @@ def _setup_model(self) -> None: super()._setup_model() if self.target_kl is not None: - self.adaptive_lr = KlAdaptiveLR(self.target_kl, self.lr_schedule(1.0)) + self.adaptive_lr = KLAdaptiveLR(self.target_kl, self.lr_schedule(1.0)) if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type] self.policy = self.policy_class( # type: ignore[assignment] From 0e14aad17fa96945298100ac1c6e1737997b30bd Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 8 Mar 2025 13:38:56 +0100 Subject: [PATCH 17/28] Add support for Gaussian actor for SAC --- sbx/common/off_policy_algorithm.py | 58 ++++++++++++++++++++++++++++++ sbx/common/policies.py | 18 ++++++---- sbx/sac/policies.py | 23 ++++++++---- sbx/sac/sac.py | 5 +++ sbx/tqc/policies.py | 4 +-- 5 files changed, 93 insertions(+), 15 deletions(-) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index ba49e1f..d2efabe 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -146,3 +146,61 @@ def load_replay_buffer( # Override replay buffer device to be always cpu for conversion to numpy assert self.replay_buffer is not None self.replay_buffer.device = get_device("cpu") + + def _sample_action( + self, + learning_starts: int, + action_noise: Optional[ActionNoise] = None, + n_envs: int = 1, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Sample an action according to the exploration policy. + This is either done by sampling the probability distribution of the policy, + or sampling a random action (from a uniform distribution over the action space) + or by adding noise to the deterministic output. + + :param action_noise: Action noise that will be used for exploration + Required for deterministic policy (e.g. TD3). This can also be used + in addition to the stochastic policy for SAC. + :param learning_starts: Number of steps before learning for the warm-up phase. + :param n_envs: + :return: action to take in the environment + and scaled action that will be stored in the replay buffer. + The two differs when the action space is not normalized (bounds are not [-1, 1]). 
+ """ + scaled_action = np.array([0.0]) + # Select action randomly or according to policy + if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup): + # Warmup phase + action = np.array([self.action_space.sample() for _ in range(n_envs)]) + if isinstance(self.action_space, spaces.Box): + scaled_action = self.policy.scale_action(action) + else: + assert self._last_obs is not None, "self._last_obs was not set" + obs_tensor, _ = self.policy.prepare_obs(self._last_obs) + action = np.array(self.policy._predict(obs_tensor, deterministic=False)) + if self.policy.squash_output: + scaled_action = action + + # Rescale the action from [low, high] to [-1, 1] + if isinstance(self.action_space, spaces.Box) and self.policy.squash_output: + # Add noise to the action (improve exploration) + if action_noise is not None: + scaled_action = np.clip(scaled_action + action_noise(), -1, 1) + + # We store the scaled action in the buffer + buffer_action = scaled_action + action = self.policy.unscale_action(scaled_action) + elif isinstance(self.action_space, spaces.Box) and not self.policy.squash_output: + # Add noise to the action (improve exploration) + if action_noise is not None: + action = action + action_noise() + + buffer_action = action + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + action = np.clip(action, self.action_space.low, self.action_space.high) + else: + # Discrete case, no need to normalize or clip + buffer_action = action + return action, buffer_action diff --git a/sbx/common/policies.py b/sbx/common/policies.py index 9a44bed..847a139 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -236,14 +236,15 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): return q_values -class SquashedGaussianActor(nn.Module): +class GaussianActor(nn.Module): net_arch: Sequence[int] action_dim: int log_std_min: float = -20 log_std_max: float = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + squash_output: bool = True ortho_init: bool = False - log_std_init: float = -1.2 # log(0.3) + log_std_init: float = 0.0 # -1.2 # log(0.3) def get_std(self): # Make it work with gSDE @@ -267,10 +268,15 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def else: mean = nn.Dense(self.action_dim)(x) log_std = nn.Dense(self.action_dim)(x) - log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) - dist = TanhTransformedDistribution( - tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), - ) + + if self.squash_output: + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) + else: + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)) return dist diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 95c319f..2cae5ea 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -11,9 +11,9 @@ from sbx.common.policies import ( BaseJaxPolicy, + GaussianActor, SimbaSquashedGaussianActor, SimbaVectorCritic, - SquashedGaussianActor, VectorCritic, ) from sbx.common.type_aliases import RLTrainState @@ -32,9 +32,9 @@ def __init__( layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, - # Note: most gSDE parameters are not used - # this is to keep API 
consistent with SB3 - log_std_init: float = -3, + log_std_init: float = 0.0, + squash_output: bool = True, + ortho_init: bool = False, use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, @@ -44,7 +44,7 @@ def __init__( optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, - actor_class: type[nn.Module] = SquashedGaussianActor, + actor_class: type[nn.Module] = GaussianActor, vector_critic_class: type[nn.Module] = VectorCritic, ): super().__init__( @@ -54,7 +54,7 @@ def __init__( features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, - squash_output=True, + squash_output=squash_output, ) self.dropout_rate = dropout_rate self.layer_norm = layer_norm @@ -71,6 +71,8 @@ def __init__( self.activation_fn = activation_fn self.actor_class = actor_class self.vector_critic_class = vector_critic_class + self.log_std_init = log_std_init + self.ortho_init = ortho_init self.key = self.noise_key = jax.random.PRNGKey(0) @@ -91,6 +93,9 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn, + squash_output=self.squash_output, + log_std_init=self.log_std_init, + ortho_init=self.ortho_init, ) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise @@ -167,7 +172,9 @@ def __init__( layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, - log_std_init: float = -3, + log_std_init: float = 0.0, + squash_output: bool = True, + ortho_init: bool = False, use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, @@ -191,6 +198,8 @@ def __init__( activation_fn, use_sde, log_std_init, + squash_output, + ortho_init, use_expln, clip_mean, features_extractor_class, diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index a6826c5..f17fa96 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -240,6 +240,11 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + try: + log_std = self.policy.actor_state.params["params"]["log_std"] + self.logger.record("train/std", np.exp(log_std).mean().item()) + except KeyError: + pass @staticmethod @jax.jit diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 0c4c03f..9147efe 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -12,9 +12,9 @@ from sbx.common.policies import ( BaseJaxPolicy, ContinuousCritic, + GaussianActor, SimbaContinuousCritic, SimbaSquashedGaussianActor, - SquashedGaussianActor, ) from sbx.common.type_aliases import RLTrainState @@ -46,7 +46,7 @@ def __init__( optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, - actor_class: type[nn.Module] = SquashedGaussianActor, + actor_class: type[nn.Module] = GaussianActor, critic_class: type[nn.Module] = ContinuousCritic, ): super().__init__( From 1a8063a6686c899b712a8a1bb4bcf523c9e5a6d1 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 26 Mar 2025 12:33:11 +0100 Subject: [PATCH 18/28] Enable Gaussian actor for TQC --- sbx/tqc/policies.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 
9147efe..faafcb2 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -34,9 +34,9 @@ def __init__( n_quantiles: int = 25, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, - # Note: most gSDE parameters are not used - # this is to keep API consistent with SB3 - log_std_init: float = -3, + log_std_init: float = 0.0, + squash_output: bool = True, + ortho_init: bool = False, use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, @@ -56,7 +56,7 @@ def __init__( features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, - squash_output=True, + squash_output=squash_output, ) self.dropout_rate = dropout_rate self.layer_norm = layer_norm @@ -79,6 +79,8 @@ def __init__( self.activation_fn = activation_fn self.actor_class = actor_class self.critic_class = critic_class + self.log_std_init = log_std_init + self.ortho_init = ortho_init self.key = self.noise_key = jax.random.PRNGKey(0) @@ -99,6 +101,9 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn, + squash_output=self.squash_output, + log_std_init=self.log_std_init, + ortho_init=self.ortho_init, ) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise @@ -190,7 +195,9 @@ def __init__( n_quantiles: int = 25, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, - log_std_init: float = -3, + log_std_init: float = 0.0, + squash_output: bool = True, + ortho_init: bool = False, use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, @@ -215,6 +222,8 @@ def __init__( activation_fn, use_sde, log_std_init, + squash_output, + ortho_init, use_expln, clip_mean, features_extractor_class, From cefbd78452a57451cac9dcd9274f41738982c72e Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 26 Mar 2025 12:56:31 +0100 Subject: [PATCH 19/28] Log std too --- sbx/tqc/tqc.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 91b5c9b..730d13a 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -246,6 +246,11 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.logger.record("train/critic_loss", qf1_loss_value.item()) self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + try: + log_std = self.policy.actor_state.params["params"]["log_std"] + self.logger.record("train/std", np.exp(log_std).mean().item()) + except KeyError: + pass @staticmethod @partial(jax.jit, static_argnames=["n_target_quantiles"]) From a809a01a6dd04d441591886a2cb5a058762c23ac Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Apr 2025 18:19:52 +0200 Subject: [PATCH 20/28] Avoid NaN in kl div approx --- sbx/ppo/ppo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 6665321..ad9f802 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -299,7 +299,8 @@ def train(self) -> None: # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 # and Schulman blog: http://joschu.net/blog/kl-approx.html - approx_kl_div = jnp.mean((ratio - 1.0) - jnp.log(ratio)).item() + eps = 1e-7 # Avoid NaN due to numerical instabilities + approx_kl_div = jnp.mean((ratio - 1.0 + eps) - 
jnp.log(ratio + eps)).item() clip_fraction = jnp.mean(jnp.abs(ratio - 1) > clip_range).item() # Compute average mean_clip_fraction += (clip_fraction - mean_clip_fraction) / n_updates From c92d840624064473e33656ecc3ecda80fb11e53c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Apr 2025 18:20:09 +0200 Subject: [PATCH 21/28] Allow to use layer_norm in actor --- sbx/common/policies.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sbx/common/policies.py b/sbx/common/policies.py index 847a139..3c52ad2 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -245,6 +245,7 @@ class GaussianActor(nn.Module): squash_output: bool = True ortho_init: bool = False log_std_init: float = 0.0 # -1.2 # log(0.3) + use_layer_norm: bool = False def get_std(self): # Make it work with gSDE @@ -255,6 +256,8 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) + if self.use_layer_norm: + x = nn.LayerNorm()(x) x = self.activation_fn(x) if self.ortho_init: From f54697af60ff7a6719ebcf54ab649981d8cd95cb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 14 May 2025 12:49:39 +0200 Subject: [PATCH 22/28] Reformat --- sbx/ppo/ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index ad9f802..d58f926 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -299,7 +299,7 @@ def train(self) -> None: # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 # and Schulman blog: http://joschu.net/blog/kl-approx.html - eps = 1e-7 # Avoid NaN due to numerical instabilities + eps = 1e-7 # Avoid NaN due to numerical instabilities approx_kl_div = jnp.mean((ratio - 1.0 + eps) - jnp.log(ratio + eps)).item() clip_fraction = jnp.mean(jnp.abs(ratio - 1) > clip_range).item() # Compute average From 7afa84f29f8fa13cd752abba2eb6e0666f501cea Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 14 May 2025 12:50:17 +0200 Subject: [PATCH 23/28] Allow max grad norm for TQC and fix optimizer class --- sbx/tqc/policies.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index faafcb2..41601aa 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -84,7 +84,7 @@ def __init__( self.key = self.noise_key = jax.random.PRNGKey(0) - def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) -> jax.Array: + def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float, max_grad_norm: float = 100) -> jax.Array: key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) key, dropout_key1, dropout_key2, self.key = jax.random.split(key, 4) # Initialize noise @@ -108,12 +108,16 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # Inject hyperparameters to be able to modify it later + # See https://stackoverflow.com/questions/78527164 + optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) + self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, + tx=optax.chain( + 
optax.clip_by_global_norm(max_grad_norm), + optimizer_class, ), ) @@ -125,6 +129,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) activation_fn=self.activation_fn, ) + optimizer_class_qf = optax.inject_hyperparams(self.optimizer_class)( + learning_rate=qf_learning_rate, **self.optimizer_kwargs + ) + self.qf1_state = RLTrainState.create( apply_fn=self.qf.apply, params=self.qf.init( @@ -137,7 +145,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=optax.adam(learning_rate=qf_learning_rate), # type: ignore[call-arg] + tx=optax.chain( + optax.clip_by_global_norm(max_grad_norm), + optimizer_class_qf, + ), ) self.qf2_state = RLTrainState.create( apply_fn=self.qf.apply, @@ -151,9 +162,9 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=self.optimizer_class( - learning_rate=qf_learning_rate, # type: ignore[call-arg] - **self.optimizer_kwargs, + tx=optax.chain( + optax.clip_by_global_norm(max_grad_norm), + optimizer_class_qf, ), ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] From 3af7c93cf817cd3425747bbbcf530cf5fcf803e3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 14 May 2025 13:25:54 +0200 Subject: [PATCH 24/28] Comment out max grad norm --- sbx/tqc/policies.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 41601aa..3d579ac 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -84,7 +84,7 @@ def __init__( self.key = self.noise_key = jax.random.PRNGKey(0) - def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float, max_grad_norm: float = 100) -> jax.Array: + def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) -> jax.Array: key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) key, dropout_key1, dropout_key2, self.key = jax.random.split(key, 4) # Initialize noise @@ -116,7 +116,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float, apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), tx=optax.chain( - optax.clip_by_global_norm(max_grad_norm), + # optax.clip_by_global_norm(max_grad_norm), optimizer_class, ), ) @@ -146,7 +146,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float, action, ), tx=optax.chain( - optax.clip_by_global_norm(max_grad_norm), + # optax.clip_by_global_norm(max_grad_norm), optimizer_class_qf, ), ) @@ -163,7 +163,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float, action, ), tx=optax.chain( - optax.clip_by_global_norm(max_grad_norm), + # optax.clip_by_global_norm(max_grad_norm), optimizer_class_qf, ), ) From 3f9727bc6d992fd758fb44cce75e51d14a1dcbe5 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 14 May 2025 15:44:26 +0200 Subject: [PATCH 25/28] Update to schedule classes --- sbx/dqn/dqn.py | 4 ++-- sbx/ppo/ppo.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 98b0f42..8d4338e 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -7,7 +7,7 @@ import numpy as np import optax from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import get_linear_fn +from stable_baselines3.common.utils import LinearSchedule from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import 
ReplayBufferSamplesNp, RLTrainState @@ -83,7 +83,7 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() - self.exploration_schedule = get_linear_fn( + self.exploration_schedule = LinearSchedule( self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction, diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index d58f926..73f4ede 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -8,7 +8,7 @@ from flax.training.train_state import TrainState from gymnasium import spaces from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import explained_variance, get_schedule_fn +from stable_baselines3.common.utils import explained_variance, FloatSchedule from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax from sbx.common.utils import KLAdaptiveLR @@ -192,12 +192,12 @@ def _setup_model(self) -> None: self.vf = self.policy.vf # type: ignore[assignment] # Initialize schedules for policy/value clipping - self.clip_range_schedule = get_schedule_fn(self.clip_range) + self.clip_range_schedule = FloatSchedule(self.clip_range) # if self.clip_range_vf is not None: # if isinstance(self.clip_range_vf, (float, int)): # assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" # - # self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + # self.clip_range_vf = FloatSchedule(self.clip_range_vf) @staticmethod @partial(jax.jit, static_argnames=["normalize_advantage"]) From d86e89d2b70fa3fd68907bffff9d8d5f8670758a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 14 May 2025 15:48:33 +0200 Subject: [PATCH 26/28] Add lr schedule support for TQC --- sbx/common/off_policy_algorithm.py | 32 +++++++++++++++++++++++++++--- sbx/crossq/crossq.py | 2 +- sbx/ppo/ppo.py | 2 +- sbx/sac/sac.py | 2 +- sbx/tqc/tqc.py | 13 +++++++++++- 5 files changed, 44 insertions(+), 7 deletions(-) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index d2efabe..8b2649b 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -4,6 +4,7 @@ import jax import numpy as np +import optax from gymnasium import spaces from stable_baselines3 import HerReplayBuffer from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer @@ -15,6 +16,8 @@ class OffPolicyAlgorithmJax(OffPolicyAlgorithm): + qf_learning_rate: float + def __init__( self, policy: type[BasePolicy], @@ -75,8 +78,8 @@ def __init__( ) # Will be updated later self.key = jax.random.PRNGKey(0) - # Note: we do not allow schedule for it - self.qf_learning_rate = qf_learning_rate + # Note: we do not allow separate schedule for it + self.initial_qf_learning_rate = qf_learning_rate self.param_resets = param_resets self.reset_idx = 0 @@ -100,6 +103,29 @@ def _excluded_save_params(self) -> list[str]: excluded.remove("policy") return excluded + def _update_learning_rate( # type: ignore[override] + self, + optimizers: Union[list[optax.OptState], optax.OptState], + learning_rate: float, + name: str = "learning_rate", + ) -> None: + """ + Update the optimizers learning rate using the current learning rate schedule + and the current progress remaining (from 1 to 0). + + :param optimizers: An optimizer or a list of optimizers. 
+ :param learning_rate: The current learning rate to apply + :param name: (Optional) A custom name for the lr (for instance qf_learning_rate) + """ + # Log the current learning rate + self.logger.record(f"train/{name}", learning_rate) + + if not isinstance(optimizers, list): + optimizers = [optimizers] + for optimizer in optimizers: + # Note: the optimizer must have been defined with inject_hyperparams + optimizer.hyperparams["learning_rate"] = learning_rate + def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] super().set_random_seed(seed) if seed is None: @@ -116,7 +142,7 @@ def _setup_model(self) -> None: self._setup_lr_schedule() # By default qf_learning_rate = pi_learning_rate - self.qf_learning_rate = self.qf_learning_rate or self.lr_schedule(1) + self.qf_learning_rate = self.initial_qf_learning_rate or self.lr_schedule(1) self.set_random_seed(self.seed) # Make a local copy as we should not pickle # the environment when using HerReplayBuffer diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index ee9d74b..d099faa 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -158,7 +158,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.lr_schedule(1), ), ) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 73f4ede..8cbad79 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -8,7 +8,7 @@ from flax.training.train_state import TrainState from gymnasium import spaces from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import explained_variance, FloatSchedule +from stable_baselines3.common.utils import FloatSchedule, explained_variance from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax from sbx.common.utils import KLAdaptiveLR diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index f17fa96..e543d9c 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -159,7 +159,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.lr_schedule(1), ), ) diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 730d13a..9db4125 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -163,7 +163,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.lr_schedule(1), ), ) @@ -199,6 +199,17 @@ def train(self, gradient_steps: int, batch_size: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + self._update_learning_rate( + self.policy.actor_state.opt_state[0], + learning_rate=self.lr_schedule(self._current_progress_remaining), + name="learning_rate_actor", + ) + # Note: for now same schedule for actor and critic unless qf_lr = cst + self._update_learning_rate( + [self.policy.qf1_state.opt_state[0], self.policy.qf2_state.opt_state[0]], + learning_rate=self.initial_qf_learning_rate or self.lr_schedule(self._current_progress_remaining), + name="learning_rate_critic", + ) # Maybe reset the parameters/optimizers fully self._maybe_reset_params() From c668fd14cd1094dcd6ec15784591b7a557809ac8 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 19 May 2025 12:32:22 +0200 Subject: [PATCH 27/28] Revert 
experimental changes and add support for lr schedule for SAC --- sbx/common/off_policy_algorithm.py | 58 ---------------- sbx/common/policies.py | 35 ++-------- sbx/ppo/policies.py | 103 ----------------------------- sbx/ppo/ppo.py | 6 +- sbx/sac/policies.py | 41 +++++------- sbx/sac/sac.py | 17 +++-- sbx/tqc/policies.py | 28 ++------ sbx/tqc/tqc.py | 9 +-- sbx/version.txt | 2 +- setup.py | 2 +- tests/test_run.py | 17 ----- 11 files changed, 49 insertions(+), 269 deletions(-) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 8b2649b..ae205a3 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -172,61 +172,3 @@ def load_replay_buffer( # Override replay buffer device to be always cpu for conversion to numpy assert self.replay_buffer is not None self.replay_buffer.device = get_device("cpu") - - def _sample_action( - self, - learning_starts: int, - action_noise: Optional[ActionNoise] = None, - n_envs: int = 1, - ) -> tuple[np.ndarray, np.ndarray]: - """ - Sample an action according to the exploration policy. - This is either done by sampling the probability distribution of the policy, - or sampling a random action (from a uniform distribution over the action space) - or by adding noise to the deterministic output. - - :param action_noise: Action noise that will be used for exploration - Required for deterministic policy (e.g. TD3). This can also be used - in addition to the stochastic policy for SAC. - :param learning_starts: Number of steps before learning for the warm-up phase. - :param n_envs: - :return: action to take in the environment - and scaled action that will be stored in the replay buffer. - The two differs when the action space is not normalized (bounds are not [-1, 1]). - """ - scaled_action = np.array([0.0]) - # Select action randomly or according to policy - if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup): - # Warmup phase - action = np.array([self.action_space.sample() for _ in range(n_envs)]) - if isinstance(self.action_space, spaces.Box): - scaled_action = self.policy.scale_action(action) - else: - assert self._last_obs is not None, "self._last_obs was not set" - obs_tensor, _ = self.policy.prepare_obs(self._last_obs) - action = np.array(self.policy._predict(obs_tensor, deterministic=False)) - if self.policy.squash_output: - scaled_action = action - - # Rescale the action from [low, high] to [-1, 1] - if isinstance(self.action_space, spaces.Box) and self.policy.squash_output: - # Add noise to the action (improve exploration) - if action_noise is not None: - scaled_action = np.clip(scaled_action + action_noise(), -1, 1) - - # We store the scaled action in the buffer - buffer_action = scaled_action - action = self.policy.unscale_action(scaled_action) - elif isinstance(self.action_space, spaces.Box) and not self.policy.squash_output: - # Add noise to the action (improve exploration) - if action_noise is not None: - action = action + action_noise() - - buffer_action = action - # Actions could be on arbitrary scale, so clip the actions to avoid - # out of bound error (e.g. 
if sampling from a Gaussian distribution) - action = np.clip(action, self.action_space.low, self.action_space.high) - else: - # Discrete case, no need to normalize or clip - buffer_action = action - return action, buffer_action diff --git a/sbx/common/policies.py b/sbx/common/policies.py index 3c52ad2..ce23d09 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -236,16 +236,12 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): return q_values -class GaussianActor(nn.Module): +class SquashedGaussianActor(nn.Module): net_arch: Sequence[int] action_dim: int log_std_min: float = -20 log_std_max: float = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - squash_output: bool = True - ortho_init: bool = False - log_std_init: float = 0.0 # -1.2 # log(0.3) - use_layer_norm: bool = False def get_std(self): # Make it work with gSDE @@ -256,30 +252,13 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) - if self.use_layer_norm: - x = nn.LayerNorm()(x) x = self.activation_fn(x) - - if self.ortho_init: - orthogonal_init = nn.initializers.orthogonal(scale=0.01) - # orthogonal_init = nn.initializers.uniform(scale=0.01) - # orthogonal_init = nn.initializers.normal(stddev=0.01) - bias_init = nn.initializers.zeros - mean = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) - log_std = self.param("log_std", nn.initializers.constant(self.log_std_init), (self.action_dim,)) - # log_std = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) - else: - mean = nn.Dense(self.action_dim)(x) - log_std = nn.Dense(self.action_dim)(x) - - if self.squash_output: - log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) - dist = TanhTransformedDistribution( - tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), - ) - else: - log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) - dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)) + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) return dist diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 395fc94..4cb73b8 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -14,7 +14,6 @@ from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule -from sbx.common.jax_layers import SimbaResidualBlock from sbx.common.policies import BaseJaxPolicy, Flatten tfd = tfp.distributions @@ -35,23 +34,6 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return x -class SimbaCritic(nn.Module): - net_arch: Sequence[int] - activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - scale_factor: int = 4 - - @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - x = Flatten()(x) - x = nn.Dense(self.net_arch[0])(x) - for n_units in self.net_arch: - x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x) - - x = nn.LayerNorm()(x) - x = nn.Dense(1)(x) - return x - - class Actor(nn.Module): action_dim: int net_arch: Sequence[int] @@ -125,47 +107,6 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def return dist -class SimbaActor(nn.Module): - action_dim: int - net_arch: Sequence[int] - log_std_init: float = 0.0 - activation_fn: 
Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - # For Discrete, MultiDiscrete and MultiBinary actions - num_discrete_choices: Optional[Union[int, Sequence[int]]] = None - # For MultiDiscrete - max_num_choices: int = 0 - # Last layer with small scale - ortho_init: bool = False - scale_factor: int = 4 - - def get_std(self) -> jnp.ndarray: - # Make it work with gSDE - return jnp.array(0.0) - - @nn.compact - def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] - x = Flatten()(x) - - x = nn.Dense(self.net_arch[0])(x) - for n_units in self.net_arch: - x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x) - x = nn.LayerNorm()(x) - - if self.ortho_init: - orthogonal_init = nn.initializers.orthogonal(scale=0.01) - bias_init = nn.initializers.zeros - mean_action = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) - - else: - mean_action = nn.Dense(self.action_dim)(x) - - # Continuous actions - log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) - dist = tfd.MultivariateNormalDiag(loc=mean_action, scale_diag=jnp.exp(log_std)) - - return dist - - class PPOPolicy(BaseJaxPolicy): def __init__( self, @@ -331,47 +272,3 @@ def _predict_all(actor_state, vf_state, observations, key): log_probs = dist.log_prob(actions) values = vf_state.apply_fn(vf_state.params, observations).flatten() return actions, log_probs, values - - -class SimbaPPOPolicy(PPOPolicy): - def __init__( - self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, - ortho_init: bool = False, - log_std_init: float = 0, - activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh, - use_sde: bool = False, - use_expln: bool = False, - clip_mean: float = 2, - features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, - share_features_extractor: bool = False, - actor_class: type[nn.Module] = SimbaActor, - critic_class: type[nn.Module] = SimbaCritic, - ): - super().__init__( - observation_space, - action_space, - lr_schedule, - net_arch, - ortho_init, - log_std_init, - activation_fn, - use_sde, - use_expln, - clip_mean, - features_extractor_class, - features_extractor_kwargs, - normalize_images, - optimizer_class, - optimizer_kwargs, - share_features_extractor, - actor_class, - critic_class, - ) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 8cbad79..b9d5f32 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -12,7 +12,7 @@ from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax from sbx.common.utils import KLAdaptiveLR -from sbx.ppo.policies import PPOPolicy, SimbaPPOPolicy +from sbx.ppo.policies import PPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -55,7 +55,7 @@ class PPO(OnPolicyAlgorithmJax): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param target_kl: Update the learning rate based on a desired KL divergence. + :param target_kl: Update the learning rate based on a desired KL divergence (see https://arxiv.org/abs/1707.02286). Note: this will overwrite any lr schedule. By default, there is no limit on the kl div. 
:param tensorboard_log: the log location for tensorboard (if None, no logging) @@ -70,8 +70,6 @@ class PPO(OnPolicyAlgorithmJax): policy_aliases: ClassVar[dict[str, type[PPOPolicy]]] = { # type: ignore[assignment] "MlpPolicy": PPOPolicy, - # Residual net, from https://github.com/SonyResearch/simba - "SimbaPolicy": SimbaPPOPolicy, # "CnnPolicy": ActorCriticCnnPolicy, # "MultiInputPolicy": MultiInputActorCriticPolicy, } diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 2cae5ea..e77e57d 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -11,9 +11,9 @@ from sbx.common.policies import ( BaseJaxPolicy, - GaussianActor, SimbaSquashedGaussianActor, SimbaVectorCritic, + SquashedGaussianActor, VectorCritic, ) from sbx.common.type_aliases import RLTrainState @@ -32,9 +32,9 @@ def __init__( layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, - log_std_init: float = 0.0, - squash_output: bool = True, - ortho_init: bool = False, + # Note: most gSDE parameters are not used + # this is to keep API consistent with SB3 + log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, @@ -44,7 +44,7 @@ def __init__( optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, - actor_class: type[nn.Module] = GaussianActor, + actor_class: type[nn.Module] = SquashedGaussianActor, vector_critic_class: type[nn.Module] = VectorCritic, ): super().__init__( @@ -54,7 +54,7 @@ def __init__( features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, - squash_output=squash_output, + squash_output=True, ) self.dropout_rate = dropout_rate self.layer_norm = layer_norm @@ -71,8 +71,6 @@ def __init__( self.activation_fn = activation_fn self.actor_class = actor_class self.vector_critic_class = vector_critic_class - self.log_std_init = log_std_init - self.ortho_init = ortho_init self.key = self.noise_key = jax.random.PRNGKey(0) @@ -93,20 +91,18 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn, - squash_output=self.squash_output, - log_std_init=self.log_std_init, - ortho_init=self.ortho_init, ) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # Inject hyperparameters to be able to modify it later + # See https://stackoverflow.com/questions/78527164 + optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) + self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + tx=optimizer_class, ) self.qf = self.vector_critic_class( @@ -117,6 +113,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) activation_fn=self.activation_fn, ) + optimizer_class_qf = optax.inject_hyperparams(self.optimizer_class)( + learning_rate=qf_learning_rate, **self.optimizer_kwargs + ) + self.qf_state = RLTrainState.create( apply_fn=self.qf.apply, params=self.qf.init( @@ -129,10 +129,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=self.optimizer_class( - learning_rate=qf_learning_rate, # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + 
tx=optimizer_class_qf, ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] @@ -172,9 +169,7 @@ def __init__( layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, - log_std_init: float = 0.0, - squash_output: bool = True, - ortho_init: bool = False, + log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, @@ -198,8 +193,6 @@ def __init__( activation_fn, use_sde, log_std_init, - squash_output, - ortho_init, use_expln, clip_mean, features_extractor_class, diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index e543d9c..73ff2f8 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -195,6 +195,18 @@ def train(self, gradient_steps: int, batch_size: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + self._update_learning_rate( + self.policy.actor_state.opt_state, + learning_rate=self.lr_schedule(self._current_progress_remaining), + name="learning_rate_actor", + ) + # Note: for now same schedule for actor and critic unless qf_lr = cst + self._update_learning_rate( + self.policy.qf_state.opt_state, + learning_rate=self.initial_qf_learning_rate or self.lr_schedule(self._current_progress_remaining), + name="learning_rate_critic", + ) + # Maybe reset the parameters/optimizers fully self._maybe_reset_params() @@ -240,11 +252,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) - try: - log_std = self.policy.actor_state.params["params"]["log_std"] - self.logger.record("train/std", np.exp(log_std).mean().item()) - except KeyError: - pass @staticmethod @jax.jit diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 3d579ac..a884687 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -12,9 +12,9 @@ from sbx.common.policies import ( BaseJaxPolicy, ContinuousCritic, - GaussianActor, SimbaContinuousCritic, SimbaSquashedGaussianActor, + SquashedGaussianActor, ) from sbx.common.type_aliases import RLTrainState @@ -34,9 +34,10 @@ def __init__( n_quantiles: int = 25, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, + # Note: most gSDE parameters are not used + # this is to keep API consistent with SB3 log_std_init: float = 0.0, squash_output: bool = True, - ortho_init: bool = False, use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, @@ -46,7 +47,7 @@ def __init__( optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, - actor_class: type[nn.Module] = GaussianActor, + actor_class: type[nn.Module] = SquashedGaussianActor, critic_class: type[nn.Module] = ContinuousCritic, ): super().__init__( @@ -80,7 +81,6 @@ def __init__( self.actor_class = actor_class self.critic_class = critic_class self.log_std_init = log_std_init - self.ortho_init = ortho_init self.key = self.noise_key = jax.random.PRNGKey(0) @@ -101,9 +101,6 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn, - squash_output=self.squash_output, - log_std_init=self.log_std_init, - ortho_init=self.ortho_init, ) # Hack to make gSDE work 
without modifying internal SB3 code self.actor.reset_noise = self.reset_noise @@ -115,10 +112,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=optax.chain( - # optax.clip_by_global_norm(max_grad_norm), - optimizer_class, - ), + tx=optimizer_class, ) self.qf = self.critic_class( @@ -145,10 +139,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=optax.chain( - # optax.clip_by_global_norm(max_grad_norm), - optimizer_class_qf, - ), + tx=optimizer_class_qf, ) self.qf2_state = RLTrainState.create( apply_fn=self.qf.apply, @@ -162,10 +153,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=optax.chain( - # optax.clip_by_global_norm(max_grad_norm), - optimizer_class_qf, - ), + tx=optimizer_class_qf, ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] self.qf.apply = jax.jit( # type: ignore[method-assign] @@ -208,7 +196,6 @@ def __init__( use_sde: bool = False, log_std_init: float = 0.0, squash_output: bool = True, - ortho_init: bool = False, use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, @@ -234,7 +221,6 @@ def __init__( use_sde, log_std_init, squash_output, - ortho_init, use_expln, clip_mean, features_extractor_class, diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 9db4125..9593535 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -200,13 +200,13 @@ def train(self, gradient_steps: int, batch_size: int) -> None: data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) self._update_learning_rate( - self.policy.actor_state.opt_state[0], + self.policy.actor_state.opt_state, learning_rate=self.lr_schedule(self._current_progress_remaining), name="learning_rate_actor", ) # Note: for now same schedule for actor and critic unless qf_lr = cst self._update_learning_rate( - [self.policy.qf1_state.opt_state[0], self.policy.qf2_state.opt_state[0]], + [self.policy.qf1_state.opt_state, self.policy.qf2_state.opt_state], learning_rate=self.initial_qf_learning_rate or self.lr_schedule(self._current_progress_remaining), name="learning_rate_critic", ) @@ -257,11 +257,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.logger.record("train/critic_loss", qf1_loss_value.item()) self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) - try: - log_std = self.policy.actor_state.params["params"]["log_std"] - self.logger.record("train/std", np.exp(log_std).mean().item()) - except KeyError: - pass @staticmethod @partial(jax.jit, static_argnames=["n_target_quantiles"]) diff --git a/sbx/version.txt b/sbx/version.txt index 5a03fb7..8854156 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.20.0 +0.21.0 diff --git a/setup.py b/setup.py index ec2e1dc..a8655da 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ packages=[package for package in find_packages() if package.startswith("sbx")], package_data={"sbx": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.5.0,<3.0", + "stable_baselines3>=2.6.1a1,<3.0", "jax>=0.4.24", "jaxlib", "flax", diff --git a/tests/test_run.py b/tests/test_run.py index 9fcb04a..6d3b5b9 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -160,23 +160,6 @@ def test_ppo(tmp_path, env_id: str) -> None: check_save_load(model, 
PPO, tmp_path) -def test_simba_ppo(tmp_path) -> None: - model = PPO( - "SimbaPolicy", - "Pendulum-v1", - verbose=1, - n_steps=32, - batch_size=32, - n_epochs=2, - policy_kwargs=dict(activation_fn=nn.leaky_relu, net_arch=[64]), - # Test adaptive lr - target_kl=0.01, - ) - model.learn(64, progress_bar=True) - - check_save_load(model, PPO, tmp_path) - - def test_dqn(tmp_path) -> None: model = DQN( "MlpPolicy", From 22b3e5483a6ae4f34521f050b102574859a3f914 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 19 May 2025 12:38:43 +0200 Subject: [PATCH 28/28] Add test for adaptive kl div, remove squash output param --- sbx/tqc/policies.py | 5 +---- tests/test_run.py | 2 ++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index a884687..4783e91 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -37,7 +37,6 @@ def __init__( # Note: most gSDE parameters are not used # this is to keep API consistent with SB3 log_std_init: float = 0.0, - squash_output: bool = True, use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, @@ -57,7 +56,7 @@ def __init__( features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, - squash_output=squash_output, + squash_output=True, ) self.dropout_rate = dropout_rate self.layer_norm = layer_norm @@ -195,7 +194,6 @@ def __init__( activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, log_std_init: float = 0.0, - squash_output: bool = True, use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, @@ -220,7 +218,6 @@ def __init__( activation_fn, use_sde, log_std_init, - squash_output, use_expln, clip_mean, features_extractor_class, diff --git a/tests/test_run.py b/tests/test_run.py index 6d3b5b9..c6e8ca1 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -146,6 +146,7 @@ def test_policy_kwargs(model_class) -> None: @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"]) def test_ppo(tmp_path, env_id: str) -> None: + model = PPO( "MlpPolicy", env_id, @@ -154,6 +155,7 @@ def test_ppo(tmp_path, env_id: str) -> None: batch_size=32, n_epochs=2, policy_kwargs=dict(activation_fn=nn.leaky_relu), + target_kl=0.04 if env_id == "Pendulum-v1" else None, ) model.learn(64, progress_bar=True)
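
For reference, the learning-rate schedule support added in these patches relies on optax.inject_hyperparams (see the "Inject hyperparameters to be able to modify it later" comment in sbx/sac/policies.py above): wrapping the optimizer factory stores its hyperparameters in the optimizer state, so _update_learning_rate can overwrite the learning rate between gradient updates. A minimal standalone sketch of that pattern, not part of the patches (the optimizer choice, learning-rate values and toy parameters are illustrative only):

    import jax.numpy as jnp
    import optax

    # Wrap the optimizer factory so its hyperparameters (here the learning rate)
    # become part of the optimizer state instead of being fixed at creation time.
    tx = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)

    params = {"w": jnp.zeros(3)}
    opt_state = tx.init(params)

    # Later, e.g. once per train() call, overwrite the stored learning rate;
    # subsequent tx.update() calls will use the new value.
    opt_state.hyperparams["learning_rate"] = 1e-4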