diff --git a/sbx/__init__.py b/sbx/__init__.py
index c2762bc..c95fe1f 100644
--- a/sbx/__init__.py
+++ b/sbx/__init__.py
@@ -1,5 +1,6 @@
 import os
 
+from sbx.bro import BRO
 from sbx.crossq import CrossQ
 from sbx.ddpg import DDPG
 from sbx.dqn import DQN
@@ -29,5 +30,6 @@ def DroQ(*args, **kwargs):
     "SAC",
     "TD3",
     "TQC",
+    "BRO",
     "CrossQ",
 ]
diff --git a/sbx/bro/__init__.py b/sbx/bro/__init__.py
new file mode 100644
index 0000000..f1be1c2
--- /dev/null
+++ b/sbx/bro/__init__.py
@@ -0,0 +1,3 @@
+from sbx.bro.bro import BRO
+
+__all__ = ["BRO"]
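Editor's note: the snippet below is not part of the patch. It is a minimal usage sketch of the newly exported `BRO` class, assuming the same SB3-style API as the other sbx algorithms; the constructor arguments and `learn()` call appear in `sbx/bro/bro.py` below, and `Pendulum-v1` is the environment used in the updated tests. The hyperparameter values are illustrative only.

```python
# Minimal usage sketch (illustrative values, not part of the patch).
from sbx import BRO

model = BRO(
    "MlpPolicy",
    "Pendulum-v1",
    n_quantiles=100,  # >1 selects the distributional (quantile) critic path
    pessimism=0.1,    # weight of the critic-disagreement penalty
    gradient_steps=2,
    verbose=1,
)
model.learn(total_timesteps=5_000, progress_bar=True)
```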
diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py
new file mode 100644
index 0000000..3da47a7
--- /dev/null
+++ b/sbx/bro/bro.py
@@ -0,0 +1,584 @@
+from functools import partial
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from flax.training.train_state import TrainState
+from gymnasium import spaces
+from jax.typing import ArrayLike
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.noise import ActionNoise
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+
+from sbx.bro.policies import BROPolicy
+from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
+from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
+
+
+class EntropyCoef(nn.Module):
+    ent_coef_init: float = 1.0
+
+    @nn.compact
+    def __call__(self) -> jnp.ndarray:
+        log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)))
+        return jnp.exp(log_ent_coef)
+
+
+class ConstantEntropyCoef(nn.Module):
+    ent_coef_init: float = 1.0
+
+    @nn.compact
+    def __call__(self) -> float:
+        # Hack to not optimize the entropy coefficient while not having to use if/else for the jit
+        # TODO: add parameter in train to remove that hack
+        self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init))
+        return self.ent_coef_init
+
+
+@jax.jit
+def _get_stats(
+    actor_state: RLTrainState,
+    qf_state: RLTrainState,
+    ent_coef_state: TrainState,
+    observations: jax.Array,
+    key: jax.Array,
+):
+    key, dropout_key, noise_key = jax.random.split(key, 3)
+    dist = actor_state.apply_fn(actor_state.params, observations)
+    actor_actions = dist.sample(seed=noise_key)
+    log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
+    qf_pi = qf_state.apply_fn(
+        qf_state.params,
+        observations,
+        actor_actions,
+        rngs={"dropout": dropout_key},
+    )
+    ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+    return qf_pi.mean(), jnp.absolute(actor_actions).mean(), ent_coef_value.mean(), -log_prob.mean()
+
+
+class BRO(OffPolicyAlgorithmJax):
+    policy_aliases: ClassVar[Dict[str, Type[BROPolicy]]] = {  # type: ignore[assignment]
+        "MlpPolicy": BROPolicy,
+        # Minimal dict support using flatten()
+        "MultiInputPolicy": BROPolicy,
+    }
+
+    policy: BROPolicy
+    action_space: spaces.Box  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        policy,
+        env: Union[GymEnv, str],
+        # BRO
+        n_quantiles: int = 100,
+        pessimism: float = 0.1,
+        learning_rate: Union[float, Schedule] = 3e-4,
+        qf_learning_rate: Optional[float] = None,
+        buffer_size: int = 1_000_000,  # 1e6
+        learning_starts: int = 100,
+        batch_size: int = 256,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        train_freq: Union[int, Tuple[int, str]] = 1,
+        gradient_steps: int = 2,
+        policy_delay: int = 1,
+        action_noise: Optional[ActionNoise] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
+        use_sde: bool = False,
+        sde_sample_freq: int = -1,
+        use_sde_at_warmup: bool = False,
+        tensorboard_log: Optional[str] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = None,
+        verbose: int = 0,
+        seed: Optional[int] = None,
+        device: str = "auto",
+        _init_setup_model: bool = True,
+    ) -> None:
+        super().__init__(
+            policy=policy,
+            env=env,
+            learning_rate=learning_rate,
+            qf_learning_rate=qf_learning_rate,
+            buffer_size=buffer_size,
+            learning_starts=learning_starts,
+            batch_size=batch_size,
+            tau=tau,
+            gamma=gamma,
+            train_freq=train_freq,
+            gradient_steps=gradient_steps,
+            action_noise=action_noise,
+            replay_buffer_class=replay_buffer_class,
+            replay_buffer_kwargs=replay_buffer_kwargs,
+            use_sde=use_sde,
+            sde_sample_freq=sde_sample_freq,
+            use_sde_at_warmup=use_sde_at_warmup,
+            policy_kwargs=policy_kwargs,
+            tensorboard_log=tensorboard_log,
+            verbose=verbose,
+            seed=seed,
+            supported_action_spaces=(spaces.Box,),
+            support_multi_env=True,
+        )
+
+        self.policy_delay = policy_delay
+        self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
+
+        self.n_quantiles = n_quantiles
+        taus_ = jnp.arange(0, n_quantiles + 1) / n_quantiles
+        self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None]
+        self.pessimism = pessimism
+
+        self.distributional = True if self.n_quantiles > 1 else False
+        if _init_setup_model:
+            self._setup_model()
+
+    def _setup_model(self) -> None:
+        super()._setup_model()
+        self.reset()
+
+    def reset(self):
+        if not hasattr(self, "policy") or self.policy is None:
+            self.policy = self.policy_class(  # type: ignore[assignment]
+                self.observation_space,
+                self.action_space,
+                self.lr_schedule,
+                self.n_quantiles,
+                **self.policy_kwargs,
+            )
+
+        assert isinstance(self.qf_learning_rate, float)
+
+        self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
+
+        self.key, ent_key = jax.random.split(self.key, 2)
+
+        self.actor = self.policy.actor  # type: ignore[assignment]
+        self.qf = self.policy.qf  # type: ignore[assignment]
+
+        # The entropy coefficient or entropy can be learned automatically
+        # see Automating Entropy Adjustment for Maximum Entropy RL section
+        # of https://arxiv.org/abs/1812.05905
+        if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"):
+            # Default initial value of ent_coef when learned
+            ent_coef_init = 1.0
+            if "_" in self.ent_coef_init:
+                ent_coef_init = float(self.ent_coef_init.split("_")[1])
+            assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0"
+
+            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
+            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
+            self.ent_coef = EntropyCoef(ent_coef_init)
+        else:
+            # This will throw an error if a malformed string (different from 'auto') is passed
+            assert isinstance(
+                self.ent_coef_init, float
+            ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}"
+            self.ent_coef = ConstantEntropyCoef(self.ent_coef_init)  # type: ignore[assignment]
+
+        self.ent_coef_state = TrainState.create(
+            apply_fn=self.ent_coef.apply,
+            params=self.ent_coef.init(ent_key)["params"],
+            tx=optax.adam(learning_rate=self.learning_rate, b1=0.5),
+        )
+
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) / 2  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
+
+    def learn(
+        self,
+        total_timesteps: int,
+        callback: MaybeCallback = None,
+        log_interval: int = 4,
+        tb_log_name: str = "BRO",
+        reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
+    ):
+        return super().learn(
+            total_timesteps=total_timesteps,
+            callback=callback,
+            log_interval=log_interval,
+            tb_log_name=tb_log_name,
+            reset_num_timesteps=reset_num_timesteps,
+            progress_bar=progress_bar,
+        )
+
+    def train(self, gradient_steps: int, batch_size: int) -> None:
+        assert self.replay_buffer is not None
+        # Sample all at once for efficiency (so we can jit the for loop)
+        data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
+
+        if isinstance(data.observations, dict):
+            keys = list(self.observation_space.keys())  # type: ignore[attr-defined]
+            obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
+            next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
+        else:
+            obs = data.observations.numpy()
+            next_obs = data.next_observations.numpy()
+
+        # Convert to numpy
+        data = ReplayBufferSamplesNp(  # type: ignore[assignment]
+            obs,
+            data.actions.numpy(),
+            next_obs,
+            data.dones.numpy().flatten(),
+            data.rewards.numpy().flatten(),
+        )
+
+        (
+            self.policy.qf_state,
+            self.policy.actor_state,
+            self.ent_coef_state,
+            self.key,
+            (actor_loss_value, qf_loss_value, ent_coef_value),
+        ) = self._train(
+            self.gamma,
+            self.tau,
+            self.target_entropy,
+            gradient_steps,
+            data,
+            self.policy_delay,
+            (self._n_updates + 1) % self.policy_delay,
+            self.policy.qf_state,
+            self.policy.actor_state,
+            self.ent_coef_state,
+            self.quantile_taus,
+            self.distributional,
+            self.pessimism,
+            self.key,
+        )
+        self._n_updates += gradient_steps
+        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+        self.logger.record("train/actor_loss", actor_loss_value.item())
+        self.logger.record("train/critic_loss", qf_loss_value.item())
+        self.logger.record("train/ent_coef", ent_coef_value.item())
+
+    @staticmethod
+    @jax.jit
+    def update_critic(
+        gamma: float,
+        actor_state: TrainState,
+        qf_state: RLTrainState,
+        ent_coef_state: TrainState,
+        observations: jax.Array,
+        actions: jax.Array,
+        next_observations: jax.Array,
+        rewards: jax.Array,
+        dones: jax.Array,
+        quantile_taus: jax.Array,
+        pessimism: float,
+        key: jax.Array,
+    ):
+        key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
+        # sample action from the actor
+        dist = actor_state.apply_fn(actor_state.params, next_observations)
+        next_state_actions = dist.sample(seed=noise_key)
+        next_log_prob = dist.log_prob(next_state_actions)
+
+        ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+
+        qf_next_values = qf_state.apply_fn(
+            qf_state.target_params,
+            next_observations,
+            next_state_actions,
+            rngs={"dropout": dropout_key_target},
+        )
+
+        ensemble_size = qf_next_values.shape[0]
+        diff = jnp.abs(qf_next_values[:, None, :, :] - qf_next_values[None, :, :, :]) / 2
+        i, j = jnp.triu_indices(ensemble_size, k=1)  # Get indices for upper triangle without the diagonal
+        critic_disagreement = jnp.mean(diff[i, j], axis=0)
+
+        next_q_values = jnp.mean(qf_next_values, axis=0) - pessimism * critic_disagreement
+        # td error + entropy term
+        next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1)
+        # shape is (batch_size, 1)
+        target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values
+
+        def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
+            # shape is (n_critics, batch_size, 1)
+            current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
+            return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum()
+
+        qf_loss_value, grads = jax.value_and_grad(mse_loss, has_aux=False)(qf_state.params, dropout_key_current)
+        qf_state = qf_state.apply_gradients(grads=grads)
+
+        return (
+            qf_state,
+            (qf_loss_value, ent_coef_value),
+            key,
+        )
+
+    @staticmethod
+    @jax.jit
+    def update_critic_quantile(
+        gamma: float,
+        actor_state: TrainState,
+        qf_state: RLTrainState,
+        ent_coef_state: TrainState,
+        observations: jax.Array,
+        actions: jax.Array,
+        next_observations: jax.Array,
+        rewards: jax.Array,
+        dones: jax.Array,
+        quantile_taus: jax.Array,
+        pessimism: float,
+        key: jax.Array,
+    ):
+
+        key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
+        # sample action from the actor
+        dist = actor_state.apply_fn(actor_state.params, next_observations)
+        next_state_actions = dist.sample(seed=noise_key)
+        next_log_prob = dist.log_prob(next_state_actions)
+
+        ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+
+        qf_next_values = qf_state.apply_fn(
+            qf_state.target_params,
+            next_observations,
+            next_state_actions,
+            rngs={"dropout": dropout_key_target},
+        )
+
+        # calculate disagreement
+        ensemble_size = qf_next_values.shape[0]
+        diff = jnp.abs(qf_next_values[:, None, :, :] - qf_next_values[None, :, :, :]) / 2
+        i, j = jnp.triu_indices(ensemble_size, k=1)  # Get indices for upper triangle without the diagonal
+        critic_disagreement = jnp.mean(diff[i, j], axis=0)
+        next_q_values = jnp.mean(qf_next_values, axis=0) - pessimism * critic_disagreement
+        # entropy term
+        next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1)
+        # shape is (batch_size, 1, n_quantiles)
+        target_q_values = rewards[..., None, None] + (1 - dones[..., None, None]) * gamma * next_q_values[:, None, :]
+
+        def quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
+            # shape is (n_critics, batch_size, n_quantiles)
+            current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
+            quantile_td_error = target_q_values[None, ...] - current_q_values[..., None]
+
+            def calculate_quantile_huber_loss(
+                quantile_td_error: jnp.ndarray, quantile_taus: jnp.ndarray, kappa: float = 1.0
+            ) -> jnp.ndarray:
+                element_wise_huber_loss = jnp.where(
+                    jnp.absolute(quantile_td_error) <= kappa,
+                    0.5 * quantile_td_error**2,
+                    kappa * (jnp.absolute(quantile_td_error) - 0.5 * kappa),
+                )
+                mask = jax.lax.stop_gradient(jnp.where(quantile_td_error < 0, 1, 0))  # detach this
+                element_wise_quantile_huber_loss = jnp.absolute(quantile_taus - mask) * element_wise_huber_loss / kappa
+                quantile_huber_loss = element_wise_quantile_huber_loss.sum(axis=0).sum(axis=1).mean()
+                return quantile_huber_loss
+
+            quantile_huber_loss = calculate_quantile_huber_loss(quantile_td_error, quantile_taus)
+            return quantile_huber_loss
+
+        qf_loss_value, grads = jax.value_and_grad(quantile_loss, has_aux=False)(qf_state.params, dropout_key_current)
+        qf_state = qf_state.apply_gradients(grads=grads)
+
+        return (
+            qf_state,
+            (qf_loss_value, ent_coef_value),
+            key,
+        )
+
+    @staticmethod
+    @jax.jit
+    def update_actor(
+        actor_state: RLTrainState,
+        qf_state: RLTrainState,
+        ent_coef_state: TrainState,
+        observations: jax.Array,
+        pessimism: float,
+        key: jax.Array,
+    ):
+        key, dropout_key, noise_key = jax.random.split(key, 3)
+
+        def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
+            dist = actor_state.apply_fn(params, observations)
+            actor_actions = dist.sample(seed=noise_key)
+            log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
+
+            qf_pi = qf_state.apply_fn(
+                qf_state.params,
+                observations,
+                actor_actions,
+                rngs={"dropout": dropout_key},
+            )
+
+            # Pessimistic lower bound: mean of the critics minus the disagreement penalty
+            ensemble_size = qf_pi.shape[0]
+            diff = jnp.abs(qf_pi[:, None, :, :] - qf_pi[None, :, :, :]) / 2
+            i, j = jnp.triu_indices(ensemble_size, k=1)  # Get indices for upper triangle without the diagonal
+            critic_disagreement = jnp.mean(diff[i, j], axis=0)
+
+            qf_pi_lb = jnp.mean(qf_pi, axis=0) - pessimism * critic_disagreement
+            qf_pi_lb = jnp.mean(qf_pi_lb, axis=-1, keepdims=True)
+            ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+            actor_loss = (ent_coef_value * log_prob - qf_pi_lb).mean()
+            return actor_loss, -log_prob.mean()
+
+        (actor_loss_value, entropy), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params)
+        actor_state = actor_state.apply_gradients(grads=grads)
+
+        return actor_state, qf_state, actor_loss_value, key, entropy
+
+    @staticmethod
+    @jax.jit
+    def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
+        qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
+        return qf_state
+
+    @staticmethod
+    @jax.jit
+    def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
+        def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
+            ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
+            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
+            return ent_coef_loss
+
+        ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
+        ent_coef_state = ent_coef_state.apply_gradients(grads=grads)
+
+        return ent_coef_state, ent_coef_loss
+
+    @classmethod
+    def update_actor_and_temperature(
+        cls,
+        actor_state: RLTrainState,
+        qf_state: RLTrainState,
+        ent_coef_state: TrainState,
+        observations: jax.Array,
+        target_entropy: ArrayLike,
+        pessimism: float,
+        key: jax.Array,
+    ):
+        (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor(
+            actor_state,
+            qf_state,
+            ent_coef_state,
+            observations,
+            pessimism,
+            key,
+        )
+        ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy)
+        return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key
+
+    @classmethod
+    @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"])
+    def _train(
+        cls,
+        gamma: float,
+        tau: float,
+        target_entropy: ArrayLike,
+        gradient_steps: int,
+        data: ReplayBufferSamplesNp,
+        policy_delay: int,
+        policy_delay_offset: int,
+        qf_state: RLTrainState,
+        actor_state: TrainState,
+        ent_coef_state: TrainState,
+        quantile_taus: jax.Array,
+        distributional: bool,
+        pessimism: float,
+        key: jax.Array,
+    ):
+        assert data.observations.shape[0] % gradient_steps == 0
+        batch_size = data.observations.shape[0] // gradient_steps
+
+        carry = {
+            "actor_state": actor_state,
+            "qf_state": qf_state,
+            "ent_coef_state": ent_coef_state,
+            "key": key,
+            "info": {
+                "actor_loss": jnp.array(0.0),
+                "qf_loss": jnp.array(0.0),
+                "ent_coef_loss": jnp.array(0.0),
+            },
+        }
+
+        def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
+            # Note: this method must be defined inline because
+            # `fori_loop` expects a signature fn(index, carry) -> carry
+            actor_state = carry["actor_state"]
+            qf_state = carry["qf_state"]
+            ent_coef_state = carry["ent_coef_state"]
+            key = carry["key"]
+            info = carry["info"]
+            batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size)
+            batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size)
+            batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size)
+            batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size)
+            batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size)
+
+            (qf_state, (qf_loss_value, ent_coef_value), key) = jax.lax.cond(
+                distributional,
+                # If True:
+                cls.update_critic_quantile,
+                # If False:
+                cls.update_critic,
+                gamma,
+                actor_state,
+                qf_state,
+                ent_coef_state,
+                batch_obs,
+                batch_act,
+                batch_next_obs,
+                batch_rew,
+                batch_done,
+                quantile_taus,
+                pessimism,
+                key,
+            )
+
+            qf_state = cls.soft_update(tau, qf_state)
+
+            (actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
+                (policy_delay_offset + i) % policy_delay == 0,
+                # If True:
+                cls.update_actor_and_temperature,
+                # If False:
+                lambda *_: (actor_state, qf_state, ent_coef_state, info["actor_loss"], info["ent_coef_loss"], key),
+                actor_state,
+                qf_state,
+                ent_coef_state,
+                batch_obs,
+                target_entropy,
+                pessimism,
+                key,
+            )
+            info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value}
+
+            return {
+                "actor_state": actor_state,
+                "qf_state": qf_state,
+                "ent_coef_state": ent_coef_state,
+                "key": key,
+                "info": info,
+            }
+
+        update_carry = jax.lax.fori_loop(0, gradient_steps, one_update, carry)
+
+        return (
+            update_carry["qf_state"],
+            update_carry["actor_state"],
+            update_carry["ent_coef_state"],
+            update_carry["key"],
+            (update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
+        )
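Editor's note: the stand-alone sketch below is not part of the patch. It illustrates the pessimistic ensemble target computed in `update_critic` / `update_critic_quantile` above: the target is the ensemble mean of the critics minus `pessimism` times the mean pairwise half-gap (disagreement) between critics. The array values are made up; only the shapes and indexing mirror the code.

```python
import jax.numpy as jnp

# Toy values: 2 critics, batch of 3, n_quantiles = 1 (scalar critic).
qf_next_values = jnp.array(
    [[[1.0], [2.0], [3.0]],
     [[1.5], [1.0], [3.0]]]
)  # shape (n_critics, batch_size, n_quantiles)
pessimism = 0.1

ensemble_size = qf_next_values.shape[0]
# Pairwise half-gaps |Q_i - Q_j| / 2 for every pair of critics
diff = jnp.abs(qf_next_values[:, None, :, :] - qf_next_values[None, :, :, :]) / 2
i, j = jnp.triu_indices(ensemble_size, k=1)  # unordered pairs, diagonal excluded
critic_disagreement = jnp.mean(diff[i, j], axis=0)  # shape (batch_size, n_quantiles)

next_q_values = jnp.mean(qf_next_values, axis=0) - pessimism * critic_disagreement
print(next_q_values)  # mean minus penalty: 1.225, 1.45, 3.0
```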
diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py
new file mode 100644
index 0000000..8b44b09
--- /dev/null
+++ b/sbx/bro/policies.py
@@ -0,0 +1,262 @@
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import tensorflow_probability.substrates.jax as tfp
+from flax.training.train_state import TrainState
+from gymnasium import spaces
+from stable_baselines3.common.type_aliases import Schedule
+
+from sbx.common.distributions import TanhTransformedDistribution
+from sbx.common.policies import BaseJaxPolicy, Flatten
+from sbx.common.type_aliases import RLTrainState
+
+tfd = tfp.distributions
+
+
+class BroNetBlock(nn.Module):
+    n_units: int
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        out = nn.Dense(self.n_units)(x)
+        out = nn.LayerNorm()(out)
+        out = self.activation_fn(out)
+        out = nn.Dense(self.n_units)(out)
+        out = nn.LayerNorm()(out)
+        return x + out
+
+
+class BroNet(nn.Module):
+    net_arch: Sequence[int]
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        x = nn.Dense(self.net_arch[0])(x)
+        x = nn.LayerNorm()(x)
+        x = self.activation_fn(x)
+        for n_units in self.net_arch:
+            x = BroNetBlock(n_units, self.activation_fn)(x)
+        return x
+
+
+class Actor(nn.Module):
+    net_arch: Sequence[int]
+    action_dim: int
+    log_std_min: float = -10
+    log_std_max: float = 2
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+
+    def get_std(self):
+        # Make it work with gSDE
+        return jnp.array(0.0)
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
+        x = Flatten()(x)
+        x = BroNet(net_arch=self.net_arch)(x)
+        mean = nn.Dense(self.action_dim)(x)
+        log_std = nn.Dense(self.action_dim)(x)
+        log_std = self.log_std_min + (self.log_std_max - self.log_std_min) * 0.5 * (1 + nn.tanh(log_std))
+        dist = TanhTransformedDistribution(
+            tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
+        )
+        return dist
+
+
+class Critic(nn.Module):
+    net_arch: Sequence[int]
+    n_quantiles: int = 100
+    dropout_rate: Optional[float] = None
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
+        x = Flatten()(x)
+        out = jnp.concatenate([x, action], -1)
+        out = BroNet(self.net_arch, self.activation_fn)(out)
+        out = nn.Dense(self.n_quantiles)(out)
+        return out
+
+
+class VectorCritic(nn.Module):
+    net_arch: Sequence[int]
+    n_quantiles: int = 100
+    n_critics: int = 2
+    dropout_rate: Optional[float] = None
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+
+    @nn.compact
+    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
+        # Idea taken from https://github.com/perrin-isir/xpag
+        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
+        vmap_critic = nn.vmap(
+            Critic,
+            variable_axes={"params": 0},  # parameters not shared between the critics
+            split_rngs={"params": True},  # different initializations
+            in_axes=None,
+            out_axes=0,
+            axis_size=self.n_critics,
+        )
+        q_values = vmap_critic(
+            net_arch=self.net_arch,
+            n_quantiles=self.n_quantiles,
+            activation_fn=self.activation_fn,
+        )(obs, action)
+        return q_values
+
+
+class BROPolicy(BaseJaxPolicy):
+    action_space: spaces.Box  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        observation_space: spaces.Space,
+        action_space: spaces.Box,
+        lr_schedule: Schedule,
+        # BRO
+        n_quantiles: int = 100,
+        n_critics: int = 2,
+        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
+        dropout_rate: float = 0.0,
+        layer_norm: bool = True,
+        activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
+        use_sde: bool = False,
+        # Note: most gSDE parameters are not used
+        # this is to keep API consistent with SB3
+        log_std_init: float = -3,
+        use_expln: bool = False,
+        clip_mean: float = 2.0,
+        features_extractor_class=None,
+        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        normalize_images: bool = True,
+        optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        share_features_extractor: bool = False,
+    ):
+        if optimizer_kwargs is None:
+            # Note: the default value for b1 is 0.9 in Adam.
+            # b1=0.5 is used in the original CrossQ implementation
+            # but shows only little overall improvement.
+            optimizer_kwargs = {}
+            if optimizer_class in [optax.adam, optax.adamw]:
+                pass
+
+        super().__init__(
+            observation_space,
+            action_space,
+            features_extractor_class,
+            features_extractor_kwargs,
+            optimizer_class=optimizer_class,
+            optimizer_kwargs=optimizer_kwargs,
+            squash_output=True,
+        )
+
+        self.dropout_rate = dropout_rate
+        self.layer_norm = layer_norm
+        self.n_quantiles = n_quantiles
+        if net_arch is not None:
+            if isinstance(net_arch, list):
+                self.net_arch_pi = self.net_arch_qf = net_arch
+            else:
+                self.net_arch_pi = net_arch["pi"]
+                self.net_arch_qf = net_arch["qf"]
+        else:
+            self.net_arch_pi = [256]
+            # In the original implementation, the authors use [512, 512]
+            # but with a higher replay ratio (RR),
+            # here we use a bigger network size to compensate for the smaller RR
+            self.net_arch_qf = [1024, 1024]
+        self.n_critics = n_critics
+        self.use_sde = use_sde
+        self.activation_fn = activation_fn
+
+        self.key = self.noise_key = jax.random.PRNGKey(0)
+
+    def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) -> jax.Array:
+        key, actor_key, qf_key, dropout_key = jax.random.split(key, 4)
+        # Keep a key for the actor
+        key, self.key = jax.random.split(key, 2)
+        # Initialize noise
+        self.reset_noise()
+
+        if isinstance(self.observation_space, spaces.Dict):
+            obs = jnp.array([spaces.flatten(self.observation_space, self.observation_space.sample())])
+        else:
+            obs = jnp.array([self.observation_space.sample()])
+        action = jnp.array([self.action_space.sample()])
+
+        self.actor = Actor(
+            action_dim=int(np.prod(self.action_space.shape)),
+            net_arch=self.net_arch_pi,
+            activation_fn=self.activation_fn,
+        )
+
+        # Hack to make gSDE work without modifying internal SB3 code
+        self.actor.reset_noise = self.reset_noise
+
+        self.actor_state = TrainState.create(
+            apply_fn=self.actor.apply,
+            params=self.actor.init(actor_key, obs),
+            tx=self.optimizer_class(
+                learning_rate=lr_schedule(1),  # type: ignore[call-arg]
+                # learning_rate=qf_learning_rate,  # type: ignore[call-arg]
+                **self.optimizer_kwargs,
+            ),
+        )
+
+        self.qf = VectorCritic(
+            net_arch=self.net_arch_qf,
+            n_quantiles=self.n_quantiles,
+            n_critics=self.n_critics,
+            dropout_rate=self.dropout_rate,
+            activation_fn=self.activation_fn,
+        )
+
+        self.qf_state = RLTrainState.create(
+            apply_fn=self.qf.apply,
+            params=self.qf.init(
+                {"params": qf_key, "dropout": dropout_key},
+                obs,
+                action,
+            ),
+            target_params=self.qf.init(
+                {"params": qf_key, "dropout": dropout_key},
+                obs,
+                action,
+            ),
+            tx=self.optimizer_class(
+                learning_rate=qf_learning_rate,  # type: ignore[call-arg]
+                **self.optimizer_kwargs,
+            ),
+        )
+
+        self.actor.apply = jax.jit(self.actor.apply)  # type: ignore[method-assign]
+        self.qf.apply = jax.jit(  # type: ignore[method-assign]
+            self.qf.apply,
+            static_argnames=("dropout_rate", "use_layer_norm"),
+        )
+
+        return key
+
+    def reset_noise(self, batch_size: int = 1) -> None:
+        """
+        Sample new weights for the exploration matrix, when using gSDE.
+        """
+        self.key, self.noise_key = jax.random.split(self.key, 2)
+
+    def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
+        return self._predict(obs, deterministic=deterministic)
+
+    def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray:  # type: ignore[override]
+        if deterministic:
+            return BaseJaxPolicy.select_action(self.actor_state, observation)
+        # Trick to use gSDE: repeat sampled noise by using the same noise key
+        if not self.use_sde:
+            self.reset_noise()
+        return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key)
diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh
index 9ecb207..b356a08 100755
--- a/scripts/run_tests.sh
+++ b/scripts/run_tests.sh
@@ -1,2 +1,2 @@
 #!/bin/bash
-python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive"
+python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive"
\ No newline at end of file
diff --git a/tests/test_run.py b/tests/test_run.py
index 6d3b5b9..283c9f4 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -8,7 +8,7 @@
 from stable_baselines3.common.envs import BitFlippingEnv
 from stable_baselines3.common.evaluation import evaluate_policy
 
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ
+from sbx import BRO, DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ
 
 
 def check_save_load(model, model_class, tmp_path):
@@ -129,7 +129,7 @@ def test_dropout(model_class):
     model.learn(110)
 
 
-@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ])
+@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ, BRO])
 def test_policy_kwargs(model_class) -> None:
     env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1"