From 4dfd17bb4f7dc408936f35c2d4c5c5e4cc4e0255 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 6 Jun 2023 15:57:04 +0200 Subject: [PATCH 01/14] Draft SAC7 implementation --- sbx/__init__.py | 2 + sbx/sac7/__init__.py | 3 + sbx/sac7/policies.py | 331 ++++++++++++++++++++++++++++++ sbx/sac7/sac7.py | 471 +++++++++++++++++++++++++++++++++++++++++++ tests/test_run.py | 4 +- 5 files changed, 809 insertions(+), 2 deletions(-) create mode 100644 sbx/sac7/__init__.py create mode 100644 sbx/sac7/policies.py create mode 100644 sbx/sac7/sac7.py diff --git a/sbx/__init__.py b/sbx/__init__.py index e08a877..b0be45e 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -4,6 +4,7 @@ from sbx.droq import DroQ from sbx.ppo import PPO from sbx.sac import SAC +from sbx.sac7 import SAC7 from sbx.tqc import TQC # Read version from file @@ -16,5 +17,6 @@ "DroQ", "PPO", "SAC", + "SAC7", "TQC", ] diff --git a/sbx/sac7/__init__.py b/sbx/sac7/__init__.py new file mode 100644 index 0000000..381b3b5 --- /dev/null +++ b/sbx/sac7/__init__.py @@ -0,0 +1,3 @@ +from sbx.sac7.sac7 import SAC7 + +__all__ = ["SAC7"] diff --git a/sbx/sac7/policies.py b/sbx/sac7/policies.py new file mode 100644 index 0000000..5d321fd --- /dev/null +++ b/sbx/sac7/policies.py @@ -0,0 +1,331 @@ +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tensorflow_probability +from flax.training.train_state import TrainState +from gymnasium import spaces +from stable_baselines3.common.type_aliases import Schedule + +from sbx.common.distributions import TanhTransformedDistribution +from sbx.common.policies import BaseJaxPolicy +from sbx.common.type_aliases import RLTrainState + +tfp = tensorflow_probability.substrates.jax +tfd = tfp.distributions + + +# class AvgL1Norm(nn.Module): +# eps: float = 1e-8 +# @nn.compact +# def __call__(self, x: jnp.ndarray) -> jnp.ndarray: +# return x / jnp.clip(jnp.mean(jnp.abs(x), axis=-1, keepdims=True), a_min=self.eps) + + +@jax.jit +def avg_l1_norm(x: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray: + return x / jnp.clip(jnp.mean(jnp.abs(x), axis=-1, keepdims=True), a_min=eps) + + +class StateEncoder(nn.Module): + net_arch: Sequence[int] + embedding_dim: int = 256 + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + for n_units in self.net_arch: + x = nn.Dense(n_units)(x) + x = nn.elu(x) + x = nn.Dense(self.embedding_dim)(x) + return avg_l1_norm(x) + + +class StateActionEncoder(nn.Module): + net_arch: Sequence[int] + embedding_dim: int = 256 + + @nn.compact + def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: + x = jnp.concatenate([x, action], -1) + for n_units in self.net_arch: + x = nn.Dense(n_units)(x) + x = nn.elu(x) + x = nn.Dense(self.embedding_dim)(x) + return avg_l1_norm(x) + + +class Critic(nn.Module): + net_arch: Sequence[int] + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + + @nn.compact + def __call__(self, x: jnp.ndarray, action: jnp.ndarray, z_state: jnp.ndarray, z_state_action: jnp.ndarray) -> jnp.ndarray: + x = jnp.concatenate([x, action], -1) + embeddings = jnp.concatenate([z_state, z_state_action], -1) + x = avg_l1_norm(nn.Dense(self.net_arch[0])(x)) + # Combine with embeddings + x = jnp.concatenate([x, embeddings], -1) + + for n_units in self.net_arch: + x = nn.Dense(n_units)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if 
self.use_layer_norm: + x = nn.LayerNorm()(x) + x = nn.relu(x) + x = nn.Dense(1)(x) + return x + + +class VectorCritic(nn.Module): + net_arch: Sequence[int] + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + n_critics: int = 2 + + @nn.compact + def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, z_state: jnp.ndarray, z_state_action: jnp.ndarray): + # Idea taken from https://github.com/perrin-isir/xpag + # Similar to https://github.com/tinkoff-ai/CORL for PyTorch + vmap_critic = nn.vmap( + Critic, + variable_axes={"params": 0}, # parameters not shared between the critics + split_rngs={"params": True, "dropout": True}, # different initializations + in_axes=None, + out_axes=0, + axis_size=self.n_critics, + ) + q_values = vmap_critic( + use_layer_norm=self.use_layer_norm, + dropout_rate=self.dropout_rate, + net_arch=self.net_arch, + )(obs, action, z_state, z_state_action) + return q_values + + +class Actor(nn.Module): + net_arch: Sequence[int] + action_dim: int + log_std_min: float = -20 + log_std_max: float = 2 + + def get_std(self): + # Make it work with gSDE + return jnp.array(0.0) + + @nn.compact + def __call__(self, x: jnp.ndarray, z_state: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + x = avg_l1_norm(nn.Dense(self.net_arch[0])(x)) + # Combine with encoding + x = jnp.concatenate([x, z_state], -1) + for n_units in self.net_arch: + x = nn.Dense(n_units)(x) + x = nn.relu(x) + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) + return dist + + +class SAC7Policy(BaseJaxPolicy): + action_space: spaces.Box # type: ignore[assignment] + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Box, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + dropout_rate: float = 0.0, + layer_norm: bool = False, + # activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + # Note: most gSDE parameters are not used + # this is to keep API consistent with SB3 + log_std_init: float = -3, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class=None, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = False, + ): + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + ) + self.dropout_rate = dropout_rate + self.layer_norm = layer_norm + if net_arch is not None: + if isinstance(net_arch, list): + self.net_arch_pi = self.net_arch_qf = net_arch + else: + self.net_arch_pi = net_arch["pi"] + self.net_arch_qf = net_arch["qf"] + else: + self.net_arch_pi = self.net_arch_qf = [256, 256] + self.n_critics = n_critics + self.use_sde = use_sde + + self.key = self.noise_key = jax.random.PRNGKey(0) + + def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rate: float) -> jax.random.KeyArray: + key, actor_key, qf_key, dropout_key = jax.random.split(key, 4) + # Keys for the encoder and state action encoder + key, encoder_key, action_encoder_key = jax.random.split(key, 3) + + # Keep a 
key for the actor + key, self.key = jax.random.split(key, 2) + # Initialize noise + self.reset_noise() + + if isinstance(self.observation_space, spaces.Dict): + obs = jnp.array([spaces.flatten(self.observation_space, self.observation_space.sample())]) + else: + obs = jnp.array([self.observation_space.sample()]) + action = jnp.array([self.action_space.sample()]) + + encoding_dim = 256 + z_state = jnp.zeros((1, encoding_dim)) + z_state_action = jnp.zeros((1, encoding_dim)) + + self.actor = Actor( + action_dim=int(np.prod(self.action_space.shape)), + net_arch=self.net_arch_pi, + ) + # Hack to make gSDE work without modifying internal SB3 code + self.actor.reset_noise = self.reset_noise + + self.actor_state = TrainState.create( + apply_fn=self.actor.apply, + params=self.actor.init(actor_key, obs, z_state), + tx=self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ), + ) + + self.qf = VectorCritic( + dropout_rate=self.dropout_rate, + use_layer_norm=self.layer_norm, + net_arch=self.net_arch_qf, + n_critics=self.n_critics, + ) + + self.qf_state = RLTrainState.create( + apply_fn=self.qf.apply, + params=self.qf.init( + {"params": qf_key, "dropout": dropout_key}, + obs, + action, + z_state, + z_state_action, + ), + target_params=self.qf.init( + {"params": qf_key, "dropout": dropout_key}, + obs, + action, + z_state, + z_state_action, + ), + tx=self.optimizer_class( + learning_rate=qf_learning_rate, # type: ignore[call-arg] + **self.optimizer_kwargs, + ), + ) + + self.encoder = StateEncoder( + net_arch=[256, 256], + ) + self.action_encoder = StateActionEncoder( + net_arch=[256, 256], + ) + + self.encoder_state = RLTrainState.create( + apply_fn=self.encoder.apply, + params=self.encoder.init( + {"params": encoder_key}, + obs, + ), + target_params=self.encoder.init( + {"params": encoder_key}, + obs, + ), + tx=self.optimizer_class( + learning_rate=qf_learning_rate, # type: ignore[call-arg] + **self.optimizer_kwargs, + ), + ) + + self.action_encoder_state = RLTrainState.create( + apply_fn=self.action_encoder.apply, + params=self.action_encoder.init( + {"params": action_encoder_key}, + obs, + action, + ), + target_params=self.action_encoder.init( + {"params": action_encoder_key}, + obs, + action, + ), + tx=self.optimizer_class( + learning_rate=qf_learning_rate, # type: ignore[call-arg] + **self.optimizer_kwargs, + ), + ) + self.encoder.apply = jax.jit(self.encoder.apply) # type: ignore[method-assign] + self.action_encoder.apply = jax.jit(self.action_encoder.apply) # type: ignore[method-assign] + self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] + self.qf.apply = jax.jit( # type: ignore[method-assign] + self.qf.apply, + static_argnames=("dropout_rate", "use_layer_norm"), + ) + + return key + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. 
+ """ + self.key, self.noise_key = jax.random.split(self.key, 2) + + def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: # type: ignore[override] + if deterministic: + return SAC7Policy.select_action(self.actor_state, self.encoder_state, observation) + # Trick to use gSDE: repeat sampled noise by using the same noise key + if not self.use_sde: + self.reset_noise() + return SAC7Policy.sample_action(self.actor_state, self.encoder_state, observation, self.noise_key) + + @staticmethod + @jax.jit + def sample_action(actor_state, encoder_state, obervations, key): + z_state = encoder_state.apply_fn(encoder_state.params, obervations) + dist = actor_state.apply_fn(actor_state.params, obervations, z_state) + action = dist.sample(seed=key) + return action + + @staticmethod + @jax.jit + def select_action(actor_state, encoder_state, obervations): + z_state = encoder_state.apply_fn(encoder_state.params, obervations) + return actor_state.apply_fn(actor_state.params, obervations, z_state).mode() diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py new file mode 100644 index 0000000..5848c85 --- /dev/null +++ b/sbx/sac7/sac7.py @@ -0,0 +1,471 @@ +from functools import partial +from typing import Any, Dict, Optional, Tuple, Type, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax.training.train_state import TrainState +from gymnasium import spaces +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule + +from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax +from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState +from sbx.sac7.policies import SAC7Policy + + +class EntropyCoef(nn.Module): + ent_coef_init: float = 1.0 + + @nn.compact + def __call__(self) -> jnp.ndarray: + log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init))) + return jnp.exp(log_ent_coef) + + +class ConstantEntropyCoef(nn.Module): + ent_coef_init: float = 1.0 + + @nn.compact + def __call__(self) -> float: + # Hack to not optimize the entropy coefficient while not having to use if/else for the jit + # TODO: add parameter in train to remove that hack + self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init)) + return self.ent_coef_init + + +class SAC7(OffPolicyAlgorithmJax): + policy_aliases: Dict[str, Type[SAC7Policy]] = { # type: ignore[assignment] + "MlpPolicy": SAC7Policy, + # Minimal dict support using flatten() + "MultiInputPolicy": SAC7Policy, + } + + policy: SAC7Policy + action_space: spaces.Box # type: ignore[assignment] + + def __init__( + self, + policy, + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + qf_learning_rate: Optional[float] = None, + buffer_size: int = 1_000_000, # 1e6 + learning_starts: int = 100, + batch_size: int = 256, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, + policy_delay: int = 1, + action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + ent_coef: Union[str, float] = "auto", + use_sde: bool = False, + sde_sample_freq: 
int = -1, + use_sde_at_warmup: bool = False, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: str = "auto", + _init_setup_model: bool = True, + ) -> None: + super().__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + qf_learning_rate=qf_learning_rate, + buffer_size=buffer_size, + learning_starts=learning_starts, + batch_size=batch_size, + tau=tau, + gamma=gamma, + train_freq=train_freq, + gradient_steps=gradient_steps, + action_noise=action_noise, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + use_sde_at_warmup=use_sde_at_warmup, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + seed=seed, + supported_action_spaces=(spaces.Box,), + support_multi_env=True, + ) + + self.policy_delay = policy_delay + self.ent_coef_init = ent_coef + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + super()._setup_model() + + if not hasattr(self, "policy") or self.policy is None: + # pytype: disable=not-instantiable + self.policy = self.policy_class( # type: ignore[assignment] + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs, + ) + # pytype: enable=not-instantiable + + assert isinstance(self.qf_learning_rate, float) + + self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) + + self.key, ent_key = jax.random.split(self.key, 2) + + self.actor = self.policy.actor # type: ignore[assignment] + self.qf = self.policy.qf # type: ignore[assignment] + self.encoder = self.policy.encoder # type: ignore[assignment] + self.action_encoder = self.policy.action_encoder # type: ignore[assignment] + + # The entropy coefficient or entropy can be learned automatically + # see Automating Entropy Adjustment for Maximum Entropy RL section + # of https://arxiv.org/abs/1812.05905 + if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"): + # Default initial value of ent_coef when learned + ent_coef_init = 1.0 + if "_" in self.ent_coef_init: + ent_coef_init = float(self.ent_coef_init.split("_")[1]) + assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0" + + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 + self.ent_coef = EntropyCoef(ent_coef_init) + else: + # This will throw an error if a malformed string (different from 'auto') is passed + assert isinstance( + self.ent_coef_init, float + ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}" + self.ent_coef = ConstantEntropyCoef(self.ent_coef_init) # type: ignore[assignment] + + self.ent_coef_state = TrainState.create( + apply_fn=self.ent_coef.apply, + params=self.ent_coef.init(ent_key)["params"], + tx=optax.adam( + learning_rate=self.learning_rate, + ), + ) + + # automatically set target entropy if needed + self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + tb_log_name: str = "SAC", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ): + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + 
reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, + ) + + def train(self, batch_size, gradient_steps): + # Sample all at once for efficiency (so we can jit the for loop) + data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + # Pre-compute the indices where we need to update the actor + # This is a hack in order to jit the train loop + # It will compile once per value of policy_delay_indices + policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0} + policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) + + if isinstance(data.observations, dict): + keys = list(self.observation_space.keys()) + obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) + next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1) + else: + obs = data.observations.numpy() + next_obs = data.next_observations.numpy() + + # Convert to numpy + data = ReplayBufferSamplesNp( + obs, + data.actions.numpy(), + next_obs, + data.dones.numpy().flatten(), + data.rewards.numpy().flatten(), + ) + + ( + self.policy.qf_state, + self.policy.actor_state, + self.ent_coef_state, + self.policy.encoder_state, + self.policy.action_encoder_state, + self.key, + (actor_loss_value, qf_loss_value, ent_coef_value), + ) = self._train( + self.gamma, + self.tau, + self.target_entropy, + gradient_steps, + data, + policy_delay_indices, + self.policy.qf_state, + self.policy.actor_state, + self.ent_coef_state, + self.policy.encoder_state, + self.policy.action_encoder_state, + self.key, + ) + self._n_updates += gradient_steps + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/actor_loss", actor_loss_value.item()) + self.logger.record("train/critic_loss", qf_loss_value.item()) + self.logger.record("train/ent_coef", ent_coef_value.item()) + + @staticmethod + @jax.jit + def update_critic( + gamma: float, + actor_state: TrainState, + qf_state: RLTrainState, + ent_coef_state: TrainState, + encoder_state: RLTrainState, + action_encoder_state: RLTrainState, + observations: np.ndarray, + actions: np.ndarray, + next_observations: np.ndarray, + rewards: np.ndarray, + dones: np.ndarray, + key: jax.random.KeyArray, + ): + key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) + + z_next = encoder_state.apply_fn(encoder_state.target_params, next_observations) + + # sample action from the actor + dist = actor_state.apply_fn(actor_state.params, next_observations, z_next) + next_state_actions = dist.sample(seed=noise_key) + next_log_prob = dist.log_prob(next_state_actions) + + ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) + + z_action_next = action_encoder_state.apply_fn( + action_encoder_state.target_params, next_observations, next_state_actions + ) + + qf_next_values = qf_state.apply_fn( + qf_state.target_params, + next_observations, + next_state_actions, + z_next, + z_action_next, + rngs={"dropout": dropout_key_target}, + ) + + next_q_values = jnp.min(qf_next_values, axis=0) + # td error + entropy term + next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) + # shape is (batch_size, 1) + target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + + z_state = encoder_state.apply_fn(encoder_state.target_params, observations) + z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, 
observations, actions) + + def mse_loss(params, dropout_key): + # shape is (n_critics, batch_size, 1) + current_q_values = qf_state.apply_fn( + params, + observations, + actions, + z_state, + z_state_action, + rngs={"dropout": dropout_key}, + ) + return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum() + + qf_loss_value, grads = jax.value_and_grad(mse_loss, has_aux=False)(qf_state.params, dropout_key_current) + qf_state = qf_state.apply_gradients(grads=grads) + + return ( + qf_state, + (qf_loss_value, ent_coef_value), + key, + ) + + @staticmethod + @jax.jit + def update_actor( + actor_state: RLTrainState, + qf_state: RLTrainState, + ent_coef_state: TrainState, + encoder_state: RLTrainState, + action_encoder_state: RLTrainState, + observations: np.ndarray, + key: jax.random.KeyArray, + ): + key, dropout_key, noise_key = jax.random.split(key, 3) + + z_state = encoder_state.apply_fn(encoder_state.target_params, observations) + + def actor_loss(params): + dist = actor_state.apply_fn(params, observations, z_state) + actor_actions = dist.sample(seed=noise_key) + log_prob = dist.log_prob(actor_actions).reshape(-1, 1) + + z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, observations, actor_actions) + + qf_pi = qf_state.apply_fn( + qf_state.params, + observations, + actor_actions, + z_state, + z_state_action, + rngs={"dropout": dropout_key}, + ) + # Take min among all critics (mean for droq) + min_qf_pi = jnp.min(qf_pi, axis=0) + ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) + actor_loss = (ent_coef_value * log_prob - min_qf_pi).mean() + return actor_loss, -log_prob.mean() + + (actor_loss_value, entropy), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params) + actor_state = actor_state.apply_gradients(grads=grads) + + return actor_state, qf_state, actor_loss_value, key, entropy + + @staticmethod + @jax.jit + def soft_update(tau: float, qf_state: RLTrainState): + qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau)) + return qf_state + + @staticmethod + @jax.jit + def update_temperature(target_entropy: np.ndarray, ent_coef_state: TrainState, entropy: float): + def temperature_loss(temp_params): + ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}) + ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() + return ent_coef_loss + + ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) + ent_coef_state = ent_coef_state.apply_gradients(grads=grads) + + return ent_coef_state, ent_coef_loss + + @staticmethod + @jax.jit + def update_encoders( + encoder_state: RLTrainState, + action_encoder_state: RLTrainState, + observations: np.ndarray, + actions: np.ndarray, + next_observations: np.ndarray, + ): + z_next_state = jax.lax.stop_gradient(encoder_state.apply_fn(encoder_state.params, next_observations)) + + def encoder_loss(encoder_params, action_encoder_params): + # TODO: include z_state in the action state encoder + z_state = encoder_state.apply_fn(encoder_params, observations) + z_state_action = action_encoder_state.apply_fn(action_encoder_params, observations, actions) + # encoder_loss = optax.huber_loss(z_state_action, z_next_state).mean() + return optax.l2_loss(z_state_action, z_next_state).mean() + + _, grads = jax.value_and_grad(encoder_loss)(encoder_state.params, action_encoder_state.params) + + encoder_state = encoder_state.apply_gradients(grads=grads) + # TODO: enable update, update 
both at the same time + # action_encoder_state = action_encoder_state.apply_gradients(grads=grads) + + return encoder_state, action_encoder_state + + @classmethod + @partial(jax.jit, static_argnames=["cls", "gradient_steps"]) + def _train( + cls, + gamma: float, + tau: float, + target_entropy: np.ndarray, + gradient_steps: int, + data: ReplayBufferSamplesNp, + policy_delay_indices: flax.core.FrozenDict, + qf_state: RLTrainState, + actor_state: TrainState, + ent_coef_state: TrainState, + encoder_state: RLTrainState, + action_encoder_state: RLTrainState, + key, + ): + actor_loss_value = jnp.array(0) + + for i in range(gradient_steps): + + def slice(x, step=i): + assert x.shape[0] % gradient_steps == 0 + batch_size = x.shape[0] // gradient_steps + return x[batch_size * step : batch_size * (step + 1)] + + encoder_state, action_encoder_state = SAC7.update_encoders( + encoder_state, + action_encoder_state, + slice(data.observations), + slice(data.actions), + slice(data.next_observations), + ) + + ( + qf_state, + (qf_loss_value, ent_coef_value), + key, + ) = SAC7.update_critic( + gamma, + actor_state, + qf_state, + ent_coef_state, + encoder_state, + action_encoder_state, + slice(data.observations), + slice(data.actions), + slice(data.next_observations), + slice(data.rewards), + slice(data.dones), + key, + ) + qf_state = SAC7.soft_update(tau, qf_state) + encoder_state = SAC7.soft_update(tau, encoder_state) + action_encoder_state = SAC7.soft_update(tau, action_encoder_state) + + # hack to be able to jit (n_updates % policy_delay == 0) + if i in policy_delay_indices: + (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor( + actor_state, + qf_state, + ent_coef_state, + encoder_state, + action_encoder_state, + slice(data.observations), + key, + ) + ent_coef_state, _ = SAC7.update_temperature(target_entropy, ent_coef_state, entropy) + + return ( + qf_state, + actor_state, + ent_coef_state, + encoder_state, + action_encoder_state, + key, + (actor_loss_value, qf_loss_value, ent_coef_value), + ) diff --git a/tests/test_run.py b/tests/test_run.py index bcf2f12..764e2e9 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -7,7 +7,7 @@ from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.evaluation import evaluate_policy -from sbx import DQN, PPO, SAC, TQC, DroQ +from sbx import DQN, PPO, SAC, SAC7, TQC, DroQ def test_droq(tmp_path): @@ -60,7 +60,7 @@ def test_tqc() -> None: model.learn(200) -@pytest.mark.parametrize("model_class", [SAC]) +@pytest.mark.parametrize("model_class", [SAC, SAC7]) def test_sac(model_class: Type[SAC]) -> None: model = model_class( "MlpPolicy", From e1e16125eed9fa2bc170af56c96f94c79c028697 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 6 Jun 2023 15:57:16 +0200 Subject: [PATCH 02/14] Update package command --- Makefile | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index a3dbc06..50ab200 100644 --- a/Makefile +++ b/Makefile @@ -44,14 +44,12 @@ commit-checks: format type lint # PyPi package release release: - python setup.py sdist - python setup.py bdist_wheel + python -m build twine upload dist/* # Test PyPi package release test-release: - python setup.py sdist - python setup.py bdist_wheel + python -m build twine upload --repository-url https://test.pypi.org/legacy/ dist/* .PHONY: clean spelling doc lint format check-codestyle commit-checks From 3689b4a0c5b127abe06e8176756501b3fb59f097 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 11:10:12 +0200 
Subject: [PATCH 03/14] Fix encoder update --- sbx/sac7/policies.py | 8 ++++---- sbx/sac7/sac7.py | 21 ++++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/sbx/sac7/policies.py b/sbx/sac7/policies.py index 5d321fd..1c5264a 100644 --- a/sbx/sac7/policies.py +++ b/sbx/sac7/policies.py @@ -48,8 +48,8 @@ class StateActionEncoder(nn.Module): embedding_dim: int = 256 @nn.compact - def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: - x = jnp.concatenate([x, action], -1) + def __call__(self, z_state: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: + x = jnp.concatenate([z_state, action], -1) for n_units in self.net_arch: x = nn.Dense(n_units)(x) x = nn.elu(x) @@ -276,12 +276,12 @@ def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rat apply_fn=self.action_encoder.apply, params=self.action_encoder.init( {"params": action_encoder_key}, - obs, + z_state, action, ), target_params=self.action_encoder.init( {"params": action_encoder_key}, - obs, + z_state, action, ), tx=self.optimizer_class( diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 5848c85..4ad09ee 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -262,9 +262,7 @@ def update_critic( ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) - z_action_next = action_encoder_state.apply_fn( - action_encoder_state.target_params, next_observations, next_state_actions - ) + z_action_next = action_encoder_state.apply_fn(action_encoder_state.target_params, z_next, next_state_actions) qf_next_values = qf_state.apply_fn( qf_state.target_params, @@ -282,7 +280,7 @@ def update_critic( target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values z_state = encoder_state.apply_fn(encoder_state.target_params, observations) - z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, observations, actions) + z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, z_state, actions) def mse_loss(params, dropout_key): # shape is (n_critics, batch_size, 1) @@ -325,7 +323,7 @@ def actor_loss(params): actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) - z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, observations, actor_actions) + z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, z_state, actor_actions) qf_pi = qf_state.apply_fn( qf_state.params, @@ -377,17 +375,18 @@ def update_encoders( z_next_state = jax.lax.stop_gradient(encoder_state.apply_fn(encoder_state.params, next_observations)) def encoder_loss(encoder_params, action_encoder_params): - # TODO: include z_state in the action state encoder z_state = encoder_state.apply_fn(encoder_params, observations) - z_state_action = action_encoder_state.apply_fn(action_encoder_params, observations, actions) + z_state_action = action_encoder_state.apply_fn(action_encoder_params, z_state, actions) # encoder_loss = optax.huber_loss(z_state_action, z_next_state).mean() return optax.l2_loss(z_state_action, z_next_state).mean() - _, grads = jax.value_and_grad(encoder_loss)(encoder_state.params, action_encoder_state.params) + _, (encoder_grads, action_encoder_grads) = jax.value_and_grad(encoder_loss, argnums=(0, 1))( + encoder_state.params, + action_encoder_state.params, + ) - encoder_state = encoder_state.apply_gradients(grads=grads) - # TODO: enable update, update both at the same time - # action_encoder_state = 
action_encoder_state.apply_gradients(grads=grads) + encoder_state = encoder_state.apply_gradients(grads=encoder_grads) + action_encoder_state = action_encoder_state.apply_gradients(grads=action_encoder_grads) return encoder_state, action_encoder_state From c92d66f0a4ecd31d41b22d0ea727ffb0a163fdaf Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 11:50:53 +0200 Subject: [PATCH 04/14] Cleanup: expose params --- sbx/sac7/policies.py | 22 ++++++++++------------ sbx/sac7/sac7.py | 8 ++++++++ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sbx/sac7/policies.py b/sbx/sac7/policies.py index 1c5264a..d699222 100644 --- a/sbx/sac7/policies.py +++ b/sbx/sac7/policies.py @@ -18,13 +18,6 @@ tfd = tfp.distributions -# class AvgL1Norm(nn.Module): -# eps: float = 1e-8 -# @nn.compact -# def __call__(self, x: jnp.ndarray) -> jnp.ndarray: -# return x / jnp.clip(jnp.mean(jnp.abs(x), axis=-1, keepdims=True), a_min=self.eps) - - @jax.jit def avg_l1_norm(x: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray: return x / jnp.clip(jnp.mean(jnp.abs(x), axis=-1, keepdims=True), a_min=eps) @@ -145,6 +138,7 @@ def __init__( net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, dropout_rate: float = 0.0, layer_norm: bool = False, + embedding_dim: int = 256, # activation_fn: Type[nn.Module] = nn.ReLU, use_sde: bool = False, # Note: most gSDE parameters are not used @@ -171,14 +165,17 @@ def __init__( ) self.dropout_rate = dropout_rate self.layer_norm = layer_norm + self.embedding_dim = embedding_dim if net_arch is not None: if isinstance(net_arch, list): self.net_arch_pi = self.net_arch_qf = net_arch else: self.net_arch_pi = net_arch["pi"] self.net_arch_qf = net_arch["qf"] + self.net_arch_encoder = net_arch["encoder"] else: self.net_arch_pi = self.net_arch_qf = [256, 256] + self.net_arch_encoder = [256, 256] self.n_critics = n_critics self.use_sde = use_sde @@ -200,9 +197,8 @@ def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rat obs = jnp.array([self.observation_space.sample()]) action = jnp.array([self.action_space.sample()]) - encoding_dim = 256 - z_state = jnp.zeros((1, encoding_dim)) - z_state_action = jnp.zeros((1, encoding_dim)) + z_state = jnp.zeros((1, self.embedding_dim)) + z_state_action = jnp.zeros((1, self.embedding_dim)) self.actor = Actor( action_dim=int(np.prod(self.action_space.shape)), @@ -250,10 +246,12 @@ def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rat ) self.encoder = StateEncoder( - net_arch=[256, 256], + net_arch=self.net_arch_encoder, + embedding_dim=self.embedding_dim, ) self.action_encoder = StateActionEncoder( - net_arch=[256, 256], + net_arch=self.net_arch_encoder, + embedding_dim=self.embedding_dim, ) self.encoder_state = RLTrainState.create( diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 4ad09ee..4dbec37 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -39,6 +39,14 @@ def __call__(self) -> float: class SAC7(OffPolicyAlgorithmJax): + """ + Including State-Action Representation learning in SAC + TD7: https://github.com/sfujim/TD7 + + Note: The rest of the tricks (LAP replay buffer, checkpoints, extrapolation error correction, huber loss) + are not yet implemented + """ + policy_aliases: Dict[str, Type[SAC7Policy]] = { # type: ignore[assignment] "MlpPolicy": SAC7Policy, # Minimal dict support using flatten() From d215ecc7fb10fede83a9f31a41cb5a4f69aefeb3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 12:08:12 +0200 Subject: [PATCH 05/14] Avoid recomputation 
of z_state --- sbx/sac7/sac7.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 4dbec37..36c9b4d 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -252,6 +252,7 @@ def update_critic( ent_coef_state: TrainState, encoder_state: RLTrainState, action_encoder_state: RLTrainState, + z_state: jnp.ndarray, observations: np.ndarray, actions: np.ndarray, next_observations: np.ndarray, @@ -287,7 +288,7 @@ def update_critic( # shape is (batch_size, 1) target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values - z_state = encoder_state.apply_fn(encoder_state.target_params, observations) + # z_state = encoder_state.apply_fn(encoder_state.target_params, observations) z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, z_state, actions) def mse_loss(params, dropout_key): @@ -319,12 +320,13 @@ def update_actor( ent_coef_state: TrainState, encoder_state: RLTrainState, action_encoder_state: RLTrainState, + z_state: jnp.ndarray, observations: np.ndarray, key: jax.random.KeyArray, ): key, dropout_key, noise_key = jax.random.split(key, 3) - z_state = encoder_state.apply_fn(encoder_state.target_params, observations) + # z_state = encoder_state.apply_fn(encoder_state.target_params, observations) def actor_loss(params): dist = actor_state.apply_fn(params, observations, z_state) @@ -432,6 +434,9 @@ def slice(x, step=i): slice(data.next_observations), ) + z_state = encoder_state.apply_fn(encoder_state.target_params, slice(data.observations)) + + ( qf_state, (qf_loss_value, ent_coef_value), @@ -443,6 +448,7 @@ def slice(x, step=i): ent_coef_state, encoder_state, action_encoder_state, + z_state, slice(data.observations), slice(data.actions), slice(data.next_observations), @@ -462,6 +468,7 @@ def slice(x, step=i): ent_coef_state, encoder_state, action_encoder_state, + z_state, slice(data.observations), key, ) From d422628b1780c2abee6d09f6f85a2affb26cc940 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 12:17:52 +0200 Subject: [PATCH 06/14] Revert "Avoid recomputation of z_state" This reverts commit d215ecc7fb10fede83a9f31a41cb5a4f69aefeb3. 
--- sbx/sac7/sac7.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 36c9b4d..4dbec37 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -252,7 +252,6 @@ def update_critic( ent_coef_state: TrainState, encoder_state: RLTrainState, action_encoder_state: RLTrainState, - z_state: jnp.ndarray, observations: np.ndarray, actions: np.ndarray, next_observations: np.ndarray, @@ -288,7 +287,7 @@ def update_critic( # shape is (batch_size, 1) target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values - # z_state = encoder_state.apply_fn(encoder_state.target_params, observations) + z_state = encoder_state.apply_fn(encoder_state.target_params, observations) z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, z_state, actions) def mse_loss(params, dropout_key): @@ -320,13 +319,12 @@ def update_actor( ent_coef_state: TrainState, encoder_state: RLTrainState, action_encoder_state: RLTrainState, - z_state: jnp.ndarray, observations: np.ndarray, key: jax.random.KeyArray, ): key, dropout_key, noise_key = jax.random.split(key, 3) - # z_state = encoder_state.apply_fn(encoder_state.target_params, observations) + z_state = encoder_state.apply_fn(encoder_state.target_params, observations) def actor_loss(params): dist = actor_state.apply_fn(params, observations, z_state) @@ -434,9 +432,6 @@ def slice(x, step=i): slice(data.next_observations), ) - z_state = encoder_state.apply_fn(encoder_state.target_params, slice(data.observations)) - - ( qf_state, (qf_loss_value, ent_coef_value), @@ -448,7 +443,6 @@ def slice(x, step=i): ent_coef_state, encoder_state, action_encoder_state, - z_state, slice(data.observations), slice(data.actions), slice(data.next_observations), @@ -468,7 +462,6 @@ def slice(x, step=i): ent_coef_state, encoder_state, action_encoder_state, - z_state, slice(data.observations), key, ) From 9648a99de9b39fd88549801c450409d46ef946eb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 13:00:46 +0200 Subject: [PATCH 07/14] Add target value clipping --- sbx/sac7/sac7.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 4dbec37..9d6958b 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -140,6 +140,10 @@ def _setup_model(self) -> None: self.encoder = self.policy.encoder # type: ignore[assignment] self.action_encoder = self.policy.action_encoder # type: ignore[assignment] + # Value clipping tracked values + self.min_qf_target = 0 + self.max_qf_target = 0 + # The entropy coefficient or entropy can be learned automatically # see Automating Entropy Adjustment for Maximum Entropy RL section # of https://arxiv.org/abs/1812.05905 @@ -206,6 +210,11 @@ def train(self, batch_size, gradient_steps): obs = data.observations.numpy() next_obs = data.next_observations.numpy() + # Initialize clipping at the first iteration + if self._n_updates == 0: + self.min_qf_target = data.rewards.min().item() + self.max_qf_target = data.rewards.max().item() + # Convert to numpy data = ReplayBufferSamplesNp( obs, @@ -223,6 +232,7 @@ def train(self, batch_size, gradient_steps): self.policy.action_encoder_state, self.key, (actor_loss_value, qf_loss_value, ent_coef_value), + (self.min_qf_target, self.max_qf_target), ) = self._train( self.gamma, self.tau, @@ -236,6 +246,8 @@ def train(self, batch_size, gradient_steps): self.policy.encoder_state, self.policy.action_encoder_state, self.key, + 
self.min_qf_target, + self.max_qf_target, ) self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") @@ -258,6 +270,8 @@ def update_critic( rewards: np.ndarray, dones: np.ndarray, key: jax.random.KeyArray, + min_qf_target: float, + max_qf_target: float, ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) @@ -284,9 +298,18 @@ def update_critic( next_q_values = jnp.min(qf_next_values, axis=0) # td error + entropy term next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) + + # Clip next q_values + next_q_values = jnp.clip(next_q_values, min_qf_target, max_qf_target) + # shape is (batch_size, 1) target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + # Compute min/max to update clipping range + # TODO: initialize them properly + qf_min = jnp.minimum(target_q_values.min(), min_qf_target) + qf_max = jnp.maximum(target_q_values.max(), max_qf_target) + z_state = encoder_state.apply_fn(encoder_state.target_params, observations) z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, z_state, actions) @@ -309,6 +332,7 @@ def mse_loss(params, dropout_key): qf_state, (qf_loss_value, ent_coef_value), key, + (qf_min, qf_max), ) @staticmethod @@ -413,7 +437,9 @@ def _train( ent_coef_state: TrainState, encoder_state: RLTrainState, action_encoder_state: RLTrainState, - key, + key: jax.random.KeyArray, + min_qf_target: float, + max_qf_target: float, ): actor_loss_value = jnp.array(0) @@ -436,6 +462,7 @@ def slice(x, step=i): qf_state, (qf_loss_value, ent_coef_value), key, + (qf_min, qf_max), ) = SAC7.update_critic( gamma, actor_state, @@ -449,11 +476,16 @@ def slice(x, step=i): slice(data.rewards), slice(data.dones), key, + min_qf_target, + max_qf_target, ) qf_state = SAC7.soft_update(tau, qf_state) encoder_state = SAC7.soft_update(tau, encoder_state) action_encoder_state = SAC7.soft_update(tau, action_encoder_state) + min_qf_target += tau * (qf_min - min_qf_target) + max_qf_target += tau * (qf_max - max_qf_target) + # hack to be able to jit (n_updates % policy_delay == 0) if i in policy_delay_indices: (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor( @@ -475,4 +507,5 @@ def slice(x, step=i): action_encoder_state, key, (actor_loss_value, qf_loss_value, ent_coef_value), + (min_qf_target, max_qf_target), ) From 7c5b89d6d5bda3428c8326781897872907ffdbc4 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 13:10:35 +0200 Subject: [PATCH 08/14] More aggressive update --- sbx/sac7/sac7.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 9d6958b..99108a1 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -483,8 +483,9 @@ def slice(x, step=i): encoder_state = SAC7.soft_update(tau, encoder_state) action_encoder_state = SAC7.soft_update(tau, action_encoder_state) - min_qf_target += tau * (qf_min - min_qf_target) - max_qf_target += tau * (qf_max - max_qf_target) + tau_update = 0.1 + min_qf_target += tau_update * (qf_min - min_qf_target) + max_qf_target += tau_update * (qf_max - max_qf_target) # hack to be able to jit (n_updates % policy_delay == 0) if i in policy_delay_indices: From c67425f60dc4ccf75f298f02907115c2d2b8d96d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 13:15:36 +0200 Subject: [PATCH 09/14] Log min and max --- sbx/sac7/sac7.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sbx/sac7/sac7.py 
b/sbx/sac7/sac7.py index 99108a1..69deffb 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -254,6 +254,8 @@ def train(self, batch_size, gradient_steps): self.logger.record("train/actor_loss", actor_loss_value.item()) self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + self.logger.record("train/min_qf_target", self.min_qf_target) + self.logger.record("train/max_qf_target", self.max_qf_target) @staticmethod @jax.jit From 173ad4d7710745987095fdaa2057920aa706fbcb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 13:19:04 +0200 Subject: [PATCH 10/14] Revert "Add target value clipping" --- sbx/sac7/sac7.py | 38 +------------------------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 69deffb..4dbec37 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -140,10 +140,6 @@ def _setup_model(self) -> None: self.encoder = self.policy.encoder # type: ignore[assignment] self.action_encoder = self.policy.action_encoder # type: ignore[assignment] - # Value clipping tracked values - self.min_qf_target = 0 - self.max_qf_target = 0 - # The entropy coefficient or entropy can be learned automatically # see Automating Entropy Adjustment for Maximum Entropy RL section # of https://arxiv.org/abs/1812.05905 @@ -210,11 +206,6 @@ def train(self, batch_size, gradient_steps): obs = data.observations.numpy() next_obs = data.next_observations.numpy() - # Initialize clipping at the first iteration - if self._n_updates == 0: - self.min_qf_target = data.rewards.min().item() - self.max_qf_target = data.rewards.max().item() - # Convert to numpy data = ReplayBufferSamplesNp( obs, @@ -232,7 +223,6 @@ def train(self, batch_size, gradient_steps): self.policy.action_encoder_state, self.key, (actor_loss_value, qf_loss_value, ent_coef_value), - (self.min_qf_target, self.max_qf_target), ) = self._train( self.gamma, self.tau, @@ -246,16 +236,12 @@ def train(self, batch_size, gradient_steps): self.policy.encoder_state, self.policy.action_encoder_state, self.key, - self.min_qf_target, - self.max_qf_target, ) self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/actor_loss", actor_loss_value.item()) self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) - self.logger.record("train/min_qf_target", self.min_qf_target) - self.logger.record("train/max_qf_target", self.max_qf_target) @staticmethod @jax.jit @@ -272,8 +258,6 @@ def update_critic( rewards: np.ndarray, dones: np.ndarray, key: jax.random.KeyArray, - min_qf_target: float, - max_qf_target: float, ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) @@ -300,18 +284,9 @@ def update_critic( next_q_values = jnp.min(qf_next_values, axis=0) # td error + entropy term next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) - - # Clip next q_values - next_q_values = jnp.clip(next_q_values, min_qf_target, max_qf_target) - # shape is (batch_size, 1) target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values - # Compute min/max to update clipping range - # TODO: initialize them properly - qf_min = jnp.minimum(target_q_values.min(), min_qf_target) - qf_max = jnp.maximum(target_q_values.max(), max_qf_target) - z_state = encoder_state.apply_fn(encoder_state.target_params, observations) 
z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, z_state, actions) @@ -334,7 +309,6 @@ def mse_loss(params, dropout_key): qf_state, (qf_loss_value, ent_coef_value), key, - (qf_min, qf_max), ) @staticmethod @@ -439,9 +413,7 @@ def _train( ent_coef_state: TrainState, encoder_state: RLTrainState, action_encoder_state: RLTrainState, - key: jax.random.KeyArray, - min_qf_target: float, - max_qf_target: float, + key, ): actor_loss_value = jnp.array(0) @@ -464,7 +436,6 @@ def slice(x, step=i): qf_state, (qf_loss_value, ent_coef_value), key, - (qf_min, qf_max), ) = SAC7.update_critic( gamma, actor_state, @@ -478,17 +449,11 @@ def slice(x, step=i): slice(data.rewards), slice(data.dones), key, - min_qf_target, - max_qf_target, ) qf_state = SAC7.soft_update(tau, qf_state) encoder_state = SAC7.soft_update(tau, encoder_state) action_encoder_state = SAC7.soft_update(tau, action_encoder_state) - tau_update = 0.1 - min_qf_target += tau_update * (qf_min - min_qf_target) - max_qf_target += tau_update * (qf_max - max_qf_target) - # hack to be able to jit (n_updates % policy_delay == 0) if i in policy_delay_indices: (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor( @@ -510,5 +475,4 @@ def slice(x, step=i): action_encoder_state, key, (actor_loss_value, qf_loss_value, ent_coef_value), - (min_qf_target, max_qf_target), ) From 571f7579bffd9aa921a24866cce1f8a44f1d7851 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Jun 2023 14:10:26 +0200 Subject: [PATCH 11/14] Fix net arch encoder param --- sbx/sac7/policies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sbx/sac7/policies.py b/sbx/sac7/policies.py index d699222..2a492d3 100644 --- a/sbx/sac7/policies.py +++ b/sbx/sac7/policies.py @@ -169,6 +169,7 @@ def __init__( if net_arch is not None: if isinstance(net_arch, list): self.net_arch_pi = self.net_arch_qf = net_arch + self.net_arch_encoder = net_arch else: self.net_arch_pi = net_arch["pi"] self.net_arch_qf = net_arch["qf"] From 159febadb1cdf3c21cdd6b1c3bc474f18fa82393 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 24 Jul 2023 11:53:03 +0200 Subject: [PATCH 12/14] Allow deterministic exploration --- sbx/sac/policies.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 7370b4e..6dbb2a3 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -112,6 +112,7 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, + deterministic_exploration: bool = False, ): super().__init__( observation_space, @@ -134,6 +135,7 @@ def __init__( self.net_arch_pi = self.net_arch_qf = [256, 256] self.n_critics = n_critics self.use_sde = use_sde + self.deterministic_exploration = deterministic_exploration self.key = self.noise_key = jax.random.PRNGKey(0) @@ -209,7 +211,7 @@ def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: return self._predict(obs, deterministic=deterministic) def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: # type: ignore[override] - if deterministic: + if deterministic or self.deterministic_exploration: return BaseJaxPolicy.select_action(self.actor_state, observation) # Trick to use gSDE: repeat sampled noise by using the same noise key if not self.use_sde: From 5b71da79da22657a6d91d72ed530f01688d669ee Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 26 Jul 2023 16:12:08 +0200 Subject: [PATCH 13/14] Update to match SB3 
--- .github/workflows/ci.yml | 6 +++--- sbx/dqn/dqn.py | 6 +++--- sbx/droq/droq.py | 4 ++-- sbx/ppo/ppo.py | 4 ++-- sbx/sac/sac.py | 4 ++-- sbx/sac7/sac7.py | 4 ++-- sbx/tqc/tqc.py | 4 ++-- setup.py | 4 ++-- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3b13b6c..54902e1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 @@ -55,8 +55,8 @@ jobs: - name: Type check run: | make type - # skip mypy type check for python3.7 - if: "!(matrix.python-version == '3.7')" + # skip pytype type check for python3.11 + if: "!(matrix.python-version == '3.11')" - name: Test with pytest run: | make pytest diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 47d0e6d..b3fc3b4 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union import gymnasium as gym import jax @@ -15,7 +15,7 @@ class DQN(OffPolicyAlgorithmJax): - policy_aliases: Dict[str, Type[DQNPolicy]] = { # type: ignore[assignment] + policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = { # type: ignore[assignment] "MlpPolicy": DQNPolicy, } # Linear schedule will be defined in `_setup_model()` @@ -248,7 +248,7 @@ def predict( if not deterministic and np.random.rand() < self.exploration_rate: if self.policy.is_vectorized_observation(observation): if isinstance(observation, dict): - n_batch = observation[list(observation.keys())[0]].shape[0] + n_batch = observation[next(iter(observation.keys()))].shape[0] else: n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) diff --git a/sbx/droq/droq.py b/sbx/droq/droq.py index bc2529a..e18f186 100644 --- a/sbx/droq/droq.py +++ b/sbx/droq/droq.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise @@ -9,7 +9,7 @@ class DroQ(TQC): - policy_aliases: Dict[str, Type[TQCPolicy]] = { + policy_aliases: ClassVar[Dict[str, Type[TQCPolicy]]] = { "MlpPolicy": TQCPolicy, } diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index c44a36e..ae2ed98 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union import jax import jax.numpy as jnp @@ -68,7 +68,7 @@ class PPO(OnPolicyAlgorithmJax): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: Dict[str, Type[PPOPolicy]] = { # type: ignore[assignment] + policy_aliases: ClassVar[Dict[str, Type[PPOPolicy]]] = { # type: ignore[assignment] "MlpPolicy": PPOPolicy, # "CnnPolicy": ActorCriticCnnPolicy, # "MultiInputPolicy": MultiInputActorCriticPolicy, diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index 60d67da..41c05a8 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union import flax import 
flax.linen as nn @@ -39,7 +39,7 @@ def __call__(self) -> float: class SAC(OffPolicyAlgorithmJax): - policy_aliases: Dict[str, Type[SACPolicy]] = { # type: ignore[assignment] + policy_aliases: ClassVar[Dict[str, Type[SACPolicy]]] = { # type: ignore[assignment] "MlpPolicy": SACPolicy, # Minimal dict support using flatten() "MultiInputPolicy": SACPolicy, diff --git a/sbx/sac7/sac7.py b/sbx/sac7/sac7.py index 4dbec37..b9f7421 100644 --- a/sbx/sac7/sac7.py +++ b/sbx/sac7/sac7.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union import flax import flax.linen as nn @@ -47,7 +47,7 @@ class SAC7(OffPolicyAlgorithmJax): are not yet implemented """ - policy_aliases: Dict[str, Type[SAC7Policy]] = { # type: ignore[assignment] + policy_aliases: ClassVar[Dict[str, Type[SAC7Policy]]] = { # type: ignore[assignment] "MlpPolicy": SAC7Policy, # Minimal dict support using flatten() "MultiInputPolicy": SAC7Policy, diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index f297811..3556f70 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union import flax import flax.linen as nn @@ -39,7 +39,7 @@ def __call__(self) -> float: class TQC(OffPolicyAlgorithmJax): - policy_aliases: Dict[str, Type[TQCPolicy]] = { # type: ignore[assignment] + policy_aliases: ClassVar[Dict[str, Type[TQCPolicy]]] = { # type: ignore[assignment] "MlpPolicy": TQCPolicy, # Minimal dict support using flatten() "MultiInputPolicy": TQCPolicy, diff --git a/setup.py b/setup.py index 35e4bc1..01c3143 100644 --- a/setup.py +++ b/setup.py @@ -74,14 +74,14 @@ long_description=long_description, long_description_content_type="text/markdown", version=__version__, - python_requires=">=3.7", + python_requires=">=3.8", # PyPI package information. 
classifiers=[ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], ) From 47abe6ad9ff8ddb19cd7a52d62f181a8c4bbe884 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 26 Jul 2023 17:37:45 +0200 Subject: [PATCH 14/14] Update min pytorch version --- .github/workflows/ci.yml | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 54902e1..d278fca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html # # Install Atari Roms # pip install autorom @@ -55,7 +55,7 @@ jobs: - name: Type check run: | make type - # skip pytype type check for python3.11 + # skip PyType, doesn't support 3.11 yet if: "!(matrix.python-version == '3.11')" - name: Test with pytest run: | diff --git a/setup.py b/setup.py index 01c3143..efd813b 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ packages=[package for package in find_packages() if package.startswith("sbx")], package_data={"sbx": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.0.0a4", + "stable_baselines3>=2.0.0", "jax", "jaxlib", "flax",
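
For reference, a minimal way to exercise the SAC7 draft introduced in this series end-to-end, mirroring the pattern used in `tests/test_run.py`. This is an illustrative sketch, not part of the patches themselves: the environment name, timestep budget and `policy_kwargs` values are assumptions, and when `net_arch` is passed as a dict the patched policy also expects an `"encoder"` entry.

```python
import gymnasium as gym

from sbx import SAC7

# Continuous-control task (SAC7 only supports Box action spaces).
env = gym.make("Pendulum-v1")  # illustrative choice, not taken from the patch

model = SAC7(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    policy_kwargs=dict(
        # With a dict net_arch, the SAC7 policy reads "pi", "qf" and "encoder" entries.
        net_arch=dict(pi=[256, 256], qf=[256, 256], encoder=[256, 256]),
    ),
    verbose=1,
)
model.learn(total_timesteps=5_000, progress_bar=False)

obs, _ = env.reset()
action, _ = model.predict(obs, deterministic=True)
```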