diff --git a/README.md b/README.md
index 1dfca5c7..b1c0e55e 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
[](https://pypi.org/project/Mighty-RL/)
-
+

[](https://github.com/automl/Mighty/actions/workflows/test.yaml)
[](https://github.com/automl/Mighty/actions/workflows/docs_test.yaml)
@@ -42,12 +42,12 @@ Mighty features:
## Installation
We recommend to using uv to install and run Mighty in a virtual environment.
-The code has been tested with python 3.10/3.11 on Unix systems.
+The code has been tested with python 3.11 on Unix systems.
First create a clean python environment:
```bash
-uv venv --python=3.10
+uv venv --python=3.11
source .venv/bin/activate
```
diff --git a/docs/installation.md b/docs/installation.md
index 621d3e1f..35186c8b 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,5 +1,5 @@
We recommend to using uv to install and run Mighty in a virtual environment.
-The code has been tested with python 3.10.
+The code has been tested with python 3.11 on Unix systems.
First create a clean python environment:
diff --git a/mighty/configs/base.yaml b/mighty/configs/base.yaml
index 313627be..67182c3c 100644
--- a/mighty/configs/base.yaml
+++ b/mighty/configs/base.yaml
@@ -1,8 +1,8 @@
defaults:
- _self_
- - algorithm: sac_mujoco
+ - /cluster: local
+ - algorithm: ppo
- environment: pufferlib_ocean/bandit
- # - search_space: dqn_gym_classic
- override hydra/job_logging: colorlog
- override hydra/hydra_logging: colorlog
- override hydra/help: mighty_help
diff --git a/mighty/configs/cluster/example_cpu.yaml b/mighty/configs/cluster/example_cpu.yaml
new file mode 100644
index 00000000..c1eca3a2
--- /dev/null
+++ b/mighty/configs/cluster/example_cpu.yaml
@@ -0,0 +1,11 @@
+# @package _global_
+defaults:
+ - override /hydra/launcher: submitit_slurm
+
+hydra:
+ launcher:
+ partition: your_partition
+ cpus_per_task: 1
+ name: mighty_cpu_experiment
+ timeout_min: 30
+ mem_gb: 10
\ No newline at end of file
diff --git a/mighty/configs/cluster/example_gpu.yaml b/mighty/configs/cluster/example_gpu.yaml
new file mode 100644
index 00000000..6b775177
--- /dev/null
+++ b/mighty/configs/cluster/example_gpu.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+defaults:
+ - override /hydra/launcher: submitit_slurm
+
+hydra:
+ launcher:
+ partition: your_partition
+ gpus_per_task: 1
+ gres: "gpu:1"
+ timeout_min: 30
+ mem_gb: 10
+ name: mighty_gpu_experiment
\ No newline at end of file
diff --git a/mighty/configs/cluster/local.yaml b/mighty/configs/cluster/local.yaml
index 582473d6..e69de29b 100644
--- a/mighty/configs/cluster/local.yaml
+++ b/mighty/configs/cluster/local.yaml
@@ -1,13 +0,0 @@
-# @package _global_
-# defaults:
-# - override /hydra/launcher: joblib
-
-# hydra:
-# launcher:
-# n_jobs: 16
-
-cluster:
- _target_: distributed.deploy.local.LocalCluster
- n_workers: ${hydra.sweeper.scenario.n_workers}
- processes: false
- threads_per_worker: 1
\ No newline at end of file
diff --git a/mighty/configs/cluster/luis.yaml b/mighty/configs/cluster/luis.yaml
deleted file mode 100644
index ae307576..00000000
--- a/mighty/configs/cluster/luis.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-# @package _global_
-defaults:
- - override /hydra/launcher: submitit_slurm
-
-cluster:
- queue: ai,tnt # partition
-
-hydra:
- launcher:
- partition: ai
- cpus_per_task: 1
- name: expl2
- timeout_min: 20
- mem_gb: 4
- setup:
- - module load Miniconda3
- - conda activate /bigwork/nhwpbenc/conda/envs/mighty
\ No newline at end of file
diff --git a/mighty/configs/cluster/noctua.yaml b/mighty/configs/cluster/noctua.yaml
deleted file mode 100644
index 3f5a2026..00000000
--- a/mighty/configs/cluster/noctua.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-# @package _global_
-defaults:
- - override /hydra/launcher: submitit_slurm
-
-hydra:
- launcher:
- partition: normal
- cpus_per_task: 1
- name: expl2
- timeout_min: 20
- mem_gb: 4
- setup:
- - micromamba activate /scratch/hpc-prf-intexml/cbenjamins/envs/mighty
-
-cluster:
- _target_: dask_jobqueue.SLURMCluster
- queue: normal # set in cluster config
- # account: myaccount
- cores: 16
- memory: 32 GB
- walltime: 01:00:00
- processes: 1
- log_directory: tmp/mighty_smac
- n_workers: 16
- death_timeout: 30
diff --git a/mighty/configs/cluster/tnt.yaml b/mighty/configs/cluster/tnt.yaml
deleted file mode 100644
index 3b2a2008..00000000
--- a/mighty/configs/cluster/tnt.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-defaults:
- - override hydra/launcher: submitit_slurm
-
-cluster:
- queue: cpu_short # partition
-
-hydra:
- launcher:
- partition: cpu_short # change this to your partition name
- #gres: gpu:1 # use this option when running on GPUs
- mem_gb: 12 # memory requirements
- cpus_per_task: 20 # number of cpus per run
- timeout_min: 720 # timeout in minutes
- setup:
- - export XLA_PYTHON_CLIENT_PREALLOCATE=false
diff --git a/mighty/configs/cmaes_hpo.yaml b/mighty/configs/cmaes_hpo.yaml
index 4de1e6a4..a468865b 100644
--- a/mighty/configs/cmaes_hpo.yaml
+++ b/mighty/configs/cmaes_hpo.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
+ - /cluster: local
- algorithm: dqn
- environment: pufferlib_ocean/bandit
- search_space: dqn_gym_classic
diff --git a/mighty/configs/nes.yaml b/mighty/configs/nes.yaml
index 8d3f761d..78d3509e 100644
--- a/mighty/configs/nes.yaml
+++ b/mighty/configs/nes.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
+ - /cluster: local
- algorithm: dqn
- environment: pufferlib_ocean/bandit
- search_space: dqn_gym_classic
diff --git a/mighty/configs/ppo_smac.yaml b/mighty/configs/ppo_smac.yaml
index 6c2e19b6..40da7c69 100644
--- a/mighty/configs/ppo_smac.yaml
+++ b/mighty/configs/ppo_smac.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
+ - /cluster: local
- algorithm: ppo_mujoco
- environment: gymnasium/pendulum
- search_space: ppo_rs
@@ -45,6 +46,6 @@ hydra:
output_directory: ${hydra.sweep.dir}
search_space: ${search_space}
run:
- dir: ./tmp/branin_smac/
+ dir: ${output_dir}/${experiment_name}_${seed}
sweep:
- dir: ./tmp/branin_smac/
\ No newline at end of file
+ dir: ${output_dir}/${experiment_name}_${seed}
diff --git a/mighty/configs/sac_smac.yaml b/mighty/configs/sac_smac.yaml
index 4cfa11f3..613efd26 100644
--- a/mighty/configs/sac_smac.yaml
+++ b/mighty/configs/sac_smac.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
+ - /cluster: local
- algorithm: sac_mujoco
- environment: gymnasium/pendulum
- search_space: sac_rs
@@ -45,6 +46,6 @@ hydra:
output_directory: ${hydra.sweep.dir}
search_space: ${search_space}
run:
- dir: ./tmp/branin_smac/
+ dir: ${output_dir}/${experiment_name}_${seed}
sweep:
- dir: ./tmp/branin_smac/
\ No newline at end of file
+ dir: ${output_dir}/${experiment_name}_${seed}
diff --git a/mighty/configs/sweep_ppo_pbt.yaml b/mighty/configs/sweep_ppo_pbt.yaml
index 5d2fc903..3aba687f 100644
--- a/mighty/configs/sweep_ppo_pbt.yaml
+++ b/mighty/configs/sweep_ppo_pbt.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
+ - /cluster: local
- algorithm: ppo
- environment: gymnasium/pendulum
- search_space: ppo_rs
diff --git a/mighty/configs/sweep_rs.yaml b/mighty/configs/sweep_rs.yaml
index d95a6236..650c3545 100644
--- a/mighty/configs/sweep_rs.yaml
+++ b/mighty/configs/sweep_rs.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
+ - /cluster: local
- algorithm: ppo
- environment: gymnasium/pendulum
- search_space: ppo_rs
diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py
index 2e30e8b7..790a7c74 100644
--- a/mighty/mighty_agents/base_agent.py
+++ b/mighty/mighty_agents/base_agent.py
@@ -34,52 +34,14 @@
from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward
-def seed_everything(seed: int, env: gym.Env | None = None):
- """
- Seed Python, NumPy, Torch (including cuDNN), plus Gym (action_space, observation_space,
- and the environment's own RNG). If `env` is vectorized (has `env.envs`), we dig into each
- sub-env. Always call this before ANY network-building or RNG usage.
- """
- # 1) Python/NumPy/Torch/Hash
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- torch.use_deterministic_algorithms(True) # enforce strictly deterministic ops
- os.environ["PYTHONHASHSEED"] = str(seed)
-
- # 2) Gym environment seeding
- if env is not None:
- # If the env is wrapped, try to unwrap to the core
- try:
- core_env = env.unwrapped
- core_env.seed(seed)
- except Exception:
- pass
-
- # If vectorized (e.g. SyncVectorEnv), seed each sub‐env separately
- if hasattr(env, "envs") and isinstance(env.envs, list):
- sub_seeds = [seed for _ in range(len(env.envs))]
- for sub_seed, subenv in zip(sub_seeds, env.envs):
- subenv.action_space.seed(sub_seed)
- subenv.observation_space.seed(sub_seed)
- try:
- subenv.unwrapped.seed(sub_seed)
- except Exception:
- pass
- # Reset the vectorized env with explicit seeds
- env.reset(seed=sub_seeds)
- else:
- # Single environment
- env.action_space.seed(seed)
- env.observation_space.seed(seed)
- try:
- env.unwrapped.seed(seed)
- except Exception:
- pass
- env.reset(seed=seed)
+def seed_env_spaces(env: gym.VectorEnv, seed: int) -> None:
+ env.action_space.seed(seed)
+ env.single_action_space.seed(seed)
+ env.observation_space.seed(seed)
+ env.single_observation_space.seed(seed)
+ for i in range(len(env.envs)):
+ env.envs[i].action_space.seed(seed)
+ env.envs[i].observation_space.seed(seed)
def update_buffer(buffer, new_data):
@@ -218,28 +180,12 @@ def __init__( # noqa: PLR0915, PLR0912
self.seed = seed
if self.seed is not None:
- seed_everything(self.seed, env=None) # type: ignore
- # Re-seed Python/NumPy/Torch here again.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(seed)
- # Also seed any RNGs you will use later (e.g. for buffer / policy):
- self.rng = np.random.default_rng(seed) # for any numpy-based sampling
- # If you ever use torch.Generator for sampling actions, you could do:
- self.torch_gen = torch.Generator().manual_seed(seed)
-
- # If you use Python’s random inside your policy, call random.seed(seed)
- random.seed(seed)
-
- else:
- self.rng = np.random.default_rng()
- self.seed = 0
-
# Replay Buffer
replay_buffer_class = retrieve_class(
cls=replay_buffer_class,
@@ -279,6 +225,10 @@ def __init__( # noqa: PLR0915, PLR0912
else:
self.eval_env = eval_env
+ if self.seed is not None:
+ seed_env_spaces(self.env, self.seed)
+ seed_env_spaces(self.eval_env, self.seed)
+
self.render_progress = render_progress
self.output_dir = output_dir
if self.output_dir is not None:
@@ -288,9 +238,9 @@ def __init__( # noqa: PLR0915, PLR0912
self.meta_modules = {}
for i, m in enumerate(meta_methods):
meta_class = retrieve_class(cls=m, default_cls=None) # type: ignore
- assert (
- meta_class is not None
- ), f"Class {m} not found, did you specify the correct loading path?"
+ assert meta_class is not None, (
+ f"Class {m} not found, did you specify the correct loading path?"
+ )
kwargs: Dict = {}
if len(meta_kwargs) > i:
kwargs = meta_kwargs[i]
@@ -345,11 +295,13 @@ def __init__( # noqa: PLR0915, PLR0912
wandb.log(starting_hps)
self.initialize_agent()
+ if self.seed is not None:
+ self.buffer.seed(self.seed)
+ self.policy.seed(self.seed)
+ for m in self.meta_modules.values():
+ m.seed(self.seed)
self.steps = 0
- seed_everything(self.seed, self.env) # type: ignore
- seed_everything(self.seed, self.eval_env) # type: ignore
-
def _initialize_agent(self) -> None:
"""Agent/algorithm specific initializations."""
raise NotImplementedError
diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py
index 5617c739..2125c794 100644
--- a/mighty/mighty_agents/sac.py
+++ b/mighty/mighty_agents/sac.py
@@ -120,11 +120,11 @@ def __init__(
# Initialize loss buffer for logging
self.loss_buffer = {
- "q_loss1": [],
- "q_loss2": [],
- "policy_loss": [],
- "td_error1": [],
- "td_error2": [],
+ "Update/q_loss1": [],
+ "Update/q_loss2": [],
+ "Update/policy_loss": [],
+ "Update/td_error1": [],
+ "Update/td_error2": [],
"step": [],
}
diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py b/mighty/mighty_exploration/mighty_exploration_policy.py
index e73dc4a2..39948d5c 100644
--- a/mighty/mighty_exploration/mighty_exploration_policy.py
+++ b/mighty/mighty_exploration/mighty_exploration_policy.py
@@ -65,6 +65,10 @@ def __init__(
else:
self.sample_action = self.sample_func_logits
+ def seed(self, seed: int) -> None:
+ """Set the random seed for reproducibility."""
+ self.rng = np.random.default_rng(seed)
+
def sample_func_q(self, state_array):
"""
Q-learning branch:
diff --git a/mighty/mighty_meta/mighty_component.py b/mighty/mighty_meta/mighty_component.py
index a303c9b2..7a47ac7e 100644
--- a/mighty/mighty_meta/mighty_component.py
+++ b/mighty/mighty_meta/mighty_component.py
@@ -2,6 +2,8 @@
from __future__ import annotations
+import numpy as np
+
class MightyMetaComponent:
"""Component for registering meta-control methods."""
@@ -17,6 +19,7 @@ def __init__(self) -> None:
self.post_update_methods = []
self.pre_episode_methods = []
self.post_episode_methods = []
+ self.rng = np.random.default_rng()
def pre_step(self, metrics):
"""Execute methods before a step.
@@ -71,3 +74,7 @@ def post_episode(self, metrics):
"""
for m in self.post_episode_methods:
m(metrics)
+
+ def seed(self, seed: int) -> None:
+ """Set the random seed for reproducibility."""
+ self.rng = np.random.default_rng(seed)
diff --git a/mighty/mighty_meta/plr.py b/mighty/mighty_meta/plr.py
index 5c6b4cfe..67d75bd4 100644
--- a/mighty/mighty_meta/plr.py
+++ b/mighty/mighty_meta/plr.py
@@ -43,7 +43,6 @@ def __init__(
:return:
"""
super().__init__()
- self.rng = np.random.default_rng()
self.alpha = alpha
self.rho = rho
self.staleness_coef = staleness_coeff
diff --git a/mighty/mighty_meta/space.py b/mighty/mighty_meta/space.py
index 4fce0669..c5b92986 100644
--- a/mighty/mighty_meta/space.py
+++ b/mighty/mighty_meta/space.py
@@ -45,11 +45,11 @@ def get_instances(self, metrics):
self.all_instances = np.array(env.instance_id_list.copy())
if self.last_evals is None and rollout_values is None:
- self.instance_set = np.random.default_rng().choice(
+ self.instance_set = self.rng.choice(
self.all_instances, size=self.current_instance_set_size
)
elif self.last_evals is None:
- self.instance_set = np.random.default_rng().choice(
+ self.instance_set = self.rng.choice(
self.all_instances, size=self.current_instance_set_size
)
self.last_evals = np.nanmean(rollout_values)
diff --git a/mighty/mighty_models/networks.py b/mighty/mighty_models/networks.py
index af44242c..3072468b 100644
--- a/mighty/mighty_models/networks.py
+++ b/mighty/mighty_models/networks.py
@@ -106,9 +106,7 @@ def __init__(
def forward(self, x, transform: bool = True):
"""Forward pass."""
if transform:
- x = (
- x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2)
- ) # noqa: PLR2004
+ x = x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) # noqa: PLR2004
return self.cnn(x)
def __getstate__(self):
@@ -194,9 +192,7 @@ def __init__(
def forward(self, x, transform: bool = True):
"""Forward pass."""
if transform:
- x = (
- x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2)
- ) # noqa: PLR2004
+ x = x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) # noqa: PLR2004
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
@@ -221,9 +217,7 @@ def __init__(self, module1, module2):
def forward(self, x):
"""Forward pass."""
- x = (
- x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2)
- ) # noqa: PLR2004
+ x = x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) # noqa: PLR2004
x = self.module1(x, False)
return self.module2(x)
diff --git a/mighty/mighty_replay/buffer.py b/mighty/mighty_replay/buffer.py
index 157b0377..d1c12ea5 100644
--- a/mighty/mighty_replay/buffer.py
+++ b/mighty/mighty_replay/buffer.py
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
+import numpy as np
+
class MightyBuffer(ABC):
@abstractmethod
@@ -21,3 +23,7 @@ def __len__(self):
@abstractmethod
def __bool__(self):
pass
+
+ def seed(self, seed: int):
+ """Set random seed."""
+ self.rng = np.random.default_rng(seed)
diff --git a/mighty/mighty_replay/mighty_prioritized_replay.py b/mighty/mighty_replay/mighty_prioritized_replay.py
index e55ce4a4..6173c2e7 100644
--- a/mighty/mighty_replay/mighty_prioritized_replay.py
+++ b/mighty/mighty_replay/mighty_prioritized_replay.py
@@ -28,6 +28,7 @@ def __init__(
self.beta = beta
self.epsilon = epsilon
self.device = torch.device(device)
+ self.rng = np.random.default_rng()
super().__init__(capacity, keep_infos, flatten_infos, device)
@@ -134,7 +135,7 @@ def sample(self, batch_size):
for i in range(batch_size):
a = segment * i
b = segment * (i + 1)
- s = np.random.uniform(a, b)
+ s = self.rng.uniform(a, b)
leaf = self._retrieve(1, s) # leaf index in [capacity..2*capacity-1]
data_idx = leaf - self.capacity # ring-buffer index
batch_indices[i] = data_idx
diff --git a/mighty/mighty_replay/mighty_rollout_buffer.py b/mighty/mighty_replay/mighty_rollout_buffer.py
index c040c655..5dd57db3 100644
--- a/mighty/mighty_replay/mighty_rollout_buffer.py
+++ b/mighty/mighty_replay/mighty_rollout_buffer.py
@@ -188,6 +188,7 @@ def __init__(
self.gae_lambda = gae_lambda
self.discrete_action = discrete_action
self.use_latents = use_latents # Store for later use
+ self.rng = np.random.default_rng()
# Shapes -----------------------------------------------------------
if isinstance(obs_shape, int):
@@ -327,7 +328,7 @@ def _flat(t: torch.Tensor | None):
logp_f = _flat(self.log_probs)
val_f = _flat(self.values)
- perm = np.random.permutation(total)
+ perm = self.rng.permutation(total)
perm = perm[: (total // batch_size) * batch_size].reshape(-1, batch_size)
minibatches: list[RolloutBatch] = []
diff --git a/mighty/mighty_runners/mighty_es_runner.py b/mighty/mighty_runners/mighty_es_runner.py
index 03340e7b..6dbddd09 100644
--- a/mighty/mighty_runners/mighty_es_runner.py
+++ b/mighty/mighty_runners/mighty_es_runner.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import importlib.util as iutil
-from typing import TYPE_CHECKING, Dict, Tuple, Callable
+from typing import TYPE_CHECKING, Callable, Dict, Tuple
import numpy as np
import torch
diff --git a/mighty/mighty_runners/mighty_runner.py b/mighty/mighty_runners/mighty_runner.py
index 51109e2d..0dafa17e 100644
--- a/mighty/mighty_runners/mighty_runner.py
+++ b/mighty/mighty_runners/mighty_runner.py
@@ -4,7 +4,7 @@
import warnings
from abc import ABC
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Tuple, Callable
+from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple
from hydra.utils import get_class
diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py
index 74d8a6d9..aba3aa12 100644
--- a/mighty/mighty_update/sac_update.py
+++ b/mighty/mighty_update/sac_update.py
@@ -72,9 +72,7 @@ def calculate_td_error(self, transition: TransitionBatch) -> Tuple:
transition.rewards, dtype=torch.float32
).unsqueeze(-1) + (
1 - torch.as_tensor(transition.dones, dtype=torch.float32).unsqueeze(-1)
- ) * self.gamma * (
- torch.min(q1_t, q2_t) - alpha * logp_next
- )
+ ) * self.gamma * (torch.min(q1_t, q2_t) - alpha * logp_next)
sa = torch.cat(
[
torch.as_tensor(transition.observations, dtype=torch.float32),
@@ -176,10 +174,10 @@ def update(self, batch: TransitionBatch) -> Dict:
# --- Logging metrics ---
td1, td2 = self.calculate_td_error(batch)
return {
- "q_loss1": q_loss1.item(),
- "q_loss2": q_loss2.item(),
- "policy_loss": policy_loss.item(),
- "alpha_loss": alpha_loss.item(),
- "td_error1": td1.mean().item(),
- "td_error2": td2.mean().item(),
+ "Update/q_loss1": q_loss1.item(),
+ "Update/q_loss2": q_loss2.item(),
+ "Update/policy_loss": policy_loss.item(),
+ "Update/alpha_loss": alpha_loss.item(),
+ "Update/td_error1": td1.mean().item(),
+ "Update/td_error2": td2.mean().item(),
}
diff --git a/mighty/mighty_utils/plotting.py b/mighty/mighty_utils/plotting.py
index ab7e4f6f..b43bb0a9 100644
--- a/mighty/mighty_utils/plotting.py
+++ b/mighty/mighty_utils/plotting.py
@@ -266,14 +266,10 @@ def plot_final_performance_comparison(
aggregation_funcs = lambda x: np.array([metrics.aggregate_iqm(x)]) # noqa: E731
metric_names = ["IQM"]
elif "mean" in aggregation:
- aggregation_funcs = lambda x: np.array(
- [metrics.aggregate_mean(x)]
- ) # noqa: E731
+ aggregation_funcs = lambda x: np.array([metrics.aggregate_mean(x)]) # noqa: E731
metric_names = ["Mean"]
elif "median" in aggregation:
- aggregation_funcs = lambda x: np.array(
- [metrics.aggregate_median(x)]
- ) # noqa: E731
+ aggregation_funcs = lambda x: np.array([metrics.aggregate_median(x)]) # noqa: E731
metric_names = ["Median"]
score_dict = {}
diff --git a/mighty/mighty_utils/test_helpers.py b/mighty/mighty_utils/test_helpers.py
index b17018f4..ef721f4d 100644
--- a/mighty/mighty_utils/test_helpers.py
+++ b/mighty/mighty_utils/test_helpers.py
@@ -27,13 +27,14 @@ def set_instance_set(self, instance_set):
self.instance_set = instance_set
def reset(self, options={}, seed=None):
+ super().reset(seed=seed if seed is not None else 0)
if self.inst_id is None:
- self.inst_id = np.random.default_rng().integers(0, 100)
+ self.inst_id = self._np_random.integers(0, 100)
return self.observation_space.sample(), {}
def step(self, action):
- tr = np.random.default_rng().choice([0, 1], p=[0.9, 0.1])
- return self.observation_space.sample(), 0, False, tr, {}
+ tr = self._np_random.choice([0, 1], p=[0.9, 0.1])
+ return self.observation_space.sample(), self._np_random.random(), False, tr, {}
class DummyContinuousEnv(gym.Env):
@@ -57,13 +58,14 @@ def set_instance_set(self, instance_set):
self.instance_set = instance_set
def reset(self, options={}, seed=None):
+ super().reset(seed=seed if seed is not None else 0)
if self.inst_id is None:
- self.inst_id = np.random.default_rng().integers(0, 100)
+ self.inst_id = self._np_random.integers(0, 100)
return self.observation_space.sample(), {}
def step(self, action):
- tr = np.random.default_rng().choice([0, 1], p=[0.9, 0.1])
- return self.observation_space.sample(), np.random.rand(), False, tr, {}
+ tr = self._np_random.choice([0, 1], p=[0.9, 0.1])
+ return self.observation_space.sample(), self._np_random.random(), False, tr, {}
class DummyModel:
diff --git a/pyproject.toml b/pyproject.toml
index 086235c6..05e97863 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ version = "0.0.1"
description = "A modular, meta-learning-ready RL library."
authors = [{ name = "AutoRL@LUHAI", email = "a.mohan@ai.uni-hannover.de" }]
readme = "README.md"
-requires-python = ">=3.10,<3.12"
+requires-python = ">=3.11,<3.12"
license = { file = "LICENSE" }
keywords = [
"Reinforcement Learning",
@@ -38,19 +38,21 @@ dependencies = [
"scipy",
"rich~=12.4",
"wandb~=0.12",
- "torch",
+ "torch~=2.7",
"dill",
"imageio",
"evosax==0.1.6",
"rliable",
"seaborn",
- "uniplot"
+ "uniplot",
+ "jax~=0.7",
+ "flax~=0.11"
]
[project.optional-dependencies]
dev = ["ruff", "mypy", "build", "pytest", "pytest-cov"]
-carl = ["carl_bench==1.1.1", "brax==0.12.1"]
-dacbench = ["dacbench==0.4.0", "torchvision", "ioh"]
+carl = ["carl_bench[brax]==1.1.1"]
+dacbench = ["dacbench==0.4.0", "torchvision"]
pufferlib = ["pufferlib==2.0.6"]
docs = ["mkdocs", "mkdocs-material", "mkdocs-autorefs",
"mkdocs-gen-files", "mkdocs-literate-nav",
@@ -71,7 +73,7 @@ ignore = [
]
[tool.mypy]
-python_version = "3.10"
+python_version = "3.11"
disallow_untyped_defs = true
show_error_codes = true
no_implicit_optional = true
diff --git a/test/agents/test_dqn_agent.py b/test/agents/test_dqn_agent.py
index 5922dc63..3ba2ea4a 100644
--- a/test/agents/test_dqn_agent.py
+++ b/test/agents/test_dqn_agent.py
@@ -24,12 +24,12 @@ def test_init(self):
dqn = MightyDQNAgent(output_dir, env, use_target=False)
assert isinstance(dqn.q, DQN), "Model should be an instance of DQN"
assert isinstance(dqn.value_function, DQN), "Vf should be an instance of DQN"
- assert isinstance(
- dqn.policy, EpsilonGreedy
- ), "Policy should be an instance of EpsilonGreedy"
- assert isinstance(
- dqn.qlearning, QLearning
- ), "Update should be an instance of QLearning"
+ assert isinstance(dqn.policy, EpsilonGreedy), (
+ "Policy should be an instance of EpsilonGreedy"
+ )
+ assert isinstance(dqn.qlearning, QLearning), (
+ "Update should be an instance of QLearning"
+ )
assert dqn.q_target is None, "Q_target should be None"
test_obs, _ = env.reset()
@@ -65,12 +65,12 @@ def test_init(self):
q_pred = dqn.q(test_obs)
target_pred = dqn.q_target(test_obs)
assert torch.allclose(q_pred, target_pred), "Q and Q_target should be equal"
- assert isinstance(
- dqn.buffer, PrioritizedReplay
- ), "Replay buffer should be an instance of PrioritizedReplay"
- assert isinstance(
- dqn.qlearning, DoubleQLearning
- ), "Update should be an instance of DoubleQLearning"
+ assert isinstance(dqn.buffer, PrioritizedReplay), (
+ "Replay buffer should be an instance of PrioritizedReplay"
+ )
+ assert isinstance(dqn.qlearning, DoubleQLearning), (
+ "Update should be an instance of DoubleQLearning"
+ )
assert dqn._batch_size == 32, "Batch size should be 32"
assert dqn.learning_rate == 0.01, "Learning rate should be 0.01"
assert dqn._epsilon == 0.1, "Epsilon should be 0.1"
@@ -82,12 +82,12 @@ def test_update(self):
output_dir = Path("test_dqn_agent")
output_dir.mkdir(parents=True, exist_ok=True)
dqn = MightyDQNAgent(output_dir, env, batch_size=2)
- dqn.run(10, 1)
+ dqn.run(20, 1)
original_optimizer = torch.optim.Adam(dqn.q.parameters(), lr=dqn.learning_rate)
original_params = deepcopy(list(dqn.q.parameters()))
original_target_params = deepcopy(list(dqn.q_target.parameters()))
original_feature_params = deepcopy(list(dqn.q.feature_extractor.parameters()))
- batch = dqn.buffer.sample(2)
+ batch = dqn.buffer.sample(20)
metrics = dqn.update_agent(batch, 0)
new_params = deepcopy(list(dqn.q.parameters()))
new_target_params = deepcopy(list(dqn.q_target.parameters()))
@@ -96,9 +96,9 @@ def test_update(self):
for old, new in zip(
original_target_params[:10], new_target_params[:10], strict=False
):
- assert not torch.allclose(
- old, new
- ), "Target model parameters should be changed"
+ assert not torch.allclose(old, new), (
+ "Target model parameters should be changed"
+ )
for old, new in zip(
original_feature_params, dqn.q.feature_extractor.parameters(), strict=False
):
@@ -106,16 +106,16 @@ def test_update(self):
for old, new, new_target in zip(
original_params[:10], new_params[:10], new_target_params[:10], strict=False
):
- assert not torch.allclose(
- old * (1 - 0.01) + new * 0.01, new_target
- ), "Target model parameters should be scaled correctly"
+ assert not torch.allclose(old * (1 - 0.01) + new * 0.01, new_target), (
+ "Target model parameters should be scaled correctly"
+ )
- batch = dqn.buffer.sample(2)
+ batch = dqn.buffer.sample(20)
preds, targets = dqn.qlearning.get_targets(batch, dqn.q, dqn.q_target)
- assert (
- np.mean(targets.detach().numpy() - metrics["Update/td_targets"]) < 0.1
- ), "TD_targets should be equal"
+ assert np.mean(targets.detach().numpy() - metrics["Update/td_targets"]) < 0.1, (
+ "TD_targets should be equal"
+ )
assert (
np.mean((targets - preds).detach().numpy() - metrics["Update/td_errors"])
< 0.1
@@ -123,16 +123,16 @@ def test_update(self):
original_optimizer.zero_grad()
loss = F.mse_loss(preds, targets)
- assert (
- np.mean(loss.detach().numpy() - metrics["Update/loss"]) < 0.05
- ), "Loss should be equal"
+ assert np.mean(loss.detach().numpy() - metrics["Update/loss"]) < 0.05, (
+ "Loss should be equal"
+ )
loss.backward()
original_optimizer.step()
manual_params = deepcopy(list(dqn.q.parameters()))
for manual, agent in zip(manual_params[:10], new_params[:10], strict=False):
- assert torch.allclose(
- manual, agent, atol=1e-2
- ), "Model parameters should be equal to manual update"
+ assert torch.allclose(manual, agent, atol=1e-2), (
+ "Model parameters should be equal to manual update"
+ )
clean(output_dir)
@@ -195,12 +195,12 @@ def process_transition(self):
{},
)
- assert (
- len(metrics["rollout_values"]) == 2
- ), f"One value prediction per state, got: {metrics['rollout_values']}"
- assert (
- len(metrics["td_error"]) == 2
- ), f"TD error should be computed per transition, got: {metrics['td_errors']}"
+ assert len(metrics["rollout_values"]) == 2, (
+ f"One value prediction per state, got: {metrics['rollout_values']}"
+ )
+ assert len(metrics["td_error"]) == 2, (
+ f"TD error should be computed per transition, got: {metrics['td_errors']}"
+ )
state, _ = env.reset()
action = dqn.policy(state, return_logp=False)
@@ -209,8 +209,66 @@ def process_transition(self):
metrics = dqn.process_transition(
state, action, reward, next_state, te or tr, 0, metrics
)
- assert (
- len(metrics["rollout_values"]) == 3
- ), "New value prediction should be added"
+ assert len(metrics["rollout_values"]) == 3, (
+ "New value prediction should be added"
+ )
assert len(metrics["td_error"]) == 1, "TD error is overwritten"
clean(output_dir)
+
+ def test_reproducibility(self):
+ env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)])
+ output_dir = Path("test_dqn_agent")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ dqn = MightyDQNAgent(output_dir, env, batch_size=2, seed=42)
+ init_params = deepcopy(list(dqn.q.parameters()))
+ dqn.run(20, 1)
+ original_batch = dqn.buffer.sample(20)
+ original_metrics = dqn.update_agent(original_batch, 0)
+ original_params = deepcopy(list(dqn.q.parameters()))
+
+ for _ in range(3):
+ env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)])
+ output_dir = Path("test_dqn_agent")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ dqn = MightyDQNAgent(output_dir, env, batch_size=2, seed=42)
+ for old, new in zip(
+ init_params[:10], list(dqn.q.parameters())[:10], strict=False
+ ):
+ assert torch.allclose(old, new), (
+ "Parameter initialization should be the same with same seed"
+ )
+ dqn.run(20, 1)
+ batch = dqn.buffer.sample(20)
+
+ for old, new in zip(
+ original_batch.observations, batch.observations, strict=False
+ ):
+ assert torch.allclose(old, new), (
+ "Batch observations should be the same with same seed"
+ )
+
+ assert torch.allclose(original_batch.actions, batch.actions), (
+ f"Batch actions should be the same with same seed: {torch.isclose(original_batch.actions, batch.actions)}"
+ )
+ assert torch.allclose(original_batch.rewards, batch.rewards), (
+ f"Batch rewards should be the same with same seed: {torch.isclose(original_batch.rewards, batch.rewards)}"
+ )
+
+ new_metrics = dqn.update_agent(batch, 0)
+ for old, new in zip(
+ original_params[:10], list(dqn.q.parameters())[:10], strict=False
+ ):
+ assert torch.allclose(old, new), (
+ "Model parameters should stay the same with same seed"
+ )
+
+ assert np.allclose(
+ original_metrics["Update/td_targets"], new_metrics["Update/td_targets"]
+ ), "TD targets should be the same with same seed"
+ assert np.allclose(
+ original_metrics["Update/loss"], new_metrics["Update/loss"]
+ ), "Loss should be the same with same seed"
+ assert np.allclose(
+ original_metrics["Update/td_errors"], new_metrics["Update/td_errors"]
+ ), "TD errors should be the same with same seed"
+ clean(output_dir)
diff --git a/test/agents/test_ppo_agent.py b/test/agents/test_ppo_agent.py
index bd427337..06b22761 100644
--- a/test/agents/test_ppo_agent.py
+++ b/test/agents/test_ppo_agent.py
@@ -94,9 +94,9 @@ def test_init_discrete(self):
prediction = agent.step(test_obs, metrics)[0]
assert len(prediction) == 1, "Prediction should have shape (1,)"
- assert (
- 0 <= prediction[0] < 4
- ), "Action should be in valid range [0, 4)" # Updated for 4 actions
+ assert 0 <= prediction[0] < 4, (
+ "Action should be in valid range [0, 4)"
+ ) # Updated for 4 actions
clean(output_dir)
@@ -166,9 +166,9 @@ def test_update(self):
print(f"Buffer size after manual collection: {len(agent.buffer)}")
# Ensure we have enough data in buffer
- assert (
- len(agent.buffer) >= agent._batch_size
- ), f"Buffer size {len(agent.buffer)} should be >= batch size {agent._batch_size}"
+ assert len(agent.buffer) >= agent._batch_size, (
+ f"Buffer size {len(agent.buffer)} should be >= batch size {agent._batch_size}"
+ )
# Perform update
update_kwargs = {"next_s": curr_s, "dones": np.array([False])}
@@ -244,15 +244,15 @@ def test_properties(self):
params = ppo.parameters
assert isinstance(params, list), "Parameters should be a list"
assert len(params) > 0, "Should have parameters"
- assert all(
- isinstance(p, torch.nn.Parameter) for p in params
- ), "All should be Parameters"
+ assert all(isinstance(p, torch.nn.Parameter) for p in params), (
+ "All should be Parameters"
+ )
# Test value_function property - updated for new structure
value_fn = ppo.value_function
- assert (
- value_fn is ppo.model.value_function_module
- ), "Value function should be model's value function module"
+ assert value_fn is ppo.model.value_function_module, (
+ "Value function should be model's value function module"
+ )
# Test that the value function wrapper works - use correct obs shape
obs_shape = env.single_observation_space.shape[
@@ -263,3 +263,123 @@ def test_properties(self):
assert value_output.shape == (1, 1), "Value function should output shape (1, 1)"
clean(output_dir)
+
+ def test_reproducibility(self):
+ env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)])
+ output_dir = Path("test_ppo_agent")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ ppo = MightyPPOAgent(output_dir, env, batch_size=2, seed=42)
+ init_params = deepcopy(list(ppo.model.parameters()))
+
+ metrics = {
+ "env": ppo.env,
+ "step": 0,
+ "hp/lr": ppo.learning_rate,
+ "hp/pi_epsilon": ppo._epsilon,
+ "hp/batch_size": ppo._batch_size,
+ "hp/learning_starts": ppo._learning_starts,
+ }
+ curr_s, _ = env.reset(seed=42)
+ # Collect enough transitions to fill the buffer
+ for step in range(128): # Fill buffer with 128 transitions
+ # Get action from agent
+ action, log_prob = ppo.step(curr_s, metrics)
+
+ # Take environment step
+ next_s, reward, terminated, truncated, _ = env.step(action)
+ dones = np.logical_or(terminated, truncated)
+
+ # Process the transition (this adds to buffer)
+ ppo.process_transition(
+ curr_s,
+ action,
+ reward,
+ next_s,
+ dones,
+ log_prob.detach().cpu().numpy(),
+ {"step": step},
+ )
+
+ # Update current state
+ curr_s = next_s
+
+ # Reset environment if done
+ if np.any(dones):
+ curr_s, _ = env.reset()
+
+ update_kwargs = {"next_s": curr_s, "dones": np.array([False])}
+ original_metrics = ppo.update(metrics, update_kwargs)
+ original_params = deepcopy(list(ppo.model.parameters()))
+
+ for _ in range(3):
+ env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)])
+ output_dir = Path("test_ppo_agent")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ ppo = MightyPPOAgent(output_dir, env, batch_size=2, seed=42)
+ for old, new in zip(
+ init_params[:10], list(ppo.model.parameters())[:10], strict=False
+ ):
+ assert torch.allclose(old, new), (
+ "Parameter initialization should be the same with same seed"
+ )
+
+ metrics = {
+ "env": ppo.env,
+ "step": 0,
+ "hp/lr": ppo.learning_rate,
+ "hp/pi_epsilon": ppo._epsilon,
+ "hp/batch_size": ppo._batch_size,
+ "hp/learning_starts": ppo._learning_starts,
+ }
+ curr_s, _ = env.reset(seed=42)
+ # Collect enough transitions to fill the buffer
+ for step in range(128): # Fill buffer with 128 transitions
+ # Get action from agent
+ action, log_prob = ppo.step(curr_s, metrics)
+
+ # Take environment step
+ next_s, reward, terminated, truncated, _ = env.step(action)
+ dones = np.logical_or(terminated, truncated)
+
+ # Process the transition (this adds to buffer)
+ ppo.process_transition(
+ curr_s,
+ action,
+ reward,
+ next_s,
+ dones,
+ log_prob.detach().cpu().numpy(),
+ {"step": step},
+ )
+
+ # Update current state
+ curr_s = next_s
+
+ # Reset environment if done
+ if np.any(dones):
+ curr_s, _ = env.reset()
+
+ update_kwargs = {"next_s": curr_s, "dones": np.array([False])}
+ new_metrics = ppo.update(metrics, update_kwargs)
+
+ for old, new in zip(
+ original_params[:10], list(ppo.model.parameters())[:10], strict=False
+ ):
+ assert torch.allclose(old, new), (
+ "Model parameters should stay the same with same seed"
+ )
+
+ assert np.isclose(
+ original_metrics["Update/policy_loss"],
+ new_metrics["Update/policy_loss"],
+ ), "Policy loss should be the same with same seed"
+ assert np.isclose(
+ original_metrics["Update/value_loss"], new_metrics["Update/value_loss"]
+ ), "Value loss should be the same with same seed"
+ assert np.allclose(
+ original_metrics["Update/approx_kl"], new_metrics["Update/approx_kl"]
+ ), "Approx KL should be the same with same seed"
+ assert np.allclose(
+ original_metrics["Update/entropy"], new_metrics["Update/entropy"]
+ ), "Entropy should be the same with same seed"
+ clean(output_dir)
diff --git a/test/agents/test_sac_agent.py b/test/agents/test_sac_agent.py
index 260d66d0..cf3d0d13 100644
--- a/test/agents/test_sac_agent.py
+++ b/test/agents/test_sac_agent.py
@@ -119,9 +119,9 @@ def test_update(self):
print(f"Buffer size after manual collection: {len(agent.buffer)}")
# Ensure we have enough data in buffer
- assert (
- len(agent.buffer) >= agent.batch_size
- ), f"Buffer size {len(agent.buffer)} should be >= batch size {agent.batch_size}"
+ assert len(agent.buffer) >= agent.batch_size, (
+ f"Buffer size {len(agent.buffer)} should be >= batch size {agent.batch_size}"
+ )
# Set agent.steps to satisfy learning_starts condition
agent.steps = agent.learning_starts + 1
@@ -180,19 +180,18 @@ def test_update(self):
)
break
-
# Modified assertions - warn instead of fail for debugging
- assert (
- policy_params_changed or q1_params_changed or q2_params_changed
- ), "At least some parameters should change after update"
+ assert policy_params_changed or q1_params_changed or q2_params_changed, (
+ "At least some parameters should change after update"
+ )
assert isinstance(result_metrics, dict), "Update should return metrics dict"
# Check for expected SAC metrics in the result
expected_metrics = [
- "q_loss1",
- "q_loss2",
- "policy_loss",
- "alpha_loss",
+ "Update/q_loss1",
+ "Update/q_loss2",
+ "Update/policy_loss",
+ "Update/alpha_loss",
] # Added alpha_loss
for metric in expected_metrics:
if metric in result_metrics:
@@ -225,31 +224,31 @@ def test_properties(self):
params = agent.parameters
assert isinstance(params, list), "Parameters should be a list"
assert len(params) > 0, "Should have parameters"
- assert all(
- isinstance(p, torch.nn.Parameter) for p in params
- ), "All should be Parameters"
+ assert all(isinstance(p, torch.nn.Parameter) for p in params), (
+ "All should be Parameters"
+ )
# Test that parameters include all three networks (policy, q1, q2)
policy_params = list(agent.model.policy_net.parameters())
q1_params = list(agent.model.q_net1.parameters())
q2_params = list(agent.model.q_net2.parameters())
expected_param_count = len(policy_params) + len(q1_params) + len(q2_params)
- assert (
- len(params) == expected_param_count
- ), f"Expected {expected_param_count} parameters, got {len(params)}"
+ assert len(params) == expected_param_count, (
+ f"Expected {expected_param_count} parameters, got {len(params)}"
+ )
# Test value_function property
value_fn = agent.value_function
- assert isinstance(
- value_fn, torch.nn.Module
- ), "Value function should be a torch module"
+ assert isinstance(value_fn, torch.nn.Module), (
+ "Value function should be a torch module"
+ )
# Test that value function can be called with a state
dummy_state = torch.randn(1, agent.env.single_observation_space.shape[0])
value_output = value_fn(dummy_state)
- assert isinstance(
- value_output, torch.Tensor
- ), "Value function should return a tensor"
+ assert isinstance(value_output, torch.Tensor), (
+ "Value function should return a tensor"
+ )
assert value_output.shape == (
1,
1,
@@ -257,8 +256,69 @@ def test_properties(self):
# Test that value function is the cached module
value_fn2 = agent.value_function
- assert (
- value_fn is value_fn2
- ), "Value function should be cached and return same instance"
+ assert value_fn is value_fn2, (
+ "Value function should be cached and return same instance"
+ )
+
+ clean(output_dir)
+
+ def test_reproducibility(self):
+ env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)])
+ output_dir = Path("test_sac_agent")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ sac = MightySACAgent(
+ output_dir, env, seed=42, learning_starts=0, update_every=20
+ )
+ init_params = deepcopy(list(sac.model.parameters()))
+ sac.run(20, 1)
+ batch = sac.buffer.sample(20)
+ original_metrics = sac.update_agent(batch, 20)
+ original_params = deepcopy(list(sac.model.parameters()))
+
+ env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)])
+ sac = MightySACAgent(output_dir, env, seed=42)
+ state, _ = sac.env.reset(seed=42)
+ step_state, _, _, _, _ = sac.env.step([0])
+ env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)])
+ sac = MightySACAgent(output_dir, env, seed=42)
+ other_state, _ = sac.env.reset(seed=42)
+ other_step_state, _, _, _, _ = sac.env.step([0])
+ assert np.allclose(state, other_state), "States should be equal with same seed"
+ assert np.allclose(step_state, other_step_state), (
+ "Step states should be equal with same seed"
+ )
+
+ for _ in range(3):
+ env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)])
+ output_dir = Path("test_sac_agent")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ sac = MightySACAgent(
+ output_dir, env, seed=42, learning_starts=0, update_every=20
+ )
+ for old, new in zip(
+ init_params[:10], list(sac.model.parameters())[:10], strict=False
+ ):
+ assert torch.allclose(old, new), (
+ "Parameter initialization should be the same with same seed"
+ )
+ sac.run(20, 1)
+ batch = sac.buffer.sample(20)
+ new_metrics = sac.update_agent(batch, 20)
+ for old, new in zip(
+ original_params[:10], list(sac.model.parameters())[:10], strict=False
+ ):
+ assert torch.allclose(old, new), (
+ "Model parameters should stay the same with same seed"
+ )
+ assert np.isclose(
+ original_metrics["Update/q_loss1"], new_metrics["Update/q_loss1"]
+ ), "Q1 loss should be the same with same seed"
+ assert np.isclose(
+ original_metrics["Update/q_loss2"], new_metrics["Update/q_loss2"]
+ ), "Q2 loss should be the same with same seed"
+ assert np.isclose(
+ original_metrics["Update/policy_loss"],
+ new_metrics["Update/policy_loss"],
+ ), "Policy loss should be the same with same seed"
clean(output_dir)
diff --git a/test/exploration/test_epsilon_greedy.py b/test/exploration/test_epsilon_greedy.py
index 4769b7cc..0e1864a9 100644
--- a/test/exploration/test_epsilon_greedy.py
+++ b/test/exploration/test_epsilon_greedy.py
@@ -27,30 +27,30 @@ def test_exploration_func(self, state):
actions, qvals = policy.explore_func(state)
greedy_actions, greedy_qvals = policy.sample_action(state)
assert len(actions) == len(state), "Action should be predicted per state."
- assert all(
- a == g for g in greedy_actions for a in actions
- ), f"Actions should match greedy: {actions}///{greedy_actions}"
- assert torch.equal(
- qvals, greedy_qvals
- ), f"Q-values should match greedy: {qvals}///{greedy_qvals}"
+ assert all(a == g for g in greedy_actions for a in actions), (
+ f"Actions should match greedy: {actions}///{greedy_actions}"
+ )
+ assert torch.equal(qvals, greedy_qvals), (
+ f"Q-values should match greedy: {qvals}///{greedy_qvals}"
+ )
policy = self.get_policy(epsilon=0.5)
actions = np.array(
[policy.explore_func(state)[0] for _ in range(100)]
).flatten()
- assert (
- sum([a == 1 for a in actions]) / (100 * len(state)) > 0.5
- ), "Actions should match greedy at least in half of cases."
- assert (
- sum([a == 1 for a in actions]) / (100 * len(state)) < 0.8
- ), "Actions should match greedy in less than 4/5 of cases."
+ assert sum([a == 1 for a in actions]) / (100 * len(state)) > 0.5, (
+ "Actions should match greedy at least in half of cases."
+ )
+ assert sum([a == 1 for a in actions]) / (100 * len(state)) < 0.8, (
+ "Actions should match greedy in less than 4/5 of cases."
+ )
policy = self.get_policy(epsilon=np.linspace(0, 1, len(state)))
actions = np.array([policy.explore_func(state)[0] for _ in range(100)])
assert all(actions[:, 0] == 1), "Low index actions should match greedy."
- assert (
- sum(actions[:, -1] == 1) / 100 < 0.33
- ), "High index actions should not match greedy more than 1/3 of the time."
+ assert sum(actions[:, -1] == 1) / 100 < 0.33, (
+ "High index actions should not match greedy more than 1/3 of the time."
+ )
@pytest.mark.parametrize(
"state",
@@ -67,6 +67,6 @@ def test_multiple_epsilons(self, state):
policy = self.get_policy(epsilon=[0.1, 0.5])
assert np.all(policy.epsilon == [0.1, 0.5]), "Epsilon should be [0.1, 0.5]."
action, _ = policy.explore_func(state)
- assert len(action) == len(
- state.numpy()
- ), f"Action should be predicted per state: len({action}) != len({state.numpy()})."
+ assert len(action) == len(state.numpy()), (
+ f"Action should be predicted per state: len({action}) != len({state.numpy()})."
+ )
diff --git a/test/exploration/test_exploration.py b/test/exploration/test_exploration.py
index 817fc92f..ed35176e 100644
--- a/test/exploration/test_exploration.py
+++ b/test/exploration/test_exploration.py
@@ -29,16 +29,16 @@ def test_call(self, state):
policy(state)
greedy_actions, qvals = policy(state, evaluate=True, return_logp=True)
- assert all(
- greedy_actions == 1
- ), f"Greedy actions should be 1: {greedy_actions}///{qvals}"
+ assert all(greedy_actions == 1), (
+ f"Greedy actions should be 1: {greedy_actions}///{qvals}"
+ )
assert qvals.shape[-1] == 5, "Q-value shape should not be changed."
assert len(qvals) == len(state), "Q-value length should not be changed."
policy = self.get_policy(action=3)
greedy_actions, qvals = policy(state, evaluate=True, return_logp=True)
- assert all(
- greedy_actions == 3
- ), f"Greedy actions should be 3: {greedy_actions}///{qvals}"
+ assert all(greedy_actions == 3), (
+ f"Greedy actions should be 3: {greedy_actions}///{qvals}"
+ )
assert qvals.shape[-1] == 5, "Q-value shape should not be changed."
assert len(qvals) == len(state), "Q-value length should not be changed."
diff --git a/test/exploration/test_ez_greedy.py b/test/exploration/test_ez_greedy.py
index 37e04b33..c2aa09dd 100644
--- a/test/exploration/test_ez_greedy.py
+++ b/test/exploration/test_ez_greedy.py
@@ -19,15 +19,15 @@ def test_init(self) -> None:
use_target=False,
policy_class="mighty.mighty_exploration.EZGreedy",
)
- assert isinstance(
- dqn.policy, EZGreedy
- ), "Policy should be an instance of EZGreedy when creating with string."
+ assert isinstance(dqn.policy, EZGreedy), (
+ "Policy should be an instance of EZGreedy when creating with string."
+ )
assert dqn.policy.epsilon == 0.1, "Default epsilon should be 0.1."
assert dqn.policy.zipf_param == 2, "Default zipf_param should be 2."
assert dqn.policy.skipped is None, "Skip should be initialized at None."
- assert (
- dqn.policy.frozen_actions is None
- ), "Frozen actions should be initialized at None."
+ assert dqn.policy.frozen_actions is None, (
+ "Frozen actions should be initialized at None."
+ )
dqn = MightyDQNAgent(
output_dir,
@@ -36,9 +36,9 @@ def test_init(self) -> None:
policy_class=EZGreedy,
policy_kwargs={"epsilon": [0.5, 0.3], "zipf_param": 3},
)
- assert isinstance(
- dqn.policy, EZGreedy
- ), "Policy should be an instance of EZGreedy when creating with class."
+ assert isinstance(dqn.policy, EZGreedy), (
+ "Policy should be an instance of EZGreedy when creating with class."
+ )
assert np.all(dqn.policy.epsilon == [0.5, 0.3]), "Epsilon should be [0.5, 0.3]."
assert dqn.policy.zipf_param == 3, "zipf_param should be 3."
clean(output_dir)
@@ -57,16 +57,16 @@ def test_skip_single(self) -> None:
state, _ = env.reset()
action = dqn.policy([state])
- assert np.all(
- action < env.single_action_space.n
- ), "Action should be within the action space."
+ assert np.all(action < env.single_action_space.n), (
+ "Action should be within the action space."
+ )
assert len(action) == len(state), "Action should be predicted per state."
dqn.policy.skipped = np.array([1])
next_action = dqn.policy([state])
- assert np.all(
- action == next_action
- ), "Action should be the same as the previous action when skip is active."
+ assert np.all(action == next_action), (
+ "Action should be the same as the previous action when skip is active."
+ )
assert dqn.policy.skipped[0] == 0, "Skip should be decayed by one."
clean(output_dir)
@@ -84,16 +84,16 @@ def test_skip_batch(self) -> None:
state, _ = env.reset()
action = dqn.policy(state)
- assert all(
- [a < env.single_action_space.n for a in action]
- ), "Actions should be within the action space."
+ assert all([a < env.single_action_space.n for a in action]), (
+ "Actions should be within the action space."
+ )
assert len(action) == len(state), "Action should be predicted per state."
dqn.policy.skipped = np.array([3, 0])
next_action = dqn.policy(state + 2)
- assert (
- action[0] == next_action[0]
- ), f"First action should be the same as the previous action when skip is active: {action[0]} != {next_action[0]}"
+ assert action[0] == next_action[0], (
+ f"First action should be the same as the previous action when skip is active: {action[0]} != {next_action[0]}"
+ )
assert dqn.policy.skipped[0] == 2, "Skip should be decayed by one."
assert dqn.policy.skipped[1] >= 0, "Skip should not be decayed below one."
clean(output_dir)
diff --git a/test/meta_components/test_cosine_schedule.py b/test/meta_components/test_cosine_schedule.py
index d20bcaad..4c9cea4e 100644
--- a/test/meta_components/test_cosine_schedule.py
+++ b/test/meta_components/test_cosine_schedule.py
@@ -25,12 +25,12 @@ def test_decay(self) -> None:
dqn.learning_rate = lr
for i in range(4):
metrics = dqn.run(n_steps=10 * (i + 1))
- assert (
- metrics["hp/lr"] == dqn.learning_rate
- ), f"Learning rate should be set to schedule value {metrics['hp/lr']} instead of {dqn.learning_rate}."
- assert (
- dqn.learning_rate < lr
- ), f"Learning rate should decrease: {dqn.learning_rate} is not less than {lr}."
+ assert metrics["hp/lr"] == dqn.learning_rate, (
+ f"Learning rate should be set to schedule value {metrics['hp/lr']} instead of {dqn.learning_rate}."
+ )
+ assert dqn.learning_rate < lr, (
+ f"Learning rate should decrease: {dqn.learning_rate} is not less than {lr}."
+ )
lr = dqn.learning_rate.copy()
clean(output_dir)
@@ -47,10 +47,10 @@ def test_restart(self) -> None:
],
)
dqn.run(6, 0)
- assert (
- dqn.meta_modules["CosineLRSchedule"].n_restarts == 1
- ), "Restart counter should increase."
- assert (
- dqn.learning_rate >= dqn.meta_modules["CosineLRSchedule"].eta_max
- ), f"Restart should increase learning rate: {dqn.learning_rate} is not {dqn.meta_modules['CosineLRSchedule'].eta_max}."
+ assert dqn.meta_modules["CosineLRSchedule"].n_restarts == 1, (
+ "Restart counter should increase."
+ )
+ assert dqn.learning_rate >= dqn.meta_modules["CosineLRSchedule"].eta_max, (
+ f"Restart should increase learning rate: {dqn.learning_rate} is not {dqn.meta_modules['CosineLRSchedule'].eta_max}."
+ )
clean(output_dir)
diff --git a/test/meta_components/test_noveld.py b/test/meta_components/test_noveld.py
index 5a4ae171..5428bfc1 100644
--- a/test/meta_components/test_noveld.py
+++ b/test/meta_components/test_noveld.py
@@ -31,33 +31,33 @@ def test_init(self) -> None:
ppo = self.init_NovelD()
assert len(ppo.meta_modules) == 1, "There should be one meta module."
assert "NovelD" in ppo.meta_modules, "NovelD should be in meta modules."
- assert isinstance(
- ppo.meta_modules["NovelD"], NovelD
- ), "NovelD should be meta module when created from string."
- assert (
- ppo.meta_modules["NovelD"].rnd_output_dim == 512
- ), "Default output dim should be 512."
- assert (
- len(ppo.meta_modules["NovelD"].rnd_network_config) == 0
- ), f"Default network config should be empty, got {ppo.meta_modules['NovelD'].rnd_network_config}."
- assert (
- ppo.meta_modules["NovelD"].internal_reward_weight == 0.1
- ), "Default internal reward weight should be 0.1."
- assert (
- ppo.meta_modules["NovelD"].rnd_lr == 0.001
- ), "Default NovelD learning rate should be 0.001."
- assert (
- ppo.meta_modules["NovelD"].rnd_eps == 1e-5
- ), "Default NovelD epsilon should be 1e-5."
- assert (
- ppo.meta_modules["NovelD"].rnd_weight_decay == 0.01
- ), "Default NovelD weight decay should be 0.01."
- assert (
- ppo.meta_modules["NovelD"].update_proportion == 0.5
- ), "Default update proportion should be 0.5."
- assert (
- ppo.meta_modules["NovelD"].rnd_net is None
- ), "NovelD network should be None."
+ assert isinstance(ppo.meta_modules["NovelD"], NovelD), (
+ "NovelD should be meta module when created from string."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_output_dim == 512, (
+ "Default output dim should be 512."
+ )
+ assert len(ppo.meta_modules["NovelD"].rnd_network_config) == 0, (
+ f"Default network config should be empty, got {ppo.meta_modules['NovelD'].rnd_network_config}."
+ )
+ assert ppo.meta_modules["NovelD"].internal_reward_weight == 0.1, (
+ "Default internal reward weight should be 0.1."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_lr == 0.001, (
+ "Default NovelD learning rate should be 0.001."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_eps == 1e-5, (
+ "Default NovelD epsilon should be 1e-5."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_weight_decay == 0.01, (
+ "Default NovelD weight decay should be 0.01."
+ )
+ assert ppo.meta_modules["NovelD"].update_proportion == 0.5, (
+ "Default update proportion should be 0.5."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_net is None, (
+ "NovelD network should be None."
+ )
env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)])
output_dir = Path("test_NovelD")
@@ -80,33 +80,33 @@ def test_init(self) -> None:
)
assert len(ppo.meta_modules) == 1, "There should be one meta module."
assert "NovelD" in ppo.meta_modules, "NovelD should be in meta modules."
- assert isinstance(
- ppo.meta_modules["NovelD"], NovelD
- ), "NovelD should be meta module when created from class."
- assert (
- ppo.meta_modules["NovelD"].rnd_output_dim == 12
- ), "Output dim should be 12."
- assert ppo.meta_modules["NovelD"].rnd_network_config == {
- "test": True
- }, "Network config should be {'test': True}."
- assert (
- ppo.meta_modules["NovelD"].internal_reward_weight == 0.5
- ), "Internal reward weight should be 0.5."
- assert (
- ppo.meta_modules["NovelD"].rnd_lr == 0.2
- ), "NovelD learning rate should be 0.2."
- assert (
- ppo.meta_modules["NovelD"].rnd_eps == 1e-4
- ), "NovelD epsilon should be 1e-4."
- assert (
- ppo.meta_modules["NovelD"].rnd_weight_decay == 0.1
- ), "NovelD weight decay should be 0.1."
- assert (
- ppo.meta_modules["NovelD"].update_proportion == 0.3
- ), "Update proportion should be 0.3."
- assert (
- ppo.meta_modules["NovelD"].rnd_net is None
- ), "NovelD network should be None."
+ assert isinstance(ppo.meta_modules["NovelD"], NovelD), (
+ "NovelD should be meta module when created from class."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_output_dim == 12, (
+ "Output dim should be 12."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_network_config == {"test": True}, (
+ "Network config should be {'test': True}."
+ )
+ assert ppo.meta_modules["NovelD"].internal_reward_weight == 0.5, (
+ "Internal reward weight should be 0.5."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_lr == 0.2, (
+ "NovelD learning rate should be 0.2."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_eps == 1e-4, (
+ "NovelD epsilon should be 1e-4."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_weight_decay == 0.1, (
+ "NovelD weight decay should be 0.1."
+ )
+ assert ppo.meta_modules["NovelD"].update_proportion == 0.3, (
+ "Update proportion should be 0.3."
+ )
+ assert ppo.meta_modules["NovelD"].rnd_net is None, (
+ "NovelD network should be None."
+ )
clean("test_NovelD")
def test_reward_computation(self) -> None:
@@ -120,15 +120,15 @@ def test_reward_computation(self) -> None:
}
}
updated_metrics = ppo.meta_modules["NovelD"].get_reward(dummy_metrics)
- assert (
- ppo.meta_modules["NovelD"].rnd_net is not None
- ), "NovelD network should be initialized."
- assert (
- ppo.meta_modules["NovelD"].last_error != 0
- ), "Last error should be non-zero."
- assert (
- "intrinsic_reward" in updated_metrics["transition"]
- ), "Intrinsic reward should be in updated metrics."
+ assert ppo.meta_modules["NovelD"].rnd_net is not None, (
+ "NovelD network should be initialized."
+ )
+ assert ppo.meta_modules["NovelD"].last_error != 0, (
+ "Last error should be non-zero."
+ )
+ assert "intrinsic_reward" in updated_metrics["transition"], (
+ "Intrinsic reward should be in updated metrics."
+ )
assert sum(updated_metrics["transition"]["intrinsic_reward"]) == sum(
updated_metrics["transition"]["reward"]
), "Intrinsic reward should be added to base reward."
@@ -143,7 +143,9 @@ def test_reward_computation(self) -> None:
* ppo.meta_modules["NovelD"].internal_reward_weight,
updated_metrics["transition"]["reward"],
atol=1e-3,
- ), f"Intrinsic reward should be difference between current and last errors. Actual {abs(ppo.meta_modules['NovelD'].last_error - current_last_error) * ppo.meta_modules['NovelD'].internal_reward_weight}, assigned {updated_metrics['transition']['reward']}."
+ ), (
+ f"Intrinsic reward should be difference between current and last errors. Actual {abs(ppo.meta_modules['NovelD'].last_error - current_last_error) * ppo.meta_modules['NovelD'].internal_reward_weight}, assigned {updated_metrics['transition']['reward']}."
+ )
clean("test_NovelD")
def test_update(self) -> None:
@@ -185,11 +187,11 @@ def test_update(self) -> None:
new_target_params = ppo.meta_modules["NovelD"].rnd_net.target.parameters()
for new_param, old_param in zip(new_predictor_params, predictor_params):
- assert not torch.allclose(
- new_param, old_param
- ), "Predictor parameters should be updated."
+ assert not torch.allclose(new_param, old_param), (
+ "Predictor parameters should be updated."
+ )
for new_param, old_param in zip(new_target_params, target_params):
- assert torch.allclose(
- new_param, old_param
- ), "Target parameters should stay fixed."
+ assert torch.allclose(new_param, old_param), (
+ "Target parameters should stay fixed."
+ )
clean("test_NovelD")
diff --git a/test/meta_components/test_plr.py b/test/meta_components/test_plr.py
index 4d0425e6..2498e318 100644
--- a/test/meta_components/test_plr.py
+++ b/test/meta_components/test_plr.py
@@ -73,16 +73,16 @@ def test_init(self) -> None:
assert plr.score_transform == "max", "Score transform should be set to max."
assert plr.temperature == 0.8, "Temperature should be set to 0.8."
assert plr.eps == 1e-5, "Epsilon should be set to 1e-5."
- assert (
- plr.staleness_transform == "max"
- ), "Staleness transform should be set to max."
- assert (
- plr.staleness_temperature == 0.8
- ), "Staleness temperature should be set to 0.8."
+ assert plr.staleness_transform == "max", (
+ "Staleness transform should be set to max."
+ )
+ assert plr.staleness_temperature == 0.8, (
+ "Staleness temperature should be set to 0.8."
+ )
- assert (
- plr.instance_scores == {}
- ), "Instance scores should be an empty dictionary."
+ assert plr.instance_scores == {}, (
+ "Instance scores should be an empty dictionary."
+ )
assert plr.staleness == {}, "Staleness should be an empty dictionary."
assert plr.all_instances is None, "All instances should be None."
assert plr.index == 0, "Index should be 0."
@@ -95,22 +95,22 @@ def test_get_instance(self) -> None:
metrics = {"env": EnvSim(1)}
plr.get_instance(metrics=metrics)
assert metrics["env"].inst_ids is not None, "Instance should not be None."
- assert (
- metrics["env"].inst_ids in plr.instance_scores.keys()
- ), "Instance should be in instance scores."
+ assert metrics["env"].inst_ids in plr.instance_scores.keys(), (
+ "Instance should be in instance scores."
+ )
assert plr.all_instances is not None, "All instances should be initialized."
- assert all(
- [i in plr.all_instances for i in plr.instance_scores.keys()]
- ), "All instances should be in instance scores."
+ assert all([i in plr.all_instances for i in plr.instance_scores.keys()]), (
+ "All instances should be in instance scores."
+ )
plr.sample_strategy = "sequential"
index = plr.index
metrics = {"env": EnvSim(10)}
original_instance = 10
plr.get_instance(metrics=metrics)
- assert (
- original_instance != metrics["env"].inst_ids
- ), "Instance should be changed."
+ assert original_instance != metrics["env"].inst_ids, (
+ "Instance should be changed."
+ )
assert plr.index == index + 1, "Index should be incremented by 1."
def test_sample_weights(self) -> None:
@@ -118,9 +118,9 @@ def test_sample_weights(self) -> None:
for m in DUMMY_METRICS:
plr.add_rollout(m)
weights = plr.sample_weights()
- assert len(weights) == len(
- plr.instance_scores.keys()
- ), "Length of weights should be equal to the number of instances."
+ assert len(weights) == len(plr.instance_scores.keys()), (
+ "Length of weights should be equal to the number of instances."
+ )
assert np.isclose(sum(weights), 1), "Sum of weights should be 1."
def test_score_transforms(self) -> None:
@@ -137,9 +137,9 @@ def test_score_transforms(self) -> None:
]:
plr.score_transform = score_transform
weights = plr.sample_weights()
- assert len(weights) == len(
- plr.instance_scores.keys()
- ), "Length of weights should be equal to the number of instances."
+ assert len(weights) == len(plr.instance_scores.keys()), (
+ "Length of weights should be equal to the number of instances."
+ )
assert np.isclose(sum(weights), 1), "Sum of weights should be 1."
@pytest.mark.parametrize("metrics", DUMMY_METRICS)
@@ -162,22 +162,22 @@ def test_score_function(self, metrics) -> None:
plr.add_rollout(metrics)
else:
plr.add_rollout(metrics)
- assert (
- plr.instance_scores[metrics["env"].inst_ids[0]] is not None
- ), "Instance score should not be None."
+ assert plr.instance_scores[metrics["env"].inst_ids[0]] is not None, (
+ "Instance score should not be None."
+ )
if score_func == "random":
- assert (
- plr.instance_scores[0] == 1.0
- ), f"Random score should be 1. Scores were: {plr.instance_scores}"
+ assert plr.instance_scores[0] == 1.0, (
+ f"Random score should be 1. Scores were: {plr.instance_scores}"
+ )
@pytest.mark.parametrize("metrics", DUMMY_METRICS)
def test_add_rollout(self, metrics) -> None:
plr = PLR()
plr.add_rollout(metrics)
- assert (
- metrics["env"].inst_ids[0] in plr.instance_scores
- ), "Instance should be added to instance scores."
+ assert metrics["env"].inst_ids[0] in plr.instance_scores, (
+ "Instance should be added to instance scores."
+ )
def test_in_loop(self) -> None:
env = ContextualVecEnv([DummyEnv for _ in range(2)])
@@ -189,13 +189,13 @@ def test_in_loop(self) -> None:
use_target=False,
meta_methods=["mighty.mighty_meta.PrioritizedLevelReplay"],
)
- assert (
- dqn.meta_modules["PrioritizedLevelReplay"] is not None
- ), "PLR should be initialized."
+ assert dqn.meta_modules["PrioritizedLevelReplay"] is not None, (
+ "PLR should be initialized."
+ )
dqn.run(100, 0)
- assert (
- dqn.meta_modules["PrioritizedLevelReplay"].all_instances is not None
- ), "All instances should be initialized."
+ assert dqn.meta_modules["PrioritizedLevelReplay"].all_instances is not None, (
+ "All instances should be initialized."
+ )
assert (
env.inst_ids[0] in dqn.meta_modules["PrioritizedLevelReplay"].all_instances
), "Instance should be in all instances."
diff --git a/test/meta_components/test_rnd.py b/test/meta_components/test_rnd.py
index 06b33b8f..df0d906b 100644
--- a/test/meta_components/test_rnd.py
+++ b/test/meta_components/test_rnd.py
@@ -31,30 +31,30 @@ def test_init(self) -> None:
ppo = self.init_rnd()
assert len(ppo.meta_modules) == 1, "There should be one meta module."
assert "RND" in ppo.meta_modules, "RND should be in meta modules."
- assert isinstance(
- ppo.meta_modules["RND"], RND
- ), "RND should be meta module when created from string."
- assert (
- ppo.meta_modules["RND"].rnd_output_dim == 512
- ), "Default output dim should be 512."
- assert (
- len(ppo.meta_modules["RND"].rnd_network_config) == 0
- ), f"Default network config should be empty, got {ppo.meta_modules['RND'].rnd_network_config}."
- assert (
- ppo.meta_modules["RND"].internal_reward_weight == 0.1
- ), "Default internal reward weight should be 0.1."
- assert (
- ppo.meta_modules["RND"].rnd_lr == 0.001
- ), "Default RND learning rate should be 0.001."
- assert (
- ppo.meta_modules["RND"].rnd_eps == 1e-5
- ), "Default RND epsilon should be 1e-5."
- assert (
- ppo.meta_modules["RND"].rnd_weight_decay == 0.01
- ), "Default RND weight decay should be 0.01."
- assert (
- ppo.meta_modules["RND"].update_proportion == 0.5
- ), "Default update proportion should be 0.5."
+ assert isinstance(ppo.meta_modules["RND"], RND), (
+ "RND should be meta module when created from string."
+ )
+ assert ppo.meta_modules["RND"].rnd_output_dim == 512, (
+ "Default output dim should be 512."
+ )
+ assert len(ppo.meta_modules["RND"].rnd_network_config) == 0, (
+ f"Default network config should be empty, got {ppo.meta_modules['RND'].rnd_network_config}."
+ )
+ assert ppo.meta_modules["RND"].internal_reward_weight == 0.1, (
+ "Default internal reward weight should be 0.1."
+ )
+ assert ppo.meta_modules["RND"].rnd_lr == 0.001, (
+ "Default RND learning rate should be 0.001."
+ )
+ assert ppo.meta_modules["RND"].rnd_eps == 1e-5, (
+ "Default RND epsilon should be 1e-5."
+ )
+ assert ppo.meta_modules["RND"].rnd_weight_decay == 0.01, (
+ "Default RND weight decay should be 0.01."
+ )
+ assert ppo.meta_modules["RND"].update_proportion == 0.5, (
+ "Default update proportion should be 0.5."
+ )
assert ppo.meta_modules["RND"].rnd_net is None, "RND network should be None."
env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)])
@@ -78,24 +78,24 @@ def test_init(self) -> None:
)
assert len(ppo.meta_modules) == 1, "There should be one meta module."
assert "RND" in ppo.meta_modules, "RND should be in meta modules."
- assert isinstance(
- ppo.meta_modules["RND"], RND
- ), "RND should be meta module when created from class."
+ assert isinstance(ppo.meta_modules["RND"], RND), (
+ "RND should be meta module when created from class."
+ )
assert ppo.meta_modules["RND"].rnd_output_dim == 12, "Output dim should be 12."
- assert ppo.meta_modules["RND"].rnd_network_config == {
- "test": True
- }, "Network config should be {'test': True}."
- assert (
- ppo.meta_modules["RND"].internal_reward_weight == 0.5
- ), "Internal reward weight should be 0.5."
+ assert ppo.meta_modules["RND"].rnd_network_config == {"test": True}, (
+ "Network config should be {'test': True}."
+ )
+ assert ppo.meta_modules["RND"].internal_reward_weight == 0.5, (
+ "Internal reward weight should be 0.5."
+ )
assert ppo.meta_modules["RND"].rnd_lr == 0.2, "RND learning rate should be 0.2."
assert ppo.meta_modules["RND"].rnd_eps == 1e-4, "RND epsilon should be 1e-4."
- assert (
- ppo.meta_modules["RND"].rnd_weight_decay == 0.1
- ), "RND weight decay should be 0.1."
- assert (
- ppo.meta_modules["RND"].update_proportion == 0.3
- ), "Update proportion should be 0.3."
+ assert ppo.meta_modules["RND"].rnd_weight_decay == 0.1, (
+ "RND weight decay should be 0.1."
+ )
+ assert ppo.meta_modules["RND"].update_proportion == 0.3, (
+ "Update proportion should be 0.3."
+ )
assert ppo.meta_modules["RND"].rnd_net is None, "RND network should be None."
clean("test_rnd")
@@ -110,12 +110,12 @@ def test_reward_computation(self) -> None:
}
}
updated_metrics = ppo.meta_modules["RND"].get_reward(dummy_metrics)
- assert (
- ppo.meta_modules["RND"].rnd_net is not None
- ), "RND network should be initialized."
- assert (
- "intrinsic_reward" in updated_metrics["transition"]
- ), "Intrinsic reward should be in updated metrics."
+ assert ppo.meta_modules["RND"].rnd_net is not None, (
+ "RND network should be initialized."
+ )
+ assert "intrinsic_reward" in updated_metrics["transition"], (
+ "Intrinsic reward should be in updated metrics."
+ )
assert sum(updated_metrics["transition"]["intrinsic_reward"]) == sum(
updated_metrics["transition"]["reward"]
), "Intrinsic reward should be added to base reward."
@@ -160,11 +160,11 @@ def test_update(self) -> None:
new_target_params = ppo.meta_modules["RND"].rnd_net.target.parameters()
for new_param, old_param in zip(new_predictor_params, predictor_params):
- assert not torch.allclose(
- new_param, old_param
- ), "Predictor parameters should be updated."
+ assert not torch.allclose(new_param, old_param), (
+ "Predictor parameters should be updated."
+ )
for new_param, old_param in zip(new_target_params, target_params):
- assert torch.allclose(
- new_param, old_param
- ), "Target parameters should stay fixed."
+ assert torch.allclose(new_param, old_param), (
+ "Target parameters should stay fixed."
+ )
clean("test_rnd")
diff --git a/test/meta_components/test_space.py b/test/meta_components/test_space.py
index be03bd98..c352d2a2 100644
--- a/test/meta_components/test_space.py
+++ b/test/meta_components/test_space.py
@@ -25,12 +25,12 @@ def test_get_instances(self) -> None:
"rollout_values": [[0.0, 0.6, 0.7]],
}
space.get_instances(metrics)
- assert (
- len(space.all_instances) == 1
- ), f"Expected 1, got {len(space.all_instances)}"
- assert (
- len(space.instance_set) == 1
- ), f"Expected 1, got {len(space.instance_set)}"
+ assert len(space.all_instances) == 1, (
+ f"Expected 1, got {len(space.all_instances)}"
+ )
+ assert len(space.instance_set) == 1, (
+ f"Expected 1, got {len(space.instance_set)}"
+ )
assert space.last_evals is not None, "Evals should not be None."
def test_get_evals(self) -> None:
@@ -53,10 +53,10 @@ def test_in_loop(self) -> None:
)
assert dqn.meta_modules["SPaCE"] is not None, "SPaCE should be initialized."
dqn.run(100, 0)
- assert (
- dqn.meta_modules["SPaCE"].all_instances is not None
- ), "All instances should be initialized."
- assert (
- env.inst_ids[0] in dqn.meta_modules["SPaCE"].all_instances
- ), "Instance should be in all instances."
+ assert dqn.meta_modules["SPaCE"].all_instances is not None, (
+ "All instances should be initialized."
+ )
+ assert env.inst_ids[0] in dqn.meta_modules["SPaCE"].all_instances, (
+ "Instance should be in all instances."
+ )
clean(output_dir)
diff --git a/test/models/test_networks.py b/test/models/test_networks.py
index d9bd54b2..61b71b97 100644
--- a/test/models/test_networks.py
+++ b/test/models/test_networks.py
@@ -115,15 +115,15 @@ def test_init(self, input_size, n_layers, hidden_sizes, activation):
assert isinstance(mlp, torch.jit.ScriptModule), "MLP is not a ScriptModule."
for n in range(n_layers):
- assert (
- type(mlp.layers[2 * n]) is torch.nn.Linear
- ), f"Layer {n} is not a Linear."
- assert (
- type(mlp.layers[2 * n + 1]) is ACTIVATIONS[activation]
- ), f"Activation {n} is not correct."
- assert (
- mlp.layers[2 * n].out_features == hidden_sizes[n]
- ), f"Wrong in_features in layer {n}."
+ assert type(mlp.layers[2 * n]) is torch.nn.Linear, (
+ f"Layer {n} is not a Linear."
+ )
+ assert type(mlp.layers[2 * n + 1]) is ACTIVATIONS[activation], (
+ f"Activation {n} is not correct."
+ )
+ assert mlp.layers[2 * n].out_features == hidden_sizes[n], (
+            f"Wrong out_features in layer {n}."
+ )
with pytest.raises(IndexError):
mlp.layers[2 * n + 2]
@@ -138,21 +138,21 @@ def test_soft_reset(self):
]
mlp.soft_reset(0, 0.5, 0.5)
reset_pred = mlp(dummy_input)
- assert torch.allclose(
- original_pred, reset_pred
- ), "Model prediction has changed."
- assert torch.allclose(
- mlp.layers[0].weight, prev_model_weights[0]
- ), "Weights have reset in layer 0 even though probability was 0."
- assert torch.allclose(
- mlp.layers[2].weight, prev_model_weights[1]
- ), "Weights have reset in layer 1 even though probability was 0."
+ assert torch.allclose(original_pred, reset_pred), (
+ "Model prediction has changed."
+ )
+ assert torch.allclose(mlp.layers[0].weight, prev_model_weights[0]), (
+ "Weights have reset in layer 0 even though probability was 0."
+ )
+ assert torch.allclose(mlp.layers[2].weight, prev_model_weights[1]), (
+ "Weights have reset in layer 1 even though probability was 0."
+ )
mlp.soft_reset(1, 0.5, 0.5)
reset_pred = mlp(dummy_input)
- assert ~torch.allclose(
- original_pred, reset_pred
- ), "Model prediction has not changed."
+        assert not torch.allclose(original_pred, reset_pred), (
+ "Model prediction has not changed."
+ )
assert not any(
torch.isclose(mlp.layers[0].weight, prev_model_weights[0])
.flatten()
@@ -170,23 +170,23 @@ def test_soft_reset(self):
original_pred = mlp(dummy_input)
mlp.soft_reset(1, 1, 0)
reset_pred = mlp(dummy_input)
- assert torch.allclose(
- original_pred, reset_pred
- ), "Model prediction has changed though perturb was 0."
+ assert torch.allclose(original_pred, reset_pred), (
+ "Model prediction has changed though perturb was 0."
+ )
for new_param, old_param in zip(mlp.parameters(), prev_params, strict=False):
- assert torch.allclose(
- new_param, old_param
- ), "Weights have been reset even though perturb was 0."
+ assert torch.allclose(new_param, old_param), (
+ "Weights have been reset even though perturb was 0."
+ )
mlp.soft_reset(1, 0.5, 0.0)
reset_pred = mlp(dummy_input)
- assert (
- original_pred * 0.5 - reset_pred
- ).sum() < 0.1, "Model prediction didn't shrink with parameter value."
+ assert (original_pred * 0.5 - reset_pred).sum() < 0.1, (
+ "Model prediction didn't shrink with parameter value."
+ )
for new_param, old_param in zip(mlp.parameters(), prev_params, strict=False):
- assert torch.allclose(
- new_param, old_param * 0.5
- ), "Weights have not been shrunk."
+ assert torch.allclose(new_param, old_param * 0.5), (
+ "Weights have not been shrunk."
+ )
mlp.layers[0].weight = deepcopy(prev_model_weights[0])
mlp.layers[2].weight = deepcopy(prev_model_weights[1])
@@ -194,13 +194,13 @@ def test_soft_reset(self):
original_pred = mlp(dummy_input)
mlp.soft_reset(1, 1, 1)
reset_pred = mlp(dummy_input)
- assert not torch.allclose(
- original_pred, reset_pred
- ), "Model prediction has not changed."
+ assert not torch.allclose(original_pred, reset_pred), (
+ "Model prediction has not changed."
+ )
for new_param, old_param in zip(mlp.parameters(), prev_params, strict=False):
- assert not torch.allclose(
- new_param, old_param
- ), "Weights have not been perturbed."
+ assert not torch.allclose(new_param, old_param), (
+ "Weights have not been perturbed."
+ )
mlp = MLP(3, 5, [100, 100, 100, 100, 100], "relu")
n_reset = 0
@@ -229,15 +229,15 @@ def test_hard_reset(self):
]
mlp.full_hard_reset()
reset_pred = mlp(dummy_input)
- assert ~torch.allclose(
- original_pred, reset_pred
- ), "Model prediction has not changed."
- assert ~torch.allclose(
- mlp.layers[0].weight, prev_model_weights[0]
- ), "Weights have not been reset in layer 0."
- assert ~torch.allclose(
- mlp.layers[2].weight, prev_model_weights[1]
- ), "Weights have not been reset in layer 1."
+        assert not torch.allclose(original_pred, reset_pred), (
+ "Model prediction has not changed."
+ )
+        assert not torch.allclose(mlp.layers[0].weight, prev_model_weights[0]), (
+ "Weights have not been reset in layer 0."
+ )
+        assert not torch.allclose(mlp.layers[2].weight, prev_model_weights[1]), (
+ "Weights have not been reset in layer 1."
+ )
def test_reset(self):
dummy_input = torch.rand(1, 3)
@@ -249,15 +249,15 @@ def test_reset(self):
]
mlp.reset(1)
reset_pred = mlp(dummy_input)
- assert ~torch.allclose(
- original_pred, reset_pred
- ), "Model prediction has not changed."
- assert ~torch.allclose(
- mlp.layers[0].weight, prev_model_weights[0]
- ), "Weights have not been reset in layer 0."
- assert ~torch.allclose(
- mlp.layers[2].weight, prev_model_weights[1]
- ), "Weights have not been reset in layer 1."
+        assert not torch.allclose(original_pred, reset_pred), (
+ "Model prediction has not changed."
+ )
+        assert not torch.allclose(mlp.layers[0].weight, prev_model_weights[0]), (
+ "Weights have not been reset in layer 0."
+ )
+        assert not torch.allclose(mlp.layers[2].weight, prev_model_weights[1]), (
+ "Weights have not been reset in layer 1."
+ )
mlp = MLP(3, 5, [100, 100, 100, 100, 100], "relu")
original_pred = mlp(dummy_input)
@@ -279,9 +279,9 @@ def test_reset(self):
n_reset += 1
reset_param.data = old_param.data
- assert ~torch.allclose(
- original_pred, reset_pred
- ), "Model prediction has not changed."
+        assert not torch.allclose(original_pred, reset_pred), (
+ "Model prediction has not changed."
+ )
assert n_reset / total_params >= 0.25, "Weights reset too rarely in soft reset."
assert n_reset / total_params <= 0.75, "Weights reset too often in soft reset."
@@ -291,9 +291,9 @@ def test_forward(self):
output = mlp(dummy_input)
assert output.shape == (3, 5), "Output shape is not correct."
assert output.dtype == torch.float32, "Output dtype is not correct."
- assert torch.allclose(
- output, mlp.layers(dummy_input)
- ), "Forward is not correct."
+ assert torch.allclose(output, mlp.layers(dummy_input)), (
+ "Forward is not correct."
+ )
class TestCNN:
@@ -313,9 +313,9 @@ def test_init(self):
)
cnn = CNN((64, 64, 3), 3, [32, 64, 64], [8, 4, 3], [4, 2, 1], [0, 0, 0], "relu")
combo = ComboNet(cnn, mlp)
- assert isinstance(
- combo, torch.jit.ScriptModule
- ), "ComboNet is not a ScriptModule."
+ assert isinstance(combo, torch.jit.ScriptModule), (
+ "ComboNet is not a ScriptModule."
+ )
assert combo.module1 == cnn, "CNN is not the first module."
assert combo.module2 == mlp, "MLP is not the second module."
diff --git a/test/models/test_ppo_networks.py b/test/models/test_ppo_networks.py
index e2cd79d9..d294370f 100644
--- a/test/models/test_ppo_networks.py
+++ b/test/models/test_ppo_networks.py
@@ -17,21 +17,21 @@ def test_init_discrete(self):
assert ppo.tanh_squash is False, "Default tanh_squash should be False"
# Check network structure - updated for new architecture
- assert hasattr(
- ppo, "feature_extractor_policy"
- ), "Should have policy feature extractor"
- assert hasattr(
- ppo, "feature_extractor_value"
- ), "Should have value feature extractor"
- assert isinstance(
- ppo.policy_head, nn.Sequential
- ), "Policy head should be Sequential"
- assert isinstance(
- ppo.value_head, nn.Sequential
- ), "Value head should be Sequential"
- assert hasattr(
- ppo, "value_function_module"
- ), "Should have value function module wrapper"
+ assert hasattr(ppo, "feature_extractor_policy"), (
+ "Should have policy feature extractor"
+ )
+ assert hasattr(ppo, "feature_extractor_value"), (
+ "Should have value feature extractor"
+ )
+ assert isinstance(ppo.policy_head, nn.Sequential), (
+ "Policy head should be Sequential"
+ )
+ assert isinstance(ppo.value_head, nn.Sequential), (
+ "Value head should be Sequential"
+ )
+ assert hasattr(ppo, "value_function_module"), (
+ "Should have value function module wrapper"
+ )
# Test forward pass shapes
dummy_input = torch.rand((10, 4))
@@ -72,9 +72,9 @@ def test_init_continuous_tanh_squash(self):
assert values.shape == (5, 1), "Values should have shape (5, 1)"
# Check that actions are in [-1, 1] range due to tanh
- assert torch.all(action >= -1.0) and torch.all(
- action <= 1.0
- ), "Actions should be in [-1, 1] range"
+ assert torch.all(action >= -1.0) and torch.all(action <= 1.0), (
+ "Actions should be in [-1, 1] range"
+ )
# Check log_std clamping
assert torch.all(log_std >= ppo.log_std_min), "Log_std should be >= log_std_min"
@@ -148,9 +148,9 @@ def test_value_function_module(self):
values_direct = ppo.forward_value(dummy_input)
values_module = ppo.value_function_module(dummy_input)
- assert torch.allclose(
- values_direct, values_module
- ), "Value function module should produce same output as forward_value"
+ assert torch.allclose(values_direct, values_module), (
+ "Value function module should produce same output as forward_value"
+ )
assert values_module.shape == (
8,
1,
@@ -190,15 +190,15 @@ def test_forward_continuous_tanh_squash(self):
assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite"
# Check tanh constraint on actions
- assert torch.all(action >= -1.0) and torch.all(
- action <= 1.0
- ), "Actions should be in [-1, 1]"
+ assert torch.all(action >= -1.0) and torch.all(action <= 1.0), (
+ "Actions should be in [-1, 1]"
+ )
# Check relationship: action = tanh(z) where z = mean + std * eps
expected_action = torch.tanh(z)
- assert torch.allclose(
- action, expected_action, atol=1e-6
- ), "Action should equal tanh(z)"
+ assert torch.allclose(action, expected_action, atol=1e-6), (
+ "Action should equal tanh(z)"
+ )
def test_forward_continuous_standard(self):
"""Test forward pass for continuous actions with standard PPO."""
@@ -219,11 +219,6 @@ def test_forward_continuous_standard(self):
assert torch.all(torch.isfinite(mean)), "Means should be finite"
assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite"
- # Check relationship: action = mean + std * eps (no tanh)
- std = torch.exp(log_std)
- # We can't check exact relationship due to random sampling, but verify no tanh constraint
- # Actions should not be constrained to [-1, 1] in standard PPO
-
def test_forward_value(self):
"""Test value network forward pass."""
ppo = PPOModel(obs_shape=5, action_size=3, continuous_action=False)
@@ -251,12 +246,12 @@ def test_custom_log_std_bounds(self):
dummy_input = torch.rand((10, 4))
_, _, _, log_std = ppo_tanh(dummy_input)
- assert torch.all(
- log_std >= log_std_min
- ), "Log_std should be >= custom log_std_min"
- assert torch.all(
- log_std <= log_std_max
- ), "Log_std should be <= custom log_std_max"
+ assert torch.all(log_std >= log_std_min), (
+ "Log_std should be >= custom log_std_min"
+ )
+ assert torch.all(log_std <= log_std_max), (
+ "Log_std should be <= custom log_std_max"
+ )
# Test with standard PPO
ppo_std = PPOModel(
@@ -270,12 +265,12 @@ def test_custom_log_std_bounds(self):
_, _, log_std = ppo_std(dummy_input)
- assert torch.all(
- log_std >= log_std_min
- ), "Log_std should be >= custom log_std_min (standard PPO)"
- assert torch.all(
- log_std <= log_std_max
- ), "Log_std should be <= custom log_std_max (standard PPO)"
+ assert torch.all(log_std >= log_std_min), (
+ "Log_std should be >= custom log_std_min (standard PPO)"
+ )
+ assert torch.all(log_std <= log_std_max), (
+ "Log_std should be <= custom log_std_max (standard PPO)"
+ )
def test_deterministic_with_same_input(self):
"""Test that same input produces different outputs due to sampling."""
@@ -294,12 +289,12 @@ def test_deterministic_with_same_input(self):
assert torch.allclose(log_std1, log_std2), "Log_stds should be identical"
# Actions and z should be different due to random sampling
- assert not torch.allclose(
- action1, action2
- ), "Actions should be different due to sampling"
- assert not torch.allclose(
- z1, z2
- ), "Raw actions should be different due to sampling"
+ assert not torch.allclose(action1, action2), (
+ "Actions should be different due to sampling"
+ )
+ assert not torch.allclose(z1, z2), (
+ "Raw actions should be different due to sampling"
+ )
# Test standard PPO mode
ppo_std = PPOModel(
@@ -310,17 +305,17 @@ def test_deterministic_with_same_input(self):
action2_std, mean2_std, log_std2_std = ppo_std(dummy_input)
# Mean and log_std should be the same (deterministic)
- assert torch.allclose(
- mean1_std, mean2_std
- ), "Means should be identical (standard PPO)"
- assert torch.allclose(
- log_std1_std, log_std2_std
- ), "Log_stds should be identical (standard PPO)"
+ assert torch.allclose(mean1_std, mean2_std), (
+ "Means should be identical (standard PPO)"
+ )
+ assert torch.allclose(log_std1_std, log_std2_std), (
+ "Log_stds should be identical (standard PPO)"
+ )
# Actions should be different due to random sampling
- assert not torch.allclose(
- action1_std, action2_std
- ), "Actions should be different due to sampling (standard PPO)"
+ assert not torch.allclose(action1_std, action2_std), (
+ "Actions should be different due to sampling (standard PPO)"
+ )
def test_orthogonal_initialization(self):
"""Test that weights are initialized with orthogonal initialization."""
@@ -334,9 +329,9 @@ def test_orthogonal_initialization(self):
module.weight, torch.zeros_like(module.weight)
), "Weights should not be all zeros"
# Check that biases are initialized to zero
- assert torch.allclose(
- module.bias, torch.zeros_like(module.bias)
- ), "Biases should be initialized to zero"
+ assert torch.allclose(module.bias, torch.zeros_like(module.bias)), (
+ "Biases should be initialized to zero"
+ )
def test_separate_feature_extractors(self):
"""Test that policy and value networks have separate feature extractors."""
@@ -348,14 +343,14 @@ def test_separate_feature_extractors(self):
value_features = ppo.feature_extractor_value(dummy_input)
# They should have the same shape but potentially different values
- assert (
- policy_features.shape == value_features.shape
- ), "Feature extractors should output same shape"
+ assert policy_features.shape == value_features.shape, (
+ "Feature extractors should output same shape"
+ )
# Verify they are separate networks by checking if they are different objects
- assert (
- ppo.feature_extractor_policy is not ppo.feature_extractor_value
- ), "Policy and value feature extractors should be separate objects"
+ assert ppo.feature_extractor_policy is not ppo.feature_extractor_value, (
+ "Policy and value feature extractors should be separate objects"
+ )
def test_different_architectures(self):
"""Test with different hidden layer configurations."""
diff --git a/test/models/test_q_networks.py b/test/models/test_q_networks.py
index d079502a..46ad3b45 100644
--- a/test/models/test_q_networks.py
+++ b/test/models/test_q_networks.py
@@ -13,18 +13,18 @@ def test_init(self):
dqn = DQN(num_actions=4, obs_size=3)
assert dqn.num_actions == 4, "Num_actions should be 4"
assert dqn.dueling is False, "Dueling should be False"
- assert isinstance(
- dqn.feature_extractor, MLP
- ), "Default feature extractor should be an instance of MLP"
- assert isinstance(
- dqn.head, torch.nn.Sequential
- ), "Head should be a nn.Sequential"
- assert isinstance(
- dqn.value, torch.nn.Linear
- ), "Value layer should be a nn.Linear"
- assert isinstance(
- dqn.advantage, torch.nn.Linear
- ), "Advantage layer should be a nn.Linear"
+ assert isinstance(dqn.feature_extractor, MLP), (
+ "Default feature extractor should be an instance of MLP"
+ )
+ assert isinstance(dqn.head, torch.nn.Sequential), (
+ "Head should be a nn.Sequential"
+ )
+ assert isinstance(dqn.value, torch.nn.Linear), (
+ "Value layer should be a nn.Linear"
+ )
+ assert isinstance(dqn.advantage, torch.nn.Linear), (
+ "Advantage layer should be a nn.Linear"
+ )
dummy_input = torch.rand((1, 3))
assert dqn(dummy_input).shape == (1, 4), "Output should have shape (1, 4)"
@@ -45,21 +45,21 @@ def test_init(self):
)
assert dqn.num_actions == 2, "Num_actions should be 4"
assert dqn.dueling is True, "Dueling should be True"
- assert isinstance(
- dqn.feature_extractor, CNN
- ), "Feature extractor should be a CNN"
- assert (
- len(dqn.feature_extractor.cnn) == 5
- ), "Feature extractor should have 2 convolutions"
- assert (
- dqn.feature_extractor.cnn[0].out_channels == 32
- ), "First convolution should have 32 output channels"
+ assert isinstance(dqn.feature_extractor, CNN), (
+ "Feature extractor should be a CNN"
+ )
+ assert len(dqn.feature_extractor.cnn) == 5, (
+ "Feature extractor should have 2 convolutions"
+ )
+ assert dqn.feature_extractor.cnn[0].out_channels == 32, (
+ "First convolution should have 32 output channels"
+ )
assert dqn.head[0].out_features == 32, "Head layer 1 should have hidden size 32"
assert dqn.head[2].out_features == 32, "Head layer 2 should have hidden size 32"
assert dqn.value.out_features == 1, "Value layer should have 1 output feature"
- assert (
- dqn.advantage.out_features == 2
- ), "Advantage layer should have 2 output features"
+ assert dqn.advantage.out_features == 2, (
+ "Advantage layer should have 2 output features"
+ )
dummy_input = torch.rand((5, 64, 64, 3))
assert dqn(dummy_input).shape == (5, 2), "Output should have shape (1, 2)"
@@ -87,9 +87,9 @@ def test_forward(self):
+ calculated_advantage
- calculated_advantage.mean(dim=1, keepdim=True)
)
- assert torch.allclose(
- dqn_pred, calculated_pred
- ), "Prediction should be equal to value + advantage - mean_advantage"
+ assert torch.allclose(dqn_pred, calculated_pred), (
+ "Prediction should be equal to value + advantage - mean_advantage"
+ )
def test_reset_head(self):
head_kwargs = {"hidden_sizes": [32, 32]}
@@ -114,9 +114,9 @@ def test_reset_head(self):
new_features = dqn.feature_extractor(dummy_input)
new_pred = dqn(dummy_input)
- assert torch.allclose(
- original_features, new_features
- ), "Features should be equal"
+ assert torch.allclose(original_features, new_features), (
+ "Features should be equal"
+ )
assert ~torch.allclose(original_pred, new_pred), "Predictions should differ"
def test_shrink_weights(self):
@@ -128,21 +128,21 @@ def test_shrink_weights(self):
for new_param, old_param in zip(
dqn.head.parameters(), prev_head_params, strict=False
):
- assert torch.allclose(
- new_param, old_param * 0.5
- ), "Weights have not been shrunk."
+ assert torch.allclose(new_param, old_param * 0.5), (
+ "Weights have not been shrunk."
+ )
for new_param, old_param in zip(
dqn.advantage.parameters(), prev_adv_params, strict=False
):
- assert torch.allclose(
- new_param, old_param * 0.5
- ), "Advantage weights have not been shrunk."
+ assert torch.allclose(new_param, old_param * 0.5), (
+ "Advantage weights have not been shrunk."
+ )
for new_param, old_param in zip(
dqn.value.parameters(), prev_value_params, strict=False
):
- assert torch.allclose(
- new_param, old_param * 0.5
- ), "Value weights have not been shrunk."
+ assert torch.allclose(new_param, old_param * 0.5), (
+ "Value weights have not been shrunk."
+ )
def test_get_state(self):
dqn = DQN(num_actions=4, obs_size=3, dueling=True)
@@ -160,18 +160,18 @@ def test_set_state(self):
state_dict = dqn.state_dict()
dqn2 = DQN(num_actions=4, obs_size=3, dueling=True)
original_pred = dqn2(dummy_input)
- assert ~torch.allclose(
- baseline_pred, original_pred
- ), "Predictions should be different before loading"
+        assert not torch.allclose(baseline_pred, original_pred), (
+ "Predictions should be different before loading"
+ )
for p1, p2 in zip(dqn.parameters(), dqn2.parameters(), strict=False):
- assert not torch.allclose(
- p1, p2
- ), "Parameters should be different before loading"
+ assert not torch.allclose(p1, p2), (
+ "Parameters should be different before loading"
+ )
dqn2.load_state_dict(state_dict)
new_pred = dqn2(dummy_input)
- assert torch.allclose(
- baseline_pred, new_pred
- ), "Predictions should be equal after loading"
+ assert torch.allclose(baseline_pred, new_pred), (
+ "Predictions should be equal after loading"
+ )
for p1, p2 in zip(dqn.parameters(), dqn2.parameters(), strict=False):
assert torch.allclose(p1, p2), "Parameters should be equal after loading"
diff --git a/test/models/test_sac_networks.py b/test/models/test_sac_networks.py
index 161bd374..6c8a123b 100644
--- a/test/models/test_sac_networks.py
+++ b/test/models/test_sac_networks.py
@@ -20,40 +20,40 @@ def test_init(self):
# Check network structure - updated for new architecture
assert hasattr(sac, "feature_extractor"), "Should have feature extractor"
- assert isinstance(
- sac.policy_net, nn.Linear
- ), "Policy network should be Linear (after feature extractor)"
+ assert isinstance(sac.policy_net, nn.Linear), (
+ "Policy network should be Linear (after feature extractor)"
+ )
assert isinstance(sac.q_net1, nn.Sequential), "Q-network 1 should be Sequential"
assert isinstance(sac.q_net2, nn.Sequential), "Q-network 2 should be Sequential"
- assert isinstance(
- sac.target_q_net1, nn.Sequential
- ), "Target Q-network 1 should be Sequential"
- assert isinstance(
- sac.target_q_net2, nn.Sequential
- ), "Target Q-network 2 should be Sequential"
- assert hasattr(
- sac, "value_function_module"
- ), "Should have value function module wrapper"
+ assert isinstance(sac.target_q_net1, nn.Sequential), (
+ "Target Q-network 1 should be Sequential"
+ )
+ assert isinstance(sac.target_q_net2, nn.Sequential), (
+ "Target Q-network 2 should be Sequential"
+ )
+ assert hasattr(sac, "value_function_module"), (
+ "Should have value function module wrapper"
+ )
# Check that target networks have gradients disabled
for param in sac.target_q_net1.parameters():
- assert (
- not param.requires_grad
- ), "Target Q-network 1 parameters should not require gradients"
+ assert not param.requires_grad, (
+ "Target Q-network 1 parameters should not require gradients"
+ )
for param in sac.target_q_net2.parameters():
- assert (
- not param.requires_grad
- ), "Target Q-network 2 parameters should not require gradients"
+ assert not param.requires_grad, (
+ "Target Q-network 2 parameters should not require gradients"
+ )
# Check that live networks have gradients enabled
for param in sac.q_net1.parameters():
- assert (
- param.requires_grad
- ), "Q-network 1 parameters should require gradients"
+ assert param.requires_grad, (
+ "Q-network 1 parameters should require gradients"
+ )
for param in sac.q_net2.parameters():
- assert (
- param.requires_grad
- ), "Q-network 2 parameters should require gradients"
+ assert param.requires_grad, (
+ "Q-network 2 parameters should require gradients"
+ )
def test_init_custom_params(self):
"""Test initialization with custom parameters."""
@@ -106,9 +106,9 @@ def test_value_function_module(self):
values_module = sac.value_function_module(dummy_state)
values_direct = sac.forward_value(dummy_state)
- assert torch.allclose(
- values_module, values_direct
- ), "Value function module should produce same output as forward_value"
+ assert torch.allclose(values_module, values_direct), (
+ "Value function module should produce same output as forward_value"
+ )
assert values_module.shape == (
8,
1,
@@ -134,9 +134,9 @@ def test_forward_stochastic(self):
assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite"
# Check tanh constraint on actions
- assert torch.all(action >= -1.0) and torch.all(
- action <= 1.0
- ), "Actions should be in [-1, 1] range"
+ assert torch.all(action >= -1.0) and torch.all(action <= 1.0), (
+ "Actions should be in [-1, 1] range"
+ )
# Check log_std clamping
assert torch.all(log_std >= sac.log_std_min), "Log_std should be >= log_std_min"
@@ -144,9 +144,9 @@ def test_forward_stochastic(self):
# Check relationship: action = tanh(z)
expected_action = torch.tanh(z)
- assert torch.allclose(
- action, expected_action, atol=1e-6
- ), "Action should equal tanh(z)"
+ assert torch.allclose(action, expected_action, atol=1e-6), (
+ "Action should equal tanh(z)"
+ )
def test_forward_deterministic(self):
"""Test forward pass with deterministic policy."""
@@ -166,9 +166,9 @@ def test_forward_deterministic(self):
# Action should still be tanh(z) = tanh(mean)
expected_action = torch.tanh(mean)
- assert torch.allclose(
- action, expected_action
- ), "Action should equal tanh(mean) in deterministic mode"
+ assert torch.allclose(action, expected_action), (
+ "Action should equal tanh(mean) in deterministic mode"
+ )
def test_stochastic_vs_deterministic(self):
"""Test that stochastic and deterministic modes produce different results."""
@@ -185,18 +185,18 @@ def test_stochastic_vs_deterministic(self):
# Mean and log_std should be the same
assert torch.allclose(mean_stoch, mean_det), "Means should be identical"
- assert torch.allclose(
- log_std_stoch, log_std_det
- ), "Log_stds should be identical"
+ assert torch.allclose(log_std_stoch, log_std_det), (
+ "Log_stds should be identical"
+ )
# In deterministic mode, z should equal mean
assert torch.allclose(z_det, mean_det), "Deterministic z should equal mean"
# Stochastic z should likely be different from mean (due to noise)
# Note: There's a tiny chance they could be the same, but extremely unlikely
- assert not torch.allclose(
- z_stoch, mean_stoch
- ), "Stochastic z should be different from mean"
+ assert not torch.allclose(z_stoch, mean_stoch), (
+ "Stochastic z should be different from mean"
+ )
def test_policy_log_prob(self):
"""Test policy log probability calculation."""
@@ -215,9 +215,9 @@ def test_policy_log_prob(self):
# Test with deterministic actions (z = mean)
log_prob_det = sac.policy_log_prob(mean, mean, log_std)
- assert torch.all(
- torch.isfinite(log_prob_det)
- ), "Deterministic log probs should be finite"
+ assert torch.all(torch.isfinite(log_prob_det)), (
+ "Deterministic log probs should be finite"
+ )
def test_q_networks(self):
"""Test Q-network forward passes."""
@@ -247,16 +247,16 @@ def test_target_networks_initialization(self):
for p1, p_target1 in zip(
sac.q_net1.parameters(), sac.target_q_net1.parameters()
):
- assert torch.allclose(
- p1, p_target1
- ), "Target Q-net 1 should have same initial weights as Q-net 1"
+ assert torch.allclose(p1, p_target1), (
+ "Target Q-net 1 should have same initial weights as Q-net 1"
+ )
for p2, p_target2 in zip(
sac.q_net2.parameters(), sac.target_q_net2.parameters()
):
- assert torch.allclose(
- p2, p_target2
- ), "Target Q-net 2 should have same initial weights as Q-net 2"
+ assert torch.allclose(p2, p_target2), (
+ "Target Q-net 2 should have same initial weights as Q-net 2"
+ )
def test_twin_q_networks_independence(self):
"""Test that twin Q-networks are independent."""
@@ -264,9 +264,9 @@ def test_twin_q_networks_independence(self):
# Check that Q-networks have different parameters (due to random initialization)
assert sac.q_net1 is not sac.q_net2, "Q-networks should be separate objects"
- assert (
- sac.target_q_net1 is not sac.target_q_net2
- ), "Target Q-networks should be separate objects"
+ assert sac.target_q_net1 is not sac.target_q_net2, (
+ "Target Q-networks should be separate objects"
+ )
def test_log_std_bounds_enforcement(self):
"""Test that log_std bounds are properly enforced."""
@@ -279,12 +279,12 @@ def test_log_std_bounds_enforcement(self):
dummy_state = torch.rand((10, 3))
_, _, _, log_std = sac(dummy_state)
- assert torch.all(
- log_std >= log_std_min
- ), "Log_std should be >= custom log_std_min"
- assert torch.all(
- log_std <= log_std_max
- ), "Log_std should be <= custom log_std_max"
+ assert torch.all(log_std >= log_std_min), (
+ "Log_std should be >= custom log_std_min"
+ )
+ assert torch.all(log_std <= log_std_max), (
+ "Log_std should be <= custom log_std_max"
+ )
def test_gradient_flow(self):
"""Test that gradients flow properly through networks."""
@@ -303,9 +303,9 @@ def test_gradient_flow(self):
feature_has_grad = any(
p.grad is not None for p in sac.feature_extractor.parameters()
)
- assert (
- policy_has_grad or feature_has_grad
- ), "Policy network or feature extractor should have gradients"
+ assert policy_has_grad or feature_has_grad, (
+ "Policy network or feature extractor should have gradients"
+ )
# Test Q-network gradients
sac.zero_grad()
@@ -333,9 +333,9 @@ def test_numerical_stability(self):
# Test log probability calculation doesn't produce NaN or inf
log_prob = sac.policy_log_prob(z, mean, log_std)
- assert torch.all(
- torch.isfinite(log_prob)
- ), "Log probabilities should be finite even with extreme inputs"
+ assert torch.all(torch.isfinite(log_prob)), (
+ "Log probabilities should be finite even with extreme inputs"
+ )
# Test with actions close to boundary values (-1, 1)
boundary_z = torch.tensor(
@@ -347,6 +347,6 @@ def test_numerical_stability(self):
boundary_log_prob = sac.policy_log_prob(
boundary_z, boundary_mean, boundary_log_std
)
- assert torch.all(
- torch.isfinite(boundary_log_prob)
- ), "Log probabilities should be finite for boundary actions"
+ assert torch.all(torch.isfinite(boundary_log_prob)), (
+ "Log probabilities should be finite for boundary actions"
+ )
diff --git a/test/replay/test_buffer.py b/test/replay/test_buffer.py
index e11b8152..d6cdfc6e 100644
--- a/test/replay/test_buffer.py
+++ b/test/replay/test_buffer.py
@@ -59,28 +59,28 @@ def test_init(self, observations, actions, rewards, next_observations, dones, si
batch = TransitionBatch(
observations, actions, rewards, next_observations, dones
)
- assert isinstance(
- batch.observations, torch.Tensor
- ), "Observations were not a tensor."
+ assert isinstance(batch.observations, torch.Tensor), (
+ "Observations were not a tensor."
+ )
assert isinstance(batch.actions, torch.Tensor), "Actions were not a tensor."
assert isinstance(batch.rewards, torch.Tensor), "Rewards were not a tensor."
- assert isinstance(
- batch.next_obs, torch.Tensor
- ), "Next observations were not a tensor."
+ assert isinstance(batch.next_obs, torch.Tensor), (
+ "Next observations were not a tensor."
+ )
assert isinstance(batch.dones, torch.Tensor), "Dones were not a tensor."
- assert (
- len(batch.observations.shape) == 2
- ), f"Observation shape was not 2D: {batch.observations.shape}."
- assert (
- batch.observations.shape == batch.next_obs.shape
- ), "Observation shape was not equal to next observation shape."
- assert (
- batch.actions.shape == batch.rewards.shape
- ), "Action shape was not equal to reward shape."
- assert (
- batch.actions.shape == batch.dones.shape
- ), "Action shape was not equal to reward shape."
+ assert len(batch.observations.shape) == 2, (
+ f"Observation shape was not 2D: {batch.observations.shape}."
+ )
+ assert batch.observations.shape == batch.next_obs.shape, (
+ "Observation shape was not equal to next observation shape."
+ )
+ assert batch.actions.shape == batch.rewards.shape, (
+ "Action shape was not equal to reward shape."
+ )
+ assert batch.actions.shape == batch.dones.shape, (
+ "Action shape was not equal to reward shape."
+ )
assert (
len(batch.actions.shape) == len(batch.observations.shape) - 1
), f"""Action shape was not one less than observation shape:
@@ -98,9 +98,9 @@ def test_iter(self, observations, actions, rewards, next_observations, dones, si
elements = 0
for obs, act, rew, next_obs, done in batch:
assert obs.numpy() in observations, "Observation was not in observations."
- assert (
- next_obs.numpy() in next_observations
- ), "Next observation was not in next_observations."
+ assert next_obs.numpy() in next_observations, (
+ "Next observation was not in next_observations."
+ )
if isinstance(actions, int):
assert act.numpy().item() == actions, "Action was not in actions."
assert rew.numpy().item() == rewards, "Reward was not in rewards."
@@ -146,17 +146,17 @@ def test_add(self, observations, actions, rewards, next_observations, dones, siz
replay = self.get_replay(batch, size, empty=True)
filled_replay = self.get_replay(batch, size)
assert len(replay) == 0, "Empty replay length was not 0."
- assert (
- len(filled_replay) == size
- ), "Filled replay length was not equal to batch size."
+ assert len(filled_replay) == size, (
+ "Filled replay length was not equal to batch size."
+ )
replay.add(batch, {})
- assert len(replay) == len(
- filled_replay
- ), "Replay length was not equal to batch size."
- assert (
- replay.index == filled_replay.index
- ), "Replay index was not equal to filled replay index."
+ assert len(replay) == len(filled_replay), (
+ "Replay length was not equal to batch size."
+ )
+ assert replay.index == filled_replay.index, (
+ "Replay index was not equal to filled replay index."
+ )
assert all(
any(torch.equal(obs, ob) for obs in batch.observations) for ob in replay.obs
), "Observations were not added to replay."
@@ -201,9 +201,9 @@ def test_sample(self):
replay = self.get_replay(batch, size)
minibatch = replay.sample(batch_size=1)
assert len(minibatch) == 1, "Minibatch length was incorrect (batch size 1)."
- assert isinstance(
- minibatch, TransitionBatch
- ), "Minibatch was not a TransitionBatch."
+ assert isinstance(minibatch, TransitionBatch), (
+ "Minibatch was not a TransitionBatch."
+ )
assert all(
any(torch.allclose(obs, ob) for obs in batch.observations)
for ob in minibatch.observations
@@ -228,9 +228,9 @@ def test_sample(self):
minibatch = replay.sample(batch_size=2)
assert len(minibatch) == 2, "Minibatch length was incorrect (batch size 2)."
- assert isinstance(
- minibatch, TransitionBatch
- ), "Minibatch was not a TransitionBatch."
+ assert isinstance(minibatch, TransitionBatch), (
+ "Minibatch was not a TransitionBatch."
+ )
assert all(
any(torch.allclose(obs, ob) for obs in batch.observations)
for ob in minibatch.observations
@@ -254,9 +254,9 @@ def test_sample(self):
batchset = [replay.sample(batch_size=1) for _ in range(10)]
all_actions = [act for batch in batchset for act in batch.actions]
- assert ~all(
- x == all_actions[0] for x in all_actions
- ), "All sampled batches were the same."
+ assert ~all(x == all_actions[0] for x in all_actions), (
+ "All sampled batches were the same."
+ )
def test_reset(self):
(
@@ -295,9 +295,9 @@ def test_len(self):
assert len(replay) == size, "Replay length was not equal to batch size."
replay.add(batch, {})
- assert (
- len(replay) == size * 2
- ), "Replay length was not doubled after doubling transitions."
+ assert len(replay) == size * 2, (
+ "Replay length was not doubled after doubling transitions."
+ )
replay = self.get_replay(batch, size, empty=True)
assert len(replay) == 0, "Replay length of empty replay was not 0."
@@ -316,42 +316,42 @@ def test_full(self):
)
replay = self.get_replay(batch, size, full=False)
assert replay.full is False, "Replay was falsely full."
- assert replay.capacity > len(
- replay
- ), "Replay capacity was not greater than length in non-full replay."
- assert (
- replay.index < replay.capacity
- ), "Replay index was not less than capacity in non-full replay."
+ assert replay.capacity > len(replay), (
+ "Replay capacity was not greater than length in non-full replay."
+ )
+ assert replay.index < replay.capacity, (
+ "Replay index was not less than capacity in non-full replay."
+ )
replay = self.get_replay(batch, size, full=True)
assert replay.full is True, "Replay was not full."
- assert replay.capacity == len(
- replay
- ), "Replay capacity was not equal to length in full replay."
- assert (
- replay.index == replay.capacity
- ), "Replay index was not equal to capacity in full replay."
+ assert replay.capacity == len(replay), (
+ "Replay capacity was not equal to length in full replay."
+ )
+ assert replay.index == replay.capacity, (
+ "Replay index was not equal to capacity in full replay."
+ )
second_batch = TransitionBatch(
observations * 2, actions * 2, rewards * 2, next_observations * 2, dones * 2
)
replay.add(second_batch, {})
- assert (
- replay.full is True
- ), "Replay was not full anymore after adding more transitions."
- assert replay.capacity == len(
- replay
- ), "Replay capacity was not equal to length after adding more transitions."
- assert (
- replay.index == replay.capacity
- ), "Replay index was not equal to capacity after adding more transitions."
+ assert replay.full is True, (
+ "Replay was not full anymore after adding more transitions."
+ )
+ assert replay.capacity == len(replay), (
+ "Replay capacity was not equal to length after adding more transitions."
+ )
+ assert replay.index == replay.capacity, (
+ "Replay index was not equal to capacity after adding more transitions."
+ )
for obs, act, rew, next_obs, done in second_batch:
assert obs in replay.obs, f"Observation {obs} was not in replay."
assert act in replay.actions, f"Action {act} was not in replay."
assert rew in replay.rewards, f"Reward {rew} was not in replay."
- assert (
- next_obs in replay.next_obs
- ), f"Next observation {next_obs} was not in replay."
+ assert next_obs in replay.next_obs, (
+ f"Next observation {next_obs} was not in replay."
+ )
assert done in replay.dones, f"Done {done} was not in replay."
def test_empty(self):
@@ -388,27 +388,27 @@ def test_save(self):
replay.save("test_replay.pkl")
with open("test_replay.pkl", "rb") as f:
loaded_replay = pkl.load(f)
- assert (
- replay.capacity == loaded_replay.capacity
- ), "Replay capacity was not loaded correctly."
- assert (
- replay.index == loaded_replay.index
- ), "Replay index was not loaded correctly."
- assert torch.allclose(
- replay.obs, loaded_replay.obs
- ), "Replay observations were not loaded correctly."
- assert torch.allclose(
- replay.actions, loaded_replay.actions
- ), "Replay actions were not loaded correctly."
- assert torch.allclose(
- replay.rewards, loaded_replay.rewards
- ), "Replay rewards were not loaded correctly."
- assert torch.allclose(
- replay.next_obs, loaded_replay.next_obs
- ), "Replay next observations were not loaded correctly."
- assert torch.allclose(
- replay.dones, loaded_replay.dones
- ), "Replay dones were not loaded correctly."
+ assert replay.capacity == loaded_replay.capacity, (
+ "Replay capacity was not loaded correctly."
+ )
+ assert replay.index == loaded_replay.index, (
+ "Replay index was not loaded correctly."
+ )
+ assert torch.allclose(replay.obs, loaded_replay.obs), (
+ "Replay observations were not loaded correctly."
+ )
+ assert torch.allclose(replay.actions, loaded_replay.actions), (
+ "Replay actions were not loaded correctly."
+ )
+ assert torch.allclose(replay.rewards, loaded_replay.rewards), (
+ "Replay rewards were not loaded correctly."
+ )
+ assert torch.allclose(replay.next_obs, loaded_replay.next_obs), (
+ "Replay next observations were not loaded correctly."
+ )
+ assert torch.allclose(replay.dones, loaded_replay.dones), (
+ "Replay dones were not loaded correctly."
+ )
Path("test_replay.pkl").unlink()
@@ -504,21 +504,21 @@ def test_add(self, observations, actions, rewards, next_observations, dones, siz
filled_replay = self.get_replay(batch, size)
assert replay.current_size == 0, "Empty replay length was not 0."
- assert (
- filled_replay.current_size == size
- ), "Filled replay length was not equal to batch size."
+ assert filled_replay.current_size == size, (
+ "Filled replay length was not equal to batch size."
+ )
# Now add to the previously empty replay with a brand-new td_error array
td_errors = rng.random(size).astype(np.float32)
replay.add(batch, {"td_error": td_errors})
# Both replays should have the same current_size and data_idx
- assert (
- replay.current_size == filled_replay.current_size
- ), "Replay length was not equal to filled replay length."
- assert (
- replay.data_idx == filled_replay.data_idx
- ), "Replay data_idx was not equal to filled replay data_idx."
+ assert replay.current_size == filled_replay.current_size, (
+ "Replay length was not equal to filled replay length."
+ )
+ assert replay.data_idx == filled_replay.data_idx, (
+ "Replay data_idx was not equal to filled replay data_idx."
+ )
# Compute expected priority = (|td_error| + ε)^α for each leaf
eps = replay.epsilon
@@ -527,9 +527,9 @@ def test_add(self, observations, actions, rewards, next_observations, dones, siz
base = replay.capacity
actual_leaves = replay.sum_tree[base : base + size]
- assert np.allclose(
- actual_leaves, expected_prios, atol=1e-6
- ), f"Expected leaves {expected_prios}, but got {actual_leaves}"
+ assert np.allclose(actual_leaves, expected_prios, atol=1e-6), (
+ f"Expected leaves {expected_prios}, but got {actual_leaves}"
+ )
@pytest.mark.parametrize(
("observations", "actions", "rewards", "next_observations", "dones", "size"),
@@ -588,9 +588,9 @@ def test_sample(
# If there is more than one element, uniform draws should vary
if size > 1:
- assert (
- len(set(seen_actions)) >= 2
- ), "Uniform-priority sampling did not vary."
+ assert len(set(seen_actions)) >= 2, (
+ "Uniform-priority sampling did not vary."
+ )
# 2) Skewed priorities (only one index has nonzero td_error):
replay = self.get_replay(batch, size, empty=True)
@@ -615,9 +615,9 @@ def test_sample(
forced_indices.append(int(batch_indices_b.item()))
# All sampled indices must be identical (the only one with nonzero priority)
- assert (
- len(set(forced_indices)) == 1
- ), f"Expected only one index to be chosen, but got {set(forced_indices)}"
+ assert len(set(forced_indices)) == 1, (
+ f"Expected only one index to be chosen, but got {set(forced_indices)}"
+ )
@pytest.mark.parametrize(
("observations", "actions", "rewards", "next_observations", "dones", "size"),
@@ -648,9 +648,9 @@ def test_reset(
# **Because reset() does not clear on‐device buffers or sum_tree**, we expect:
# - current_size remains equal to `size`
# - sum_tree[1] (the total priority) remains > 0
- assert (
- replay.current_size == size
- ), f"After reset(), expected current_size to still be {size}, but got {replay.current_size}"
- assert (
- replay.sum_tree[1] > 0.0
- ), "After reset(), expected total priority (sum_tree[1]) to remain > 0"
+ assert replay.current_size == size, (
+ f"After reset(), expected current_size to still be {size}, but got {replay.current_size}"
+ )
+ assert replay.sum_tree[1] > 0.0, (
+ "After reset(), expected total priority (sum_tree[1]) to remain > 0"
+ )
diff --git a/test/replay/test_rollout_buffer.py b/test/replay/test_rollout_buffer.py
index 42c8d7e8..f7be466b 100644
--- a/test/replay/test_rollout_buffer.py
+++ b/test/replay/test_rollout_buffer.py
@@ -92,16 +92,16 @@ def test_init(
assert isinstance(batch.rewards, torch.Tensor), "Rewards not tensor"
assert isinstance(batch.advantages, torch.Tensor), "Advantages not tensor"
assert isinstance(batch.returns, torch.Tensor), "Returns not tensor"
- assert isinstance(
- batch.episode_starts, torch.Tensor
- ), "Episode starts not tensor"
+ assert isinstance(batch.episode_starts, torch.Tensor), (
+ "Episode starts not tensor"
+ )
assert isinstance(batch.log_probs, torch.Tensor), "Log probs not tensor"
assert isinstance(batch.values, torch.Tensor), "Values not tensor"
# Check dimensions are promoted correctly
- assert (
- batch.observations.dim() >= 2
- ), f"Obs dim too low: {batch.observations.shape}"
+ assert batch.observations.dim() >= 2, (
+ f"Obs dim too low: {batch.observations.shape}"
+ )
assert batch.actions.dim() >= 1, f"Actions dim too low: {batch.actions.shape}"
# For discrete actions, latents should be None
@@ -123,9 +123,9 @@ def test_init(
values=values,
)
assert batch_with_latents.latents is not None, "Latents should not be None"
- assert isinstance(
- batch_with_latents.latents, torch.Tensor
- ), "Latents not tensor"
+ assert isinstance(batch_with_latents.latents, torch.Tensor), (
+ "Latents not tensor"
+ )
@pytest.mark.parametrize(
(
@@ -215,9 +215,9 @@ def test_iter(
if discrete:
assert lat is None, "Latent should be None for discrete"
else:
- assert lat is None or isinstance(
- lat, torch.Tensor
- ), "Latent issue in iteration"
+ assert lat is None or isinstance(lat, torch.Tensor), (
+ "Latent issue in iteration"
+ )
elements += 1
assert elements == size, f"Expected {size} elements, got {elements}"
@@ -391,16 +391,16 @@ def test_attribute_stacking(self):
# Test stacked observations
# After stacking: (2, 1, 1, 2), after reshape: (2, 1, 2)
expected_obs = torch.tensor([[[1, 2]], [[3, 4]]], dtype=torch.float32)
- assert torch.allclose(
- maxi.observations, expected_obs
- ), f"Observations not stacked correctly. Expected {expected_obs}, got {maxi.observations}"
+ assert torch.allclose(maxi.observations, expected_obs), (
+ f"Observations not stacked correctly. Expected {expected_obs}, got {maxi.observations}"
+ )
# Test stacked actions
# After stacking: (2, 1, 1), after reshape: (2, 1)
expected_actions = torch.tensor([[0], [1]], dtype=torch.float32)
- assert torch.allclose(
- maxi.actions, expected_actions
- ), f"Actions not stacked correctly. Expected {expected_actions}, got {maxi.actions}"
+ assert torch.allclose(maxi.actions, expected_actions), (
+ f"Actions not stacked correctly. Expected {expected_actions}, got {maxi.actions}"
+ )
def test_latents_handling(self):
"""Test MaxiBatch handles latents correctly (None vs tensor)"""
@@ -426,9 +426,9 @@ def test_latents_handling(self):
expected_latents_shape = torch.Size(
[1]
) # zeros_like creates (1,1), stacked to (1,1,1), reshaped to (1,)
- assert (
- latents.shape == expected_latents_shape
- ), f"Latents shape {latents.shape} should be {expected_latents_shape}"
+ assert latents.shape == expected_latents_shape, (
+ f"Latents shape {latents.shape} should be {expected_latents_shape}"
+ )
def test_empty_tensor_for_empty_batch(self):
"""Test that empty MaxiBatch returns empty tensors"""
@@ -483,9 +483,9 @@ def test_init_continuous(self):
1,
2,
), "Wrong actions buffer shape (continuous)"
- assert (
- buffer.latents is not None
- ), "Latents should not be None for continuous with use_latents=True"
+ assert buffer.latents is not None, (
+ "Latents should not be None for continuous with use_latents=True"
+ )
assert buffer.latents.shape == (10, 1, 2), "Wrong latents buffer shape"
def test_init_continuous_no_latents(self):
@@ -616,9 +616,9 @@ def test_add_multi_step(self):
)
buffer.add(rb)
- assert (
- buffer.pos == 3
- ), f"Position should be 3 after adding 3 steps, got {buffer.pos}"
+ assert buffer.pos == 3, (
+ f"Position should be 3 after adding 3 steps, got {buffer.pos}"
+ )
def test_buffer_overflow(self):
buffer = self.get_buffer(buffer_size=2) # Small buffer
@@ -670,12 +670,12 @@ def test_compute_returns_and_advantage_single_env(self):
buffer.compute_returns_and_advantage(last_values, dones)
# Check that advantages and returns were computed (non-zero)
- assert not torch.allclose(
- buffer.advantages[:2], torch.zeros(2, 1)
- ), "Advantages should be computed (non-zero)"
- assert not torch.allclose(
- buffer.returns[:2], torch.zeros(2, 1)
- ), "Returns should be computed (non-zero)"
+ assert not torch.allclose(buffer.advantages[:2], torch.zeros(2, 1)), (
+ "Advantages should be computed (non-zero)"
+ )
+ assert not torch.allclose(buffer.returns[:2], torch.zeros(2, 1)), (
+ "Returns should be computed (non-zero)"
+ )
def test_compute_returns_and_advantage_multi_env(self):
"""Test GAE computation with multiple environments"""
@@ -690,9 +690,7 @@ def test_compute_returns_and_advantage_multi_env(self):
advantages=np.array([[0.0, 0.0]]), # (1, 2) ✓
returns=np.array([[0.0, 0.0]]), # (1, 2) ✓
episode_starts=np.array([[1, 1]]), # (1, 2) ✓
- log_probs=np.array(
- [[-0.5, -0.8]]
- ), # (1, 2)
+ log_probs=np.array([[-0.5, -0.8]]), # (1, 2)
values=np.array([[1.0, 0.5]]), # (1, 2) ✓
)
@@ -725,18 +723,18 @@ def test_compute_returns_and_advantage_multi_env(self):
print(f"Computed returns: {returns_computed}")
# Basic sanity checks
- assert not torch.allclose(
- advantages_computed, torch.zeros(2)
- ), "Advantages should be non-zero"
- assert not torch.allclose(
- returns_computed, torch.zeros(2)
- ), "Returns should be non-zero"
+ assert not torch.allclose(advantages_computed, torch.zeros(2)), (
+ "Advantages should be non-zero"
+ )
+ assert not torch.allclose(returns_computed, torch.zeros(2)), (
+ "Returns should be non-zero"
+ )
# For GAE, returns = advantages + values (at time t)
expected_returns = advantages_computed + buffer.values[0]
- assert torch.allclose(
- returns_computed, expected_returns, atol=1e-6
- ), f"Returns should equal advantages + values: {returns_computed} vs {expected_returns}"
+ assert torch.allclose(returns_computed, expected_returns, atol=1e-6), (
+ f"Returns should equal advantages + values: {returns_computed} vs {expected_returns}"
+ )
def test_compute_returns_empty_buffer(self):
"""Test GAE computation on empty buffer"""
@@ -775,9 +773,9 @@ def test_sample_insufficient_data(self):
# Try to sample batch_size=4 when only 1 transition available
maxi_batch = buffer.sample(batch_size=4)
- assert (
- len(maxi_batch) == 0
- ), "Should return empty MaxiBatch when insufficient data"
+ assert len(maxi_batch) == 0, (
+ "Should return empty MaxiBatch when insufficient data"
+ )
def test_sample_with_data(self):
buffer = self.get_buffer(buffer_size=10, n_envs=2, discrete=True)
@@ -822,12 +820,10 @@ def test_sample_with_data(self):
)
buffer.add(rb)
-
# We have 2 timesteps × 2 envs = 4 total transitions
- assert (
- len(buffer) == 4
- ), f"Buffer should contain 4 transitions, got {len(buffer)}"
-
+ assert len(buffer) == 4, (
+ f"Buffer should contain 4 transitions, got {len(buffer)}"
+ )
maxi_batch = buffer.sample(batch_size=2)
@@ -856,29 +852,29 @@ def test_sample_with_data(self):
# Test that each minibatch has valid data
for i, mb in enumerate(minibatches):
- assert (
- mb.observations is not None
- ), f"Minibatch {i} observations should not be None"
- assert (
- mb.log_probs is not None
- ), f"Minibatch {i} log_probs should not be None"
- assert (
- mb.observations.shape[0] > 0
- ), f"Minibatch {i} should have some observations"
+ assert mb.observations is not None, (
+ f"Minibatch {i} observations should not be None"
+ )
+ assert mb.log_probs is not None, (
+ f"Minibatch {i} log_probs should not be None"
+ )
+ assert mb.observations.shape[0] > 0, (
+ f"Minibatch {i} should have some observations"
+ )
print(f"Minibatch {i} validated: obs.shape={mb.observations.shape}")
else:
# Original expected behavior
- assert (
- len(maxi_batch) == 4
- ), f"Should have 4 total sampled elements, got {len(maxi_batch)}"
- assert (
- len(minibatches) == 2
- ), f"Should have 2 minibatches, got {len(minibatches)}"
+ assert len(maxi_batch) == 4, (
+ f"Should have 4 total sampled elements, got {len(maxi_batch)}"
+ )
+ assert len(minibatches) == 2, (
+ f"Should have 2 minibatches, got {len(minibatches)}"
+ )
for i, mb in enumerate(minibatches):
- assert (
- len(mb) == 2
- ), f"Minibatch {i} should have 2 elements, got {len(mb)}"
+ assert len(mb) == 2, (
+ f"Minibatch {i} should have 2 elements, got {len(mb)}"
+ )
def test_len_and_bool(self):
buffer = self.get_buffer(n_envs=2)
@@ -965,9 +961,9 @@ def test_save_and_load(self):
assert loaded_buffer.n_envs == buffer.n_envs, "N envs mismatch"
assert loaded_buffer.gamma == buffer.gamma, "Gamma mismatch"
assert loaded_buffer.gae_lambda == buffer.gae_lambda, "GAE lambda mismatch"
- assert torch.allclose(
- loaded_buffer.observations, buffer.observations
- ), "Observations mismatch"
+ assert torch.allclose(loaded_buffer.observations, buffer.observations), (
+ "Observations mismatch"
+ )
assert torch.allclose(loaded_buffer.actions, buffer.actions), "Actions mismatch"
assert torch.allclose(loaded_buffer.rewards, buffer.rewards), "Rewards mismatch"
@@ -1025,20 +1021,6 @@ def test_gae_computation_details(self):
buffer.compute_returns_and_advantage(last_values, dones)
- # Manually calculate expected values for verification
- gamma = buffer.gamma
- lam = buffer.gae_lambda
-
- # Step 1 (t=1): delta = r1 + gamma * V(s2) * (1-done) - V(s1)
- # = 2.0 + 0.99 * 1.5 * 1 - 1.0 = 2.485
- # Step 0 (t=0): delta = r0 + gamma * V(s1) * (1-episode_start[1]) - V(s0)
- # = 1.0 + 0.99 * 1.0 * 1 - 0.5 = 1.49
-
- # GAE calculation (backward):
- # gae_1 = delta_1 = 2.485
- # gae_0 = delta_0 + gamma * lambda * (1-episode_start[1]) * gae_1
- # = 1.49 + 0.99 * 0.95 * 1 * 2.485 ≈ 3.82
-
advantages = buffer.advantages[:2, 0].cpu().numpy()
returns = buffer.returns[:2, 0].cpu().numpy()
@@ -1048,9 +1030,9 @@ def test_gae_computation_details(self):
# Returns should be advantages + values
expected_returns = advantages + np.array([0.5, 1.0])
- assert np.allclose(
- returns, expected_returns, atol=1e-5
- ), f"Returns should equal advantages + values: {returns} vs {expected_returns}"
+ assert np.allclose(returns, expected_returns, atol=1e-5), (
+ f"Returns should equal advantages + values: {returns} vs {expected_returns}"
+ )
def test_episode_boundary_handling(self):
"""Test that episode boundaries are handled correctly in GAE"""
@@ -1119,9 +1101,9 @@ def test_episode_boundary_handling(self):
advantages = buffer.advantages[:3, 0].cpu().numpy()
# All advantages should be computed (non-zero)
- assert not np.allclose(
- advantages, [0.0, 0.0, 0.0]
- ), "All advantages should be computed"
+ assert not np.allclose(advantages, [0.0, 0.0, 0.0]), (
+ "All advantages should be computed"
+ )
def test_multi_env_independence(self):
"""Test that multiple environments are handled independently"""
@@ -1137,9 +1119,7 @@ def test_multi_env_independence(self):
advantages=np.array([[0.0, 0.0, 0.0]]), # (1, 3)
returns=np.array([[0.0, 0.0, 0.0]]), # (1, 3)
episode_starts=np.array([[1, 1, 1]]), # (1, 3) - All start new episodes
- log_probs=np.array(
- [[-0.5, -0.8, -0.3]]
- ), # (1, 3)
+ log_probs=np.array([[-0.5, -0.8, -0.3]]), # (1, 3)
values=np.array([[0.5, 1.0, 0.3]]), # (1, 3)
)
@@ -1155,28 +1135,27 @@ def test_multi_env_independence(self):
# All environments should have computed advantages
assert len(advantages) == 3, "Should have advantages for all 3 envs"
- assert not np.allclose(
- advantages, [0.0, 0.0, 0.0]
- ), "All advantages should be computed"
+ assert not np.allclose(advantages, [0.0, 0.0, 0.0]), (
+ "All advantages should be computed"
+ )
# The done environment (env 1) should have different computation
# (no bootstrap from next value)
- assert (
- advantages[1] != advantages[0]
- ), "Done env should have different advantage"
- assert (
- advantages[1] != advantages[2]
- ), "Done env should have different advantage"
-
+ assert advantages[1] != advantages[0], (
+ "Done env should have different advantage"
+ )
+ assert advantages[1] != advantages[2], (
+ "Done env should have different advantage"
+ )
# Additional verification: returns should equal advantages + values for GAE
returns = buffer.returns[0, :].cpu().numpy()
values = buffer.values[0, :].cpu().numpy()
expected_returns = advantages + values
- assert np.allclose(
- returns, expected_returns, atol=1e-6
- ), f"Returns should equal advantages + values: {returns} vs {expected_returns}"
+ assert np.allclose(returns, expected_returns, atol=1e-6), (
+ f"Returns should equal advantages + values: {returns} vs {expected_returns}"
+ )
def test_sampling_randomness(self):
"""Test that sampling produces different results when called multiple times"""
@@ -1218,9 +1197,9 @@ def test_sampling_randomness(self):
# Should get some variety in samples (not all identical)
unique_samples = set(samples)
- assert (
- len(unique_samples) > 1
- ), f"Sampling should be random, got {len(unique_samples)} unique samples from: {samples}"
+ assert len(unique_samples) > 1, (
+ f"Sampling should be random, got {len(unique_samples)} unique samples from: {samples}"
+ )
def test_mixed_data_types(self):
"""Test buffer handles different numpy data types correctly"""
@@ -1243,8 +1222,8 @@ def test_mixed_data_types(self):
assert buffer.pos == 1, "Should successfully add data with mixed types"
# All stored tensors should be float32
- assert (
- buffer.observations.dtype == torch.float32
- ), "Observations should be float32"
+ assert buffer.observations.dtype == torch.float32, (
+ "Observations should be float32"
+ )
assert buffer.actions.dtype == torch.float32, "Actions should be float32"
assert buffer.rewards.dtype == torch.float32, "Rewards should be float32"
diff --git a/test/runners/test_es_runner.py b/test/runners/test_es_runner.py
index 2f25bf6a..c7c078e0 100644
--- a/test/runners/test_es_runner.py
+++ b/test/runners/test_es_runner.py
@@ -61,15 +61,15 @@ class TestMightyNESRunner:
def test_init(self):
runner = MightyESRunner(self.runner_config)
- assert isinstance(
- runner, MightyRunner
- ), "MightyNESRunner should be an instance of MightyRunner"
- assert isinstance(
- runner.agent, MightyAgent
- ), "MightyNESRunner should have a MightyAgent"
- assert isinstance(
- runner.agent.eval_env, PufferlibToGymAdapter
- ), "Eval env should be a PufferlibToGymAdapter"
+ assert isinstance(runner, MightyRunner), (
+ "MightyNESRunner should be an instance of MightyRunner"
+ )
+ assert isinstance(runner.agent, MightyAgent), (
+ "MightyNESRunner should have a MightyAgent"
+ )
+ assert isinstance(runner.agent.eval_env, PufferlibToGymAdapter), (
+ "Eval env should be a PufferlibToGymAdapter"
+ )
assert runner.agent.env is not None, "Env should be set"
assert runner.iterations is not None, "Iterations should be set"
assert runner.es is not None, "ES should be set"
@@ -85,18 +85,18 @@ def test_run(self):
new_params = runner.agent.parameters
assert isinstance(train_results, dict), "Train results should be a dictionary"
assert isinstance(eval_results, dict), "Eval results should be a dictionary"
- assert (
- "mean_eval_reward" in eval_results
- ), "Mean eval reward should be in eval results"
+ assert "mean_eval_reward" in eval_results, (
+ "Mean eval reward should be in eval results"
+ )
param_equals = [o == p for o, p in zip(old_params, new_params)]
for params in param_equals:
- assert not all(
- params.flatten()
- ), "Parameters should have changed in training"
- assert (
- not old_lr == runner.agent.learning_rate
- ), "Learning rate should have changed in training"
- assert (
- not old_batch_size == runner.agent._batch_size
- ), "Batch size should have changed in training"
+ assert not all(params.flatten()), (
+ "Parameters should have changed in training"
+ )
+ assert not old_lr == runner.agent.learning_rate, (
+ "Learning rate should have changed in training"
+ )
+ assert not old_batch_size == runner.agent._batch_size, (
+ "Batch size should have changed in training"
+ )
shutil.rmtree("test_nes_runner")
diff --git a/test/runners/test_runner.py b/test/runners/test_runner.py
index e2b6ce31..fb94f41d 100644
--- a/test/runners/test_runner.py
+++ b/test/runners/test_runner.py
@@ -2,14 +2,14 @@
import shutil
-import pytest
import gymnasium as gym
+import pytest
from omegaconf import OmegaConf
from mighty.mighty_agents import MightyAgent
from mighty.mighty_runners import MightyOnlineRunner, MightyRunner
-from mighty.mighty_utils.wrappers import PufferlibToGymAdapter
from mighty.mighty_utils.test_helpers import DummyEnv
+from mighty.mighty_utils.wrappers import PufferlibToGymAdapter
class TestMightyRunner:
@@ -57,22 +57,22 @@ class TestMightyRunner:
def test_init(self):
runner = MightyOnlineRunner(self.runner_config)
- assert isinstance(
- runner, MightyRunner
- ), "MightyOnlineRunner should be an instance of MightyRunner"
- assert isinstance(
- runner.agent, MightyAgent
- ), "MightyOnlineRunner should have a MightyAgent"
- assert isinstance(
- runner.agent.eval_env, PufferlibToGymAdapter
- ), "Eval env should be a PufferlibToGymAdapter"
+ assert isinstance(runner, MightyRunner), (
+ "MightyOnlineRunner should be an instance of MightyRunner"
+ )
+ assert isinstance(runner.agent, MightyAgent), (
+ "MightyOnlineRunner should have a MightyAgent"
+ )
+ assert isinstance(runner.agent.eval_env, PufferlibToGymAdapter), (
+ "Eval env should be a PufferlibToGymAdapter"
+ )
assert runner.agent.env is not None, "Env should not be None"
- assert (
- runner.eval_every_n_steps == self.runner_config.eval_every_n_steps
- ), "Eval every n steps should be set"
- assert (
- runner.num_steps == self.runner_config.num_steps
- ), "Num steps should be set"
+ assert runner.eval_every_n_steps == self.runner_config.eval_every_n_steps, (
+ "Eval every n steps should be set"
+ )
+ assert runner.num_steps == self.runner_config.num_steps, (
+ "Num steps should be set"
+ )
def test_train(self):
runner = MightyOnlineRunner(self.runner_config)
@@ -96,14 +96,14 @@ def test_run(self):
train_results, eval_results = runner.run()
assert isinstance(train_results, dict), "Train results should be a dictionary"
assert isinstance(eval_results, dict), "Eval results should be a dictionary"
- assert (
- "mean_eval_reward" in eval_results
- ), "Eval results should have mean_eval_reward"
+ assert "mean_eval_reward" in eval_results, (
+ "Eval results should have mean_eval_reward"
+ )
shutil.rmtree("test_runner")
def test_run_with_alternate_env(self):
dummy_env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(3)])
- dummy_eval_func = lambda: gym.vector.SyncVectorEnv( # noqa: E731
+ dummy_eval_func = lambda: gym.vector.SyncVectorEnv( # noqa: E731
[DummyEnv for _ in range(10)]
)
eval_default = 10
diff --git a/test/runners/test_runner_factory.py b/test/runners/test_runner_factory.py
index bfad9840..77f556c4 100644
--- a/test/runners/test_runner_factory.py
+++ b/test/runners/test_runner_factory.py
@@ -17,9 +17,9 @@ class TestFactory:
def test_create_agent(self):
for runner_type in VALID_RUNNER_TYPES:
runner_class = get_runner_class(runner_type)
- assert (
- runner_class == RUNNER_CLASSES[runner_type]
- ), f"Runner class should be {RUNNER_CLASSES[runner_type]}"
+ assert runner_class == RUNNER_CLASSES[runner_type], (
+ f"Runner class should be {RUNNER_CLASSES[runner_type]}"
+ )
def test_create_agent_with_invalid_type(self):
with pytest.raises(ValueError):
diff --git a/test/test_env_creation.py b/test/test_env_creation.py
index c6b6eff4..dbe41d14 100644
--- a/test/test_env_creation.py
+++ b/test/test_env_creation.py
@@ -120,56 +120,56 @@ class TestEnvCreation:
def check_vector_env(self, env):
"""Check if environment is a vector environment."""
- assert hasattr(
- env, "num_envs"
- ), f"Vector environment should have num_envs attribute: {env}"
- assert hasattr(
- env, "reset"
- ), f"Vector environment should have reset method: {env}."
- assert hasattr(
- env, "step"
- ), f"Vector environment should have step method: {env}."
- assert hasattr(
- env, "close"
- ), f"Vector environment should have close method: {env}."
- assert hasattr(
- env, "single_action_space"
- ), f"Vector environment should have single action space view: {env}."
- assert hasattr(
- env, "single_observation_space"
- ), f"Vector environment should have single observation space view: {env}."
- assert hasattr(
- env, "envs"
- ), f"Environments should be kept in envs attribute: {env}."
+ assert hasattr(env, "num_envs"), (
+ f"Vector environment should have num_envs attribute: {env}"
+ )
+ assert hasattr(env, "reset"), (
+ f"Vector environment should have reset method: {env}."
+ )
+ assert hasattr(env, "step"), (
+ f"Vector environment should have step method: {env}."
+ )
+ assert hasattr(env, "close"), (
+ f"Vector environment should have close method: {env}."
+ )
+ assert hasattr(env, "single_action_space"), (
+ f"Vector environment should have single action space view: {env}."
+ )
+ assert hasattr(env, "single_observation_space"), (
+ f"Vector environment should have single observation space view: {env}."
+ )
+ assert hasattr(env, "envs"), (
+ f"Environments should be kept in envs attribute: {env}."
+ )
def test_make_gym_env(self):
"""Test env creation with make_gym_env."""
env, eval_env, eval_default = make_gym_env(self.gym_config)
self.check_vector_env(env)
self.check_vector_env(eval_env())
- assert (
- eval_default == self.gym_config.n_episodes_eval
- ), "Default number of eval episodes should match config"
- assert (
- len(env.envs) == self.gym_config.num_envs
- ), "Number of environments should match config."
- assert (
- len(eval_env().envs) == self.gym_config.n_episodes_eval
- ), "Number of environments should match config."
+ assert eval_default == self.gym_config.n_episodes_eval, (
+ "Default number of eval episodes should match config"
+ )
+ assert len(env.envs) == self.gym_config.num_envs, (
+ "Number of environments should match config."
+ )
+ assert len(eval_env().envs) == self.gym_config.n_episodes_eval, (
+ "Number of environments should match config."
+ )
- assert (
- self.gym_config.env == env.envs[0].spec.id
- ), "Environment should be created with the correct id."
- assert (
- self.gym_config.env == eval_env().envs[0].spec.id
- ), "Eval environment should be created with the correct id."
+ assert self.gym_config.env == env.envs[0].spec.id, (
+ "Environment should be created with the correct id."
+ )
+ assert self.gym_config.env == eval_env().envs[0].spec.id, (
+ "Eval environment should be created with the correct id."
+ )
- assert isinstance(
- env, gym.vector.SyncVectorEnv
- ), "Gym environment should be a SyncVectorEnv."
- assert isinstance(
- eval_env(), gym.vector.SyncVectorEnv
- ), "Eval environment should be a SyncVectorEnv."
+ assert isinstance(env, gym.vector.SyncVectorEnv), (
+ "Gym environment should be a SyncVectorEnv."
+ )
+ assert isinstance(eval_env(), gym.vector.SyncVectorEnv), (
+ "Eval environment should be a SyncVectorEnv."
+ )
def test_make_dacbench_env(self):
"""Test env creation with make_dacbench_env."""
@@ -180,32 +180,34 @@ def test_make_dacbench_env(self):
eval_default
== len(env.envs[0].instance_set.keys())
* self.dacbench_config.n_episodes_eval
- ), "Default number of eval episodes should instance set size times evaluation episodes."
- assert (
- len(env.envs) == self.dacbench_config.num_envs
- ), "Number of environments should match config."
+ ), (
+            "Default number of eval episodes should be instance set size times evaluation episodes."
+ )
+ assert len(env.envs) == self.dacbench_config.num_envs, (
+ "Number of environments should match config."
+ )
assert (
len(eval_env().envs)
== len(env.envs[0].instance_set.keys())
* self.dacbench_config.n_episodes_eval
), "Number of environments should match eval length."
- assert isinstance(
- env, gym.vector.SyncVectorEnv
- ), "DACBench environment should be a SyncVectorEnv."
- assert isinstance(
- eval_env(), gym.vector.SyncVectorEnv
- ), "Eval environment should be a SyncVectorEnv."
+ assert isinstance(env, gym.vector.SyncVectorEnv), (
+ "DACBench environment should be a SyncVectorEnv."
+ )
+ assert isinstance(eval_env(), gym.vector.SyncVectorEnv), (
+ "Eval environment should be a SyncVectorEnv."
+ )
bench = getattr(benchmarks, self.dacbench_config.env)()
for k in self.dacbench_config.env_kwargs:
bench.config[k] = self.dacbench_config.env_kwargs[k]
- assert isinstance(
- env.envs[0], type(bench.get_environment())
- ), "Environment should have correct type."
- assert isinstance(
- eval_env().envs[0], type(bench.get_environment())
- ), "Eval environment should have correct type."
+ assert isinstance(env.envs[0], type(bench.get_environment())), (
+ "Environment should have correct type."
+ )
+ assert isinstance(eval_env().envs[0], type(bench.get_environment())), (
+ "Eval environment should have correct type."
+ )
assert (
env.envs[0].config.instance_set_path
== self.dacbench_config.env_kwargs.instance_set_path
@@ -221,22 +223,24 @@ def test_make_dacbench_benchmark_mode(self):
eval_default
== len(env.envs[0].instance_set.keys())
* self.dacbench_config_benchmark.n_episodes_eval
- ), "Default number of eval episodes should instance set size times evaluation episodes."
- assert (
- len(env.envs) == self.dacbench_config_benchmark.num_envs
- ), "Number of environments should match config."
+ ), (
+            "Default number of eval episodes should be instance set size times evaluation episodes."
+ )
+ assert len(env.envs) == self.dacbench_config_benchmark.num_envs, (
+ "Number of environments should match config."
+ )
assert (
len(eval_env().envs)
== len(env.envs[0].instance_set.keys())
* self.dacbench_config_benchmark.n_episodes_eval
), "Number of environments should match eval length."
- assert isinstance(
- env, gym.vector.SyncVectorEnv
- ), "DACBench environment should be a SyncVectorEnv."
- assert isinstance(
- eval_env(), gym.vector.SyncVectorEnv
- ), "Eval environment should be a SyncVectorEnv."
+ assert isinstance(env, gym.vector.SyncVectorEnv), (
+ "DACBench environment should be a SyncVectorEnv."
+ )
+ assert isinstance(eval_env(), gym.vector.SyncVectorEnv), (
+ "Eval environment should be a SyncVectorEnv."
+ )
benchmark_kwargs = OmegaConf.to_container(
self.dacbench_config_benchmark.env_kwargs, resolve=True
@@ -245,12 +249,12 @@ def test_make_dacbench_benchmark_mode(self):
del benchmark_kwargs["config_space"]
bench = getattr(benchmarks, self.dacbench_config_benchmark.env)()
benchmark_env = bench.get_benchmark(**benchmark_kwargs)
- assert isinstance(
- env.envs[0], type(benchmark_env)
- ), "Environment should have correct type."
- assert isinstance(
- eval_env().envs[0], type(benchmark_env)
- ), "Eval environment should have correct type."
+ assert isinstance(env.envs[0], type(benchmark_env)), (
+ "Environment should have correct type."
+ )
+ assert isinstance(eval_env().envs[0], type(benchmark_env)), (
+ "Eval environment should have correct type."
+ )
assert eval_env().envs[0].test, "Eval environment should be in test mode."
for k in env.envs[0].config.keys():
if k == "observation_space_args":
@@ -260,19 +264,25 @@ def test_make_dacbench_benchmark_mode(self):
assert (
env.envs[0].config[k][i].functions[0].a
== benchmark_env.config[k][i].functions[0].a
- ), f"Environment should have matching instances, mismatch for function parameter a at instance {i}: {env.envs[0].config[k][i].functions[0].a} != {benchmark_env.config[k][i].functions[0].a}"
+ ), (
+ f"Environment should have matching instances, mismatch for function parameter a at instance {i}: {env.envs[0].config[k][i].functions[0].a} != {benchmark_env.config[k][i].functions[0].a}"
+ )
assert (
env.envs[0].config[k][i].functions[0].b
== benchmark_env.config[k][i].functions[0].b
- ), f"Environment should have matching instances, mismatch for function parameter b at instance {i}: {env.envs[0].config[k][i].functions[0].b} != {benchmark_env.config[k][i].functions[0].b}"
+ ), (
+ f"Environment should have matching instances, mismatch for function parameter b at instance {i}: {env.envs[0].config[k][i].functions[0].b} != {benchmark_env.config[k][i].functions[0].b}"
+ )
assert (
env.envs[0].config[k][i].omit_instance_type
== benchmark_env.config[k][i].omit_instance_type
- ), f"Environment should have matching instances, mismatch for omit_instance_type at instance {i}: {env.envs[0].config[k][i].omit_instance_type} != {benchmark_env.config[k][i].omit_instance_type}"
+ ), (
+ f"Environment should have matching instances, mismatch for omit_instance_type at instance {i}: {env.envs[0].config[k][i].omit_instance_type} != {benchmark_env.config[k][i].omit_instance_type}"
+ )
else:
- assert (
- env.envs[0].config[k] == benchmark_env.config[k]
- ), f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}"
+ assert env.envs[0].config[k] == benchmark_env.config[k], (
+ f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}"
+ )
def test_make_carl_env(self):
"""Test env creation with make_carl_env."""
@@ -284,19 +294,19 @@ def test_make_carl_env(self):
), "Default number of eval episodes should match config"
env_class = getattr(carl.envs, self.carl_config.env)
- assert isinstance(
- env.envs[0], env_class
- ), "Environment should have the correct type."
- assert isinstance(
- eval_env().envs[0], env_class
- ), "Eval environment should have the correct type."
+ assert isinstance(env.envs[0], env_class), (
+ "Environment should have the correct type."
+ )
+ assert isinstance(eval_env().envs[0], env_class), (
+ "Eval environment should have the correct type."
+ )
- assert isinstance(
- env, CARLVectorEnvSimulator
- ), "CARL environment should be wrapped."
- assert isinstance(
- eval_env(), CARLVectorEnvSimulator
- ), "CARL eval environment should be wrapped."
+ assert isinstance(env, CARLVectorEnvSimulator), (
+ "CARL environment should be wrapped."
+ )
+ assert isinstance(eval_env(), CARLVectorEnvSimulator), (
+ "CARL eval environment should be wrapped."
+ )
def test_make_carl_context(self):
"""Test env creation with make_carl_env."""
@@ -312,9 +322,9 @@ def test_make_carl_context(self):
assert (
len(train_contexts) == self.carl_config_context.env_kwargs.num_contexts
), "Number of training contexts should match config."
- assert (
- len(eval_contexts) == 100
- ), "Number of eval contexts should match default."
+ assert len(eval_contexts) == 100, (
+ "Number of eval contexts should match default."
+ )
assert not all(
[
@@ -409,32 +419,36 @@ def test_make_carl_context(self):
), "Eval contexts lie above lower bound for gravity."
assert isinstance(
env.envs[0].context_selector, carl.context.selection.StaticSelector
- ), f"Context selector should be switched to a StaticSelector based on keyword but is {type(env.envs[0].context_selector)}."
+ ), (
+ f"Context selector should be switched to a StaticSelector based on keyword but is {type(env.envs[0].context_selector)}."
+ )
assert isinstance(
eval_env().envs[0].context_selector,
carl.context.selection.RoundRobinSelector,
- ), f"Eval env context selector should stay round robin but is {type(eval_env().envs[0].context_selector)}."
+ ), (
+ f"Eval env context selector should stay round robin but is {type(eval_env().envs[0].context_selector)}."
+ )
def test_make_procgen_env(self):
"""Test env creation with make_procgen_env."""
if PROCGEN:
env, eval_env, eval_default = make_procgen_env(self.procgen_config)
- assert (
- eval_default == self.procgen_config.n_episodes_eval
- ), "Default number of eval episodes should match config"
+ assert eval_default == self.procgen_config.n_episodes_eval, (
+ "Default number of eval episodes should match config"
+ )
self.check_vector_env(env)
self.check_vector_env(eval_env())
if ENVPOOL:
- assert isinstance(
- env, envpool.VectorEnv
- ), "Environment should be an envpool env if we create a gym env with envpool installed."
+ assert isinstance(env, envpool.VectorEnv), (
+ "Environment should be an envpool env if we create a gym env with envpool installed."
+ )
else:
- assert isinstance(
- env, ProcgenVecEnv
- ), "Environment should be ProcGen env if we create a gym env without envpool installed."
- assert isinstance(
- eval_env(), ProcgenVecEnv
- ), "Eval env should be a ProcGen env."
+ assert isinstance(env, ProcgenVecEnv), (
+ "Environment should be ProcGen env if we create a gym env without envpool installed."
+ )
+ assert isinstance(eval_env(), ProcgenVecEnv), (
+ "Eval env should be a ProcGen env."
+ )
else:
Warning("Procgen not installed, skipping test.")
@@ -443,57 +457,57 @@ def test_make_pufferlib_env(self):
env, eval_env, eval_default = make_pufferlib_env(self.pufferlib_config)
self.check_vector_env(env)
self.check_vector_env(eval_env())
- assert (
- eval_default == self.pufferlib_config.n_episodes_eval
- ), "Default number of eval episodes should match config"
- assert (
- len(env.envs) == self.pufferlib_config.num_envs
- ), "Number of environments should match config."
- assert (
- len(eval_env().envs) == self.pufferlib_config.n_episodes_eval
- ), "Number of environments should match config."
+ assert eval_default == self.pufferlib_config.n_episodes_eval, (
+ "Default number of eval episodes should match config"
+ )
+ assert len(env.envs) == self.pufferlib_config.num_envs, (
+ "Number of environments should match config."
+ )
+ assert len(eval_env().envs) == self.pufferlib_config.n_episodes_eval, (
+ "Number of environments should match config."
+ )
domain = ".".join(self.pufferlib_config.env.split(".")[:-1])
name = self.pufferlib_config.env.split(".")[-1]
get_env_func = importlib.import_module(domain).env_creator
make_env = get_env_func(name)(**self.pufferlib_config.env_kwargs)
- assert isinstance(
- env.envs[0], type(make_env)
- ), "Environment should have correct type."
- assert isinstance(
- eval_env().envs[0], type(make_env)
- ), "Eval environment should have correct type."
+ assert isinstance(env.envs[0], type(make_env)), (
+ "Environment should have correct type."
+ )
+ assert isinstance(eval_env().envs[0], type(make_env)), (
+ "Eval environment should have correct type."
+ )
- assert isinstance(
- env, PufferlibToGymAdapter
- ), "Pufferlib env should be wrapped."
- assert isinstance(
- eval_env(), PufferlibToGymAdapter
- ), "Pufferlib eval env should be wrapped."
+ assert isinstance(env, PufferlibToGymAdapter), (
+ "Pufferlib env should be wrapped."
+ )
+ assert isinstance(eval_env(), PufferlibToGymAdapter), (
+ "Pufferlib eval env should be wrapped."
+ )
def test_make_mighty_env(self):
"""Test correct typing of environments when creating with make_mighty_env."""
env, eval_env, eval_default = make_mighty_env(self.gym_config)
- assert (
- eval_default == self.gym_config.n_episodes_eval
- ), "Default number of eval episodes should match config"
+ assert eval_default == self.gym_config.n_episodes_eval, (
+ "Default number of eval episodes should match config"
+ )
self.check_vector_env(env)
self.check_vector_env(eval_env())
if ENVPOOL:
- assert isinstance(
- env, envpool.VectorEnv
- ), "Mighty environment should be an envpool env if we create a gym env with envpool installed."
- assert isinstance(
- eval_env(), gym.vector.SyncVectorEnv
- ), "Eval env should be a SyncVectorEnv env if we create a gym env with envpool installed."
+ assert isinstance(env, envpool.VectorEnv), (
+ "Mighty environment should be an envpool env if we create a gym env with envpool installed."
+ )
+ assert isinstance(eval_env(), gym.vector.SyncVectorEnv), (
+ "Eval env should be a SyncVectorEnv env if we create a gym env with envpool installed."
+ )
else:
Warning("Envpool not installed, skipping test.")
- assert isinstance(
- env, gym.vector.SyncVectorEnv
- ), "Mighty environment should be a SyncVectorEnv if we create a gym env without envpool installed."
- assert isinstance(
- env, gym.vector.SyncVectorEnv
- ), "Eval environment should be a SyncVectorEnv if we create a gym env without envpool installed."
+ assert isinstance(env, gym.vector.SyncVectorEnv), (
+ "Mighty environment should be a SyncVectorEnv if we create a gym env without envpool installed."
+ )
+ assert isinstance(env, gym.vector.SyncVectorEnv), (
+ "Eval environment should be a SyncVectorEnv if we create a gym env without envpool installed."
+ )
for config in [
self.dacbench_config,
diff --git a/test/update/test_ppo_update.py b/test/update/test_ppo_update.py
index 1c66bad0..009524ac 100644
--- a/test/update/test_ppo_update.py
+++ b/test/update/test_ppo_update.py
@@ -213,9 +213,9 @@ def test_discrete_action_update(self):
]
for metric in required_metrics:
assert metric in metrics, f"Missing metric: {metric}"
- assert isinstance(
- metrics[metric], (int, float)
- ), f"Metric {metric} should be numeric"
+ assert isinstance(metrics[metric], (int, float)), (
+ f"Metric {metric} should be numeric"
+ )
def test_continuous_action_update(self):
"""Test PPO update with continuous actions."""
@@ -261,9 +261,9 @@ def test_continuous_action_update(self):
]
for metric in required_metrics:
assert metric in metrics, f"Missing metric: {metric}"
- assert isinstance(
- metrics[metric], (int, float)
- ), f"Metric {metric} should be numeric"
+ assert isinstance(metrics[metric], (int, float)), (
+ f"Metric {metric} should be numeric"
+ )
def test_value_clipping(self):
"""Test value clipping mechanism."""
@@ -314,15 +314,11 @@ def test_adaptive_learning_rate(self):
"""Test adaptive learning rate adjustment."""
update, model = self.get_update_and_model(adaptive_lr=True, kl_target=0.01)
- # Store initial learning rates
- initial_policy_lr = update.optimizer.param_groups[0]["lr"]
- initial_value_lr = update.optimizer.param_groups[1]["lr"]
-
# Create batch that might trigger LR adaptation
batch = DummyMaxiBatch()
# Run update
- metrics = update.update(batch)
+ update.update(batch)
# Learning rates might have changed (depending on KL divergence)
final_policy_lr = update.optimizer.param_groups[0]["lr"]
@@ -389,8 +385,8 @@ def test_metric_shapes_and_types(self, continuous_action):
for metric_name in expected_metrics:
assert metric_name in metrics, f"Missing metric: {metric_name}"
metric_value = metrics[metric_name]
- assert isinstance(
- metric_value, (int, float)
- ), f"Metric {metric_name} should be scalar"
+ assert isinstance(metric_value, (int, float)), (
+ f"Metric {metric_name} should be scalar"
+ )
assert not np.isnan(metric_value), f"Metric {metric_name} should not be NaN"
assert np.isfinite(metric_value), f"Metric {metric_name} should be finite"
diff --git a/test/update/test_q_update.py b/test/update/test_q_update.py
index 17f83185..a7a20418 100644
--- a/test/update/test_q_update.py
+++ b/test/update/test_q_update.py
@@ -54,12 +54,12 @@ def test_update(self, weight, bias):
"""Test Q-learning update."""
update, model = self.get_update(initial_weights=weight, initial_biases=bias)
checked_model = deepcopy(model)
- assert torch.allclose(
- model.layer.weight, checked_model.layer.weight
- ), "Wrong initial weights."
- assert torch.allclose(
- model.layer.bias, checked_model.layer.bias
- ), "Wrong initial biases."
+ assert torch.allclose(model.layer.weight, checked_model.layer.weight), (
+ "Wrong initial weights."
+ )
+ assert torch.allclose(model.layer.bias, checked_model.layer.bias), (
+ "Wrong initial biases."
+ )
preds, targets = update.get_targets(batch, model)
loss_stats = update.apply_update(preds, targets)
@@ -77,9 +77,9 @@ def test_update(self, weight, bias):
assert torch.allclose(
model.layer.weight, checked_model.layer.weight, atol=1e-3
), "Wrong weights after update"
- assert torch.allclose(
- model.layer.bias, checked_model.layer.bias, atol=1e-3
- ), "Wrong biases after update."
+ assert torch.allclose(model.layer.bias, checked_model.layer.bias, atol=1e-3), (
+ "Wrong biases after update."
+ )
def test_get_targets(self):
"""Test get_targets method."""
@@ -148,9 +148,9 @@ def test_get_targets(self):
correct_targets = batch.rewards.unsqueeze(-1) + mask * 0.99 * target(
torch.as_tensor(batch.next_obs, dtype=torch.float32)
).max(1)[0].unsqueeze(-1)
- assert torch.allclose(
- targets.detach(), correct_targets.type(torch.float32)
- ), "Wrong targets (weight 0)."
+ assert torch.allclose(targets.detach(), correct_targets.type(torch.float32)), (
+ "Wrong targets (weight 0)."
+ )
update, model, target = self.get_update(initial_weights=3)
preds, targets = update.get_targets(batch, model, target)
@@ -193,12 +193,10 @@ def test_get_targets(self):
~batch.dones.unsqueeze(-1)
) * 0.99 * torch.minimum(
batch.next_obs.sum(axis=1), torch.zeros(batch.next_obs.shape).sum(axis=1)
- ).unsqueeze(
- -1
+ ).unsqueeze(-1)
+ assert torch.allclose(targets.detach(), correct_targets.type(torch.float32)), (
+ "Wrong targets (weight 0)."
)
- assert torch.allclose(
- targets.detach(), correct_targets.type(torch.float32)
- ), "Wrong targets (weight 0)."
update, model, target = self.get_update(initial_weights=3)
preds, targets = update.get_targets(batch, model, target)
diff --git a/test/update/test_sac_update.py b/test/update/test_sac_update.py
index 009ea078..10328444 100644
--- a/test/update/test_sac_update.py
+++ b/test/update/test_sac_update.py
@@ -133,7 +133,7 @@ def test_initialization(self):
# Check hyperparameters
assert update.gamma == 0.99
assert update.tau == 0.005
- assert update.auto_alpha == True
+ assert update.auto_alpha
# Check new frequency parameters
assert hasattr(update, "policy_frequency")
@@ -144,7 +144,7 @@ def test_initialization_without_auto_alpha(self):
"""Test SAC initialization with fixed alpha."""
update, model = self.get_update_and_model(auto_alpha=False, alpha=0.1)
- assert update.auto_alpha == False
+ assert not update.auto_alpha
assert update.alpha == 0.1 # Should use the provided alpha value
assert not hasattr(update, "log_alpha")
assert not hasattr(update, "alpha_optimizer")
@@ -187,18 +187,18 @@ def test_basic_update(self):
# Check metrics - updated to include alpha_loss
required_metrics = [
- "q_loss1",
- "q_loss2",
- "policy_loss",
- "alpha_loss",
- "td_error1",
- "td_error2",
+ "Update/q_loss1",
+ "Update/q_loss2",
+ "Update/policy_loss",
+ "Update/alpha_loss",
+ "Update/td_error1",
+ "Update/td_error2",
]
for metric in required_metrics:
assert metric in metrics, f"Missing metric: {metric}"
- assert isinstance(
- metrics[metric], (int, float)
- ), f"Metric {metric} should be numeric"
+ assert isinstance(metrics[metric], (int, float)), (
+ f"Metric {metric} should be numeric"
+ )
assert np.isfinite(metrics[metric]), f"Metric {metric} should be finite"
def test_target_network_updates(self):
@@ -210,10 +210,6 @@ def test_target_network_updates(self):
initial_target_q1 = [p.clone() for p in model.target_q_net1.parameters()]
initial_target_q2 = [p.clone() for p in model.target_q_net2.parameters()]
- # Store initial main Q parameters
- initial_q1 = [p.clone() for p in model.q_net1.parameters()]
- initial_q2 = [p.clone() for p in model.q_net2.parameters()]
-
# Run update
update.update(batch)
@@ -239,9 +235,9 @@ def test_target_network_updates(self):
initial_target_q1,
):
expected = (1 - tau) * p_old_target + tau * p_main
- assert torch.allclose(
- p_target, expected, atol=1e-5
- ), "Target update should follow polyak averaging"
+ assert torch.allclose(p_target, expected, atol=1e-5), (
+ "Target update should follow polyak averaging"
+ )
def test_td_error_calculation(self):
"""Test TD error calculation."""
@@ -285,17 +281,17 @@ def test_fixed_alpha_mode(self):
# Metrics should still be valid - but alpha_loss should be 0
required_metrics = [
- "q_loss1",
- "q_loss2",
- "policy_loss",
- "alpha_loss",
- "td_error1",
- "td_error2",
+ "Update/q_loss1",
+ "Update/q_loss2",
+ "Update/policy_loss",
+ "Update/alpha_loss",
+ "Update/td_error1",
+ "Update/td_error2",
]
for metric in required_metrics:
assert metric in metrics
# Alpha loss should be 0 when auto_alpha=False
- assert metrics["alpha_loss"] == 0.0
+ assert metrics["Update/alpha_loss"] == 0.0
def test_custom_target_entropy(self):
"""Test SAC with custom target entropy."""
@@ -308,7 +304,7 @@ def test_custom_target_entropy(self):
batch = DummyTransitionBatch()
metrics = update.update(batch)
- assert "policy_loss" in metrics
+ assert "Update/policy_loss" in metrics
def test_different_learning_rates(self):
"""Test SAC with different learning rates for policy and Q-networks."""
@@ -358,9 +354,9 @@ def test_different_tau_values(self):
)
)
- assert (
- total_change_large > total_change
- ), "Larger tau should cause bigger target network changes"
+ assert total_change_large > total_change, (
+ "Larger tau should cause bigger target network changes"
+ )
def test_zero_rewards_batch(self):
"""Test SAC with zero rewards."""
@@ -372,9 +368,9 @@ def test_zero_rewards_batch(self):
metrics = update.update(batch)
for metric_name, metric_value in metrics.items():
- assert np.isfinite(
- metric_value
- ), f"Metric {metric_name} should be finite with zero rewards"
+ assert np.isfinite(metric_value), (
+ f"Metric {metric_name} should be finite with zero rewards"
+ )
def test_all_done_batch(self):
"""Test SAC with all episodes terminated."""
@@ -386,9 +382,9 @@ def test_all_done_batch(self):
metrics = update.update(batch)
for metric_name, metric_value in metrics.items():
- assert np.isfinite(
- metric_value
- ), f"Metric {metric_name} should be finite with all done"
+ assert np.isfinite(metric_value), (
+ f"Metric {metric_name} should be finite with all done"
+ )
def test_metric_ranges(self):
"""Test that metrics are in reasonable ranges."""
@@ -398,18 +394,20 @@ def test_metric_ranges(self):
metrics = update.update(batch)
# Q losses should be non-negative (MSE loss)
- assert metrics["q_loss1"] >= 0, "Q loss 1 should be non-negative"
- assert metrics["q_loss2"] >= 0, "Q loss 2 should be non-negative"
+ assert metrics["Update/q_loss1"] >= 0, "Q loss 1 should be non-negative"
+ assert metrics["Update/q_loss2"] >= 0, "Q loss 2 should be non-negative"
# Policy loss can be negative (we want to maximize Q - alpha*entropy)
- assert np.isfinite(metrics["policy_loss"]), "Policy loss should be finite"
+ assert np.isfinite(metrics["Update/policy_loss"]), (
+ "Policy loss should be finite"
+ )
# Alpha loss can be positive or negative
- assert np.isfinite(metrics["alpha_loss"]), "Alpha loss should be finite"
+ assert np.isfinite(metrics["Update/alpha_loss"]), "Alpha loss should be finite"
# TD errors can be positive or negative
- assert np.isfinite(metrics["td_error1"]), "TD error 1 should be finite"
- assert np.isfinite(metrics["td_error2"]), "TD error 2 should be finite"
+ assert np.isfinite(metrics["Update/td_error1"]), "TD error 1 should be finite"
+ assert np.isfinite(metrics["Update/td_error2"]), "TD error 2 should be finite"
@pytest.mark.parametrize("batch_size", [1, 16, 64, 128])
def test_different_batch_sizes(self, batch_size):
@@ -421,12 +419,12 @@ def test_different_batch_sizes(self, batch_size):
metrics = update.update(batch)
required_metrics = [
- "q_loss1",
- "q_loss2",
- "policy_loss",
- "alpha_loss",
- "td_error1",
- "td_error2",
+ "Update/q_loss1",
+ "Update/q_loss2",
+ "Update/policy_loss",
+ "Update/alpha_loss",
+ "Update/td_error1",
+ "Update/td_error2",
]
for metric in required_metrics:
assert metric in metrics
@@ -445,12 +443,12 @@ def test_policy_frequency(self):
# Run updates less than policy_frequency - policy shouldn't change
for i in range(policy_freq - 1):
metrics = update.update(batch)
- assert (
- metrics["policy_loss"] == 0.0
- ), "Policy loss should be 0 when no policy update"
- assert (
- metrics["alpha_loss"] == 0.0
- ), "Alpha loss should be 0 when no policy update"
+ assert metrics["Update/policy_loss"] == 0.0, (
+ "Policy loss should be 0 when no policy update"
+ )
+ assert metrics["Update/alpha_loss"] == 0.0, (
+ "Alpha loss should be 0 when no policy update"
+ )
# Policy parameters shouldn't have changed yet
policy_unchanged = all(
@@ -459,16 +457,16 @@ def test_policy_frequency(self):
)
alpha_unchanged = torch.allclose(initial_alpha, update.log_alpha, atol=1e-6)
- assert (
- policy_unchanged
- ), "Policy parameters should not change before policy_frequency"
+ assert policy_unchanged, (
+ "Policy parameters should not change before policy_frequency"
+ )
assert alpha_unchanged, "Alpha should not change before policy_frequency"
# Now run one more update - should trigger policy update
metrics = update.update(batch)
- assert (
- metrics["policy_loss"] != 0.0
- ), "Policy loss should be non-zero when policy updates"
+ assert metrics["Update/policy_loss"] != 0.0, (
+ "Policy loss should be non-zero when policy updates"
+ )
# Policy parameters should have changed now
policy_changed = any(
@@ -507,14 +505,14 @@ def test_gradient_flow(self):
# Use more lenient tolerance since SAC updates can be small
if change > 1e-8: # Much more lenient than 1e-6
changed_params += 1
-
+
# At least some parameters should change
change_ratio = changed_params / total_params
- assert (
- change_ratio > 0.1
- ), f"Only {change_ratio:.2%} of parameters changed, gradient flow might be broken. Changes: {param_changes}"
+ assert change_ratio > 0.1, (
+ f"Only {change_ratio:.2%} of parameters changed, gradient flow might be broken. Changes: {param_changes}"
+ )
# Additional check: ensure losses are reasonable
- assert np.isfinite(metrics["q_loss1"]) and metrics["q_loss1"] >= 0
- assert np.isfinite(metrics["q_loss2"]) and metrics["q_loss2"] >= 0
- assert np.isfinite(metrics["policy_loss"])
+ assert np.isfinite(metrics["Update/q_loss1"]) and metrics["Update/q_loss1"] >= 0
+ assert np.isfinite(metrics["Update/q_loss2"]) and metrics["Update/q_loss2"] >= 0
+ assert np.isfinite(metrics["Update/policy_loss"])