diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index b7b385c..ae205a3 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -4,6 +4,7 @@ import jax import numpy as np +import optax from gymnasium import spaces from stable_baselines3 import HerReplayBuffer from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer @@ -15,6 +16,8 @@ class OffPolicyAlgorithmJax(OffPolicyAlgorithm): + qf_learning_rate: float + def __init__( self, policy: type[BasePolicy], @@ -75,8 +78,8 @@ def __init__( ) # Will be updated later self.key = jax.random.PRNGKey(0) - # Note: we do not allow schedule for it - self.qf_learning_rate = qf_learning_rate + # Note: we do not allow separate schedule for it + self.initial_qf_learning_rate = qf_learning_rate self.param_resets = param_resets self.reset_idx = 0 @@ -89,7 +92,7 @@ def _maybe_reset_params(self) -> None: ): # Note: we are not resetting the entropy coeff assert isinstance(self.qf_learning_rate, float) - self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) + self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) # type: ignore[operator] self.reset_idx += 1 def _get_torch_save_params(self): @@ -100,6 +103,29 @@ def _excluded_save_params(self) -> list[str]: excluded.remove("policy") return excluded + def _update_learning_rate( # type: ignore[override] + self, + optimizers: Union[list[optax.OptState], optax.OptState], + learning_rate: float, + name: str = "learning_rate", + ) -> None: + """ + Update the optimizers learning rate using the current learning rate schedule + and the current progress remaining (from 1 to 0). + + :param optimizers: An optimizer or a list of optimizers. + :param learning_rate: The current learning rate to apply + :param name: (Optional) A custom name for the lr (for instance qf_learning_rate) + """ + # Log the current learning rate + self.logger.record(f"train/{name}", learning_rate) + + if not isinstance(optimizers, list): + optimizers = [optimizers] + for optimizer in optimizers: + # Note: the optimizer must have been defined with inject_hyperparams + optimizer.hyperparams["learning_rate"] = learning_rate + def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] super().set_random_seed(seed) if seed is None: @@ -116,7 +142,7 @@ def _setup_model(self) -> None: self._setup_lr_schedule() # By default qf_learning_rate = pi_learning_rate - self.qf_learning_rate = self.qf_learning_rate or self.lr_schedule(1) + self.qf_learning_rate = self.initial_qf_learning_rate or self.lr_schedule(1) self.set_random_seed(self.seed) # Make a local copy as we should not pickle # the environment when using HerReplayBuffer diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index 36b8ea6..d0c7d97 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -3,6 +3,7 @@ import gymnasium as gym import jax import numpy as np +import optax import torch as th from gymnasium import spaces from stable_baselines3.common.buffers import RolloutBuffer @@ -75,6 +76,27 @@ def _excluded_save_params(self) -> list[str]: excluded.remove("policy") return excluded + def _update_learning_rate( # type: ignore[override] + self, + optimizers: Union[list[optax.OptState], optax.OptState], + learning_rate: float, + ) -> None: + """ + Update the optimizers learning rate using the current learning rate schedule + and the current progress remaining (from 1 to 0). 
+ + :param optimizers: + An optimizer or a list of optimizers. + """ + # Log the current learning rate + self.logger.record("train/learning_rate", learning_rate) + + if not isinstance(optimizers, list): + optimizers = [optimizers] + for optimizer in optimizers: + # Note: the optimizer must have been defined with inject_hyperparams + optimizer.hyperparams["learning_rate"] = learning_rate + def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] super().set_random_seed(seed) if seed is None: @@ -167,12 +189,8 @@ def collect_rollouts( # Handle timeout by bootstraping with value function # see GitHub issue #633 - for idx, done in enumerate(dones): - if ( - done - and infos[idx].get("terminal_observation") is not None - and infos[idx].get("TimeLimit.truncated", False) - ): + for idx in dones.nonzero()[0]: + if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False): terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0] terminal_value = np.array( self.vf.apply( # type: ignore[union-attr] diff --git a/sbx/common/utils.py b/sbx/common/utils.py new file mode 100644 index 0000000..447c4f3 --- /dev/null +++ b/sbx/common/utils.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class KLAdaptiveLR: + """Adaptive lr schedule, see https://arxiv.org/abs/1707.02286""" + + # If set will trigger adaptive lr + target_kl: float + current_adaptive_lr: float + # Values taken from https://github.com/leggedrobotics/rsl_rl + min_learning_rate: float = 1e-5 + max_learning_rate: float = 1e-2 + kl_margin: float = 2.0 + # Divide or multiply the lr by this factor + adaptive_lr_factor: float = 1.5 + + def update(self, kl_div: float) -> None: + if kl_div > self.target_kl * self.kl_margin: + self.current_adaptive_lr /= self.adaptive_lr_factor + elif kl_div < self.target_kl / self.kl_margin: + self.current_adaptive_lr *= self.adaptive_lr_factor + + self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate) diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index ee9d74b..d099faa 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -158,7 +158,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.lr_schedule(1), ), ) diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 98b0f42..8d4338e 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -7,7 +7,7 @@ import numpy as np import optax from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import get_linear_fn +from stable_baselines3.common.utils import LinearSchedule from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState @@ -83,7 +83,7 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() - self.exploration_schedule = get_linear_fn( + self.exploration_schedule = LinearSchedule( self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction, diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 570849e..4cb73b8 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -44,6 +44,8 @@ class Actor(nn.Module): # For MultiDiscrete max_num_choices: int = 0 split_indices: np.ndarray = field(default_factory=lambda: np.array([])) + # Last layer with small scale + 
ortho_init: bool = False def get_std(self) -> jnp.ndarray: # Make it work with gSDE @@ -65,7 +67,15 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def x = nn.Dense(n_units)(x) x = self.activation_fn(x) - action_logits = nn.Dense(self.action_dim)(x) + if self.ortho_init: + orthogonal_init = nn.initializers.orthogonal(scale=0.01) + bias_init = nn.initializers.zeros + action_logits = nn.Dense(self.action_dim, kernel_init=orthogonal_init, bias_init=bias_init)(x) + + else: + action_logits = nn.Dense(self.action_dim)(x) + + log_std = jnp.zeros(1) if self.num_discrete_choices is None: # Continuous actions log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) @@ -118,6 +128,8 @@ def __init__( optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, optimizer_kwargs: Optional[dict[str, Any]] = None, share_features_extractor: bool = False, + actor_class: type[nn.Module] = Actor, + critic_class: type[nn.Module] = Critic, ): if optimizer_kwargs is None: # Small values to avoid NaN in Adam optimizer @@ -146,6 +158,9 @@ def __init__( else: self.net_arch_pi = self.net_arch_vf = [64, 64] self.use_sde = use_sde + self.ortho_init = ortho_init + self.actor_class = actor_class + self.critic_class = critic_class self.key = self.noise_key = jax.random.PRNGKey(0) @@ -189,38 +204,38 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> else: raise NotImplementedError(f"{self.action_space}") - self.actor = Actor( + self.actor = self.actor_class( net_arch=self.net_arch_pi, log_std_init=self.log_std_init, activation_fn=self.activation_fn, + ortho_init=self.ortho_init, **actor_kwargs, # type: ignore[arg-type] ) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # Inject hyperparameters to be able to modify it later + # See https://stackoverflow.com/questions/78527164 + # Note: eps=1e-5 for Adam + optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) + self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), - self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, # , eps=1e-5 - ), + optimizer_class, ), ) - self.vf = Critic(net_arch=self.net_arch_vf, activation_fn=self.activation_fn) + self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn) self.vf_state = TrainState.create( apply_fn=self.vf.apply, params=self.vf.init({"params": vf_key}, obs), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), - self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, # , eps=1e-5 - ), + optimizer_class, ), ) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index acbdea6..b9d5f32 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -8,9 +8,10 @@ from flax.training.train_state import TrainState from gymnasium import spaces from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import explained_variance, get_schedule_fn +from stable_baselines3.common.utils import FloatSchedule, explained_variance from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax +from sbx.common.utils import KLAdaptiveLR from sbx.ppo.policies import PPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -54,9 +55,8 @@ class 
PPO(OnPolicyAlgorithmJax): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param target_kl: Limit the KL divergence between updates, - because the clipping is not enough to prevent large update - see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + :param target_kl: Update the learning rate based on a desired KL divergence (see https://arxiv.org/abs/1707.02286). + Note: this will overwrite any lr schedule. By default, there is no limit on the kl div. :param tensorboard_log: the log location for tensorboard (if None, no logging) :param policy_kwargs: additional arguments to be passed to the policy on creation @@ -74,6 +74,7 @@ class PPO(OnPolicyAlgorithmJax): # "MultiInputPolicy": MultiInputActorCriticPolicy, } policy: PPOPolicy # type: ignore[assignment] + adaptive_lr: KLAdaptiveLR def __init__( self, @@ -159,7 +160,10 @@ def __init__( self.clip_range = clip_range self.clip_range_vf = clip_range_vf self.normalize_advantage = normalize_advantage + # If set will trigger adaptive lr self.target_kl = target_kl + if target_kl is not None and self.verbose > 0: + print(f"Using adaptive learning rate with {target_kl=}, any other lr schedule will be skipped.") if _init_setup_model: self._setup_model() @@ -167,6 +171,9 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() + if self.target_kl is not None: + self.adaptive_lr = KLAdaptiveLR(self.target_kl, self.lr_schedule(1.0)) + if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type] self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, @@ -179,16 +186,16 @@ def _setup_model(self) -> None: self.key, ent_key = jax.random.split(self.key, 2) - self.actor = self.policy.actor - self.vf = self.policy.vf + self.actor = self.policy.actor # type: ignore[assignment] + self.vf = self.policy.vf # type: ignore[assignment] # Initialize schedules for policy/value clipping - self.clip_range_schedule = get_schedule_fn(self.clip_range) + self.clip_range_schedule = FloatSchedule(self.clip_range) # if self.clip_range_vf is not None: # if isinstance(self.clip_range_vf, (float, int)): # assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" # - # self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + # self.clip_range_vf = FloatSchedule(self.clip_range_vf) @staticmethod @partial(jax.jit, static_argnames=["normalize_advantage"]) @@ -229,9 +236,9 @@ def actor_loss(params): entropy_loss = -jnp.mean(entropy) total_policy_loss = policy_loss + ent_coef * entropy_loss - return total_policy_loss + return total_policy_loss, ratio - pg_loss_value, grads = jax.value_and_grad(actor_loss, has_aux=False)(actor_state.params) + (pg_loss_value, ratio), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params) actor_state = actor_state.apply_gradients(grads=grads) def critic_loss(params): @@ -243,28 +250,36 @@ def critic_loss(params): vf_state = vf_state.apply_gradients(grads=grads) # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss - return (actor_state, vf_state), (pg_loss_value, vf_loss_value) + return (actor_state, vf_state), (pg_loss_value, vf_loss_value, ratio) def train(self) -> None: """ Update policy using the currently gathered rollout buffer. 
""" # Update optimizer learning rate - # self._update_learning_rate(self.policy.optimizer) + if self.target_kl is None: + self._update_learning_rate( + [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]], + learning_rate=self.lr_schedule(self._current_progress_remaining), + ) # Compute current clip range clip_range = self.clip_range_schedule(self._current_progress_remaining) + n_updates = 0 + mean_clip_fraction = 0.0 + mean_kl_div = 0.0 # train for n_epochs epochs for _ in range(self.n_epochs): # JIT only one update for rollout_data in self.rollout_buffer.get(self.batch_size): # type: ignore[attr-defined] + n_updates += 1 if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to int actions = rollout_data.actions.flatten().numpy().astype(np.int32) else: actions = rollout_data.actions.numpy() - (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss) = self._one_update( + (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss, ratio) = self._one_update( actor_state=self.policy.actor_state, vf_state=self.policy.vf_state, observations=rollout_data.observations.numpy(), @@ -278,6 +293,25 @@ def train(self) -> None: normalize_advantage=self.normalize_advantage, ) + # Calculate approximate form of reverse KL Divergence for adaptive lr + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + eps = 1e-7 # Avoid NaN due to numerical instabilities + approx_kl_div = jnp.mean((ratio - 1.0 + eps) - jnp.log(ratio + eps)).item() + clip_fraction = jnp.mean(jnp.abs(ratio - 1) > clip_range).item() + # Compute average + mean_clip_fraction += (clip_fraction - mean_clip_fraction) / n_updates + mean_kl_div += (approx_kl_div - mean_kl_div) / n_updates + + # Adaptive lr schedule, see https://arxiv.org/abs/1707.02286 + if self.target_kl is not None: + self.adaptive_lr.update(approx_kl_div) + + self._update_learning_rate( + [self.policy.actor_state.opt_state[1], self.policy.vf_state.opt_state[1]], + learning_rate=self.adaptive_lr.current_adaptive_lr, + ) self._n_updates += self.n_epochs explained_var = explained_variance( self.rollout_buffer.values.flatten(), # type: ignore[attr-defined] @@ -289,8 +323,8 @@ def train(self) -> None: # self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) # TODO: use mean instead of one point self.logger.record("train/value_loss", value_loss.item()) - # self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) - # self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/approx_kl", mean_kl_div) + self.logger.record("train/clip_fraction", mean_clip_fraction) self.logger.record("train/pg_loss", pg_loss.item()) self.logger.record("train/explained_variance", explained_var) try: diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 95c319f..e77e57d 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -95,13 +95,14 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # Inject hyperparameters to be able to modify it later + # See https://stackoverflow.com/questions/78527164 + optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) + self.actor_state = TrainState.create( 
apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + tx=optimizer_class, ) self.qf = self.vector_critic_class( @@ -112,6 +113,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) activation_fn=self.activation_fn, ) + optimizer_class_qf = optax.inject_hyperparams(self.optimizer_class)( + learning_rate=qf_learning_rate, **self.optimizer_kwargs + ) + self.qf_state = RLTrainState.create( apply_fn=self.qf.apply, params=self.qf.init( @@ -124,10 +129,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=self.optimizer_class( - learning_rate=qf_learning_rate, # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + tx=optimizer_class_qf, ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index a6826c5..73ff2f8 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -159,7 +159,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.lr_schedule(1), ), ) @@ -195,6 +195,18 @@ def train(self, gradient_steps: int, batch_size: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + self._update_learning_rate( + self.policy.actor_state.opt_state, + learning_rate=self.lr_schedule(self._current_progress_remaining), + name="learning_rate_actor", + ) + # Note: for now same schedule for actor and critic unless qf_lr = cst + self._update_learning_rate( + self.policy.qf_state.opt_state, + learning_rate=self.initial_qf_learning_rate or self.lr_schedule(self._current_progress_remaining), + name="learning_rate_critic", + ) + # Maybe reset the parameters/optimizers fully self._maybe_reset_params() diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 0c4c03f..4783e91 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -36,7 +36,7 @@ def __init__( use_sde: bool = False, # Note: most gSDE parameters are not used # this is to keep API consistent with SB3 - log_std_init: float = -3, + log_std_init: float = 0.0, use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, @@ -79,6 +79,7 @@ def __init__( self.activation_fn = activation_fn self.actor_class = actor_class self.critic_class = critic_class + self.log_std_init = log_std_init self.key = self.noise_key = jax.random.PRNGKey(0) @@ -103,13 +104,14 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # Inject hyperparameters to be able to modify it later + # See https://stackoverflow.com/questions/78527164 + optimizer_class = optax.inject_hyperparams(self.optimizer_class)(learning_rate=lr_schedule(1), **self.optimizer_kwargs) + self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + tx=optimizer_class, ) self.qf = self.critic_class( @@ -120,6 +122,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) activation_fn=self.activation_fn, ) + optimizer_class_qf = 
optax.inject_hyperparams(self.optimizer_class)( + learning_rate=qf_learning_rate, **self.optimizer_kwargs + ) + self.qf1_state = RLTrainState.create( apply_fn=self.qf.apply, params=self.qf.init( @@ -132,7 +138,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=optax.adam(learning_rate=qf_learning_rate), # type: ignore[call-arg] + tx=optimizer_class_qf, ) self.qf2_state = RLTrainState.create( apply_fn=self.qf.apply, @@ -146,10 +152,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs, action, ), - tx=self.optimizer_class( - learning_rate=qf_learning_rate, # type: ignore[call-arg] - **self.optimizer_kwargs, - ), + tx=optimizer_class_qf, ) self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] self.qf.apply = jax.jit( # type: ignore[method-assign] @@ -190,7 +193,7 @@ def __init__( n_quantiles: int = 25, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, - log_std_init: float = -3, + log_std_init: float = 0.0, use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 91b5c9b..9593535 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -163,7 +163,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.lr_schedule(1), ), ) @@ -199,6 +199,17 @@ def train(self, gradient_steps: int, batch_size: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + self._update_learning_rate( + self.policy.actor_state.opt_state, + learning_rate=self.lr_schedule(self._current_progress_remaining), + name="learning_rate_actor", + ) + # Note: for now same schedule for actor and critic unless qf_lr = cst + self._update_learning_rate( + [self.policy.qf1_state.opt_state, self.policy.qf2_state.opt_state], + learning_rate=self.initial_qf_learning_rate or self.lr_schedule(self._current_progress_remaining), + name="learning_rate_critic", + ) # Maybe reset the parameters/optimizers fully self._maybe_reset_params() diff --git a/sbx/version.txt b/sbx/version.txt index 5a03fb7..8854156 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.20.0 +0.21.0 diff --git a/setup.py b/setup.py index ec2e1dc..a8655da 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ packages=[package for package in find_packages() if package.startswith("sbx")], package_data={"sbx": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.5.0,<3.0", + "stable_baselines3>=2.6.1a1,<3.0", "jax>=0.4.24", "jaxlib", "flax", diff --git a/tests/test_run.py b/tests/test_run.py index 6d3b5b9..c6e8ca1 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -146,6 +146,7 @@ def test_policy_kwargs(model_class) -> None: @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"]) def test_ppo(tmp_path, env_id: str) -> None: + model = PPO( "MlpPolicy", env_id, @@ -154,6 +155,7 @@ def test_ppo(tmp_path, env_id: str) -> None: batch_size=32, n_epochs=2, policy_kwargs=dict(activation_fn=nn.leaky_relu), + target_kl=0.04 if env_id == "Pendulum-v1" else None, ) model.learn(64, progress_bar=True)
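Usage note (sketch, not part of the patch): the `_update_learning_rate` overrides above only work because every optimizer is now built through `optax.inject_hyperparams`, which moves `learning_rate` into the optimizer state so it can be overwritten between gradient steps. A minimal, self-contained illustration of the pattern, using toy `params`/`grads` placeholders and illustrative values:

import jax.numpy as jnp
import optax

from sbx.common.utils import KLAdaptiveLR  # added by this patch

# Wrap Adam so that `learning_rate` becomes part of the optimizer state.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)

params = {"w": jnp.zeros(3)}   # toy parameters
opt_state = optimizer.init(params)

grads = {"w": jnp.ones(3)}     # toy gradients
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# Adapt the lr from a measured approximate KL divergence.
adaptive_lr = KLAdaptiveLR(target_kl=0.04, current_adaptive_lr=3e-4)
adaptive_lr.update(kl_div=0.1)  # KL above target_kl * kl_margin -> lr divided by 1.5

# Mutating the dict in-place is enough: the next `optimizer.update` call
# reads the new value. When the optimizer is wrapped in
# `optax.chain(optax.clip_by_global_norm(...), ...)` as in the PPO policy,
# the injected state is the second element, i.e. `opt_state[1].hyperparams`.
opt_state.hyperparams["learning_rate"] = adaptive_lr.current_adaptive_lr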