From e45ea4cdf619e57729872a6794dba48c17193910 Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 10:37:23 +0200 Subject: [PATCH 1/5] extended batch size in td comparison --- test/agents/test_dqn_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/agents/test_dqn_agent.py b/test/agents/test_dqn_agent.py index 5922dc63..ca9cb5d6 100644 --- a/test/agents/test_dqn_agent.py +++ b/test/agents/test_dqn_agent.py @@ -82,12 +82,12 @@ def test_update(self): output_dir = Path("test_dqn_agent") output_dir.mkdir(parents=True, exist_ok=True) dqn = MightyDQNAgent(output_dir, env, batch_size=2) - dqn.run(10, 1) + dqn.run(20, 1) original_optimizer = torch.optim.Adam(dqn.q.parameters(), lr=dqn.learning_rate) original_params = deepcopy(list(dqn.q.parameters())) original_target_params = deepcopy(list(dqn.q_target.parameters())) original_feature_params = deepcopy(list(dqn.q.feature_extractor.parameters())) - batch = dqn.buffer.sample(2) + batch = dqn.buffer.sample(20) metrics = dqn.update_agent(batch, 0) new_params = deepcopy(list(dqn.q.parameters())) new_target_params = deepcopy(list(dqn.q_target.parameters())) @@ -110,7 +110,7 @@ def test_update(self): old * (1 - 0.01) + new * 0.01, new_target ), "Target model parameters should be scaled correctly" - batch = dqn.buffer.sample(2) + batch = dqn.buffer.sample(20) preds, targets = dqn.qlearning.get_targets(batch, dqn.q, dqn.q_target) assert ( From b10e74c851780bd5be757fe09fe39a737fc4fc08 Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 10:37:34 +0200 Subject: [PATCH 2/5] added rough versioning to packages --- pyproject.toml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 086235c6..05cf4c2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,19 +38,21 @@ dependencies = [ "scipy", "rich~=12.4", "wandb~=0.12", - "torch", + "torch~=2.7", "dill", "imageio", "evosax==0.1.6", "rliable", "seaborn", - "uniplot" + "uniplot", + "jax~=0.7", + "flax~=0.11" ] [project.optional-dependencies] dev = ["ruff", "mypy", "build", "pytest", "pytest-cov"] -carl = ["carl_bench==1.1.1", "brax==0.12.1"] -dacbench = ["dacbench==0.4.0", "torchvision", "ioh"] +carl = ["carl_bench[brax]==1.1.1"] +dacbench = ["dacbench==0.4.0"] pufferlib = ["pufferlib==2.0.6"] docs = ["mkdocs", "mkdocs-material", "mkdocs-autorefs", "mkdocs-gen-files", "mkdocs-literate-nav", From 97fadf07983ba6c01b0f6da948787024ac06eebb Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 10:59:03 +0200 Subject: [PATCH 3/5] fix td test, add reproducibility test --- mighty/configs/cluster/local.yaml | 8 --- mighty/mighty_agents/base_agent.py | 31 ++------- mighty/mighty_agents/sac.py | 10 +-- mighty/mighty_update/sac_update.py | 12 ++-- test/agents/test_dqn_agent.py | 30 +++++++++ test/agents/test_ppo_agent.py | 105 +++++++++++++++++++++++++++++ test/agents/test_sac_agent.py | 40 +++++++++-- 7 files changed, 188 insertions(+), 48 deletions(-) diff --git a/mighty/configs/cluster/local.yaml b/mighty/configs/cluster/local.yaml index 582473d6..858953b7 100644 --- a/mighty/configs/cluster/local.yaml +++ b/mighty/configs/cluster/local.yaml @@ -1,11 +1,3 @@ -# @package _global_ -# defaults: -# - override /hydra/launcher: joblib - -# hydra: -# launcher: -# n_jobs: 16 - cluster: _target_: distributed.deploy.local.LocalCluster n_workers: ${hydra.sweeper.scenario.n_workers} diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 2e30e8b7..255f1243 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -52,34 +52,15 @@ def seed_everything(seed: int, env: gym.Env | None = None): # 2) Gym environment seeding if env is not None: - # If the env is wrapped, try to unwrap to the core - try: - core_env = env.unwrapped - core_env.seed(seed) - except Exception: - pass - - # If vectorized (e.g. SyncVectorEnv), seed each sub‐env separately - if hasattr(env, "envs") and isinstance(env.envs, list): - sub_seeds = [seed for _ in range(len(env.envs))] - for sub_seed, subenv in zip(sub_seeds, env.envs): - subenv.action_space.seed(sub_seed) - subenv.observation_space.seed(sub_seed) - try: - subenv.unwrapped.seed(sub_seed) - except Exception: - pass - # Reset the vectorized env with explicit seeds - env.reset(seed=sub_seeds) - else: - # Single environment - env.action_space.seed(seed) - env.observation_space.seed(seed) + sub_seeds = [seed for _ in range(len(env.envs))] + for sub_seed, subenv in zip(sub_seeds, env.envs): + subenv.action_space.seed(sub_seed) + subenv.observation_space.seed(sub_seed) try: - env.unwrapped.seed(seed) + subenv.unwrapped.seed(sub_seed) except Exception: pass - env.reset(seed=seed) + env.reset(seed=sub_seeds) def update_buffer(buffer, new_data): diff --git a/mighty/mighty_agents/sac.py b/mighty/mighty_agents/sac.py index 5617c739..2125c794 100644 --- a/mighty/mighty_agents/sac.py +++ b/mighty/mighty_agents/sac.py @@ -120,11 +120,11 @@ def __init__( # Initialize loss buffer for logging self.loss_buffer = { - "q_loss1": [], - "q_loss2": [], - "policy_loss": [], - "td_error1": [], - "td_error2": [], + "Update/q_loss1": [], + "Update/q_loss2": [], + "Update/policy_loss": [], + "Update/td_error1": [], + "Update/td_error2": [], "step": [], } diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index 74d8a6d9..9cee6998 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -176,10 +176,10 @@ def update(self, batch: TransitionBatch) -> Dict: # --- Logging metrics --- td1, td2 = self.calculate_td_error(batch) return { - "q_loss1": q_loss1.item(), - "q_loss2": q_loss2.item(), - "policy_loss": policy_loss.item(), - "alpha_loss": alpha_loss.item(), - "td_error1": td1.mean().item(), - "td_error2": td2.mean().item(), + "Update/q_loss1": q_loss1.item(), + "Update/q_loss2": q_loss2.item(), + "Update/policy_loss": policy_loss.item(), + "Update/alpha_loss": alpha_loss.item(), + "Update/td_error1": td1.mean().item(), + "Update/td_error2": td2.mean().item(), } diff --git a/test/agents/test_dqn_agent.py b/test/agents/test_dqn_agent.py index ca9cb5d6..81238100 100644 --- a/test/agents/test_dqn_agent.py +++ b/test/agents/test_dqn_agent.py @@ -214,3 +214,33 @@ def process_transition(self): ), "New value prediction should be added" assert len(metrics["td_error"]) == 1, "TD error is overwritten" clean(output_dir) + + def test_reproducibility(self): + env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) + output_dir = Path("test_dqn_agent") + output_dir.mkdir(parents=True, exist_ok=True) + dqn = MightyDQNAgent(output_dir, env, batch_size=2, seed=42) + init_params = deepcopy(list(dqn.q.parameters())) + dqn.run(20, 1) + batch = dqn.buffer.sample(20) + original_metrics = dqn.update_agent(batch, 0) + original_params = deepcopy(list(dqn.q.parameters())) + + for _ in range(3): + env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) + output_dir = Path("test_dqn_agent") + output_dir.mkdir(parents=True, exist_ok=True) + dqn = MightyDQNAgent(output_dir, env, batch_size=2, seed=42) + for old, new in zip(init_params[:10], list(dqn.q.parameters())[:10], strict=False): + assert not torch.allclose(old, new), "Parameter initialization should be the same with same seed" + dqn.run(20, 1) + batch = dqn.buffer.sample(20) + new_metrics = dqn.update_agent(batch, 0) + for old, new in zip(original_params[:10], list(dqn.q.parameters())[:10], strict=False): + assert not torch.allclose(old, new), "Model parameters should stay the same with same seed" + + for old, new in zip(original_metrics["Update/td_targets"][:10], new_metrics["Update/td_targets"][:10], strict=False): + assert torch.allclose(old, new), "TD targets should be the same with same seed" + for old, new in zip(original_metrics["Update/td_errors"][:10], new_metrics["Update/td_errors"][:10], strict=False): + assert torch.allclose(old, new), "TD errors should be the same with same seed" + clean(output_dir) \ No newline at end of file diff --git a/test/agents/test_ppo_agent.py b/test/agents/test_ppo_agent.py index bd427337..635b0dc0 100644 --- a/test/agents/test_ppo_agent.py +++ b/test/agents/test_ppo_agent.py @@ -263,3 +263,108 @@ def test_properties(self): assert value_output.shape == (1, 1), "Value function should output shape (1, 1)" clean(output_dir) + + def test_reproducibility(self): + env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) + output_dir = Path("test_ppo_agent") + output_dir.mkdir(parents=True, exist_ok=True) + ppo = MightyPPOAgent(output_dir, env, batch_size=2, seed=42) + init_params = deepcopy(list(ppo.model.parameters())) + + metrics = { + "env": ppo.env, + "step": 0, + "hp/lr": ppo.learning_rate, + "hp/pi_epsilon": ppo._epsilon, + "hp/batch_size": ppo._batch_size, + "hp/learning_starts": ppo._learning_starts, + } + curr_s, _ = env.reset(seed=42) + # Collect enough transitions to fill the buffer + for step in range(128): # Fill buffer with 128 transitions + # Get action from agent + action, log_prob = ppo.step(curr_s, metrics) + + # Take environment step + next_s, reward, terminated, truncated, _ = env.step(action) + dones = np.logical_or(terminated, truncated) + + # Process the transition (this adds to buffer) + ppo.process_transition( + curr_s, + action, + reward, + next_s, + dones, + log_prob.detach().cpu().numpy(), + {"step": step}, + ) + + # Update current state + curr_s = next_s + + # Reset environment if done + if np.any(dones): + curr_s, _ = env.reset() + + update_kwargs = {"next_s": curr_s, "dones": np.array([False])} + original_metrics = ppo.update(metrics, update_kwargs) + original_params = deepcopy(list(ppo.model.parameters())) + + for _ in range(3): + env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) + output_dir = Path("test_ppo_agent") + output_dir.mkdir(parents=True, exist_ok=True) + ppo = MightyPPOAgent(output_dir, env, batch_size=2, seed=42) + for old, new in zip(init_params[:10], list(ppo.model.parameters())[:10], strict=False): + assert not torch.allclose(old, new), "Parameter initialization should be the same with same seed" + + metrics = { + "env": ppo.env, + "step": 0, + "hp/lr": ppo.learning_rate, + "hp/pi_epsilon": ppo._epsilon, + "hp/batch_size": ppo._batch_size, + "hp/learning_starts": ppo._learning_starts, + } + curr_s, _ = env.reset(seed=42) + # Collect enough transitions to fill the buffer + for step in range(128): # Fill buffer with 128 transitions + # Get action from agent + action, log_prob = ppo.step(curr_s, metrics) + + # Take environment step + next_s, reward, terminated, truncated, _ = env.step(action) + dones = np.logical_or(terminated, truncated) + + # Process the transition (this adds to buffer) + ppo.process_transition( + curr_s, + action, + reward, + next_s, + dones, + log_prob.detach().cpu().numpy(), + {"step": step}, + ) + + # Update current state + curr_s = next_s + + # Reset environment if done + if np.any(dones): + curr_s, _ = env.reset() + + update_kwargs = {"next_s": curr_s, "dones": np.array([False])} + new_metrics = ppo.update(metrics, update_kwargs) + + for old, new in zip(original_params[:10], list(ppo.model.parameters())[:10], strict=False): + assert not torch.allclose(old, new), "Model parameters should stay the same with same seed" + + for old, new in zip(original_metrics["Update/value_loss"], new_metrics["Update/value_loss"], strict=False): + assert torch.allclose(old, new), "Value loss should be the same with same seed" + for old, new in zip(original_metrics["Update/policy_loss"], new_metrics["Update/policy_loss"], strict=False): + assert torch.allclose(old, new), "Policy loss should be the same with same seed" + for old, new in zip(original_metrics["Update/entropy"], new_metrics["Update/entropy"], strict=False): + assert torch.allclose(old, new), "Entropy should be the same with same seed" + clean(output_dir) \ No newline at end of file diff --git a/test/agents/test_sac_agent.py b/test/agents/test_sac_agent.py index 260d66d0..d9285740 100644 --- a/test/agents/test_sac_agent.py +++ b/test/agents/test_sac_agent.py @@ -189,10 +189,10 @@ def test_update(self): # Check for expected SAC metrics in the result expected_metrics = [ - "q_loss1", - "q_loss2", - "policy_loss", - "alpha_loss", + "Update/q_loss1", + "Update/q_loss2", + "Update/policy_loss", + "Update/alpha_loss", ] # Added alpha_loss for metric in expected_metrics: if metric in result_metrics: @@ -262,3 +262,35 @@ def test_properties(self): ), "Value function should be cached and return same instance" clean(output_dir) + + def test_reproducibility(self): + env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) + output_dir = Path("test_sac_agent") + output_dir.mkdir(parents=True, exist_ok=True) + sac = MightySACAgent(output_dir, env, seed=42) + init_params = deepcopy(list(sac.model.parameters())) + sac.run(20, 1) + batch = sac.buffer.sample(20) + original_metrics = sac.update_agent(batch, 0) + original_params = deepcopy(list(sac.model.parameters())) + + for _ in range(3): + env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) + output_dir = Path("test_sac_agent") + output_dir.mkdir(parents=True, exist_ok=True) + sac = MightySACAgent(output_dir, env, seed=42) + for old, new in zip(init_params[:10], list(sac.model.parameters())[:10], strict=False): + assert not torch.allclose(old, new), "Parameter initialization should be the same with same seed" + sac.run(20, 1) + batch = sac.buffer.sample(20) + new_metrics = sac.update_agent(batch, 0) + for old, new in zip(original_params[:10], list(sac.model.parameters())[:10], strict=False): + assert not torch.allclose(old, new), "Model parameters should stay the same with same seed" + + for old, new in zip(original_metrics["Update/q_loss1"][:10], new_metrics["Update/q_loss1"][:10], strict=False): + assert torch.allclose(old, new), "Q1 loss should be the same with same seed" + for old, new in zip(original_metrics["Update/q_loss2"][:10], new_metrics["Update/q_loss2"][:10], strict=False): + assert torch.allclose(old, new), "Q2 loss should be the same with same seed" + for old, new in zip(original_metrics["Update/policy_loss"][:10], new_metrics["Update/policy_loss"][:10], strict=False): + assert torch.allclose(old, new), "Policy loss should be the same with same seed" + clean(output_dir) From d65f46e2ef828c1615425e16f3eb6ef71160ef2c Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 12:36:01 +0200 Subject: [PATCH 4/5] ugrade python, format, lint --- README.md | 6 +- docs/installation.md | 2 +- mighty/configs/base.yaml | 4 +- mighty/configs/cluster/example_cpu.yaml | 11 + mighty/configs/cluster/example_gpu.yaml | 12 + mighty/configs/cluster/local.yaml | 5 - mighty/configs/cluster/luis.yaml | 17 - mighty/configs/cluster/noctua.yaml | 25 -- mighty/configs/cluster/tnt.yaml | 15 - mighty/configs/cmaes_hpo.yaml | 1 + mighty/configs/nes.yaml | 1 + mighty/configs/ppo_smac.yaml | 5 +- mighty/configs/sac_smac.yaml | 5 +- mighty/configs/sweep_ppo_pbt.yaml | 1 + mighty/configs/sweep_rs.yaml | 1 + mighty/mighty_agents/base_agent.py | 69 ++-- .../mighty_exploration_policy.py | 4 + mighty/mighty_meta/mighty_component.py | 7 + mighty/mighty_meta/plr.py | 1 - mighty/mighty_meta/space.py | 4 +- mighty/mighty_models/networks.py | 12 +- mighty/mighty_replay/buffer.py | 6 + .../mighty_prioritized_replay.py | 3 +- mighty/mighty_replay/mighty_rollout_buffer.py | 3 +- mighty/mighty_runners/mighty_es_runner.py | 2 +- mighty/mighty_runners/mighty_runner.py | 2 +- mighty/mighty_update/sac_update.py | 4 +- mighty/mighty_utils/plotting.py | 8 +- mighty/mighty_utils/test_helpers.py | 14 +- pyproject.toml | 4 +- test/agents/test_dqn_agent.py | 126 +++++--- test/agents/test_ppo_agent.py | 65 ++-- test/agents/test_sac_agent.py | 104 +++--- test/exploration/test_epsilon_greedy.py | 36 +-- test/exploration/test_exploration.py | 12 +- test/exploration/test_ez_greedy.py | 42 +-- test/meta_components/test_cosine_schedule.py | 24 +- test/meta_components/test_noveld.py | 142 ++++---- test/meta_components/test_plr.py | 78 ++--- test/meta_components/test_rnd.py | 102 +++--- test/meta_components/test_space.py | 24 +- test/models/test_networks.py | 132 ++++---- test/models/test_ppo_networks.py | 131 ++++---- test/models/test_q_networks.py | 96 +++--- test/models/test_sac_networks.py | 138 ++++---- test/replay/test_buffer.py | 222 ++++++------- test/replay/test_rollout_buffer.py | 211 ++++++------ test/runners/test_es_runner.py | 42 +-- test/runners/test_runner.py | 42 +-- test/runners/test_runner_factory.py | 6 +- test/test_env_creation.py | 304 +++++++++--------- test/update/test_ppo_update.py | 24 +- test/update/test_q_update.py | 32 +- test/update/test_sac_update.py | 130 ++++---- 54 files changed, 1259 insertions(+), 1260 deletions(-) create mode 100644 mighty/configs/cluster/example_cpu.yaml create mode 100644 mighty/configs/cluster/example_gpu.yaml delete mode 100644 mighty/configs/cluster/luis.yaml delete mode 100644 mighty/configs/cluster/noctua.yaml delete mode 100644 mighty/configs/cluster/tnt.yaml diff --git a/README.md b/README.md index 1dfca5c7..b1c0e55e 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@
[![PyPI Version](https://img.shields.io/pypi/v/mighty-rl.svg)](https://pypi.org/project/Mighty-RL/) -![Python](https://img.shields.io/badge/Python-3.10-3776AB) +![Python](https://img.shields.io/badge/Python-3.11-3776AB) ![License](https://img.shields.io/badge/License-BSD3-orange) [![Test](https://github.com/automl/Mighty/actions/workflows/test.yaml/badge.svg)](https://github.com/automl/Mighty/actions/workflows/test.yaml) [![Doc Status](https://github.com/automl/Mighty/actions/workflows/docs_test.yaml/badge.svg)](https://github.com/automl/Mighty/actions/workflows/docs_test.yaml) @@ -42,12 +42,12 @@ Mighty features: ## Installation We recommend to using uv to install and run Mighty in a virtual environment. -The code has been tested with python 3.10/3.11 on Unix systems. +The code has been tested with python 3.11 on Unix systems. First create a clean python environment: ```bash -uv venv --python=3.10 +uv venv --python=3.11 source .venv/bin/activate ``` diff --git a/docs/installation.md b/docs/installation.md index 621d3e1f..35186c8b 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,5 +1,5 @@ We recommend to using uv to install and run Mighty in a virtual environment. -The code has been tested with python 3.10. +The code has been tested with python 3.11 on Unix Systems. First create a clean python environment: diff --git a/mighty/configs/base.yaml b/mighty/configs/base.yaml index 313627be..67182c3c 100644 --- a/mighty/configs/base.yaml +++ b/mighty/configs/base.yaml @@ -1,8 +1,8 @@ defaults: - _self_ - - algorithm: sac_mujoco + - /cluster: local + - algorithm: ppo - environment: pufferlib_ocean/bandit - # - search_space: dqn_gym_classic - override hydra/job_logging: colorlog - override hydra/hydra_logging: colorlog - override hydra/help: mighty_help diff --git a/mighty/configs/cluster/example_cpu.yaml b/mighty/configs/cluster/example_cpu.yaml new file mode 100644 index 00000000..c1eca3a2 --- /dev/null +++ b/mighty/configs/cluster/example_cpu.yaml @@ -0,0 +1,11 @@ +# @package _global_ +defaults: + - override /hydra/launcher: submitit_slurm + +hydra: + launcher: + partition: your_partition + cpus_per_task: 1 + name: mighty_cpu_experiment + timeout_min: 30 + mem_gb: 10 \ No newline at end of file diff --git a/mighty/configs/cluster/example_gpu.yaml b/mighty/configs/cluster/example_gpu.yaml new file mode 100644 index 00000000..6b775177 --- /dev/null +++ b/mighty/configs/cluster/example_gpu.yaml @@ -0,0 +1,12 @@ +# @package _global_ +defaults: + - override /hydra/launcher: submitit_slurm + +hydra: + launcher: + partition: your_partition + gpus_per_task: 1 + gres: "gpu:1" + timeout_min: 30 + mem_gb: 10 + name: mighty_gpu_experiment \ No newline at end of file diff --git a/mighty/configs/cluster/local.yaml b/mighty/configs/cluster/local.yaml index 858953b7..e69de29b 100644 --- a/mighty/configs/cluster/local.yaml +++ b/mighty/configs/cluster/local.yaml @@ -1,5 +0,0 @@ -cluster: - _target_: distributed.deploy.local.LocalCluster - n_workers: ${hydra.sweeper.scenario.n_workers} - processes: false - threads_per_worker: 1 \ No newline at end of file diff --git a/mighty/configs/cluster/luis.yaml b/mighty/configs/cluster/luis.yaml deleted file mode 100644 index ae307576..00000000 --- a/mighty/configs/cluster/luis.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# @package _global_ -defaults: - - override /hydra/launcher: submitit_slurm - -cluster: - queue: ai,tnt # partition - -hydra: - launcher: - partition: ai - cpus_per_task: 1 - name: expl2 - timeout_min: 20 - mem_gb: 4 - setup: - - module load Miniconda3 - - conda activate /bigwork/nhwpbenc/conda/envs/mighty \ No newline at end of file diff --git a/mighty/configs/cluster/noctua.yaml b/mighty/configs/cluster/noctua.yaml deleted file mode 100644 index 3f5a2026..00000000 --- a/mighty/configs/cluster/noctua.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# @package _global_ -defaults: - - override /hydra/launcher: submitit_slurm - -hydra: - launcher: - partition: normal - cpus_per_task: 1 - name: expl2 - timeout_min: 20 - mem_gb: 4 - setup: - - micromamba activate /scratch/hpc-prf-intexml/cbenjamins/envs/mighty - -cluster: - _target_: dask_jobqueue.SLURMCluster - queue: normal # set in cluster config - # account: myaccount - cores: 16 - memory: 32 GB - walltime: 01:00:00 - processes: 1 - log_directory: tmp/mighty_smac - n_workers: 16 - death_timeout: 30 diff --git a/mighty/configs/cluster/tnt.yaml b/mighty/configs/cluster/tnt.yaml deleted file mode 100644 index 3b2a2008..00000000 --- a/mighty/configs/cluster/tnt.yaml +++ /dev/null @@ -1,15 +0,0 @@ -defaults: - - override hydra/launcher: submitit_slurm - -cluster: - queue: cpu_short # partition - -hydra: - launcher: - partition: cpu_short # change this to your partition name - #gres: gpu:1 # use this option when running on GPUs - mem_gb: 12 # memory requirements - cpus_per_task: 20 # number of cpus per run - timeout_min: 720 # timeout in minutes - setup: - - export XLA_PYTHON_CLIENT_PREALLOCATE=false diff --git a/mighty/configs/cmaes_hpo.yaml b/mighty/configs/cmaes_hpo.yaml index 4de1e6a4..a468865b 100644 --- a/mighty/configs/cmaes_hpo.yaml +++ b/mighty/configs/cmaes_hpo.yaml @@ -1,5 +1,6 @@ defaults: - _self_ + - /cluster: local - algorithm: dqn - environment: pufferlib_ocean/bandit - search_space: dqn_gym_classic diff --git a/mighty/configs/nes.yaml b/mighty/configs/nes.yaml index 8d3f761d..78d3509e 100644 --- a/mighty/configs/nes.yaml +++ b/mighty/configs/nes.yaml @@ -1,5 +1,6 @@ defaults: - _self_ + - /cluster: local - algorithm: dqn - environment: pufferlib_ocean/bandit - search_space: dqn_gym_classic diff --git a/mighty/configs/ppo_smac.yaml b/mighty/configs/ppo_smac.yaml index 6c2e19b6..40da7c69 100644 --- a/mighty/configs/ppo_smac.yaml +++ b/mighty/configs/ppo_smac.yaml @@ -1,5 +1,6 @@ defaults: - _self_ + - /cluster: local - algorithm: ppo_mujoco - environment: gymnasium/pendulum - search_space: ppo_rs @@ -45,6 +46,6 @@ hydra: output_directory: ${hydra.sweep.dir} search_space: ${search_space} run: - dir: ./tmp/branin_smac/ + dir: ${output_dir}/${experiment_name}_${seed} sweep: - dir: ./tmp/branin_smac/ \ No newline at end of file + dir: ${output_dir}/${experiment_name}_${seed} diff --git a/mighty/configs/sac_smac.yaml b/mighty/configs/sac_smac.yaml index 4cfa11f3..613efd26 100644 --- a/mighty/configs/sac_smac.yaml +++ b/mighty/configs/sac_smac.yaml @@ -1,5 +1,6 @@ defaults: - _self_ + - /cluster: local - algorithm: sac_mujoco - environment: gymnasium/pendulum - search_space: sac_rs @@ -45,6 +46,6 @@ hydra: output_directory: ${hydra.sweep.dir} search_space: ${search_space} run: - dir: ./tmp/branin_smac/ + dir: ${output_dir}/${experiment_name}_${seed} sweep: - dir: ./tmp/branin_smac/ \ No newline at end of file + dir: ${output_dir}/${experiment_name}_${seed} diff --git a/mighty/configs/sweep_ppo_pbt.yaml b/mighty/configs/sweep_ppo_pbt.yaml index 5d2fc903..3aba687f 100644 --- a/mighty/configs/sweep_ppo_pbt.yaml +++ b/mighty/configs/sweep_ppo_pbt.yaml @@ -1,5 +1,6 @@ defaults: - _self_ + - /cluster: local - algorithm: ppo - environment: gymnasium/pendulum - search_space: ppo_rs diff --git a/mighty/configs/sweep_rs.yaml b/mighty/configs/sweep_rs.yaml index d95a6236..650c3545 100644 --- a/mighty/configs/sweep_rs.yaml +++ b/mighty/configs/sweep_rs.yaml @@ -1,5 +1,6 @@ defaults: - _self_ + - /cluster: local - algorithm: ppo - environment: gymnasium/pendulum - search_space: ppo_rs diff --git a/mighty/mighty_agents/base_agent.py b/mighty/mighty_agents/base_agent.py index 255f1243..790a7c74 100644 --- a/mighty/mighty_agents/base_agent.py +++ b/mighty/mighty_agents/base_agent.py @@ -34,33 +34,14 @@ from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward -def seed_everything(seed: int, env: gym.Env | None = None): - """ - Seed Python, NumPy, Torch (including cuDNN), plus Gym (action_space, observation_space, - and the environment's own RNG). If `env` is vectorized (has `env.envs`), we dig into each - sub-env. Always call this before ANY network-building or RNG usage. - """ - # 1) Python/NumPy/Torch/Hash - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.use_deterministic_algorithms(True) # enforce strictly deterministic ops - os.environ["PYTHONHASHSEED"] = str(seed) - - # 2) Gym environment seeding - if env is not None: - sub_seeds = [seed for _ in range(len(env.envs))] - for sub_seed, subenv in zip(sub_seeds, env.envs): - subenv.action_space.seed(sub_seed) - subenv.observation_space.seed(sub_seed) - try: - subenv.unwrapped.seed(sub_seed) - except Exception: - pass - env.reset(seed=sub_seeds) +def seed_env_spaces(env: gym.VectorEnv, seed: int) -> None: + env.action_space.seed(seed) + env.single_action_space.seed(seed) + env.observation_space.seed(seed) + env.single_observation_space.seed(seed) + for i in range(len(env.envs)): + env.envs[i].action_space.seed(seed) + env.envs[i].observation_space.seed(seed) def update_buffer(buffer, new_data): @@ -199,28 +180,12 @@ def __init__( # noqa: PLR0915, PLR0912 self.seed = seed if self.seed is not None: - seed_everything(self.seed, env=None) # type: ignore - # Re-seed Python/NumPy/Torch here again. random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False os.environ["PYTHONHASHSEED"] = str(seed) - # Also seed any RNGs you will use later (e.g. for buffer / policy): - self.rng = np.random.default_rng(seed) # for any numpy-based sampling - # If you ever use torch.Generator for sampling actions, you could do: - self.torch_gen = torch.Generator().manual_seed(seed) - - # If you use Python’s random inside your policy, call random.seed(seed) - random.seed(seed) - - else: - self.rng = np.random.default_rng() - self.seed = 0 - # Replay Buffer replay_buffer_class = retrieve_class( cls=replay_buffer_class, @@ -260,6 +225,10 @@ def __init__( # noqa: PLR0915, PLR0912 else: self.eval_env = eval_env + if self.seed is not None: + seed_env_spaces(self.env, self.seed) + seed_env_spaces(self.eval_env, self.seed) + self.render_progress = render_progress self.output_dir = output_dir if self.output_dir is not None: @@ -269,9 +238,9 @@ def __init__( # noqa: PLR0915, PLR0912 self.meta_modules = {} for i, m in enumerate(meta_methods): meta_class = retrieve_class(cls=m, default_cls=None) # type: ignore - assert ( - meta_class is not None - ), f"Class {m} not found, did you specify the correct loading path?" + assert meta_class is not None, ( + f"Class {m} not found, did you specify the correct loading path?" + ) kwargs: Dict = {} if len(meta_kwargs) > i: kwargs = meta_kwargs[i] @@ -326,11 +295,13 @@ def __init__( # noqa: PLR0915, PLR0912 wandb.log(starting_hps) self.initialize_agent() + if self.seed is not None: + self.buffer.seed(self.seed) + self.policy.seed(self.seed) + for m in self.meta_modules.values(): + m.seed(self.seed) self.steps = 0 - seed_everything(self.seed, self.env) # type: ignore - seed_everything(self.seed, self.eval_env) # type: ignore - def _initialize_agent(self) -> None: """Agent/algorithm specific initializations.""" raise NotImplementedError diff --git a/mighty/mighty_exploration/mighty_exploration_policy.py b/mighty/mighty_exploration/mighty_exploration_policy.py index e73dc4a2..39948d5c 100644 --- a/mighty/mighty_exploration/mighty_exploration_policy.py +++ b/mighty/mighty_exploration/mighty_exploration_policy.py @@ -65,6 +65,10 @@ def __init__( else: self.sample_action = self.sample_func_logits + def seed(self, seed: int) -> None: + """Set the random seed for reproducibility.""" + self.rng = np.random.default_rng(seed) + def sample_func_q(self, state_array): """ Q-learning branch: diff --git a/mighty/mighty_meta/mighty_component.py b/mighty/mighty_meta/mighty_component.py index a303c9b2..7a47ac7e 100644 --- a/mighty/mighty_meta/mighty_component.py +++ b/mighty/mighty_meta/mighty_component.py @@ -2,6 +2,8 @@ from __future__ import annotations +import numpy as np + class MightyMetaComponent: """Component for registering meta-control methods.""" @@ -17,6 +19,7 @@ def __init__(self) -> None: self.post_update_methods = [] self.pre_episode_methods = [] self.post_episode_methods = [] + self.rng = np.random.default_rng() def pre_step(self, metrics): """Execute methods before a step. @@ -71,3 +74,7 @@ def post_episode(self, metrics): """ for m in self.post_episode_methods: m(metrics) + + def seed(self, seed: int) -> None: + """Set the random seed for reproducibility.""" + self.rng = np.random.default_rng(seed) diff --git a/mighty/mighty_meta/plr.py b/mighty/mighty_meta/plr.py index 5c6b4cfe..67d75bd4 100644 --- a/mighty/mighty_meta/plr.py +++ b/mighty/mighty_meta/plr.py @@ -43,7 +43,6 @@ def __init__( :return: """ super().__init__() - self.rng = np.random.default_rng() self.alpha = alpha self.rho = rho self.staleness_coef = staleness_coeff diff --git a/mighty/mighty_meta/space.py b/mighty/mighty_meta/space.py index 4fce0669..c5b92986 100644 --- a/mighty/mighty_meta/space.py +++ b/mighty/mighty_meta/space.py @@ -45,11 +45,11 @@ def get_instances(self, metrics): self.all_instances = np.array(env.instance_id_list.copy()) if self.last_evals is None and rollout_values is None: - self.instance_set = np.random.default_rng().choice( + self.instance_set = self.rng.choice( self.all_instances, size=self.current_instance_set_size ) elif self.last_evals is None: - self.instance_set = np.random.default_rng().choice( + self.instance_set = self.rng.choice( self.all_instances, size=self.current_instance_set_size ) self.last_evals = np.nanmean(rollout_values) diff --git a/mighty/mighty_models/networks.py b/mighty/mighty_models/networks.py index af44242c..3072468b 100644 --- a/mighty/mighty_models/networks.py +++ b/mighty/mighty_models/networks.py @@ -106,9 +106,7 @@ def __init__( def forward(self, x, transform: bool = True): """Forward pass.""" if transform: - x = ( - x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) - ) # noqa: PLR2004 + x = x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) # noqa: PLR2004 return self.cnn(x) def __getstate__(self): @@ -194,9 +192,7 @@ def __init__( def forward(self, x, transform: bool = True): """Forward pass.""" if transform: - x = ( - x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) - ) # noqa: PLR2004 + x = x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) # noqa: PLR2004 x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) @@ -221,9 +217,7 @@ def __init__(self, module1, module2): def forward(self, x): """Forward pass.""" - x = ( - x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) - ) # noqa: PLR2004 + x = x.permute(2, 0, 1) if len(x.shape) == 3 else x.permute(0, 3, 1, 2) # noqa: PLR2004 x = self.module1(x, False) return self.module2(x) diff --git a/mighty/mighty_replay/buffer.py b/mighty/mighty_replay/buffer.py index 157b0377..d1c12ea5 100644 --- a/mighty/mighty_replay/buffer.py +++ b/mighty/mighty_replay/buffer.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +import numpy as np + class MightyBuffer(ABC): @abstractmethod @@ -21,3 +23,7 @@ def __len__(self): @abstractmethod def __bool__(self): pass + + def seed(self, seed: int): + """Set random seed.""" + self.rng = np.random.default_rng(seed) diff --git a/mighty/mighty_replay/mighty_prioritized_replay.py b/mighty/mighty_replay/mighty_prioritized_replay.py index e55ce4a4..6173c2e7 100644 --- a/mighty/mighty_replay/mighty_prioritized_replay.py +++ b/mighty/mighty_replay/mighty_prioritized_replay.py @@ -28,6 +28,7 @@ def __init__( self.beta = beta self.epsilon = epsilon self.device = torch.device(device) + self.rng = np.random.default_rng() super().__init__(capacity, keep_infos, flatten_infos, device) @@ -134,7 +135,7 @@ def sample(self, batch_size): for i in range(batch_size): a = segment * i b = segment * (i + 1) - s = np.random.uniform(a, b) + s = self.rng.uniform(a, b) leaf = self._retrieve(1, s) # leaf index in [capacity..2*capacity-1] data_idx = leaf - self.capacity # ring-buffer index batch_indices[i] = data_idx diff --git a/mighty/mighty_replay/mighty_rollout_buffer.py b/mighty/mighty_replay/mighty_rollout_buffer.py index c040c655..5dd57db3 100644 --- a/mighty/mighty_replay/mighty_rollout_buffer.py +++ b/mighty/mighty_replay/mighty_rollout_buffer.py @@ -188,6 +188,7 @@ def __init__( self.gae_lambda = gae_lambda self.discrete_action = discrete_action self.use_latents = use_latents # Store for later use + self.rng = np.random.default_rng() # Shapes ----------------------------------------------------------- if isinstance(obs_shape, int): @@ -327,7 +328,7 @@ def _flat(t: torch.Tensor | None): logp_f = _flat(self.log_probs) val_f = _flat(self.values) - perm = np.random.permutation(total) + perm = self.rng.permutation(total) perm = perm[: (total // batch_size) * batch_size].reshape(-1, batch_size) minibatches: list[RolloutBatch] = [] diff --git a/mighty/mighty_runners/mighty_es_runner.py b/mighty/mighty_runners/mighty_es_runner.py index 03340e7b..6dbddd09 100644 --- a/mighty/mighty_runners/mighty_es_runner.py +++ b/mighty/mighty_runners/mighty_es_runner.py @@ -1,7 +1,7 @@ from __future__ import annotations import importlib.util as iutil -from typing import TYPE_CHECKING, Dict, Tuple, Callable +from typing import TYPE_CHECKING, Callable, Dict, Tuple import numpy as np import torch diff --git a/mighty/mighty_runners/mighty_runner.py b/mighty/mighty_runners/mighty_runner.py index 51109e2d..0dafa17e 100644 --- a/mighty/mighty_runners/mighty_runner.py +++ b/mighty/mighty_runners/mighty_runner.py @@ -4,7 +4,7 @@ import warnings from abc import ABC from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Tuple, Callable +from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple from hydra.utils import get_class diff --git a/mighty/mighty_update/sac_update.py b/mighty/mighty_update/sac_update.py index 9cee6998..aba3aa12 100644 --- a/mighty/mighty_update/sac_update.py +++ b/mighty/mighty_update/sac_update.py @@ -72,9 +72,7 @@ def calculate_td_error(self, transition: TransitionBatch) -> Tuple: transition.rewards, dtype=torch.float32 ).unsqueeze(-1) + ( 1 - torch.as_tensor(transition.dones, dtype=torch.float32).unsqueeze(-1) - ) * self.gamma * ( - torch.min(q1_t, q2_t) - alpha * logp_next - ) + ) * self.gamma * (torch.min(q1_t, q2_t) - alpha * logp_next) sa = torch.cat( [ torch.as_tensor(transition.observations, dtype=torch.float32), diff --git a/mighty/mighty_utils/plotting.py b/mighty/mighty_utils/plotting.py index ab7e4f6f..b43bb0a9 100644 --- a/mighty/mighty_utils/plotting.py +++ b/mighty/mighty_utils/plotting.py @@ -266,14 +266,10 @@ def plot_final_performance_comparison( aggregation_funcs = lambda x: np.array([metrics.aggregate_iqm(x)]) # noqa: E731 metric_names = ["IQM"] elif "mean" in aggregation: - aggregation_funcs = lambda x: np.array( - [metrics.aggregate_mean(x)] - ) # noqa: E731 + aggregation_funcs = lambda x: np.array([metrics.aggregate_mean(x)]) # noqa: E731 metric_names = ["Mean"] elif "median" in aggregation: - aggregation_funcs = lambda x: np.array( - [metrics.aggregate_median(x)] - ) # noqa: E731 + aggregation_funcs = lambda x: np.array([metrics.aggregate_median(x)]) # noqa: E731 metric_names = ["Median"] score_dict = {} diff --git a/mighty/mighty_utils/test_helpers.py b/mighty/mighty_utils/test_helpers.py index b17018f4..ef721f4d 100644 --- a/mighty/mighty_utils/test_helpers.py +++ b/mighty/mighty_utils/test_helpers.py @@ -27,13 +27,14 @@ def set_instance_set(self, instance_set): self.instance_set = instance_set def reset(self, options={}, seed=None): + super().reset(seed=seed if seed is not None else 0) if self.inst_id is None: - self.inst_id = np.random.default_rng().integers(0, 100) + self.inst_id = self._np_random.integers(0, 100) return self.observation_space.sample(), {} def step(self, action): - tr = np.random.default_rng().choice([0, 1], p=[0.9, 0.1]) - return self.observation_space.sample(), 0, False, tr, {} + tr = self._np_random.choice([0, 1], p=[0.9, 0.1]) + return self.observation_space.sample(), self._np_random.random(), False, tr, {} class DummyContinuousEnv(gym.Env): @@ -57,13 +58,14 @@ def set_instance_set(self, instance_set): self.instance_set = instance_set def reset(self, options={}, seed=None): + super().reset(seed=seed if seed is not None else 0) if self.inst_id is None: - self.inst_id = np.random.default_rng().integers(0, 100) + self.inst_id = self._np_random.integers(0, 100) return self.observation_space.sample(), {} def step(self, action): - tr = np.random.default_rng().choice([0, 1], p=[0.9, 0.1]) - return self.observation_space.sample(), np.random.rand(), False, tr, {} + tr = self._np_random.choice([0, 1], p=[0.9, 0.1]) + return self.observation_space.sample(), self._np_random.random(), False, tr, {} class DummyModel: diff --git a/pyproject.toml b/pyproject.toml index 05cf4c2a..9e7ff247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "0.0.1" description = "A modular, meta-learning-ready RL library." authors = [{ name = "AutoRL@LUHAI", email = "a.mohan@ai.uni-hannover.de" }] readme = "README.md" -requires-python = ">=3.10,<3.12" +requires-python = ">=3.11,<3.12" license = { file = "LICENSE" } keywords = [ "Reinforcement Learning", @@ -73,7 +73,7 @@ ignore = [ ] [tool.mypy] -python_version = "3.10" +python_version = "3.11" disallow_untyped_defs = true show_error_codes = true no_implicit_optional = true diff --git a/test/agents/test_dqn_agent.py b/test/agents/test_dqn_agent.py index 81238100..3ba2ea4a 100644 --- a/test/agents/test_dqn_agent.py +++ b/test/agents/test_dqn_agent.py @@ -24,12 +24,12 @@ def test_init(self): dqn = MightyDQNAgent(output_dir, env, use_target=False) assert isinstance(dqn.q, DQN), "Model should be an instance of DQN" assert isinstance(dqn.value_function, DQN), "Vf should be an instance of DQN" - assert isinstance( - dqn.policy, EpsilonGreedy - ), "Policy should be an instance of EpsilonGreedy" - assert isinstance( - dqn.qlearning, QLearning - ), "Update should be an instance of QLearning" + assert isinstance(dqn.policy, EpsilonGreedy), ( + "Policy should be an instance of EpsilonGreedy" + ) + assert isinstance(dqn.qlearning, QLearning), ( + "Update should be an instance of QLearning" + ) assert dqn.q_target is None, "Q_target should be None" test_obs, _ = env.reset() @@ -65,12 +65,12 @@ def test_init(self): q_pred = dqn.q(test_obs) target_pred = dqn.q_target(test_obs) assert torch.allclose(q_pred, target_pred), "Q and Q_target should be equal" - assert isinstance( - dqn.buffer, PrioritizedReplay - ), "Replay buffer should be an instance of PrioritizedReplay" - assert isinstance( - dqn.qlearning, DoubleQLearning - ), "Update should be an instance of DoubleQLearning" + assert isinstance(dqn.buffer, PrioritizedReplay), ( + "Replay buffer should be an instance of PrioritizedReplay" + ) + assert isinstance(dqn.qlearning, DoubleQLearning), ( + "Update should be an instance of DoubleQLearning" + ) assert dqn._batch_size == 32, "Batch size should be 32" assert dqn.learning_rate == 0.01, "Learning rate should be 0.01" assert dqn._epsilon == 0.1, "Epsilon should be 0.1" @@ -96,9 +96,9 @@ def test_update(self): for old, new in zip( original_target_params[:10], new_target_params[:10], strict=False ): - assert not torch.allclose( - old, new - ), "Target model parameters should be changed" + assert not torch.allclose(old, new), ( + "Target model parameters should be changed" + ) for old, new in zip( original_feature_params, dqn.q.feature_extractor.parameters(), strict=False ): @@ -106,16 +106,16 @@ def test_update(self): for old, new, new_target in zip( original_params[:10], new_params[:10], new_target_params[:10], strict=False ): - assert not torch.allclose( - old * (1 - 0.01) + new * 0.01, new_target - ), "Target model parameters should be scaled correctly" + assert not torch.allclose(old * (1 - 0.01) + new * 0.01, new_target), ( + "Target model parameters should be scaled correctly" + ) batch = dqn.buffer.sample(20) preds, targets = dqn.qlearning.get_targets(batch, dqn.q, dqn.q_target) - assert ( - np.mean(targets.detach().numpy() - metrics["Update/td_targets"]) < 0.1 - ), "TD_targets should be equal" + assert np.mean(targets.detach().numpy() - metrics["Update/td_targets"]) < 0.1, ( + "TD_targets should be equal" + ) assert ( np.mean((targets - preds).detach().numpy() - metrics["Update/td_errors"]) < 0.1 @@ -123,16 +123,16 @@ def test_update(self): original_optimizer.zero_grad() loss = F.mse_loss(preds, targets) - assert ( - np.mean(loss.detach().numpy() - metrics["Update/loss"]) < 0.05 - ), "Loss should be equal" + assert np.mean(loss.detach().numpy() - metrics["Update/loss"]) < 0.05, ( + "Loss should be equal" + ) loss.backward() original_optimizer.step() manual_params = deepcopy(list(dqn.q.parameters())) for manual, agent in zip(manual_params[:10], new_params[:10], strict=False): - assert torch.allclose( - manual, agent, atol=1e-2 - ), "Model parameters should be equal to manual update" + assert torch.allclose(manual, agent, atol=1e-2), ( + "Model parameters should be equal to manual update" + ) clean(output_dir) @@ -195,12 +195,12 @@ def process_transition(self): {}, ) - assert ( - len(metrics["rollout_values"]) == 2 - ), f"One value prediction per state, got: {metrics['rollout_values']}" - assert ( - len(metrics["td_error"]) == 2 - ), f"TD error should be computed per transition, got: {metrics['td_errors']}" + assert len(metrics["rollout_values"]) == 2, ( + f"One value prediction per state, got: {metrics['rollout_values']}" + ) + assert len(metrics["td_error"]) == 2, ( + f"TD error should be computed per transition, got: {metrics['td_errors']}" + ) state, _ = env.reset() action = dqn.policy(state, return_logp=False) @@ -209,9 +209,9 @@ def process_transition(self): metrics = dqn.process_transition( state, action, reward, next_state, te or tr, 0, metrics ) - assert ( - len(metrics["rollout_values"]) == 3 - ), "New value prediction should be added" + assert len(metrics["rollout_values"]) == 3, ( + "New value prediction should be added" + ) assert len(metrics["td_error"]) == 1, "TD error is overwritten" clean(output_dir) @@ -222,25 +222,53 @@ def test_reproducibility(self): dqn = MightyDQNAgent(output_dir, env, batch_size=2, seed=42) init_params = deepcopy(list(dqn.q.parameters())) dqn.run(20, 1) - batch = dqn.buffer.sample(20) - original_metrics = dqn.update_agent(batch, 0) + original_batch = dqn.buffer.sample(20) + original_metrics = dqn.update_agent(original_batch, 0) original_params = deepcopy(list(dqn.q.parameters())) - + for _ in range(3): env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) output_dir = Path("test_dqn_agent") output_dir.mkdir(parents=True, exist_ok=True) dqn = MightyDQNAgent(output_dir, env, batch_size=2, seed=42) - for old, new in zip(init_params[:10], list(dqn.q.parameters())[:10], strict=False): - assert not torch.allclose(old, new), "Parameter initialization should be the same with same seed" + for old, new in zip( + init_params[:10], list(dqn.q.parameters())[:10], strict=False + ): + assert torch.allclose(old, new), ( + "Parameter initialization should be the same with same seed" + ) dqn.run(20, 1) batch = dqn.buffer.sample(20) + + for old, new in zip( + original_batch.observations, batch.observations, strict=False + ): + assert torch.allclose(old, new), ( + "Batch observations should be the same with same seed" + ) + + assert torch.allclose(original_batch.actions, batch.actions), ( + f"Batch actions should be the same with same seed: {torch.isclose(original_batch.actions, batch.actions)}" + ) + assert torch.allclose(original_batch.rewards, batch.rewards), ( + f"Batch rewards should be the same with same seed: {torch.isclose(original_batch.rewards, batch.rewards)}" + ) + new_metrics = dqn.update_agent(batch, 0) - for old, new in zip(original_params[:10], list(dqn.q.parameters())[:10], strict=False): - assert not torch.allclose(old, new), "Model parameters should stay the same with same seed" - - for old, new in zip(original_metrics["Update/td_targets"][:10], new_metrics["Update/td_targets"][:10], strict=False): - assert torch.allclose(old, new), "TD targets should be the same with same seed" - for old, new in zip(original_metrics["Update/td_errors"][:10], new_metrics["Update/td_errors"][:10], strict=False): - assert torch.allclose(old, new), "TD errors should be the same with same seed" - clean(output_dir) \ No newline at end of file + for old, new in zip( + original_params[:10], list(dqn.q.parameters())[:10], strict=False + ): + assert torch.allclose(old, new), ( + "Model parameters should stay the same with same seed" + ) + + assert np.allclose( + original_metrics["Update/td_targets"], new_metrics["Update/td_targets"] + ), "TD targets should be the same with same seed" + assert np.allclose( + original_metrics["Update/loss"], new_metrics["Update/loss"] + ), "Loss should be the same with same seed" + assert np.allclose( + original_metrics["Update/td_errors"], new_metrics["Update/td_errors"] + ), "TD errors should be the same with same seed" + clean(output_dir) diff --git a/test/agents/test_ppo_agent.py b/test/agents/test_ppo_agent.py index 635b0dc0..06b22761 100644 --- a/test/agents/test_ppo_agent.py +++ b/test/agents/test_ppo_agent.py @@ -94,9 +94,9 @@ def test_init_discrete(self): prediction = agent.step(test_obs, metrics)[0] assert len(prediction) == 1, "Prediction should have shape (1,)" - assert ( - 0 <= prediction[0] < 4 - ), "Action should be in valid range [0, 4)" # Updated for 4 actions + assert 0 <= prediction[0] < 4, ( + "Action should be in valid range [0, 4)" + ) # Updated for 4 actions clean(output_dir) @@ -166,9 +166,9 @@ def test_update(self): print(f"Buffer size after manual collection: {len(agent.buffer)}") # Ensure we have enough data in buffer - assert ( - len(agent.buffer) >= agent._batch_size - ), f"Buffer size {len(agent.buffer)} should be >= batch size {agent._batch_size}" + assert len(agent.buffer) >= agent._batch_size, ( + f"Buffer size {len(agent.buffer)} should be >= batch size {agent._batch_size}" + ) # Perform update update_kwargs = {"next_s": curr_s, "dones": np.array([False])} @@ -244,15 +244,15 @@ def test_properties(self): params = ppo.parameters assert isinstance(params, list), "Parameters should be a list" assert len(params) > 0, "Should have parameters" - assert all( - isinstance(p, torch.nn.Parameter) for p in params - ), "All should be Parameters" + assert all(isinstance(p, torch.nn.Parameter) for p in params), ( + "All should be Parameters" + ) # Test value_function property - updated for new structure value_fn = ppo.value_function - assert ( - value_fn is ppo.model.value_function_module - ), "Value function should be model's value function module" + assert value_fn is ppo.model.value_function_module, ( + "Value function should be model's value function module" + ) # Test that the value function wrapper works - use correct obs shape obs_shape = env.single_observation_space.shape[ @@ -310,15 +310,19 @@ def test_reproducibility(self): update_kwargs = {"next_s": curr_s, "dones": np.array([False])} original_metrics = ppo.update(metrics, update_kwargs) original_params = deepcopy(list(ppo.model.parameters())) - + for _ in range(3): env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) output_dir = Path("test_ppo_agent") output_dir.mkdir(parents=True, exist_ok=True) ppo = MightyPPOAgent(output_dir, env, batch_size=2, seed=42) - for old, new in zip(init_params[:10], list(ppo.model.parameters())[:10], strict=False): - assert not torch.allclose(old, new), "Parameter initialization should be the same with same seed" - + for old, new in zip( + init_params[:10], list(ppo.model.parameters())[:10], strict=False + ): + assert torch.allclose(old, new), ( + "Parameter initialization should be the same with same seed" + ) + metrics = { "env": ppo.env, "step": 0, @@ -358,13 +362,24 @@ def test_reproducibility(self): update_kwargs = {"next_s": curr_s, "dones": np.array([False])} new_metrics = ppo.update(metrics, update_kwargs) - for old, new in zip(original_params[:10], list(ppo.model.parameters())[:10], strict=False): - assert not torch.allclose(old, new), "Model parameters should stay the same with same seed" + for old, new in zip( + original_params[:10], list(ppo.model.parameters())[:10], strict=False + ): + assert torch.allclose(old, new), ( + "Model parameters should stay the same with same seed" + ) - for old, new in zip(original_metrics["Update/value_loss"], new_metrics["Update/value_loss"], strict=False): - assert torch.allclose(old, new), "Value loss should be the same with same seed" - for old, new in zip(original_metrics["Update/policy_loss"], new_metrics["Update/policy_loss"], strict=False): - assert torch.allclose(old, new), "Policy loss should be the same with same seed" - for old, new in zip(original_metrics["Update/entropy"], new_metrics["Update/entropy"], strict=False): - assert torch.allclose(old, new), "Entropy should be the same with same seed" - clean(output_dir) \ No newline at end of file + assert np.isclose( + original_metrics["Update/policy_loss"], + new_metrics["Update/policy_loss"], + ), "Policy loss should be the same with same seed" + assert np.isclose( + original_metrics["Update/value_loss"], new_metrics["Update/value_loss"] + ), "Value loss should be the same with same seed" + assert np.allclose( + original_metrics["Update/approx_kl"], new_metrics["Update/approx_kl"] + ), "Approx KL should be the same with same seed" + assert np.allclose( + original_metrics["Update/entropy"], new_metrics["Update/entropy"] + ), "Entropy should be the same with same seed" + clean(output_dir) diff --git a/test/agents/test_sac_agent.py b/test/agents/test_sac_agent.py index d9285740..cf3d0d13 100644 --- a/test/agents/test_sac_agent.py +++ b/test/agents/test_sac_agent.py @@ -119,9 +119,9 @@ def test_update(self): print(f"Buffer size after manual collection: {len(agent.buffer)}") # Ensure we have enough data in buffer - assert ( - len(agent.buffer) >= agent.batch_size - ), f"Buffer size {len(agent.buffer)} should be >= batch size {agent.batch_size}" + assert len(agent.buffer) >= agent.batch_size, ( + f"Buffer size {len(agent.buffer)} should be >= batch size {agent.batch_size}" + ) # Set agent.steps to satisfy learning_starts condition agent.steps = agent.learning_starts + 1 @@ -180,11 +180,10 @@ def test_update(self): ) break - # Modified assertions - warn instead of fail for debugging - assert ( - policy_params_changed or q1_params_changed or q2_params_changed - ), "At least some parameters should change after update" + assert policy_params_changed or q1_params_changed or q2_params_changed, ( + "At least some parameters should change after update" + ) assert isinstance(result_metrics, dict), "Update should return metrics dict" # Check for expected SAC metrics in the result @@ -225,31 +224,31 @@ def test_properties(self): params = agent.parameters assert isinstance(params, list), "Parameters should be a list" assert len(params) > 0, "Should have parameters" - assert all( - isinstance(p, torch.nn.Parameter) for p in params - ), "All should be Parameters" + assert all(isinstance(p, torch.nn.Parameter) for p in params), ( + "All should be Parameters" + ) # Test that parameters include all three networks (policy, q1, q2) policy_params = list(agent.model.policy_net.parameters()) q1_params = list(agent.model.q_net1.parameters()) q2_params = list(agent.model.q_net2.parameters()) expected_param_count = len(policy_params) + len(q1_params) + len(q2_params) - assert ( - len(params) == expected_param_count - ), f"Expected {expected_param_count} parameters, got {len(params)}" + assert len(params) == expected_param_count, ( + f"Expected {expected_param_count} parameters, got {len(params)}" + ) # Test value_function property value_fn = agent.value_function - assert isinstance( - value_fn, torch.nn.Module - ), "Value function should be a torch module" + assert isinstance(value_fn, torch.nn.Module), ( + "Value function should be a torch module" + ) # Test that value function can be called with a state dummy_state = torch.randn(1, agent.env.single_observation_space.shape[0]) value_output = value_fn(dummy_state) - assert isinstance( - value_output, torch.Tensor - ), "Value function should return a tensor" + assert isinstance(value_output, torch.Tensor), ( + "Value function should return a tensor" + ) assert value_output.shape == ( 1, 1, @@ -257,9 +256,9 @@ def test_properties(self): # Test that value function is the cached module value_fn2 = agent.value_function - assert ( - value_fn is value_fn2 - ), "Value function should be cached and return same instance" + assert value_fn is value_fn2, ( + "Value function should be cached and return same instance" + ) clean(output_dir) @@ -267,30 +266,59 @@ def test_reproducibility(self): env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) output_dir = Path("test_sac_agent") output_dir.mkdir(parents=True, exist_ok=True) - sac = MightySACAgent(output_dir, env, seed=42) + sac = MightySACAgent( + output_dir, env, seed=42, learning_starts=0, update_every=20 + ) init_params = deepcopy(list(sac.model.parameters())) sac.run(20, 1) batch = sac.buffer.sample(20) - original_metrics = sac.update_agent(batch, 0) + original_metrics = sac.update_agent(batch, 20) original_params = deepcopy(list(sac.model.parameters())) - + + env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) + sac = MightySACAgent(output_dir, env, seed=42) + state, _ = sac.env.reset(seed=42) + step_state, _, _, _, _ = sac.env.step([0]) + env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) + sac = MightySACAgent(output_dir, env, seed=42) + other_state, _ = sac.env.reset(seed=42) + other_step_state, _, _, _, _ = sac.env.step([0]) + assert np.allclose(state, other_state), "States should be equal with same seed" + assert np.allclose(step_state, other_step_state), ( + "Step states should be equal with same seed" + ) + for _ in range(3): env = gym.vector.SyncVectorEnv([DummyContinuousEnv for _ in range(1)]) output_dir = Path("test_sac_agent") output_dir.mkdir(parents=True, exist_ok=True) - sac = MightySACAgent(output_dir, env, seed=42) - for old, new in zip(init_params[:10], list(sac.model.parameters())[:10], strict=False): - assert not torch.allclose(old, new), "Parameter initialization should be the same with same seed" + sac = MightySACAgent( + output_dir, env, seed=42, learning_starts=0, update_every=20 + ) + for old, new in zip( + init_params[:10], list(sac.model.parameters())[:10], strict=False + ): + assert torch.allclose(old, new), ( + "Parameter initialization should be the same with same seed" + ) sac.run(20, 1) batch = sac.buffer.sample(20) - new_metrics = sac.update_agent(batch, 0) - for old, new in zip(original_params[:10], list(sac.model.parameters())[:10], strict=False): - assert not torch.allclose(old, new), "Model parameters should stay the same with same seed" - - for old, new in zip(original_metrics["Update/q_loss1"][:10], new_metrics["Update/q_loss1"][:10], strict=False): - assert torch.allclose(old, new), "Q1 loss should be the same with same seed" - for old, new in zip(original_metrics["Update/q_loss2"][:10], new_metrics["Update/q_loss2"][:10], strict=False): - assert torch.allclose(old, new), "Q2 loss should be the same with same seed" - for old, new in zip(original_metrics["Update/policy_loss"][:10], new_metrics["Update/policy_loss"][:10], strict=False): - assert torch.allclose(old, new), "Policy loss should be the same with same seed" + new_metrics = sac.update_agent(batch, 20) + for old, new in zip( + original_params[:10], list(sac.model.parameters())[:10], strict=False + ): + assert torch.allclose(old, new), ( + "Model parameters should stay the same with same seed" + ) + + assert np.isclose( + original_metrics["Update/q_loss1"], new_metrics["Update/q_loss1"] + ), "Q1 loss should be the same with same seed" + assert np.isclose( + original_metrics["Update/q_loss2"], new_metrics["Update/q_loss2"] + ), "Q2 loss should be the same with same seed" + assert np.isclose( + original_metrics["Update/policy_loss"], + new_metrics["Update/policy_loss"], + ), "Policy loss should be the same with same seed" clean(output_dir) diff --git a/test/exploration/test_epsilon_greedy.py b/test/exploration/test_epsilon_greedy.py index 4769b7cc..0e1864a9 100644 --- a/test/exploration/test_epsilon_greedy.py +++ b/test/exploration/test_epsilon_greedy.py @@ -27,30 +27,30 @@ def test_exploration_func(self, state): actions, qvals = policy.explore_func(state) greedy_actions, greedy_qvals = policy.sample_action(state) assert len(actions) == len(state), "Action should be predicted per state." - assert all( - a == g for g in greedy_actions for a in actions - ), f"Actions should match greedy: {actions}///{greedy_actions}" - assert torch.equal( - qvals, greedy_qvals - ), f"Q-values should match greedy: {qvals}///{greedy_qvals}" + assert all(a == g for g in greedy_actions for a in actions), ( + f"Actions should match greedy: {actions}///{greedy_actions}" + ) + assert torch.equal(qvals, greedy_qvals), ( + f"Q-values should match greedy: {qvals}///{greedy_qvals}" + ) policy = self.get_policy(epsilon=0.5) actions = np.array( [policy.explore_func(state)[0] for _ in range(100)] ).flatten() - assert ( - sum([a == 1 for a in actions]) / (100 * len(state)) > 0.5 - ), "Actions should match greedy at least in half of cases." - assert ( - sum([a == 1 for a in actions]) / (100 * len(state)) < 0.8 - ), "Actions should match greedy in less than 4/5 of cases." + assert sum([a == 1 for a in actions]) / (100 * len(state)) > 0.5, ( + "Actions should match greedy at least in half of cases." + ) + assert sum([a == 1 for a in actions]) / (100 * len(state)) < 0.8, ( + "Actions should match greedy in less than 4/5 of cases." + ) policy = self.get_policy(epsilon=np.linspace(0, 1, len(state))) actions = np.array([policy.explore_func(state)[0] for _ in range(100)]) assert all(actions[:, 0] == 1), "Low index actions should match greedy." - assert ( - sum(actions[:, -1] == 1) / 100 < 0.33 - ), "High index actions should not match greedy more than 1/3 of the time." + assert sum(actions[:, -1] == 1) / 100 < 0.33, ( + "High index actions should not match greedy more than 1/3 of the time." + ) @pytest.mark.parametrize( "state", @@ -67,6 +67,6 @@ def test_multiple_epsilons(self, state): policy = self.get_policy(epsilon=[0.1, 0.5]) assert np.all(policy.epsilon == [0.1, 0.5]), "Epsilon should be [0.1, 0.5]." action, _ = policy.explore_func(state) - assert len(action) == len( - state.numpy() - ), f"Action should be predicted per state: len({action}) != len({state.numpy()})." + assert len(action) == len(state.numpy()), ( + f"Action should be predicted per state: len({action}) != len({state.numpy()})." + ) diff --git a/test/exploration/test_exploration.py b/test/exploration/test_exploration.py index 817fc92f..ed35176e 100644 --- a/test/exploration/test_exploration.py +++ b/test/exploration/test_exploration.py @@ -29,16 +29,16 @@ def test_call(self, state): policy(state) greedy_actions, qvals = policy(state, evaluate=True, return_logp=True) - assert all( - greedy_actions == 1 - ), f"Greedy actions should be 1: {greedy_actions}///{qvals}" + assert all(greedy_actions == 1), ( + f"Greedy actions should be 1: {greedy_actions}///{qvals}" + ) assert qvals.shape[-1] == 5, "Q-value shape should not be changed." assert len(qvals) == len(state), "Q-value length should not be changed." policy = self.get_policy(action=3) greedy_actions, qvals = policy(state, evaluate=True, return_logp=True) - assert all( - greedy_actions == 3 - ), f"Greedy actions should be 3: {greedy_actions}///{qvals}" + assert all(greedy_actions == 3), ( + f"Greedy actions should be 3: {greedy_actions}///{qvals}" + ) assert qvals.shape[-1] == 5, "Q-value shape should not be changed." assert len(qvals) == len(state), "Q-value length should not be changed." diff --git a/test/exploration/test_ez_greedy.py b/test/exploration/test_ez_greedy.py index 37e04b33..c2aa09dd 100644 --- a/test/exploration/test_ez_greedy.py +++ b/test/exploration/test_ez_greedy.py @@ -19,15 +19,15 @@ def test_init(self) -> None: use_target=False, policy_class="mighty.mighty_exploration.EZGreedy", ) - assert isinstance( - dqn.policy, EZGreedy - ), "Policy should be an instance of EZGreedy when creating with string." + assert isinstance(dqn.policy, EZGreedy), ( + "Policy should be an instance of EZGreedy when creating with string." + ) assert dqn.policy.epsilon == 0.1, "Default epsilon should be 0.1." assert dqn.policy.zipf_param == 2, "Default zipf_param should be 2." assert dqn.policy.skipped is None, "Skip should be initialized at None." - assert ( - dqn.policy.frozen_actions is None - ), "Frozen actions should be initialized at None." + assert dqn.policy.frozen_actions is None, ( + "Frozen actions should be initialized at None." + ) dqn = MightyDQNAgent( output_dir, @@ -36,9 +36,9 @@ def test_init(self) -> None: policy_class=EZGreedy, policy_kwargs={"epsilon": [0.5, 0.3], "zipf_param": 3}, ) - assert isinstance( - dqn.policy, EZGreedy - ), "Policy should be an instance of EZGreedy when creating with class." + assert isinstance(dqn.policy, EZGreedy), ( + "Policy should be an instance of EZGreedy when creating with class." + ) assert np.all(dqn.policy.epsilon == [0.5, 0.3]), "Epsilon should be [0.5, 0.3]." assert dqn.policy.zipf_param == 3, "zipf_param should be 3." clean(output_dir) @@ -57,16 +57,16 @@ def test_skip_single(self) -> None: state, _ = env.reset() action = dqn.policy([state]) - assert np.all( - action < env.single_action_space.n - ), "Action should be within the action space." + assert np.all(action < env.single_action_space.n), ( + "Action should be within the action space." + ) assert len(action) == len(state), "Action should be predicted per state." dqn.policy.skipped = np.array([1]) next_action = dqn.policy([state]) - assert np.all( - action == next_action - ), "Action should be the same as the previous action when skip is active." + assert np.all(action == next_action), ( + "Action should be the same as the previous action when skip is active." + ) assert dqn.policy.skipped[0] == 0, "Skip should be decayed by one." clean(output_dir) @@ -84,16 +84,16 @@ def test_skip_batch(self) -> None: state, _ = env.reset() action = dqn.policy(state) - assert all( - [a < env.single_action_space.n for a in action] - ), "Actions should be within the action space." + assert all([a < env.single_action_space.n for a in action]), ( + "Actions should be within the action space." + ) assert len(action) == len(state), "Action should be predicted per state." dqn.policy.skipped = np.array([3, 0]) next_action = dqn.policy(state + 2) - assert ( - action[0] == next_action[0] - ), f"First action should be the same as the previous action when skip is active: {action[0]} != {next_action[0]}" + assert action[0] == next_action[0], ( + f"First action should be the same as the previous action when skip is active: {action[0]} != {next_action[0]}" + ) assert dqn.policy.skipped[0] == 2, "Skip should be decayed by one." assert dqn.policy.skipped[1] >= 0, "Skip should not be decayed below one." clean(output_dir) diff --git a/test/meta_components/test_cosine_schedule.py b/test/meta_components/test_cosine_schedule.py index d20bcaad..4c9cea4e 100644 --- a/test/meta_components/test_cosine_schedule.py +++ b/test/meta_components/test_cosine_schedule.py @@ -25,12 +25,12 @@ def test_decay(self) -> None: dqn.learning_rate = lr for i in range(4): metrics = dqn.run(n_steps=10 * (i + 1)) - assert ( - metrics["hp/lr"] == dqn.learning_rate - ), f"Learning rate should be set to schedule value {metrics['hp/lr']} instead of {dqn.learning_rate}." - assert ( - dqn.learning_rate < lr - ), f"Learning rate should decrease: {dqn.learning_rate} is not less than {lr}." + assert metrics["hp/lr"] == dqn.learning_rate, ( + f"Learning rate should be set to schedule value {metrics['hp/lr']} instead of {dqn.learning_rate}." + ) + assert dqn.learning_rate < lr, ( + f"Learning rate should decrease: {dqn.learning_rate} is not less than {lr}." + ) lr = dqn.learning_rate.copy() clean(output_dir) @@ -47,10 +47,10 @@ def test_restart(self) -> None: ], ) dqn.run(6, 0) - assert ( - dqn.meta_modules["CosineLRSchedule"].n_restarts == 1 - ), "Restart counter should increase." - assert ( - dqn.learning_rate >= dqn.meta_modules["CosineLRSchedule"].eta_max - ), f"Restart should increase learning rate: {dqn.learning_rate} is not {dqn.meta_modules['CosineLRSchedule'].eta_max}." + assert dqn.meta_modules["CosineLRSchedule"].n_restarts == 1, ( + "Restart counter should increase." + ) + assert dqn.learning_rate >= dqn.meta_modules["CosineLRSchedule"].eta_max, ( + f"Restart should increase learning rate: {dqn.learning_rate} is not {dqn.meta_modules['CosineLRSchedule'].eta_max}." + ) clean(output_dir) diff --git a/test/meta_components/test_noveld.py b/test/meta_components/test_noveld.py index 5a4ae171..5428bfc1 100644 --- a/test/meta_components/test_noveld.py +++ b/test/meta_components/test_noveld.py @@ -31,33 +31,33 @@ def test_init(self) -> None: ppo = self.init_NovelD() assert len(ppo.meta_modules) == 1, "There should be one meta module." assert "NovelD" in ppo.meta_modules, "NovelD should be in meta modules." - assert isinstance( - ppo.meta_modules["NovelD"], NovelD - ), "NovelD should be meta module when created from string." - assert ( - ppo.meta_modules["NovelD"].rnd_output_dim == 512 - ), "Default output dim should be 512." - assert ( - len(ppo.meta_modules["NovelD"].rnd_network_config) == 0 - ), f"Default network config should be empty, got {ppo.meta_modules['NovelD'].rnd_network_config}." - assert ( - ppo.meta_modules["NovelD"].internal_reward_weight == 0.1 - ), "Default internal reward weight should be 0.1." - assert ( - ppo.meta_modules["NovelD"].rnd_lr == 0.001 - ), "Default NovelD learning rate should be 0.001." - assert ( - ppo.meta_modules["NovelD"].rnd_eps == 1e-5 - ), "Default NovelD epsilon should be 1e-5." - assert ( - ppo.meta_modules["NovelD"].rnd_weight_decay == 0.01 - ), "Default NovelD weight decay should be 0.01." - assert ( - ppo.meta_modules["NovelD"].update_proportion == 0.5 - ), "Default update proportion should be 0.5." - assert ( - ppo.meta_modules["NovelD"].rnd_net is None - ), "NovelD network should be None." + assert isinstance(ppo.meta_modules["NovelD"], NovelD), ( + "NovelD should be meta module when created from string." + ) + assert ppo.meta_modules["NovelD"].rnd_output_dim == 512, ( + "Default output dim should be 512." + ) + assert len(ppo.meta_modules["NovelD"].rnd_network_config) == 0, ( + f"Default network config should be empty, got {ppo.meta_modules['NovelD'].rnd_network_config}." + ) + assert ppo.meta_modules["NovelD"].internal_reward_weight == 0.1, ( + "Default internal reward weight should be 0.1." + ) + assert ppo.meta_modules["NovelD"].rnd_lr == 0.001, ( + "Default NovelD learning rate should be 0.001." + ) + assert ppo.meta_modules["NovelD"].rnd_eps == 1e-5, ( + "Default NovelD epsilon should be 1e-5." + ) + assert ppo.meta_modules["NovelD"].rnd_weight_decay == 0.01, ( + "Default NovelD weight decay should be 0.01." + ) + assert ppo.meta_modules["NovelD"].update_proportion == 0.5, ( + "Default update proportion should be 0.5." + ) + assert ppo.meta_modules["NovelD"].rnd_net is None, ( + "NovelD network should be None." + ) env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) output_dir = Path("test_NovelD") @@ -80,33 +80,33 @@ def test_init(self) -> None: ) assert len(ppo.meta_modules) == 1, "There should be one meta module." assert "NovelD" in ppo.meta_modules, "NovelD should be in meta modules." - assert isinstance( - ppo.meta_modules["NovelD"], NovelD - ), "NovelD should be meta module when created from class." - assert ( - ppo.meta_modules["NovelD"].rnd_output_dim == 12 - ), "Output dim should be 12." - assert ppo.meta_modules["NovelD"].rnd_network_config == { - "test": True - }, "Network config should be {'test': True}." - assert ( - ppo.meta_modules["NovelD"].internal_reward_weight == 0.5 - ), "Internal reward weight should be 0.5." - assert ( - ppo.meta_modules["NovelD"].rnd_lr == 0.2 - ), "NovelD learning rate should be 0.2." - assert ( - ppo.meta_modules["NovelD"].rnd_eps == 1e-4 - ), "NovelD epsilon should be 1e-4." - assert ( - ppo.meta_modules["NovelD"].rnd_weight_decay == 0.1 - ), "NovelD weight decay should be 0.1." - assert ( - ppo.meta_modules["NovelD"].update_proportion == 0.3 - ), "Update proportion should be 0.3." - assert ( - ppo.meta_modules["NovelD"].rnd_net is None - ), "NovelD network should be None." + assert isinstance(ppo.meta_modules["NovelD"], NovelD), ( + "NovelD should be meta module when created from class." + ) + assert ppo.meta_modules["NovelD"].rnd_output_dim == 12, ( + "Output dim should be 12." + ) + assert ppo.meta_modules["NovelD"].rnd_network_config == {"test": True}, ( + "Network config should be {'test': True}." + ) + assert ppo.meta_modules["NovelD"].internal_reward_weight == 0.5, ( + "Internal reward weight should be 0.5." + ) + assert ppo.meta_modules["NovelD"].rnd_lr == 0.2, ( + "NovelD learning rate should be 0.2." + ) + assert ppo.meta_modules["NovelD"].rnd_eps == 1e-4, ( + "NovelD epsilon should be 1e-4." + ) + assert ppo.meta_modules["NovelD"].rnd_weight_decay == 0.1, ( + "NovelD weight decay should be 0.1." + ) + assert ppo.meta_modules["NovelD"].update_proportion == 0.3, ( + "Update proportion should be 0.3." + ) + assert ppo.meta_modules["NovelD"].rnd_net is None, ( + "NovelD network should be None." + ) clean("test_NovelD") def test_reward_computation(self) -> None: @@ -120,15 +120,15 @@ def test_reward_computation(self) -> None: } } updated_metrics = ppo.meta_modules["NovelD"].get_reward(dummy_metrics) - assert ( - ppo.meta_modules["NovelD"].rnd_net is not None - ), "NovelD network should be initialized." - assert ( - ppo.meta_modules["NovelD"].last_error != 0 - ), "Last error should be non-zero." - assert ( - "intrinsic_reward" in updated_metrics["transition"] - ), "Intrinsic reward should be in updated metrics." + assert ppo.meta_modules["NovelD"].rnd_net is not None, ( + "NovelD network should be initialized." + ) + assert ppo.meta_modules["NovelD"].last_error != 0, ( + "Last error should be non-zero." + ) + assert "intrinsic_reward" in updated_metrics["transition"], ( + "Intrinsic reward should be in updated metrics." + ) assert sum(updated_metrics["transition"]["intrinsic_reward"]) == sum( updated_metrics["transition"]["reward"] ), "Intrinsic reward should be added to base reward." @@ -143,7 +143,9 @@ def test_reward_computation(self) -> None: * ppo.meta_modules["NovelD"].internal_reward_weight, updated_metrics["transition"]["reward"], atol=1e-3, - ), f"Intrinsic reward should be difference between current and last errors. Actual {abs(ppo.meta_modules['NovelD'].last_error - current_last_error) * ppo.meta_modules['NovelD'].internal_reward_weight}, assigned {updated_metrics['transition']['reward']}." + ), ( + f"Intrinsic reward should be difference between current and last errors. Actual {abs(ppo.meta_modules['NovelD'].last_error - current_last_error) * ppo.meta_modules['NovelD'].internal_reward_weight}, assigned {updated_metrics['transition']['reward']}." + ) clean("test_NovelD") def test_update(self) -> None: @@ -185,11 +187,11 @@ def test_update(self) -> None: new_target_params = ppo.meta_modules["NovelD"].rnd_net.target.parameters() for new_param, old_param in zip(new_predictor_params, predictor_params): - assert not torch.allclose( - new_param, old_param - ), "Predictor parameters should be updated." + assert not torch.allclose(new_param, old_param), ( + "Predictor parameters should be updated." + ) for new_param, old_param in zip(new_target_params, target_params): - assert torch.allclose( - new_param, old_param - ), "Target parameters should stay fixed." + assert torch.allclose(new_param, old_param), ( + "Target parameters should stay fixed." + ) clean("test_NovelD") diff --git a/test/meta_components/test_plr.py b/test/meta_components/test_plr.py index 4d0425e6..2498e318 100644 --- a/test/meta_components/test_plr.py +++ b/test/meta_components/test_plr.py @@ -73,16 +73,16 @@ def test_init(self) -> None: assert plr.score_transform == "max", "Score transform should be set to max." assert plr.temperature == 0.8, "Temperature should be set to 0.8." assert plr.eps == 1e-5, "Epsilon should be set to 1e-5." - assert ( - plr.staleness_transform == "max" - ), "Staleness transform should be set to max." - assert ( - plr.staleness_temperature == 0.8 - ), "Staleness temperature should be set to 0.8." + assert plr.staleness_transform == "max", ( + "Staleness transform should be set to max." + ) + assert plr.staleness_temperature == 0.8, ( + "Staleness temperature should be set to 0.8." + ) - assert ( - plr.instance_scores == {} - ), "Instance scores should be an empty dictionary." + assert plr.instance_scores == {}, ( + "Instance scores should be an empty dictionary." + ) assert plr.staleness == {}, "Staleness should be an empty dictionary." assert plr.all_instances is None, "All instances should be None." assert plr.index == 0, "Index should be 0." @@ -95,22 +95,22 @@ def test_get_instance(self) -> None: metrics = {"env": EnvSim(1)} plr.get_instance(metrics=metrics) assert metrics["env"].inst_ids is not None, "Instance should not be None." - assert ( - metrics["env"].inst_ids in plr.instance_scores.keys() - ), "Instance should be in instance scores." + assert metrics["env"].inst_ids in plr.instance_scores.keys(), ( + "Instance should be in instance scores." + ) assert plr.all_instances is not None, "All instances should be initialized." - assert all( - [i in plr.all_instances for i in plr.instance_scores.keys()] - ), "All instances should be in instance scores." + assert all([i in plr.all_instances for i in plr.instance_scores.keys()]), ( + "All instances should be in instance scores." + ) plr.sample_strategy = "sequential" index = plr.index metrics = {"env": EnvSim(10)} original_instance = 10 plr.get_instance(metrics=metrics) - assert ( - original_instance != metrics["env"].inst_ids - ), "Instance should be changed." + assert original_instance != metrics["env"].inst_ids, ( + "Instance should be changed." + ) assert plr.index == index + 1, "Index should be incremented by 1." def test_sample_weights(self) -> None: @@ -118,9 +118,9 @@ def test_sample_weights(self) -> None: for m in DUMMY_METRICS: plr.add_rollout(m) weights = plr.sample_weights() - assert len(weights) == len( - plr.instance_scores.keys() - ), "Length of weights should be equal to the number of instances." + assert len(weights) == len(plr.instance_scores.keys()), ( + "Length of weights should be equal to the number of instances." + ) assert np.isclose(sum(weights), 1), "Sum of weights should be 1." def test_score_transforms(self) -> None: @@ -137,9 +137,9 @@ def test_score_transforms(self) -> None: ]: plr.score_transform = score_transform weights = plr.sample_weights() - assert len(weights) == len( - plr.instance_scores.keys() - ), "Length of weights should be equal to the number of instances." + assert len(weights) == len(plr.instance_scores.keys()), ( + "Length of weights should be equal to the number of instances." + ) assert np.isclose(sum(weights), 1), "Sum of weights should be 1." @pytest.mark.parametrize("metrics", DUMMY_METRICS) @@ -162,22 +162,22 @@ def test_score_function(self, metrics) -> None: plr.add_rollout(metrics) else: plr.add_rollout(metrics) - assert ( - plr.instance_scores[metrics["env"].inst_ids[0]] is not None - ), "Instance score should not be None." + assert plr.instance_scores[metrics["env"].inst_ids[0]] is not None, ( + "Instance score should not be None." + ) if score_func == "random": - assert ( - plr.instance_scores[0] == 1.0 - ), f"Random score should be 1. Scores were: {plr.instance_scores}" + assert plr.instance_scores[0] == 1.0, ( + f"Random score should be 1. Scores were: {plr.instance_scores}" + ) @pytest.mark.parametrize("metrics", DUMMY_METRICS) def test_add_rollout(self, metrics) -> None: plr = PLR() plr.add_rollout(metrics) - assert ( - metrics["env"].inst_ids[0] in plr.instance_scores - ), "Instance should be added to instance scores." + assert metrics["env"].inst_ids[0] in plr.instance_scores, ( + "Instance should be added to instance scores." + ) def test_in_loop(self) -> None: env = ContextualVecEnv([DummyEnv for _ in range(2)]) @@ -189,13 +189,13 @@ def test_in_loop(self) -> None: use_target=False, meta_methods=["mighty.mighty_meta.PrioritizedLevelReplay"], ) - assert ( - dqn.meta_modules["PrioritizedLevelReplay"] is not None - ), "PLR should be initialized." + assert dqn.meta_modules["PrioritizedLevelReplay"] is not None, ( + "PLR should be initialized." + ) dqn.run(100, 0) - assert ( - dqn.meta_modules["PrioritizedLevelReplay"].all_instances is not None - ), "All instances should be initialized." + assert dqn.meta_modules["PrioritizedLevelReplay"].all_instances is not None, ( + "All instances should be initialized." + ) assert ( env.inst_ids[0] in dqn.meta_modules["PrioritizedLevelReplay"].all_instances ), "Instance should be in all instances." diff --git a/test/meta_components/test_rnd.py b/test/meta_components/test_rnd.py index 06b33b8f..df0d906b 100644 --- a/test/meta_components/test_rnd.py +++ b/test/meta_components/test_rnd.py @@ -31,30 +31,30 @@ def test_init(self) -> None: ppo = self.init_rnd() assert len(ppo.meta_modules) == 1, "There should be one meta module." assert "RND" in ppo.meta_modules, "RND should be in meta modules." - assert isinstance( - ppo.meta_modules["RND"], RND - ), "RND should be meta module when created from string." - assert ( - ppo.meta_modules["RND"].rnd_output_dim == 512 - ), "Default output dim should be 512." - assert ( - len(ppo.meta_modules["RND"].rnd_network_config) == 0 - ), f"Default network config should be empty, got {ppo.meta_modules['RND'].rnd_network_config}." - assert ( - ppo.meta_modules["RND"].internal_reward_weight == 0.1 - ), "Default internal reward weight should be 0.1." - assert ( - ppo.meta_modules["RND"].rnd_lr == 0.001 - ), "Default RND learning rate should be 0.001." - assert ( - ppo.meta_modules["RND"].rnd_eps == 1e-5 - ), "Default RND epsilon should be 1e-5." - assert ( - ppo.meta_modules["RND"].rnd_weight_decay == 0.01 - ), "Default RND weight decay should be 0.01." - assert ( - ppo.meta_modules["RND"].update_proportion == 0.5 - ), "Default update proportion should be 0.5." + assert isinstance(ppo.meta_modules["RND"], RND), ( + "RND should be meta module when created from string." + ) + assert ppo.meta_modules["RND"].rnd_output_dim == 512, ( + "Default output dim should be 512." + ) + assert len(ppo.meta_modules["RND"].rnd_network_config) == 0, ( + f"Default network config should be empty, got {ppo.meta_modules['RND'].rnd_network_config}." + ) + assert ppo.meta_modules["RND"].internal_reward_weight == 0.1, ( + "Default internal reward weight should be 0.1." + ) + assert ppo.meta_modules["RND"].rnd_lr == 0.001, ( + "Default RND learning rate should be 0.001." + ) + assert ppo.meta_modules["RND"].rnd_eps == 1e-5, ( + "Default RND epsilon should be 1e-5." + ) + assert ppo.meta_modules["RND"].rnd_weight_decay == 0.01, ( + "Default RND weight decay should be 0.01." + ) + assert ppo.meta_modules["RND"].update_proportion == 0.5, ( + "Default update proportion should be 0.5." + ) assert ppo.meta_modules["RND"].rnd_net is None, "RND network should be None." env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(1)]) @@ -78,24 +78,24 @@ def test_init(self) -> None: ) assert len(ppo.meta_modules) == 1, "There should be one meta module." assert "RND" in ppo.meta_modules, "RND should be in meta modules." - assert isinstance( - ppo.meta_modules["RND"], RND - ), "RND should be meta module when created from class." + assert isinstance(ppo.meta_modules["RND"], RND), ( + "RND should be meta module when created from class." + ) assert ppo.meta_modules["RND"].rnd_output_dim == 12, "Output dim should be 12." - assert ppo.meta_modules["RND"].rnd_network_config == { - "test": True - }, "Network config should be {'test': True}." - assert ( - ppo.meta_modules["RND"].internal_reward_weight == 0.5 - ), "Internal reward weight should be 0.5." + assert ppo.meta_modules["RND"].rnd_network_config == {"test": True}, ( + "Network config should be {'test': True}." + ) + assert ppo.meta_modules["RND"].internal_reward_weight == 0.5, ( + "Internal reward weight should be 0.5." + ) assert ppo.meta_modules["RND"].rnd_lr == 0.2, "RND learning rate should be 0.2." assert ppo.meta_modules["RND"].rnd_eps == 1e-4, "RND epsilon should be 1e-4." - assert ( - ppo.meta_modules["RND"].rnd_weight_decay == 0.1 - ), "RND weight decay should be 0.1." - assert ( - ppo.meta_modules["RND"].update_proportion == 0.3 - ), "Update proportion should be 0.3." + assert ppo.meta_modules["RND"].rnd_weight_decay == 0.1, ( + "RND weight decay should be 0.1." + ) + assert ppo.meta_modules["RND"].update_proportion == 0.3, ( + "Update proportion should be 0.3." + ) assert ppo.meta_modules["RND"].rnd_net is None, "RND network should be None." clean("test_rnd") @@ -110,12 +110,12 @@ def test_reward_computation(self) -> None: } } updated_metrics = ppo.meta_modules["RND"].get_reward(dummy_metrics) - assert ( - ppo.meta_modules["RND"].rnd_net is not None - ), "RND network should be initialized." - assert ( - "intrinsic_reward" in updated_metrics["transition"] - ), "Intrinsic reward should be in updated metrics." + assert ppo.meta_modules["RND"].rnd_net is not None, ( + "RND network should be initialized." + ) + assert "intrinsic_reward" in updated_metrics["transition"], ( + "Intrinsic reward should be in updated metrics." + ) assert sum(updated_metrics["transition"]["intrinsic_reward"]) == sum( updated_metrics["transition"]["reward"] ), "Intrinsic reward should be added to base reward." @@ -160,11 +160,11 @@ def test_update(self) -> None: new_target_params = ppo.meta_modules["RND"].rnd_net.target.parameters() for new_param, old_param in zip(new_predictor_params, predictor_params): - assert not torch.allclose( - new_param, old_param - ), "Predictor parameters should be updated." + assert not torch.allclose(new_param, old_param), ( + "Predictor parameters should be updated." + ) for new_param, old_param in zip(new_target_params, target_params): - assert torch.allclose( - new_param, old_param - ), "Target parameters should stay fixed." + assert torch.allclose(new_param, old_param), ( + "Target parameters should stay fixed." + ) clean("test_rnd") diff --git a/test/meta_components/test_space.py b/test/meta_components/test_space.py index be03bd98..c352d2a2 100644 --- a/test/meta_components/test_space.py +++ b/test/meta_components/test_space.py @@ -25,12 +25,12 @@ def test_get_instances(self) -> None: "rollout_values": [[0.0, 0.6, 0.7]], } space.get_instances(metrics) - assert ( - len(space.all_instances) == 1 - ), f"Expected 1, got {len(space.all_instances)}" - assert ( - len(space.instance_set) == 1 - ), f"Expected 1, got {len(space.instance_set)}" + assert len(space.all_instances) == 1, ( + f"Expected 1, got {len(space.all_instances)}" + ) + assert len(space.instance_set) == 1, ( + f"Expected 1, got {len(space.instance_set)}" + ) assert space.last_evals is not None, "Evals should not be None." def test_get_evals(self) -> None: @@ -53,10 +53,10 @@ def test_in_loop(self) -> None: ) assert dqn.meta_modules["SPaCE"] is not None, "SPaCE should be initialized." dqn.run(100, 0) - assert ( - dqn.meta_modules["SPaCE"].all_instances is not None - ), "All instances should be initialized." - assert ( - env.inst_ids[0] in dqn.meta_modules["SPaCE"].all_instances - ), "Instance should be in all instances." + assert dqn.meta_modules["SPaCE"].all_instances is not None, ( + "All instances should be initialized." + ) + assert env.inst_ids[0] in dqn.meta_modules["SPaCE"].all_instances, ( + "Instance should be in all instances." + ) clean(output_dir) diff --git a/test/models/test_networks.py b/test/models/test_networks.py index d9bd54b2..61b71b97 100644 --- a/test/models/test_networks.py +++ b/test/models/test_networks.py @@ -115,15 +115,15 @@ def test_init(self, input_size, n_layers, hidden_sizes, activation): assert isinstance(mlp, torch.jit.ScriptModule), "MLP is not a ScriptModule." for n in range(n_layers): - assert ( - type(mlp.layers[2 * n]) is torch.nn.Linear - ), f"Layer {n} is not a Linear." - assert ( - type(mlp.layers[2 * n + 1]) is ACTIVATIONS[activation] - ), f"Activation {n} is not correct." - assert ( - mlp.layers[2 * n].out_features == hidden_sizes[n] - ), f"Wrong in_features in layer {n}." + assert type(mlp.layers[2 * n]) is torch.nn.Linear, ( + f"Layer {n} is not a Linear." + ) + assert type(mlp.layers[2 * n + 1]) is ACTIVATIONS[activation], ( + f"Activation {n} is not correct." + ) + assert mlp.layers[2 * n].out_features == hidden_sizes[n], ( + f"Wrong in_features in layer {n}." + ) with pytest.raises(IndexError): mlp.layers[2 * n + 2] @@ -138,21 +138,21 @@ def test_soft_reset(self): ] mlp.soft_reset(0, 0.5, 0.5) reset_pred = mlp(dummy_input) - assert torch.allclose( - original_pred, reset_pred - ), "Model prediction has changed." - assert torch.allclose( - mlp.layers[0].weight, prev_model_weights[0] - ), "Weights have reset in layer 0 even though probability was 0." - assert torch.allclose( - mlp.layers[2].weight, prev_model_weights[1] - ), "Weights have reset in layer 1 even though probability was 0." + assert torch.allclose(original_pred, reset_pred), ( + "Model prediction has changed." + ) + assert torch.allclose(mlp.layers[0].weight, prev_model_weights[0]), ( + "Weights have reset in layer 0 even though probability was 0." + ) + assert torch.allclose(mlp.layers[2].weight, prev_model_weights[1]), ( + "Weights have reset in layer 1 even though probability was 0." + ) mlp.soft_reset(1, 0.5, 0.5) reset_pred = mlp(dummy_input) - assert ~torch.allclose( - original_pred, reset_pred - ), "Model prediction has not changed." + assert ~torch.allclose(original_pred, reset_pred), ( + "Model prediction has not changed." + ) assert not any( torch.isclose(mlp.layers[0].weight, prev_model_weights[0]) .flatten() @@ -170,23 +170,23 @@ def test_soft_reset(self): original_pred = mlp(dummy_input) mlp.soft_reset(1, 1, 0) reset_pred = mlp(dummy_input) - assert torch.allclose( - original_pred, reset_pred - ), "Model prediction has changed though perturb was 0." + assert torch.allclose(original_pred, reset_pred), ( + "Model prediction has changed though perturb was 0." + ) for new_param, old_param in zip(mlp.parameters(), prev_params, strict=False): - assert torch.allclose( - new_param, old_param - ), "Weights have been reset even though perturb was 0." + assert torch.allclose(new_param, old_param), ( + "Weights have been reset even though perturb was 0." + ) mlp.soft_reset(1, 0.5, 0.0) reset_pred = mlp(dummy_input) - assert ( - original_pred * 0.5 - reset_pred - ).sum() < 0.1, "Model prediction didn't shrink with parameter value." + assert (original_pred * 0.5 - reset_pred).sum() < 0.1, ( + "Model prediction didn't shrink with parameter value." + ) for new_param, old_param in zip(mlp.parameters(), prev_params, strict=False): - assert torch.allclose( - new_param, old_param * 0.5 - ), "Weights have not been shrunk." + assert torch.allclose(new_param, old_param * 0.5), ( + "Weights have not been shrunk." + ) mlp.layers[0].weight = deepcopy(prev_model_weights[0]) mlp.layers[2].weight = deepcopy(prev_model_weights[1]) @@ -194,13 +194,13 @@ def test_soft_reset(self): original_pred = mlp(dummy_input) mlp.soft_reset(1, 1, 1) reset_pred = mlp(dummy_input) - assert not torch.allclose( - original_pred, reset_pred - ), "Model prediction has not changed." + assert not torch.allclose(original_pred, reset_pred), ( + "Model prediction has not changed." + ) for new_param, old_param in zip(mlp.parameters(), prev_params, strict=False): - assert not torch.allclose( - new_param, old_param - ), "Weights have not been perturbed." + assert not torch.allclose(new_param, old_param), ( + "Weights have not been perturbed." + ) mlp = MLP(3, 5, [100, 100, 100, 100, 100], "relu") n_reset = 0 @@ -229,15 +229,15 @@ def test_hard_reset(self): ] mlp.full_hard_reset() reset_pred = mlp(dummy_input) - assert ~torch.allclose( - original_pred, reset_pred - ), "Model prediction has not changed." - assert ~torch.allclose( - mlp.layers[0].weight, prev_model_weights[0] - ), "Weights have not been reset in layer 0." - assert ~torch.allclose( - mlp.layers[2].weight, prev_model_weights[1] - ), "Weights have not been reset in layer 1." + assert ~torch.allclose(original_pred, reset_pred), ( + "Model prediction has not changed." + ) + assert ~torch.allclose(mlp.layers[0].weight, prev_model_weights[0]), ( + "Weights have not been reset in layer 0." + ) + assert ~torch.allclose(mlp.layers[2].weight, prev_model_weights[1]), ( + "Weights have not been reset in layer 1." + ) def test_reset(self): dummy_input = torch.rand(1, 3) @@ -249,15 +249,15 @@ def test_reset(self): ] mlp.reset(1) reset_pred = mlp(dummy_input) - assert ~torch.allclose( - original_pred, reset_pred - ), "Model prediction has not changed." - assert ~torch.allclose( - mlp.layers[0].weight, prev_model_weights[0] - ), "Weights have not been reset in layer 0." - assert ~torch.allclose( - mlp.layers[2].weight, prev_model_weights[1] - ), "Weights have not been reset in layer 1." + assert ~torch.allclose(original_pred, reset_pred), ( + "Model prediction has not changed." + ) + assert ~torch.allclose(mlp.layers[0].weight, prev_model_weights[0]), ( + "Weights have not been reset in layer 0." + ) + assert ~torch.allclose(mlp.layers[2].weight, prev_model_weights[1]), ( + "Weights have not been reset in layer 1." + ) mlp = MLP(3, 5, [100, 100, 100, 100, 100], "relu") original_pred = mlp(dummy_input) @@ -279,9 +279,9 @@ def test_reset(self): n_reset += 1 reset_param.data = old_param.data - assert ~torch.allclose( - original_pred, reset_pred - ), "Model prediction has not changed." + assert ~torch.allclose(original_pred, reset_pred), ( + "Model prediction has not changed." + ) assert n_reset / total_params >= 0.25, "Weights reset too rarely in soft reset." assert n_reset / total_params <= 0.75, "Weights reset too often in soft reset." @@ -291,9 +291,9 @@ def test_forward(self): output = mlp(dummy_input) assert output.shape == (3, 5), "Output shape is not correct." assert output.dtype == torch.float32, "Output dtype is not correct." - assert torch.allclose( - output, mlp.layers(dummy_input) - ), "Forward is not correct." + assert torch.allclose(output, mlp.layers(dummy_input)), ( + "Forward is not correct." + ) class TestCNN: @@ -313,9 +313,9 @@ def test_init(self): ) cnn = CNN((64, 64, 3), 3, [32, 64, 64], [8, 4, 3], [4, 2, 1], [0, 0, 0], "relu") combo = ComboNet(cnn, mlp) - assert isinstance( - combo, torch.jit.ScriptModule - ), "ComboNet is not a ScriptModule." + assert isinstance(combo, torch.jit.ScriptModule), ( + "ComboNet is not a ScriptModule." + ) assert combo.module1 == cnn, "CNN is not the first module." assert combo.module2 == mlp, "MLP is not the second module." diff --git a/test/models/test_ppo_networks.py b/test/models/test_ppo_networks.py index e2cd79d9..d294370f 100644 --- a/test/models/test_ppo_networks.py +++ b/test/models/test_ppo_networks.py @@ -17,21 +17,21 @@ def test_init_discrete(self): assert ppo.tanh_squash is False, "Default tanh_squash should be False" # Check network structure - updated for new architecture - assert hasattr( - ppo, "feature_extractor_policy" - ), "Should have policy feature extractor" - assert hasattr( - ppo, "feature_extractor_value" - ), "Should have value feature extractor" - assert isinstance( - ppo.policy_head, nn.Sequential - ), "Policy head should be Sequential" - assert isinstance( - ppo.value_head, nn.Sequential - ), "Value head should be Sequential" - assert hasattr( - ppo, "value_function_module" - ), "Should have value function module wrapper" + assert hasattr(ppo, "feature_extractor_policy"), ( + "Should have policy feature extractor" + ) + assert hasattr(ppo, "feature_extractor_value"), ( + "Should have value feature extractor" + ) + assert isinstance(ppo.policy_head, nn.Sequential), ( + "Policy head should be Sequential" + ) + assert isinstance(ppo.value_head, nn.Sequential), ( + "Value head should be Sequential" + ) + assert hasattr(ppo, "value_function_module"), ( + "Should have value function module wrapper" + ) # Test forward pass shapes dummy_input = torch.rand((10, 4)) @@ -72,9 +72,9 @@ def test_init_continuous_tanh_squash(self): assert values.shape == (5, 1), "Values should have shape (5, 1)" # Check that actions are in [-1, 1] range due to tanh - assert torch.all(action >= -1.0) and torch.all( - action <= 1.0 - ), "Actions should be in [-1, 1] range" + assert torch.all(action >= -1.0) and torch.all(action <= 1.0), ( + "Actions should be in [-1, 1] range" + ) # Check log_std clamping assert torch.all(log_std >= ppo.log_std_min), "Log_std should be >= log_std_min" @@ -148,9 +148,9 @@ def test_value_function_module(self): values_direct = ppo.forward_value(dummy_input) values_module = ppo.value_function_module(dummy_input) - assert torch.allclose( - values_direct, values_module - ), "Value function module should produce same output as forward_value" + assert torch.allclose(values_direct, values_module), ( + "Value function module should produce same output as forward_value" + ) assert values_module.shape == ( 8, 1, @@ -190,15 +190,15 @@ def test_forward_continuous_tanh_squash(self): assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite" # Check tanh constraint on actions - assert torch.all(action >= -1.0) and torch.all( - action <= 1.0 - ), "Actions should be in [-1, 1]" + assert torch.all(action >= -1.0) and torch.all(action <= 1.0), ( + "Actions should be in [-1, 1]" + ) # Check relationship: action = tanh(z) where z = mean + std * eps expected_action = torch.tanh(z) - assert torch.allclose( - action, expected_action, atol=1e-6 - ), "Action should equal tanh(z)" + assert torch.allclose(action, expected_action, atol=1e-6), ( + "Action should equal tanh(z)" + ) def test_forward_continuous_standard(self): """Test forward pass for continuous actions with standard PPO.""" @@ -219,11 +219,6 @@ def test_forward_continuous_standard(self): assert torch.all(torch.isfinite(mean)), "Means should be finite" assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite" - # Check relationship: action = mean + std * eps (no tanh) - std = torch.exp(log_std) - # We can't check exact relationship due to random sampling, but verify no tanh constraint - # Actions should not be constrained to [-1, 1] in standard PPO - def test_forward_value(self): """Test value network forward pass.""" ppo = PPOModel(obs_shape=5, action_size=3, continuous_action=False) @@ -251,12 +246,12 @@ def test_custom_log_std_bounds(self): dummy_input = torch.rand((10, 4)) _, _, _, log_std = ppo_tanh(dummy_input) - assert torch.all( - log_std >= log_std_min - ), "Log_std should be >= custom log_std_min" - assert torch.all( - log_std <= log_std_max - ), "Log_std should be <= custom log_std_max" + assert torch.all(log_std >= log_std_min), ( + "Log_std should be >= custom log_std_min" + ) + assert torch.all(log_std <= log_std_max), ( + "Log_std should be <= custom log_std_max" + ) # Test with standard PPO ppo_std = PPOModel( @@ -270,12 +265,12 @@ def test_custom_log_std_bounds(self): _, _, log_std = ppo_std(dummy_input) - assert torch.all( - log_std >= log_std_min - ), "Log_std should be >= custom log_std_min (standard PPO)" - assert torch.all( - log_std <= log_std_max - ), "Log_std should be <= custom log_std_max (standard PPO)" + assert torch.all(log_std >= log_std_min), ( + "Log_std should be >= custom log_std_min (standard PPO)" + ) + assert torch.all(log_std <= log_std_max), ( + "Log_std should be <= custom log_std_max (standard PPO)" + ) def test_deterministic_with_same_input(self): """Test that same input produces different outputs due to sampling.""" @@ -294,12 +289,12 @@ def test_deterministic_with_same_input(self): assert torch.allclose(log_std1, log_std2), "Log_stds should be identical" # Actions and z should be different due to random sampling - assert not torch.allclose( - action1, action2 - ), "Actions should be different due to sampling" - assert not torch.allclose( - z1, z2 - ), "Raw actions should be different due to sampling" + assert not torch.allclose(action1, action2), ( + "Actions should be different due to sampling" + ) + assert not torch.allclose(z1, z2), ( + "Raw actions should be different due to sampling" + ) # Test standard PPO mode ppo_std = PPOModel( @@ -310,17 +305,17 @@ def test_deterministic_with_same_input(self): action2_std, mean2_std, log_std2_std = ppo_std(dummy_input) # Mean and log_std should be the same (deterministic) - assert torch.allclose( - mean1_std, mean2_std - ), "Means should be identical (standard PPO)" - assert torch.allclose( - log_std1_std, log_std2_std - ), "Log_stds should be identical (standard PPO)" + assert torch.allclose(mean1_std, mean2_std), ( + "Means should be identical (standard PPO)" + ) + assert torch.allclose(log_std1_std, log_std2_std), ( + "Log_stds should be identical (standard PPO)" + ) # Actions should be different due to random sampling - assert not torch.allclose( - action1_std, action2_std - ), "Actions should be different due to sampling (standard PPO)" + assert not torch.allclose(action1_std, action2_std), ( + "Actions should be different due to sampling (standard PPO)" + ) def test_orthogonal_initialization(self): """Test that weights are initialized with orthogonal initialization.""" @@ -334,9 +329,9 @@ def test_orthogonal_initialization(self): module.weight, torch.zeros_like(module.weight) ), "Weights should not be all zeros" # Check that biases are initialized to zero - assert torch.allclose( - module.bias, torch.zeros_like(module.bias) - ), "Biases should be initialized to zero" + assert torch.allclose(module.bias, torch.zeros_like(module.bias)), ( + "Biases should be initialized to zero" + ) def test_separate_feature_extractors(self): """Test that policy and value networks have separate feature extractors.""" @@ -348,14 +343,14 @@ def test_separate_feature_extractors(self): value_features = ppo.feature_extractor_value(dummy_input) # They should have the same shape but potentially different values - assert ( - policy_features.shape == value_features.shape - ), "Feature extractors should output same shape" + assert policy_features.shape == value_features.shape, ( + "Feature extractors should output same shape" + ) # Verify they are separate networks by checking if they are different objects - assert ( - ppo.feature_extractor_policy is not ppo.feature_extractor_value - ), "Policy and value feature extractors should be separate objects" + assert ppo.feature_extractor_policy is not ppo.feature_extractor_value, ( + "Policy and value feature extractors should be separate objects" + ) def test_different_architectures(self): """Test with different hidden layer configurations.""" diff --git a/test/models/test_q_networks.py b/test/models/test_q_networks.py index d079502a..46ad3b45 100644 --- a/test/models/test_q_networks.py +++ b/test/models/test_q_networks.py @@ -13,18 +13,18 @@ def test_init(self): dqn = DQN(num_actions=4, obs_size=3) assert dqn.num_actions == 4, "Num_actions should be 4" assert dqn.dueling is False, "Dueling should be False" - assert isinstance( - dqn.feature_extractor, MLP - ), "Default feature extractor should be an instance of MLP" - assert isinstance( - dqn.head, torch.nn.Sequential - ), "Head should be a nn.Sequential" - assert isinstance( - dqn.value, torch.nn.Linear - ), "Value layer should be a nn.Linear" - assert isinstance( - dqn.advantage, torch.nn.Linear - ), "Advantage layer should be a nn.Linear" + assert isinstance(dqn.feature_extractor, MLP), ( + "Default feature extractor should be an instance of MLP" + ) + assert isinstance(dqn.head, torch.nn.Sequential), ( + "Head should be a nn.Sequential" + ) + assert isinstance(dqn.value, torch.nn.Linear), ( + "Value layer should be a nn.Linear" + ) + assert isinstance(dqn.advantage, torch.nn.Linear), ( + "Advantage layer should be a nn.Linear" + ) dummy_input = torch.rand((1, 3)) assert dqn(dummy_input).shape == (1, 4), "Output should have shape (1, 4)" @@ -45,21 +45,21 @@ def test_init(self): ) assert dqn.num_actions == 2, "Num_actions should be 4" assert dqn.dueling is True, "Dueling should be True" - assert isinstance( - dqn.feature_extractor, CNN - ), "Feature extractor should be a CNN" - assert ( - len(dqn.feature_extractor.cnn) == 5 - ), "Feature extractor should have 2 convolutions" - assert ( - dqn.feature_extractor.cnn[0].out_channels == 32 - ), "First convolution should have 32 output channels" + assert isinstance(dqn.feature_extractor, CNN), ( + "Feature extractor should be a CNN" + ) + assert len(dqn.feature_extractor.cnn) == 5, ( + "Feature extractor should have 2 convolutions" + ) + assert dqn.feature_extractor.cnn[0].out_channels == 32, ( + "First convolution should have 32 output channels" + ) assert dqn.head[0].out_features == 32, "Head layer 1 should have hidden size 32" assert dqn.head[2].out_features == 32, "Head layer 2 should have hidden size 32" assert dqn.value.out_features == 1, "Value layer should have 1 output feature" - assert ( - dqn.advantage.out_features == 2 - ), "Advantage layer should have 2 output features" + assert dqn.advantage.out_features == 2, ( + "Advantage layer should have 2 output features" + ) dummy_input = torch.rand((5, 64, 64, 3)) assert dqn(dummy_input).shape == (5, 2), "Output should have shape (1, 2)" @@ -87,9 +87,9 @@ def test_forward(self): + calculated_advantage - calculated_advantage.mean(dim=1, keepdim=True) ) - assert torch.allclose( - dqn_pred, calculated_pred - ), "Prediction should be equal to value + advantage - mean_advantage" + assert torch.allclose(dqn_pred, calculated_pred), ( + "Prediction should be equal to value + advantage - mean_advantage" + ) def test_reset_head(self): head_kwargs = {"hidden_sizes": [32, 32]} @@ -114,9 +114,9 @@ def test_reset_head(self): new_features = dqn.feature_extractor(dummy_input) new_pred = dqn(dummy_input) - assert torch.allclose( - original_features, new_features - ), "Features should be equal" + assert torch.allclose(original_features, new_features), ( + "Features should be equal" + ) assert ~torch.allclose(original_pred, new_pred), "Predictions should differ" def test_shrink_weights(self): @@ -128,21 +128,21 @@ def test_shrink_weights(self): for new_param, old_param in zip( dqn.head.parameters(), prev_head_params, strict=False ): - assert torch.allclose( - new_param, old_param * 0.5 - ), "Weights have not been shrunk." + assert torch.allclose(new_param, old_param * 0.5), ( + "Weights have not been shrunk." + ) for new_param, old_param in zip( dqn.advantage.parameters(), prev_adv_params, strict=False ): - assert torch.allclose( - new_param, old_param * 0.5 - ), "Advantage weights have not been shrunk." + assert torch.allclose(new_param, old_param * 0.5), ( + "Advantage weights have not been shrunk." + ) for new_param, old_param in zip( dqn.value.parameters(), prev_value_params, strict=False ): - assert torch.allclose( - new_param, old_param * 0.5 - ), "Value weights have not been shrunk." + assert torch.allclose(new_param, old_param * 0.5), ( + "Value weights have not been shrunk." + ) def test_get_state(self): dqn = DQN(num_actions=4, obs_size=3, dueling=True) @@ -160,18 +160,18 @@ def test_set_state(self): state_dict = dqn.state_dict() dqn2 = DQN(num_actions=4, obs_size=3, dueling=True) original_pred = dqn2(dummy_input) - assert ~torch.allclose( - baseline_pred, original_pred - ), "Predictions should be different before loading" + assert ~torch.allclose(baseline_pred, original_pred), ( + "Predictions should be different before loading" + ) for p1, p2 in zip(dqn.parameters(), dqn2.parameters(), strict=False): - assert not torch.allclose( - p1, p2 - ), "Parameters should be different before loading" + assert not torch.allclose(p1, p2), ( + "Parameters should be different before loading" + ) dqn2.load_state_dict(state_dict) new_pred = dqn2(dummy_input) - assert torch.allclose( - baseline_pred, new_pred - ), "Predictions should be equal after loading" + assert torch.allclose(baseline_pred, new_pred), ( + "Predictions should be equal after loading" + ) for p1, p2 in zip(dqn.parameters(), dqn2.parameters(), strict=False): assert torch.allclose(p1, p2), "Parameters should be equal after loading" diff --git a/test/models/test_sac_networks.py b/test/models/test_sac_networks.py index 161bd374..6c8a123b 100644 --- a/test/models/test_sac_networks.py +++ b/test/models/test_sac_networks.py @@ -20,40 +20,40 @@ def test_init(self): # Check network structure - updated for new architecture assert hasattr(sac, "feature_extractor"), "Should have feature extractor" - assert isinstance( - sac.policy_net, nn.Linear - ), "Policy network should be Linear (after feature extractor)" + assert isinstance(sac.policy_net, nn.Linear), ( + "Policy network should be Linear (after feature extractor)" + ) assert isinstance(sac.q_net1, nn.Sequential), "Q-network 1 should be Sequential" assert isinstance(sac.q_net2, nn.Sequential), "Q-network 2 should be Sequential" - assert isinstance( - sac.target_q_net1, nn.Sequential - ), "Target Q-network 1 should be Sequential" - assert isinstance( - sac.target_q_net2, nn.Sequential - ), "Target Q-network 2 should be Sequential" - assert hasattr( - sac, "value_function_module" - ), "Should have value function module wrapper" + assert isinstance(sac.target_q_net1, nn.Sequential), ( + "Target Q-network 1 should be Sequential" + ) + assert isinstance(sac.target_q_net2, nn.Sequential), ( + "Target Q-network 2 should be Sequential" + ) + assert hasattr(sac, "value_function_module"), ( + "Should have value function module wrapper" + ) # Check that target networks have gradients disabled for param in sac.target_q_net1.parameters(): - assert ( - not param.requires_grad - ), "Target Q-network 1 parameters should not require gradients" + assert not param.requires_grad, ( + "Target Q-network 1 parameters should not require gradients" + ) for param in sac.target_q_net2.parameters(): - assert ( - not param.requires_grad - ), "Target Q-network 2 parameters should not require gradients" + assert not param.requires_grad, ( + "Target Q-network 2 parameters should not require gradients" + ) # Check that live networks have gradients enabled for param in sac.q_net1.parameters(): - assert ( - param.requires_grad - ), "Q-network 1 parameters should require gradients" + assert param.requires_grad, ( + "Q-network 1 parameters should require gradients" + ) for param in sac.q_net2.parameters(): - assert ( - param.requires_grad - ), "Q-network 2 parameters should require gradients" + assert param.requires_grad, ( + "Q-network 2 parameters should require gradients" + ) def test_init_custom_params(self): """Test initialization with custom parameters.""" @@ -106,9 +106,9 @@ def test_value_function_module(self): values_module = sac.value_function_module(dummy_state) values_direct = sac.forward_value(dummy_state) - assert torch.allclose( - values_module, values_direct - ), "Value function module should produce same output as forward_value" + assert torch.allclose(values_module, values_direct), ( + "Value function module should produce same output as forward_value" + ) assert values_module.shape == ( 8, 1, @@ -134,9 +134,9 @@ def test_forward_stochastic(self): assert torch.all(torch.isfinite(log_std)), "Log_stds should be finite" # Check tanh constraint on actions - assert torch.all(action >= -1.0) and torch.all( - action <= 1.0 - ), "Actions should be in [-1, 1] range" + assert torch.all(action >= -1.0) and torch.all(action <= 1.0), ( + "Actions should be in [-1, 1] range" + ) # Check log_std clamping assert torch.all(log_std >= sac.log_std_min), "Log_std should be >= log_std_min" @@ -144,9 +144,9 @@ def test_forward_stochastic(self): # Check relationship: action = tanh(z) expected_action = torch.tanh(z) - assert torch.allclose( - action, expected_action, atol=1e-6 - ), "Action should equal tanh(z)" + assert torch.allclose(action, expected_action, atol=1e-6), ( + "Action should equal tanh(z)" + ) def test_forward_deterministic(self): """Test forward pass with deterministic policy.""" @@ -166,9 +166,9 @@ def test_forward_deterministic(self): # Action should still be tanh(z) = tanh(mean) expected_action = torch.tanh(mean) - assert torch.allclose( - action, expected_action - ), "Action should equal tanh(mean) in deterministic mode" + assert torch.allclose(action, expected_action), ( + "Action should equal tanh(mean) in deterministic mode" + ) def test_stochastic_vs_deterministic(self): """Test that stochastic and deterministic modes produce different results.""" @@ -185,18 +185,18 @@ def test_stochastic_vs_deterministic(self): # Mean and log_std should be the same assert torch.allclose(mean_stoch, mean_det), "Means should be identical" - assert torch.allclose( - log_std_stoch, log_std_det - ), "Log_stds should be identical" + assert torch.allclose(log_std_stoch, log_std_det), ( + "Log_stds should be identical" + ) # In deterministic mode, z should equal mean assert torch.allclose(z_det, mean_det), "Deterministic z should equal mean" # Stochastic z should likely be different from mean (due to noise) # Note: There's a tiny chance they could be the same, but extremely unlikely - assert not torch.allclose( - z_stoch, mean_stoch - ), "Stochastic z should be different from mean" + assert not torch.allclose(z_stoch, mean_stoch), ( + "Stochastic z should be different from mean" + ) def test_policy_log_prob(self): """Test policy log probability calculation.""" @@ -215,9 +215,9 @@ def test_policy_log_prob(self): # Test with deterministic actions (z = mean) log_prob_det = sac.policy_log_prob(mean, mean, log_std) - assert torch.all( - torch.isfinite(log_prob_det) - ), "Deterministic log probs should be finite" + assert torch.all(torch.isfinite(log_prob_det)), ( + "Deterministic log probs should be finite" + ) def test_q_networks(self): """Test Q-network forward passes.""" @@ -247,16 +247,16 @@ def test_target_networks_initialization(self): for p1, p_target1 in zip( sac.q_net1.parameters(), sac.target_q_net1.parameters() ): - assert torch.allclose( - p1, p_target1 - ), "Target Q-net 1 should have same initial weights as Q-net 1" + assert torch.allclose(p1, p_target1), ( + "Target Q-net 1 should have same initial weights as Q-net 1" + ) for p2, p_target2 in zip( sac.q_net2.parameters(), sac.target_q_net2.parameters() ): - assert torch.allclose( - p2, p_target2 - ), "Target Q-net 2 should have same initial weights as Q-net 2" + assert torch.allclose(p2, p_target2), ( + "Target Q-net 2 should have same initial weights as Q-net 2" + ) def test_twin_q_networks_independence(self): """Test that twin Q-networks are independent.""" @@ -264,9 +264,9 @@ def test_twin_q_networks_independence(self): # Check that Q-networks have different parameters (due to random initialization) assert sac.q_net1 is not sac.q_net2, "Q-networks should be separate objects" - assert ( - sac.target_q_net1 is not sac.target_q_net2 - ), "Target Q-networks should be separate objects" + assert sac.target_q_net1 is not sac.target_q_net2, ( + "Target Q-networks should be separate objects" + ) def test_log_std_bounds_enforcement(self): """Test that log_std bounds are properly enforced.""" @@ -279,12 +279,12 @@ def test_log_std_bounds_enforcement(self): dummy_state = torch.rand((10, 3)) _, _, _, log_std = sac(dummy_state) - assert torch.all( - log_std >= log_std_min - ), "Log_std should be >= custom log_std_min" - assert torch.all( - log_std <= log_std_max - ), "Log_std should be <= custom log_std_max" + assert torch.all(log_std >= log_std_min), ( + "Log_std should be >= custom log_std_min" + ) + assert torch.all(log_std <= log_std_max), ( + "Log_std should be <= custom log_std_max" + ) def test_gradient_flow(self): """Test that gradients flow properly through networks.""" @@ -303,9 +303,9 @@ def test_gradient_flow(self): feature_has_grad = any( p.grad is not None for p in sac.feature_extractor.parameters() ) - assert ( - policy_has_grad or feature_has_grad - ), "Policy network or feature extractor should have gradients" + assert policy_has_grad or feature_has_grad, ( + "Policy network or feature extractor should have gradients" + ) # Test Q-network gradients sac.zero_grad() @@ -333,9 +333,9 @@ def test_numerical_stability(self): # Test log probability calculation doesn't produce NaN or inf log_prob = sac.policy_log_prob(z, mean, log_std) - assert torch.all( - torch.isfinite(log_prob) - ), "Log probabilities should be finite even with extreme inputs" + assert torch.all(torch.isfinite(log_prob)), ( + "Log probabilities should be finite even with extreme inputs" + ) # Test with actions close to boundary values (-1, 1) boundary_z = torch.tensor( @@ -347,6 +347,6 @@ def test_numerical_stability(self): boundary_log_prob = sac.policy_log_prob( boundary_z, boundary_mean, boundary_log_std ) - assert torch.all( - torch.isfinite(boundary_log_prob) - ), "Log probabilities should be finite for boundary actions" + assert torch.all(torch.isfinite(boundary_log_prob)), ( + "Log probabilities should be finite for boundary actions" + ) diff --git a/test/replay/test_buffer.py b/test/replay/test_buffer.py index e11b8152..d6cdfc6e 100644 --- a/test/replay/test_buffer.py +++ b/test/replay/test_buffer.py @@ -59,28 +59,28 @@ def test_init(self, observations, actions, rewards, next_observations, dones, si batch = TransitionBatch( observations, actions, rewards, next_observations, dones ) - assert isinstance( - batch.observations, torch.Tensor - ), "Observations were not a tensor." + assert isinstance(batch.observations, torch.Tensor), ( + "Observations were not a tensor." + ) assert isinstance(batch.actions, torch.Tensor), "Actions were not a tensor." assert isinstance(batch.rewards, torch.Tensor), "Rewards were not a tensor." - assert isinstance( - batch.next_obs, torch.Tensor - ), "Next observations were not a tensor." + assert isinstance(batch.next_obs, torch.Tensor), ( + "Next observations were not a tensor." + ) assert isinstance(batch.dones, torch.Tensor), "Dones were not a tensor." - assert ( - len(batch.observations.shape) == 2 - ), f"Observation shape was not 2D: {batch.observations.shape}." - assert ( - batch.observations.shape == batch.next_obs.shape - ), "Observation shape was not equal to next observation shape." - assert ( - batch.actions.shape == batch.rewards.shape - ), "Action shape was not equal to reward shape." - assert ( - batch.actions.shape == batch.dones.shape - ), "Action shape was not equal to reward shape." + assert len(batch.observations.shape) == 2, ( + f"Observation shape was not 2D: {batch.observations.shape}." + ) + assert batch.observations.shape == batch.next_obs.shape, ( + "Observation shape was not equal to next observation shape." + ) + assert batch.actions.shape == batch.rewards.shape, ( + "Action shape was not equal to reward shape." + ) + assert batch.actions.shape == batch.dones.shape, ( + "Action shape was not equal to reward shape." + ) assert ( len(batch.actions.shape) == len(batch.observations.shape) - 1 ), f"""Action shape was not one less than observation shape: @@ -98,9 +98,9 @@ def test_iter(self, observations, actions, rewards, next_observations, dones, si elements = 0 for obs, act, rew, next_obs, done in batch: assert obs.numpy() in observations, "Observation was not in observations." - assert ( - next_obs.numpy() in next_observations - ), "Next observation was not in next_observations." + assert next_obs.numpy() in next_observations, ( + "Next observation was not in next_observations." + ) if isinstance(actions, int): assert act.numpy().item() == actions, "Action was not in actions." assert rew.numpy().item() == rewards, "Reward was not in rewards." @@ -146,17 +146,17 @@ def test_add(self, observations, actions, rewards, next_observations, dones, siz replay = self.get_replay(batch, size, empty=True) filled_replay = self.get_replay(batch, size) assert len(replay) == 0, "Empty replay length was not 0." - assert ( - len(filled_replay) == size - ), "Filled replay length was not equal to batch size." + assert len(filled_replay) == size, ( + "Filled replay length was not equal to batch size." + ) replay.add(batch, {}) - assert len(replay) == len( - filled_replay - ), "Replay length was not equal to batch size." - assert ( - replay.index == filled_replay.index - ), "Replay index was not equal to filled replay index." + assert len(replay) == len(filled_replay), ( + "Replay length was not equal to batch size." + ) + assert replay.index == filled_replay.index, ( + "Replay index was not equal to filled replay index." + ) assert all( any(torch.equal(obs, ob) for obs in batch.observations) for ob in replay.obs ), "Observations were not added to replay." @@ -201,9 +201,9 @@ def test_sample(self): replay = self.get_replay(batch, size) minibatch = replay.sample(batch_size=1) assert len(minibatch) == 1, "Minibatch length was incorrect (batch size 1)." - assert isinstance( - minibatch, TransitionBatch - ), "Minibatch was not a TransitionBatch." + assert isinstance(minibatch, TransitionBatch), ( + "Minibatch was not a TransitionBatch." + ) assert all( any(torch.allclose(obs, ob) for obs in batch.observations) for ob in minibatch.observations @@ -228,9 +228,9 @@ def test_sample(self): minibatch = replay.sample(batch_size=2) assert len(minibatch) == 2, "Minibatch length was incorrect (batch size 2)." - assert isinstance( - minibatch, TransitionBatch - ), "Minibatch was not a TransitionBatch." + assert isinstance(minibatch, TransitionBatch), ( + "Minibatch was not a TransitionBatch." + ) assert all( any(torch.allclose(obs, ob) for obs in batch.observations) for ob in minibatch.observations @@ -254,9 +254,9 @@ def test_sample(self): batchset = [replay.sample(batch_size=1) for _ in range(10)] all_actions = [act for batch in batchset for act in batch.actions] - assert ~all( - x == all_actions[0] for x in all_actions - ), "All sampled batches were the same." + assert ~all(x == all_actions[0] for x in all_actions), ( + "All sampled batches were the same." + ) def test_reset(self): ( @@ -295,9 +295,9 @@ def test_len(self): assert len(replay) == size, "Replay length was not equal to batch size." replay.add(batch, {}) - assert ( - len(replay) == size * 2 - ), "Replay length was not doubled after doubling transitions." + assert len(replay) == size * 2, ( + "Replay length was not doubled after doubling transitions." + ) replay = self.get_replay(batch, size, empty=True) assert len(replay) == 0, "Replay length of empty replay was not 0." @@ -316,42 +316,42 @@ def test_full(self): ) replay = self.get_replay(batch, size, full=False) assert replay.full is False, "Replay was falsely full." - assert replay.capacity > len( - replay - ), "Replay capacity was not greater than length in non-full replay." - assert ( - replay.index < replay.capacity - ), "Replay index was not less than capacity in non-full replay." + assert replay.capacity > len(replay), ( + "Replay capacity was not greater than length in non-full replay." + ) + assert replay.index < replay.capacity, ( + "Replay index was not less than capacity in non-full replay." + ) replay = self.get_replay(batch, size, full=True) assert replay.full is True, "Replay was not full." - assert replay.capacity == len( - replay - ), "Replay capacity was not equal to length in full replay." - assert ( - replay.index == replay.capacity - ), "Replay index was not equal to capacity in full replay." + assert replay.capacity == len(replay), ( + "Replay capacity was not equal to length in full replay." + ) + assert replay.index == replay.capacity, ( + "Replay index was not equal to capacity in full replay." + ) second_batch = TransitionBatch( observations * 2, actions * 2, rewards * 2, next_observations * 2, dones * 2 ) replay.add(second_batch, {}) - assert ( - replay.full is True - ), "Replay was not full anymore after adding more transitions." - assert replay.capacity == len( - replay - ), "Replay capacity was not equal to length after adding more transitions." - assert ( - replay.index == replay.capacity - ), "Replay index was not equal to capacity after adding more transitions." + assert replay.full is True, ( + "Replay was not full anymore after adding more transitions." + ) + assert replay.capacity == len(replay), ( + "Replay capacity was not equal to length after adding more transitions." + ) + assert replay.index == replay.capacity, ( + "Replay index was not equal to capacity after adding more transitions." + ) for obs, act, rew, next_obs, done in second_batch: assert obs in replay.obs, f"Observation {obs} was not in replay." assert act in replay.actions, f"Action {act} was not in replay." assert rew in replay.rewards, f"Reward {rew} was not in replay." - assert ( - next_obs in replay.next_obs - ), f"Next observation {next_obs} was not in replay." + assert next_obs in replay.next_obs, ( + f"Next observation {next_obs} was not in replay." + ) assert done in replay.dones, f"Done {done} was not in replay." def test_empty(self): @@ -388,27 +388,27 @@ def test_save(self): replay.save("test_replay.pkl") with open("test_replay.pkl", "rb") as f: loaded_replay = pkl.load(f) - assert ( - replay.capacity == loaded_replay.capacity - ), "Replay capacity was not loaded correctly." - assert ( - replay.index == loaded_replay.index - ), "Replay index was not loaded correctly." - assert torch.allclose( - replay.obs, loaded_replay.obs - ), "Replay observations were not loaded correctly." - assert torch.allclose( - replay.actions, loaded_replay.actions - ), "Replay actions were not loaded correctly." - assert torch.allclose( - replay.rewards, loaded_replay.rewards - ), "Replay rewards were not loaded correctly." - assert torch.allclose( - replay.next_obs, loaded_replay.next_obs - ), "Replay next observations were not loaded correctly." - assert torch.allclose( - replay.dones, loaded_replay.dones - ), "Replay dones were not loaded correctly." + assert replay.capacity == loaded_replay.capacity, ( + "Replay capacity was not loaded correctly." + ) + assert replay.index == loaded_replay.index, ( + "Replay index was not loaded correctly." + ) + assert torch.allclose(replay.obs, loaded_replay.obs), ( + "Replay observations were not loaded correctly." + ) + assert torch.allclose(replay.actions, loaded_replay.actions), ( + "Replay actions were not loaded correctly." + ) + assert torch.allclose(replay.rewards, loaded_replay.rewards), ( + "Replay rewards were not loaded correctly." + ) + assert torch.allclose(replay.next_obs, loaded_replay.next_obs), ( + "Replay next observations were not loaded correctly." + ) + assert torch.allclose(replay.dones, loaded_replay.dones), ( + "Replay dones were not loaded correctly." + ) Path("test_replay.pkl").unlink() @@ -504,21 +504,21 @@ def test_add(self, observations, actions, rewards, next_observations, dones, siz filled_replay = self.get_replay(batch, size) assert replay.current_size == 0, "Empty replay length was not 0." - assert ( - filled_replay.current_size == size - ), "Filled replay length was not equal to batch size." + assert filled_replay.current_size == size, ( + "Filled replay length was not equal to batch size." + ) # Now add to the previously empty replay with a brand-new td_error array td_errors = rng.random(size).astype(np.float32) replay.add(batch, {"td_error": td_errors}) # Both replays should have the same current_size and data_idx - assert ( - replay.current_size == filled_replay.current_size - ), "Replay length was not equal to filled replay length." - assert ( - replay.data_idx == filled_replay.data_idx - ), "Replay data_idx was not equal to filled replay data_idx." + assert replay.current_size == filled_replay.current_size, ( + "Replay length was not equal to filled replay length." + ) + assert replay.data_idx == filled_replay.data_idx, ( + "Replay data_idx was not equal to filled replay data_idx." + ) # Compute expected priority = (|td_error| + ε)^α for each leaf eps = replay.epsilon @@ -527,9 +527,9 @@ def test_add(self, observations, actions, rewards, next_observations, dones, siz base = replay.capacity actual_leaves = replay.sum_tree[base : base + size] - assert np.allclose( - actual_leaves, expected_prios, atol=1e-6 - ), f"Expected leaves {expected_prios}, but got {actual_leaves}" + assert np.allclose(actual_leaves, expected_prios, atol=1e-6), ( + f"Expected leaves {expected_prios}, but got {actual_leaves}" + ) @pytest.mark.parametrize( ("observations", "actions", "rewards", "next_observations", "dones", "size"), @@ -588,9 +588,9 @@ def test_sample( # If there is more than one element, uniform draws should vary if size > 1: - assert ( - len(set(seen_actions)) >= 2 - ), "Uniform-priority sampling did not vary." + assert len(set(seen_actions)) >= 2, ( + "Uniform-priority sampling did not vary." + ) # 2) Skewed priorities (only one index has nonzero td_error): replay = self.get_replay(batch, size, empty=True) @@ -615,9 +615,9 @@ def test_sample( forced_indices.append(int(batch_indices_b.item())) # All sampled indices must be identical (the only one with nonzero priority) - assert ( - len(set(forced_indices)) == 1 - ), f"Expected only one index to be chosen, but got {set(forced_indices)}" + assert len(set(forced_indices)) == 1, ( + f"Expected only one index to be chosen, but got {set(forced_indices)}" + ) @pytest.mark.parametrize( ("observations", "actions", "rewards", "next_observations", "dones", "size"), @@ -648,9 +648,9 @@ def test_reset( # **Because reset() does not clear on‐device buffers or sum_tree**, we expect: # - current_size remains equal to `size` # - sum_tree[1] (the total priority) remains > 0 - assert ( - replay.current_size == size - ), f"After reset(), expected current_size to still be {size}, but got {replay.current_size}" - assert ( - replay.sum_tree[1] > 0.0 - ), "After reset(), expected total priority (sum_tree[1]) to remain > 0" + assert replay.current_size == size, ( + f"After reset(), expected current_size to still be {size}, but got {replay.current_size}" + ) + assert replay.sum_tree[1] > 0.0, ( + "After reset(), expected total priority (sum_tree[1]) to remain > 0" + ) diff --git a/test/replay/test_rollout_buffer.py b/test/replay/test_rollout_buffer.py index 42c8d7e8..f7be466b 100644 --- a/test/replay/test_rollout_buffer.py +++ b/test/replay/test_rollout_buffer.py @@ -92,16 +92,16 @@ def test_init( assert isinstance(batch.rewards, torch.Tensor), "Rewards not tensor" assert isinstance(batch.advantages, torch.Tensor), "Advantages not tensor" assert isinstance(batch.returns, torch.Tensor), "Returns not tensor" - assert isinstance( - batch.episode_starts, torch.Tensor - ), "Episode starts not tensor" + assert isinstance(batch.episode_starts, torch.Tensor), ( + "Episode starts not tensor" + ) assert isinstance(batch.log_probs, torch.Tensor), "Log probs not tensor" assert isinstance(batch.values, torch.Tensor), "Values not tensor" # Check dimensions are promoted correctly - assert ( - batch.observations.dim() >= 2 - ), f"Obs dim too low: {batch.observations.shape}" + assert batch.observations.dim() >= 2, ( + f"Obs dim too low: {batch.observations.shape}" + ) assert batch.actions.dim() >= 1, f"Actions dim too low: {batch.actions.shape}" # For discrete actions, latents should be None @@ -123,9 +123,9 @@ def test_init( values=values, ) assert batch_with_latents.latents is not None, "Latents should not be None" - assert isinstance( - batch_with_latents.latents, torch.Tensor - ), "Latents not tensor" + assert isinstance(batch_with_latents.latents, torch.Tensor), ( + "Latents not tensor" + ) @pytest.mark.parametrize( ( @@ -215,9 +215,9 @@ def test_iter( if discrete: assert lat is None, "Latent should be None for discrete" else: - assert lat is None or isinstance( - lat, torch.Tensor - ), "Latent issue in iteration" + assert lat is None or isinstance(lat, torch.Tensor), ( + "Latent issue in iteration" + ) elements += 1 assert elements == size, f"Expected {size} elements, got {elements}" @@ -391,16 +391,16 @@ def test_attribute_stacking(self): # Test stacked observations # After stacking: (2, 1, 1, 2), after reshape: (2, 1, 2) expected_obs = torch.tensor([[[1, 2]], [[3, 4]]], dtype=torch.float32) - assert torch.allclose( - maxi.observations, expected_obs - ), f"Observations not stacked correctly. Expected {expected_obs}, got {maxi.observations}" + assert torch.allclose(maxi.observations, expected_obs), ( + f"Observations not stacked correctly. Expected {expected_obs}, got {maxi.observations}" + ) # Test stacked actions # After stacking: (2, 1, 1), after reshape: (2, 1) expected_actions = torch.tensor([[0], [1]], dtype=torch.float32) - assert torch.allclose( - maxi.actions, expected_actions - ), f"Actions not stacked correctly. Expected {expected_actions}, got {maxi.actions}" + assert torch.allclose(maxi.actions, expected_actions), ( + f"Actions not stacked correctly. Expected {expected_actions}, got {maxi.actions}" + ) def test_latents_handling(self): """Test MaxiBatch handles latents correctly (None vs tensor)""" @@ -426,9 +426,9 @@ def test_latents_handling(self): expected_latents_shape = torch.Size( [1] ) # zeros_like creates (1,1), stacked to (1,1,1), reshaped to (1,) - assert ( - latents.shape == expected_latents_shape - ), f"Latents shape {latents.shape} should be {expected_latents_shape}" + assert latents.shape == expected_latents_shape, ( + f"Latents shape {latents.shape} should be {expected_latents_shape}" + ) def test_empty_tensor_for_empty_batch(self): """Test that empty MaxiBatch returns empty tensors""" @@ -483,9 +483,9 @@ def test_init_continuous(self): 1, 2, ), "Wrong actions buffer shape (continuous)" - assert ( - buffer.latents is not None - ), "Latents should not be None for continuous with use_latents=True" + assert buffer.latents is not None, ( + "Latents should not be None for continuous with use_latents=True" + ) assert buffer.latents.shape == (10, 1, 2), "Wrong latents buffer shape" def test_init_continuous_no_latents(self): @@ -616,9 +616,9 @@ def test_add_multi_step(self): ) buffer.add(rb) - assert ( - buffer.pos == 3 - ), f"Position should be 3 after adding 3 steps, got {buffer.pos}" + assert buffer.pos == 3, ( + f"Position should be 3 after adding 3 steps, got {buffer.pos}" + ) def test_buffer_overflow(self): buffer = self.get_buffer(buffer_size=2) # Small buffer @@ -670,12 +670,12 @@ def test_compute_returns_and_advantage_single_env(self): buffer.compute_returns_and_advantage(last_values, dones) # Check that advantages and returns were computed (non-zero) - assert not torch.allclose( - buffer.advantages[:2], torch.zeros(2, 1) - ), "Advantages should be computed (non-zero)" - assert not torch.allclose( - buffer.returns[:2], torch.zeros(2, 1) - ), "Returns should be computed (non-zero)" + assert not torch.allclose(buffer.advantages[:2], torch.zeros(2, 1)), ( + "Advantages should be computed (non-zero)" + ) + assert not torch.allclose(buffer.returns[:2], torch.zeros(2, 1)), ( + "Returns should be computed (non-zero)" + ) def test_compute_returns_and_advantage_multi_env(self): """Test GAE computation with multiple environments""" @@ -690,9 +690,7 @@ def test_compute_returns_and_advantage_multi_env(self): advantages=np.array([[0.0, 0.0]]), # (1, 2) ✓ returns=np.array([[0.0, 0.0]]), # (1, 2) ✓ episode_starts=np.array([[1, 1]]), # (1, 2) ✓ - log_probs=np.array( - [[-0.5, -0.8]] - ), # (1, 2) + log_probs=np.array([[-0.5, -0.8]]), # (1, 2) values=np.array([[1.0, 0.5]]), # (1, 2) ✓ ) @@ -725,18 +723,18 @@ def test_compute_returns_and_advantage_multi_env(self): print(f"Computed returns: {returns_computed}") # Basic sanity checks - assert not torch.allclose( - advantages_computed, torch.zeros(2) - ), "Advantages should be non-zero" - assert not torch.allclose( - returns_computed, torch.zeros(2) - ), "Returns should be non-zero" + assert not torch.allclose(advantages_computed, torch.zeros(2)), ( + "Advantages should be non-zero" + ) + assert not torch.allclose(returns_computed, torch.zeros(2)), ( + "Returns should be non-zero" + ) # For GAE, returns = advantages + values (at time t) expected_returns = advantages_computed + buffer.values[0] - assert torch.allclose( - returns_computed, expected_returns, atol=1e-6 - ), f"Returns should equal advantages + values: {returns_computed} vs {expected_returns}" + assert torch.allclose(returns_computed, expected_returns, atol=1e-6), ( + f"Returns should equal advantages + values: {returns_computed} vs {expected_returns}" + ) def test_compute_returns_empty_buffer(self): """Test GAE computation on empty buffer""" @@ -775,9 +773,9 @@ def test_sample_insufficient_data(self): # Try to sample batch_size=4 when only 1 transition available maxi_batch = buffer.sample(batch_size=4) - assert ( - len(maxi_batch) == 0 - ), "Should return empty MaxiBatch when insufficient data" + assert len(maxi_batch) == 0, ( + "Should return empty MaxiBatch when insufficient data" + ) def test_sample_with_data(self): buffer = self.get_buffer(buffer_size=10, n_envs=2, discrete=True) @@ -822,12 +820,10 @@ def test_sample_with_data(self): ) buffer.add(rb) - # We have 2 timesteps × 2 envs = 4 total transitions - assert ( - len(buffer) == 4 - ), f"Buffer should contain 4 transitions, got {len(buffer)}" - + assert len(buffer) == 4, ( + f"Buffer should contain 4 transitions, got {len(buffer)}" + ) maxi_batch = buffer.sample(batch_size=2) @@ -856,29 +852,29 @@ def test_sample_with_data(self): # Test that each minibatch has valid data for i, mb in enumerate(minibatches): - assert ( - mb.observations is not None - ), f"Minibatch {i} observations should not be None" - assert ( - mb.log_probs is not None - ), f"Minibatch {i} log_probs should not be None" - assert ( - mb.observations.shape[0] > 0 - ), f"Minibatch {i} should have some observations" + assert mb.observations is not None, ( + f"Minibatch {i} observations should not be None" + ) + assert mb.log_probs is not None, ( + f"Minibatch {i} log_probs should not be None" + ) + assert mb.observations.shape[0] > 0, ( + f"Minibatch {i} should have some observations" + ) print(f"Minibatch {i} validated: obs.shape={mb.observations.shape}") else: # Original expected behavior - assert ( - len(maxi_batch) == 4 - ), f"Should have 4 total sampled elements, got {len(maxi_batch)}" - assert ( - len(minibatches) == 2 - ), f"Should have 2 minibatches, got {len(minibatches)}" + assert len(maxi_batch) == 4, ( + f"Should have 4 total sampled elements, got {len(maxi_batch)}" + ) + assert len(minibatches) == 2, ( + f"Should have 2 minibatches, got {len(minibatches)}" + ) for i, mb in enumerate(minibatches): - assert ( - len(mb) == 2 - ), f"Minibatch {i} should have 2 elements, got {len(mb)}" + assert len(mb) == 2, ( + f"Minibatch {i} should have 2 elements, got {len(mb)}" + ) def test_len_and_bool(self): buffer = self.get_buffer(n_envs=2) @@ -965,9 +961,9 @@ def test_save_and_load(self): assert loaded_buffer.n_envs == buffer.n_envs, "N envs mismatch" assert loaded_buffer.gamma == buffer.gamma, "Gamma mismatch" assert loaded_buffer.gae_lambda == buffer.gae_lambda, "GAE lambda mismatch" - assert torch.allclose( - loaded_buffer.observations, buffer.observations - ), "Observations mismatch" + assert torch.allclose(loaded_buffer.observations, buffer.observations), ( + "Observations mismatch" + ) assert torch.allclose(loaded_buffer.actions, buffer.actions), "Actions mismatch" assert torch.allclose(loaded_buffer.rewards, buffer.rewards), "Rewards mismatch" @@ -1025,20 +1021,6 @@ def test_gae_computation_details(self): buffer.compute_returns_and_advantage(last_values, dones) - # Manually calculate expected values for verification - gamma = buffer.gamma - lam = buffer.gae_lambda - - # Step 1 (t=1): delta = r1 + gamma * V(s2) * (1-done) - V(s1) - # = 2.0 + 0.99 * 1.5 * 1 - 1.0 = 2.485 - # Step 0 (t=0): delta = r0 + gamma * V(s1) * (1-episode_start[1]) - V(s0) - # = 1.0 + 0.99 * 1.0 * 1 - 0.5 = 1.49 - - # GAE calculation (backward): - # gae_1 = delta_1 = 2.485 - # gae_0 = delta_0 + gamma * lambda * (1-episode_start[1]) * gae_1 - # = 1.49 + 0.99 * 0.95 * 1 * 2.485 ≈ 3.82 - advantages = buffer.advantages[:2, 0].cpu().numpy() returns = buffer.returns[:2, 0].cpu().numpy() @@ -1048,9 +1030,9 @@ def test_gae_computation_details(self): # Returns should be advantages + values expected_returns = advantages + np.array([0.5, 1.0]) - assert np.allclose( - returns, expected_returns, atol=1e-5 - ), f"Returns should equal advantages + values: {returns} vs {expected_returns}" + assert np.allclose(returns, expected_returns, atol=1e-5), ( + f"Returns should equal advantages + values: {returns} vs {expected_returns}" + ) def test_episode_boundary_handling(self): """Test that episode boundaries are handled correctly in GAE""" @@ -1119,9 +1101,9 @@ def test_episode_boundary_handling(self): advantages = buffer.advantages[:3, 0].cpu().numpy() # All advantages should be computed (non-zero) - assert not np.allclose( - advantages, [0.0, 0.0, 0.0] - ), "All advantages should be computed" + assert not np.allclose(advantages, [0.0, 0.0, 0.0]), ( + "All advantages should be computed" + ) def test_multi_env_independence(self): """Test that multiple environments are handled independently""" @@ -1137,9 +1119,7 @@ def test_multi_env_independence(self): advantages=np.array([[0.0, 0.0, 0.0]]), # (1, 3) returns=np.array([[0.0, 0.0, 0.0]]), # (1, 3) episode_starts=np.array([[1, 1, 1]]), # (1, 3) - All start new episodes - log_probs=np.array( - [[-0.5, -0.8, -0.3]] - ), # (1, 3) + log_probs=np.array([[-0.5, -0.8, -0.3]]), # (1, 3) values=np.array([[0.5, 1.0, 0.3]]), # (1, 3) ) @@ -1155,28 +1135,27 @@ def test_multi_env_independence(self): # All environments should have computed advantages assert len(advantages) == 3, "Should have advantages for all 3 envs" - assert not np.allclose( - advantages, [0.0, 0.0, 0.0] - ), "All advantages should be computed" + assert not np.allclose(advantages, [0.0, 0.0, 0.0]), ( + "All advantages should be computed" + ) # The done environment (env 1) should have different computation # (no bootstrap from next value) - assert ( - advantages[1] != advantages[0] - ), "Done env should have different advantage" - assert ( - advantages[1] != advantages[2] - ), "Done env should have different advantage" - + assert advantages[1] != advantages[0], ( + "Done env should have different advantage" + ) + assert advantages[1] != advantages[2], ( + "Done env should have different advantage" + ) # Additional verification: returns should equal advantages + values for GAE returns = buffer.returns[0, :].cpu().numpy() values = buffer.values[0, :].cpu().numpy() expected_returns = advantages + values - assert np.allclose( - returns, expected_returns, atol=1e-6 - ), f"Returns should equal advantages + values: {returns} vs {expected_returns}" + assert np.allclose(returns, expected_returns, atol=1e-6), ( + f"Returns should equal advantages + values: {returns} vs {expected_returns}" + ) def test_sampling_randomness(self): """Test that sampling produces different results when called multiple times""" @@ -1218,9 +1197,9 @@ def test_sampling_randomness(self): # Should get some variety in samples (not all identical) unique_samples = set(samples) - assert ( - len(unique_samples) > 1 - ), f"Sampling should be random, got {len(unique_samples)} unique samples from: {samples}" + assert len(unique_samples) > 1, ( + f"Sampling should be random, got {len(unique_samples)} unique samples from: {samples}" + ) def test_mixed_data_types(self): """Test buffer handles different numpy data types correctly""" @@ -1243,8 +1222,8 @@ def test_mixed_data_types(self): assert buffer.pos == 1, "Should successfully add data with mixed types" # All stored tensors should be float32 - assert ( - buffer.observations.dtype == torch.float32 - ), "Observations should be float32" + assert buffer.observations.dtype == torch.float32, ( + "Observations should be float32" + ) assert buffer.actions.dtype == torch.float32, "Actions should be float32" assert buffer.rewards.dtype == torch.float32, "Rewards should be float32" diff --git a/test/runners/test_es_runner.py b/test/runners/test_es_runner.py index 2f25bf6a..c7c078e0 100644 --- a/test/runners/test_es_runner.py +++ b/test/runners/test_es_runner.py @@ -61,15 +61,15 @@ class TestMightyNESRunner: def test_init(self): runner = MightyESRunner(self.runner_config) - assert isinstance( - runner, MightyRunner - ), "MightyNESRunner should be an instance of MightyRunner" - assert isinstance( - runner.agent, MightyAgent - ), "MightyNESRunner should have a MightyAgent" - assert isinstance( - runner.agent.eval_env, PufferlibToGymAdapter - ), "Eval env should be a PufferlibToGymAdapter" + assert isinstance(runner, MightyRunner), ( + "MightyNESRunner should be an instance of MightyRunner" + ) + assert isinstance(runner.agent, MightyAgent), ( + "MightyNESRunner should have a MightyAgent" + ) + assert isinstance(runner.agent.eval_env, PufferlibToGymAdapter), ( + "Eval env should be a PufferlibToGymAdapter" + ) assert runner.agent.env is not None, "Env should be set" assert runner.iterations is not None, "Iterations should be set" assert runner.es is not None, "ES should be set" @@ -85,18 +85,18 @@ def test_run(self): new_params = runner.agent.parameters assert isinstance(train_results, dict), "Train results should be a dictionary" assert isinstance(eval_results, dict), "Eval results should be a dictionary" - assert ( - "mean_eval_reward" in eval_results - ), "Mean eval reward should be in eval results" + assert "mean_eval_reward" in eval_results, ( + "Mean eval reward should be in eval results" + ) param_equals = [o == p for o, p in zip(old_params, new_params)] for params in param_equals: - assert not all( - params.flatten() - ), "Parameters should have changed in training" - assert ( - not old_lr == runner.agent.learning_rate - ), "Learning rate should have changed in training" - assert ( - not old_batch_size == runner.agent._batch_size - ), "Batch size should have changed in training" + assert not all(params.flatten()), ( + "Parameters should have changed in training" + ) + assert not old_lr == runner.agent.learning_rate, ( + "Learning rate should have changed in training" + ) + assert not old_batch_size == runner.agent._batch_size, ( + "Batch size should have changed in training" + ) shutil.rmtree("test_nes_runner") diff --git a/test/runners/test_runner.py b/test/runners/test_runner.py index e2b6ce31..fb94f41d 100644 --- a/test/runners/test_runner.py +++ b/test/runners/test_runner.py @@ -2,14 +2,14 @@ import shutil -import pytest import gymnasium as gym +import pytest from omegaconf import OmegaConf from mighty.mighty_agents import MightyAgent from mighty.mighty_runners import MightyOnlineRunner, MightyRunner -from mighty.mighty_utils.wrappers import PufferlibToGymAdapter from mighty.mighty_utils.test_helpers import DummyEnv +from mighty.mighty_utils.wrappers import PufferlibToGymAdapter class TestMightyRunner: @@ -57,22 +57,22 @@ class TestMightyRunner: def test_init(self): runner = MightyOnlineRunner(self.runner_config) - assert isinstance( - runner, MightyRunner - ), "MightyOnlineRunner should be an instance of MightyRunner" - assert isinstance( - runner.agent, MightyAgent - ), "MightyOnlineRunner should have a MightyAgent" - assert isinstance( - runner.agent.eval_env, PufferlibToGymAdapter - ), "Eval env should be a PufferlibToGymAdapter" + assert isinstance(runner, MightyRunner), ( + "MightyOnlineRunner should be an instance of MightyRunner" + ) + assert isinstance(runner.agent, MightyAgent), ( + "MightyOnlineRunner should have a MightyAgent" + ) + assert isinstance(runner.agent.eval_env, PufferlibToGymAdapter), ( + "Eval env should be a PufferlibToGymAdapter" + ) assert runner.agent.env is not None, "Env should not be None" - assert ( - runner.eval_every_n_steps == self.runner_config.eval_every_n_steps - ), "Eval every n steps should be set" - assert ( - runner.num_steps == self.runner_config.num_steps - ), "Num steps should be set" + assert runner.eval_every_n_steps == self.runner_config.eval_every_n_steps, ( + "Eval every n steps should be set" + ) + assert runner.num_steps == self.runner_config.num_steps, ( + "Num steps should be set" + ) def test_train(self): runner = MightyOnlineRunner(self.runner_config) @@ -96,14 +96,14 @@ def test_run(self): train_results, eval_results = runner.run() assert isinstance(train_results, dict), "Train results should be a dictionary" assert isinstance(eval_results, dict), "Eval results should be a dictionary" - assert ( - "mean_eval_reward" in eval_results - ), "Eval results should have mean_eval_reward" + assert "mean_eval_reward" in eval_results, ( + "Eval results should have mean_eval_reward" + ) shutil.rmtree("test_runner") def test_run_with_alternate_env(self): dummy_env = gym.vector.SyncVectorEnv([DummyEnv for _ in range(3)]) - dummy_eval_func = lambda: gym.vector.SyncVectorEnv( # noqa: E731 + dummy_eval_func = lambda: gym.vector.SyncVectorEnv( # noqa: E731 [DummyEnv for _ in range(10)] ) eval_default = 10 diff --git a/test/runners/test_runner_factory.py b/test/runners/test_runner_factory.py index bfad9840..77f556c4 100644 --- a/test/runners/test_runner_factory.py +++ b/test/runners/test_runner_factory.py @@ -17,9 +17,9 @@ class TestFactory: def test_create_agent(self): for runner_type in VALID_RUNNER_TYPES: runner_class = get_runner_class(runner_type) - assert ( - runner_class == RUNNER_CLASSES[runner_type] - ), f"Runner class should be {RUNNER_CLASSES[runner_type]}" + assert runner_class == RUNNER_CLASSES[runner_type], ( + f"Runner class should be {RUNNER_CLASSES[runner_type]}" + ) def test_create_agent_with_invalid_type(self): with pytest.raises(ValueError): diff --git a/test/test_env_creation.py b/test/test_env_creation.py index c6b6eff4..dbe41d14 100644 --- a/test/test_env_creation.py +++ b/test/test_env_creation.py @@ -120,56 +120,56 @@ class TestEnvCreation: def check_vector_env(self, env): """Check if environment is a vector environment.""" - assert hasattr( - env, "num_envs" - ), f"Vector environment should have num_envs attribute: {env}" - assert hasattr( - env, "reset" - ), f"Vector environment should have reset method: {env}." - assert hasattr( - env, "step" - ), f"Vector environment should have step method: {env}." - assert hasattr( - env, "close" - ), f"Vector environment should have close method: {env}." - assert hasattr( - env, "single_action_space" - ), f"Vector environment should have single action space view: {env}." - assert hasattr( - env, "single_observation_space" - ), f"Vector environment should have single observation space view: {env}." - assert hasattr( - env, "envs" - ), f"Environments should be kept in envs attribute: {env}." + assert hasattr(env, "num_envs"), ( + f"Vector environment should have num_envs attribute: {env}" + ) + assert hasattr(env, "reset"), ( + f"Vector environment should have reset method: {env}." + ) + assert hasattr(env, "step"), ( + f"Vector environment should have step method: {env}." + ) + assert hasattr(env, "close"), ( + f"Vector environment should have close method: {env}." + ) + assert hasattr(env, "single_action_space"), ( + f"Vector environment should have single action space view: {env}." + ) + assert hasattr(env, "single_observation_space"), ( + f"Vector environment should have single observation space view: {env}." + ) + assert hasattr(env, "envs"), ( + f"Environments should be kept in envs attribute: {env}." + ) def test_make_gym_env(self): """Test env creation with make_gym_env.""" env, eval_env, eval_default = make_gym_env(self.gym_config) self.check_vector_env(env) self.check_vector_env(eval_env()) - assert ( - eval_default == self.gym_config.n_episodes_eval - ), "Default number of eval episodes should match config" - assert ( - len(env.envs) == self.gym_config.num_envs - ), "Number of environments should match config." - assert ( - len(eval_env().envs) == self.gym_config.n_episodes_eval - ), "Number of environments should match config." + assert eval_default == self.gym_config.n_episodes_eval, ( + "Default number of eval episodes should match config" + ) + assert len(env.envs) == self.gym_config.num_envs, ( + "Number of environments should match config." + ) + assert len(eval_env().envs) == self.gym_config.n_episodes_eval, ( + "Number of environments should match config." + ) - assert ( - self.gym_config.env == env.envs[0].spec.id - ), "Environment should be created with the correct id." - assert ( - self.gym_config.env == eval_env().envs[0].spec.id - ), "Eval environment should be created with the correct id." + assert self.gym_config.env == env.envs[0].spec.id, ( + "Environment should be created with the correct id." + ) + assert self.gym_config.env == eval_env().envs[0].spec.id, ( + "Eval environment should be created with the correct id." + ) - assert isinstance( - env, gym.vector.SyncVectorEnv - ), "Gym environment should be a SyncVectorEnv." - assert isinstance( - eval_env(), gym.vector.SyncVectorEnv - ), "Eval environment should be a SyncVectorEnv." + assert isinstance(env, gym.vector.SyncVectorEnv), ( + "Gym environment should be a SyncVectorEnv." + ) + assert isinstance(eval_env(), gym.vector.SyncVectorEnv), ( + "Eval environment should be a SyncVectorEnv." + ) def test_make_dacbench_env(self): """Test env creation with make_dacbench_env.""" @@ -180,32 +180,34 @@ def test_make_dacbench_env(self): eval_default == len(env.envs[0].instance_set.keys()) * self.dacbench_config.n_episodes_eval - ), "Default number of eval episodes should instance set size times evaluation episodes." - assert ( - len(env.envs) == self.dacbench_config.num_envs - ), "Number of environments should match config." + ), ( + "Default number of eval episodes should instance set size times evaluation episodes." + ) + assert len(env.envs) == self.dacbench_config.num_envs, ( + "Number of environments should match config." + ) assert ( len(eval_env().envs) == len(env.envs[0].instance_set.keys()) * self.dacbench_config.n_episodes_eval ), "Number of environments should match eval length." - assert isinstance( - env, gym.vector.SyncVectorEnv - ), "DACBench environment should be a SyncVectorEnv." - assert isinstance( - eval_env(), gym.vector.SyncVectorEnv - ), "Eval environment should be a SyncVectorEnv." + assert isinstance(env, gym.vector.SyncVectorEnv), ( + "DACBench environment should be a SyncVectorEnv." + ) + assert isinstance(eval_env(), gym.vector.SyncVectorEnv), ( + "Eval environment should be a SyncVectorEnv." + ) bench = getattr(benchmarks, self.dacbench_config.env)() for k in self.dacbench_config.env_kwargs: bench.config[k] = self.dacbench_config.env_kwargs[k] - assert isinstance( - env.envs[0], type(bench.get_environment()) - ), "Environment should have correct type." - assert isinstance( - eval_env().envs[0], type(bench.get_environment()) - ), "Eval environment should have correct type." + assert isinstance(env.envs[0], type(bench.get_environment())), ( + "Environment should have correct type." + ) + assert isinstance(eval_env().envs[0], type(bench.get_environment())), ( + "Eval environment should have correct type." + ) assert ( env.envs[0].config.instance_set_path == self.dacbench_config.env_kwargs.instance_set_path @@ -221,22 +223,24 @@ def test_make_dacbench_benchmark_mode(self): eval_default == len(env.envs[0].instance_set.keys()) * self.dacbench_config_benchmark.n_episodes_eval - ), "Default number of eval episodes should instance set size times evaluation episodes." - assert ( - len(env.envs) == self.dacbench_config_benchmark.num_envs - ), "Number of environments should match config." + ), ( + "Default number of eval episodes should instance set size times evaluation episodes." + ) + assert len(env.envs) == self.dacbench_config_benchmark.num_envs, ( + "Number of environments should match config." + ) assert ( len(eval_env().envs) == len(env.envs[0].instance_set.keys()) * self.dacbench_config_benchmark.n_episodes_eval ), "Number of environments should match eval length." - assert isinstance( - env, gym.vector.SyncVectorEnv - ), "DACBench environment should be a SyncVectorEnv." - assert isinstance( - eval_env(), gym.vector.SyncVectorEnv - ), "Eval environment should be a SyncVectorEnv." + assert isinstance(env, gym.vector.SyncVectorEnv), ( + "DACBench environment should be a SyncVectorEnv." + ) + assert isinstance(eval_env(), gym.vector.SyncVectorEnv), ( + "Eval environment should be a SyncVectorEnv." + ) benchmark_kwargs = OmegaConf.to_container( self.dacbench_config_benchmark.env_kwargs, resolve=True @@ -245,12 +249,12 @@ def test_make_dacbench_benchmark_mode(self): del benchmark_kwargs["config_space"] bench = getattr(benchmarks, self.dacbench_config_benchmark.env)() benchmark_env = bench.get_benchmark(**benchmark_kwargs) - assert isinstance( - env.envs[0], type(benchmark_env) - ), "Environment should have correct type." - assert isinstance( - eval_env().envs[0], type(benchmark_env) - ), "Eval environment should have correct type." + assert isinstance(env.envs[0], type(benchmark_env)), ( + "Environment should have correct type." + ) + assert isinstance(eval_env().envs[0], type(benchmark_env)), ( + "Eval environment should have correct type." + ) assert eval_env().envs[0].test, "Eval environment should be in test mode." for k in env.envs[0].config.keys(): if k == "observation_space_args": @@ -260,19 +264,25 @@ def test_make_dacbench_benchmark_mode(self): assert ( env.envs[0].config[k][i].functions[0].a == benchmark_env.config[k][i].functions[0].a - ), f"Environment should have matching instances, mismatch for function parameter a at instance {i}: {env.envs[0].config[k][i].functions[0].a} != {benchmark_env.config[k][i].functions[0].a}" + ), ( + f"Environment should have matching instances, mismatch for function parameter a at instance {i}: {env.envs[0].config[k][i].functions[0].a} != {benchmark_env.config[k][i].functions[0].a}" + ) assert ( env.envs[0].config[k][i].functions[0].b == benchmark_env.config[k][i].functions[0].b - ), f"Environment should have matching instances, mismatch for function parameter b at instance {i}: {env.envs[0].config[k][i].functions[0].b} != {benchmark_env.config[k][i].functions[0].b}" + ), ( + f"Environment should have matching instances, mismatch for function parameter b at instance {i}: {env.envs[0].config[k][i].functions[0].b} != {benchmark_env.config[k][i].functions[0].b}" + ) assert ( env.envs[0].config[k][i].omit_instance_type == benchmark_env.config[k][i].omit_instance_type - ), f"Environment should have matching instances, mismatch for omit_instance_type at instance {i}: {env.envs[0].config[k][i].omit_instance_type} != {benchmark_env.config[k][i].omit_instance_type}" + ), ( + f"Environment should have matching instances, mismatch for omit_instance_type at instance {i}: {env.envs[0].config[k][i].omit_instance_type} != {benchmark_env.config[k][i].omit_instance_type}" + ) else: - assert ( - env.envs[0].config[k] == benchmark_env.config[k] - ), f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}" + assert env.envs[0].config[k] == benchmark_env.config[k], ( + f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}" + ) def test_make_carl_env(self): """Test env creation with make_carl_env.""" @@ -284,19 +294,19 @@ def test_make_carl_env(self): ), "Default number of eval episodes should match config" env_class = getattr(carl.envs, self.carl_config.env) - assert isinstance( - env.envs[0], env_class - ), "Environment should have the correct type." - assert isinstance( - eval_env().envs[0], env_class - ), "Eval environment should have the correct type." + assert isinstance(env.envs[0], env_class), ( + "Environment should have the correct type." + ) + assert isinstance(eval_env().envs[0], env_class), ( + "Eval environment should have the correct type." + ) - assert isinstance( - env, CARLVectorEnvSimulator - ), "CARL environment should be wrapped." - assert isinstance( - eval_env(), CARLVectorEnvSimulator - ), "CARL eval environment should be wrapped." + assert isinstance(env, CARLVectorEnvSimulator), ( + "CARL environment should be wrapped." + ) + assert isinstance(eval_env(), CARLVectorEnvSimulator), ( + "CARL eval environment should be wrapped." + ) def test_make_carl_context(self): """Test env creation with make_carl_env.""" @@ -312,9 +322,9 @@ def test_make_carl_context(self): assert ( len(train_contexts) == self.carl_config_context.env_kwargs.num_contexts ), "Number of training contexts should match config." - assert ( - len(eval_contexts) == 100 - ), "Number of eval contexts should match default." + assert len(eval_contexts) == 100, ( + "Number of eval contexts should match default." + ) assert not all( [ @@ -409,32 +419,36 @@ def test_make_carl_context(self): ), "Eval contexts lie above lower bound for gravity." assert isinstance( env.envs[0].context_selector, carl.context.selection.StaticSelector - ), f"Context selector should be switched to a StaticSelector based on keyword but is {type(env.envs[0].context_selector)}." + ), ( + f"Context selector should be switched to a StaticSelector based on keyword but is {type(env.envs[0].context_selector)}." + ) assert isinstance( eval_env().envs[0].context_selector, carl.context.selection.RoundRobinSelector, - ), f"Eval env context selector should stay round robin but is {type(eval_env().envs[0].context_selector)}." + ), ( + f"Eval env context selector should stay round robin but is {type(eval_env().envs[0].context_selector)}." + ) def test_make_procgen_env(self): """Test env creation with make_procgen_env.""" if PROCGEN: env, eval_env, eval_default = make_procgen_env(self.procgen_config) - assert ( - eval_default == self.procgen_config.n_episodes_eval - ), "Default number of eval episodes should match config" + assert eval_default == self.procgen_config.n_episodes_eval, ( + "Default number of eval episodes should match config" + ) self.check_vector_env(env) self.check_vector_env(eval_env()) if ENVPOOL: - assert isinstance( - env, envpool.VectorEnv - ), "Environment should be an envpool env if we create a gym env with envpool installed." + assert isinstance(env, envpool.VectorEnv), ( + "Environment should be an envpool env if we create a gym env with envpool installed." + ) else: - assert isinstance( - env, ProcgenVecEnv - ), "Environment should be ProcGen env if we create a gym env without envpool installed." - assert isinstance( - eval_env(), ProcgenVecEnv - ), "Eval env should be a ProcGen env." + assert isinstance(env, ProcgenVecEnv), ( + "Environment should be ProcGen env if we create a gym env without envpool installed." + ) + assert isinstance(eval_env(), ProcgenVecEnv), ( + "Eval env should be a ProcGen env." + ) else: Warning("Procgen not installed, skipping test.") @@ -443,57 +457,57 @@ def test_make_pufferlib_env(self): env, eval_env, eval_default = make_pufferlib_env(self.pufferlib_config) self.check_vector_env(env) self.check_vector_env(eval_env()) - assert ( - eval_default == self.pufferlib_config.n_episodes_eval - ), "Default number of eval episodes should match config" - assert ( - len(env.envs) == self.pufferlib_config.num_envs - ), "Number of environments should match config." - assert ( - len(eval_env().envs) == self.pufferlib_config.n_episodes_eval - ), "Number of environments should match config." + assert eval_default == self.pufferlib_config.n_episodes_eval, ( + "Default number of eval episodes should match config" + ) + assert len(env.envs) == self.pufferlib_config.num_envs, ( + "Number of environments should match config." + ) + assert len(eval_env().envs) == self.pufferlib_config.n_episodes_eval, ( + "Number of environments should match config." + ) domain = ".".join(self.pufferlib_config.env.split(".")[:-1]) name = self.pufferlib_config.env.split(".")[-1] get_env_func = importlib.import_module(domain).env_creator make_env = get_env_func(name)(**self.pufferlib_config.env_kwargs) - assert isinstance( - env.envs[0], type(make_env) - ), "Environment should have correct type." - assert isinstance( - eval_env().envs[0], type(make_env) - ), "Eval environment should have correct type." + assert isinstance(env.envs[0], type(make_env)), ( + "Environment should have correct type." + ) + assert isinstance(eval_env().envs[0], type(make_env)), ( + "Eval environment should have correct type." + ) - assert isinstance( - env, PufferlibToGymAdapter - ), "Pufferlib env should be wrapped." - assert isinstance( - eval_env(), PufferlibToGymAdapter - ), "Pufferlib eval env should be wrapped." + assert isinstance(env, PufferlibToGymAdapter), ( + "Pufferlib env should be wrapped." + ) + assert isinstance(eval_env(), PufferlibToGymAdapter), ( + "Pufferlib eval env should be wrapped." + ) def test_make_mighty_env(self): """Test correct typing of environments when creating with make_mighty_env.""" env, eval_env, eval_default = make_mighty_env(self.gym_config) - assert ( - eval_default == self.gym_config.n_episodes_eval - ), "Default number of eval episodes should match config" + assert eval_default == self.gym_config.n_episodes_eval, ( + "Default number of eval episodes should match config" + ) self.check_vector_env(env) self.check_vector_env(eval_env()) if ENVPOOL: - assert isinstance( - env, envpool.VectorEnv - ), "Mighty environment should be an envpool env if we create a gym env with envpool installed." - assert isinstance( - eval_env(), gym.vector.SyncVectorEnv - ), "Eval env should be a SyncVectorEnv env if we create a gym env with envpool installed." + assert isinstance(env, envpool.VectorEnv), ( + "Mighty environment should be an envpool env if we create a gym env with envpool installed." + ) + assert isinstance(eval_env(), gym.vector.SyncVectorEnv), ( + "Eval env should be a SyncVectorEnv env if we create a gym env with envpool installed." + ) else: Warning("Envpool not installed, skipping test.") - assert isinstance( - env, gym.vector.SyncVectorEnv - ), "Mighty environment should be a SyncVectorEnv if we create a gym env without envpool installed." - assert isinstance( - env, gym.vector.SyncVectorEnv - ), "Eval environment should be a SyncVectorEnv if we create a gym env without envpool installed." + assert isinstance(env, gym.vector.SyncVectorEnv), ( + "Mighty environment should be a SyncVectorEnv if we create a gym env without envpool installed." + ) + assert isinstance(env, gym.vector.SyncVectorEnv), ( + "Eval environment should be a SyncVectorEnv if we create a gym env without envpool installed." + ) for config in [ self.dacbench_config, diff --git a/test/update/test_ppo_update.py b/test/update/test_ppo_update.py index 1c66bad0..009524ac 100644 --- a/test/update/test_ppo_update.py +++ b/test/update/test_ppo_update.py @@ -213,9 +213,9 @@ def test_discrete_action_update(self): ] for metric in required_metrics: assert metric in metrics, f"Missing metric: {metric}" - assert isinstance( - metrics[metric], (int, float) - ), f"Metric {metric} should be numeric" + assert isinstance(metrics[metric], (int, float)), ( + f"Metric {metric} should be numeric" + ) def test_continuous_action_update(self): """Test PPO update with continuous actions.""" @@ -261,9 +261,9 @@ def test_continuous_action_update(self): ] for metric in required_metrics: assert metric in metrics, f"Missing metric: {metric}" - assert isinstance( - metrics[metric], (int, float) - ), f"Metric {metric} should be numeric" + assert isinstance(metrics[metric], (int, float)), ( + f"Metric {metric} should be numeric" + ) def test_value_clipping(self): """Test value clipping mechanism.""" @@ -314,15 +314,11 @@ def test_adaptive_learning_rate(self): """Test adaptive learning rate adjustment.""" update, model = self.get_update_and_model(adaptive_lr=True, kl_target=0.01) - # Store initial learning rates - initial_policy_lr = update.optimizer.param_groups[0]["lr"] - initial_value_lr = update.optimizer.param_groups[1]["lr"] - # Create batch that might trigger LR adaptation batch = DummyMaxiBatch() # Run update - metrics = update.update(batch) + update.update(batch) # Learning rates might have changed (depending on KL divergence) final_policy_lr = update.optimizer.param_groups[0]["lr"] @@ -389,8 +385,8 @@ def test_metric_shapes_and_types(self, continuous_action): for metric_name in expected_metrics: assert metric_name in metrics, f"Missing metric: {metric_name}" metric_value = metrics[metric_name] - assert isinstance( - metric_value, (int, float) - ), f"Metric {metric_name} should be scalar" + assert isinstance(metric_value, (int, float)), ( + f"Metric {metric_name} should be scalar" + ) assert not np.isnan(metric_value), f"Metric {metric_name} should not be NaN" assert np.isfinite(metric_value), f"Metric {metric_name} should be finite" diff --git a/test/update/test_q_update.py b/test/update/test_q_update.py index 17f83185..a7a20418 100644 --- a/test/update/test_q_update.py +++ b/test/update/test_q_update.py @@ -54,12 +54,12 @@ def test_update(self, weight, bias): """Test Q-learning update.""" update, model = self.get_update(initial_weights=weight, initial_biases=bias) checked_model = deepcopy(model) - assert torch.allclose( - model.layer.weight, checked_model.layer.weight - ), "Wrong initial weights." - assert torch.allclose( - model.layer.bias, checked_model.layer.bias - ), "Wrong initial biases." + assert torch.allclose(model.layer.weight, checked_model.layer.weight), ( + "Wrong initial weights." + ) + assert torch.allclose(model.layer.bias, checked_model.layer.bias), ( + "Wrong initial biases." + ) preds, targets = update.get_targets(batch, model) loss_stats = update.apply_update(preds, targets) @@ -77,9 +77,9 @@ def test_update(self, weight, bias): assert torch.allclose( model.layer.weight, checked_model.layer.weight, atol=1e-3 ), "Wrong weights after update" - assert torch.allclose( - model.layer.bias, checked_model.layer.bias, atol=1e-3 - ), "Wrong biases after update." + assert torch.allclose(model.layer.bias, checked_model.layer.bias, atol=1e-3), ( + "Wrong biases after update." + ) def test_get_targets(self): """Test get_targets method.""" @@ -148,9 +148,9 @@ def test_get_targets(self): correct_targets = batch.rewards.unsqueeze(-1) + mask * 0.99 * target( torch.as_tensor(batch.next_obs, dtype=torch.float32) ).max(1)[0].unsqueeze(-1) - assert torch.allclose( - targets.detach(), correct_targets.type(torch.float32) - ), "Wrong targets (weight 0)." + assert torch.allclose(targets.detach(), correct_targets.type(torch.float32)), ( + "Wrong targets (weight 0)." + ) update, model, target = self.get_update(initial_weights=3) preds, targets = update.get_targets(batch, model, target) @@ -193,12 +193,10 @@ def test_get_targets(self): ~batch.dones.unsqueeze(-1) ) * 0.99 * torch.minimum( batch.next_obs.sum(axis=1), torch.zeros(batch.next_obs.shape).sum(axis=1) - ).unsqueeze( - -1 + ).unsqueeze(-1) + assert torch.allclose(targets.detach(), correct_targets.type(torch.float32)), ( + "Wrong targets (weight 0)." ) - assert torch.allclose( - targets.detach(), correct_targets.type(torch.float32) - ), "Wrong targets (weight 0)." update, model, target = self.get_update(initial_weights=3) preds, targets = update.get_targets(batch, model, target) diff --git a/test/update/test_sac_update.py b/test/update/test_sac_update.py index 009ea078..10328444 100644 --- a/test/update/test_sac_update.py +++ b/test/update/test_sac_update.py @@ -133,7 +133,7 @@ def test_initialization(self): # Check hyperparameters assert update.gamma == 0.99 assert update.tau == 0.005 - assert update.auto_alpha == True + assert update.auto_alpha # Check new frequency parameters assert hasattr(update, "policy_frequency") @@ -144,7 +144,7 @@ def test_initialization_without_auto_alpha(self): """Test SAC initialization with fixed alpha.""" update, model = self.get_update_and_model(auto_alpha=False, alpha=0.1) - assert update.auto_alpha == False + assert not update.auto_alpha assert update.alpha == 0.1 # Should use the provided alpha value assert not hasattr(update, "log_alpha") assert not hasattr(update, "alpha_optimizer") @@ -187,18 +187,18 @@ def test_basic_update(self): # Check metrics - updated to include alpha_loss required_metrics = [ - "q_loss1", - "q_loss2", - "policy_loss", - "alpha_loss", - "td_error1", - "td_error2", + "Update/q_loss1", + "Update/q_loss2", + "Update/policy_loss", + "Update/alpha_loss", + "Update/td_error1", + "Update/td_error2", ] for metric in required_metrics: assert metric in metrics, f"Missing metric: {metric}" - assert isinstance( - metrics[metric], (int, float) - ), f"Metric {metric} should be numeric" + assert isinstance(metrics[metric], (int, float)), ( + f"Metric {metric} should be numeric" + ) assert np.isfinite(metrics[metric]), f"Metric {metric} should be finite" def test_target_network_updates(self): @@ -210,10 +210,6 @@ def test_target_network_updates(self): initial_target_q1 = [p.clone() for p in model.target_q_net1.parameters()] initial_target_q2 = [p.clone() for p in model.target_q_net2.parameters()] - # Store initial main Q parameters - initial_q1 = [p.clone() for p in model.q_net1.parameters()] - initial_q2 = [p.clone() for p in model.q_net2.parameters()] - # Run update update.update(batch) @@ -239,9 +235,9 @@ def test_target_network_updates(self): initial_target_q1, ): expected = (1 - tau) * p_old_target + tau * p_main - assert torch.allclose( - p_target, expected, atol=1e-5 - ), "Target update should follow polyak averaging" + assert torch.allclose(p_target, expected, atol=1e-5), ( + "Target update should follow polyak averaging" + ) def test_td_error_calculation(self): """Test TD error calculation.""" @@ -285,17 +281,17 @@ def test_fixed_alpha_mode(self): # Metrics should still be valid - but alpha_loss should be 0 required_metrics = [ - "q_loss1", - "q_loss2", - "policy_loss", - "alpha_loss", - "td_error1", - "td_error2", + "Update/q_loss1", + "Update/q_loss2", + "Update/policy_loss", + "Update/alpha_loss", + "Update/td_error1", + "Update/td_error2", ] for metric in required_metrics: assert metric in metrics # Alpha loss should be 0 when auto_alpha=False - assert metrics["alpha_loss"] == 0.0 + assert metrics["Update/alpha_loss"] == 0.0 def test_custom_target_entropy(self): """Test SAC with custom target entropy.""" @@ -308,7 +304,7 @@ def test_custom_target_entropy(self): batch = DummyTransitionBatch() metrics = update.update(batch) - assert "policy_loss" in metrics + assert "Update/policy_loss" in metrics def test_different_learning_rates(self): """Test SAC with different learning rates for policy and Q-networks.""" @@ -358,9 +354,9 @@ def test_different_tau_values(self): ) ) - assert ( - total_change_large > total_change - ), "Larger tau should cause bigger target network changes" + assert total_change_large > total_change, ( + "Larger tau should cause bigger target network changes" + ) def test_zero_rewards_batch(self): """Test SAC with zero rewards.""" @@ -372,9 +368,9 @@ def test_zero_rewards_batch(self): metrics = update.update(batch) for metric_name, metric_value in metrics.items(): - assert np.isfinite( - metric_value - ), f"Metric {metric_name} should be finite with zero rewards" + assert np.isfinite(metric_value), ( + f"Metric {metric_name} should be finite with zero rewards" + ) def test_all_done_batch(self): """Test SAC with all episodes terminated.""" @@ -386,9 +382,9 @@ def test_all_done_batch(self): metrics = update.update(batch) for metric_name, metric_value in metrics.items(): - assert np.isfinite( - metric_value - ), f"Metric {metric_name} should be finite with all done" + assert np.isfinite(metric_value), ( + f"Metric {metric_name} should be finite with all done" + ) def test_metric_ranges(self): """Test that metrics are in reasonable ranges.""" @@ -398,18 +394,20 @@ def test_metric_ranges(self): metrics = update.update(batch) # Q losses should be non-negative (MSE loss) - assert metrics["q_loss1"] >= 0, "Q loss 1 should be non-negative" - assert metrics["q_loss2"] >= 0, "Q loss 2 should be non-negative" + assert metrics["Update/q_loss1"] >= 0, "Q loss 1 should be non-negative" + assert metrics["Update/q_loss2"] >= 0, "Q loss 2 should be non-negative" # Policy loss can be negative (we want to maximize Q - alpha*entropy) - assert np.isfinite(metrics["policy_loss"]), "Policy loss should be finite" + assert np.isfinite(metrics["Update/policy_loss"]), ( + "Policy loss should be finite" + ) # Alpha loss can be positive or negative - assert np.isfinite(metrics["alpha_loss"]), "Alpha loss should be finite" + assert np.isfinite(metrics["Update/alpha_loss"]), "Alpha loss should be finite" # TD errors can be positive or negative - assert np.isfinite(metrics["td_error1"]), "TD error 1 should be finite" - assert np.isfinite(metrics["td_error2"]), "TD error 2 should be finite" + assert np.isfinite(metrics["Update/td_error1"]), "TD error 1 should be finite" + assert np.isfinite(metrics["Update/td_error2"]), "TD error 2 should be finite" @pytest.mark.parametrize("batch_size", [1, 16, 64, 128]) def test_different_batch_sizes(self, batch_size): @@ -421,12 +419,12 @@ def test_different_batch_sizes(self, batch_size): metrics = update.update(batch) required_metrics = [ - "q_loss1", - "q_loss2", - "policy_loss", - "alpha_loss", - "td_error1", - "td_error2", + "Update/q_loss1", + "Update/q_loss2", + "Update/policy_loss", + "Update/alpha_loss", + "Update/td_error1", + "Update/td_error2", ] for metric in required_metrics: assert metric in metrics @@ -445,12 +443,12 @@ def test_policy_frequency(self): # Run updates less than policy_frequency - policy shouldn't change for i in range(policy_freq - 1): metrics = update.update(batch) - assert ( - metrics["policy_loss"] == 0.0 - ), "Policy loss should be 0 when no policy update" - assert ( - metrics["alpha_loss"] == 0.0 - ), "Alpha loss should be 0 when no policy update" + assert metrics["Update/policy_loss"] == 0.0, ( + "Policy loss should be 0 when no policy update" + ) + assert metrics["Update/alpha_loss"] == 0.0, ( + "Alpha loss should be 0 when no policy update" + ) # Policy parameters shouldn't have changed yet policy_unchanged = all( @@ -459,16 +457,16 @@ def test_policy_frequency(self): ) alpha_unchanged = torch.allclose(initial_alpha, update.log_alpha, atol=1e-6) - assert ( - policy_unchanged - ), "Policy parameters should not change before policy_frequency" + assert policy_unchanged, ( + "Policy parameters should not change before policy_frequency" + ) assert alpha_unchanged, "Alpha should not change before policy_frequency" # Now run one more update - should trigger policy update metrics = update.update(batch) - assert ( - metrics["policy_loss"] != 0.0 - ), "Policy loss should be non-zero when policy updates" + assert metrics["Update/policy_loss"] != 0.0, ( + "Policy loss should be non-zero when policy updates" + ) # Policy parameters should have changed now policy_changed = any( @@ -507,14 +505,14 @@ def test_gradient_flow(self): # Use more lenient tolerance since SAC updates can be small if change > 1e-8: # Much more lenient than 1e-6 changed_params += 1 - + # At least some parameters should change change_ratio = changed_params / total_params - assert ( - change_ratio > 0.1 - ), f"Only {change_ratio:.2%} of parameters changed, gradient flow might be broken. Changes: {param_changes}" + assert change_ratio > 0.1, ( + f"Only {change_ratio:.2%} of parameters changed, gradient flow might be broken. Changes: {param_changes}" + ) # Additional check: ensure losses are reasonable - assert np.isfinite(metrics["q_loss1"]) and metrics["q_loss1"] >= 0 - assert np.isfinite(metrics["q_loss2"]) and metrics["q_loss2"] >= 0 - assert np.isfinite(metrics["policy_loss"]) + assert np.isfinite(metrics["Update/q_loss1"]) and metrics["Update/q_loss1"] >= 0 + assert np.isfinite(metrics["Update/q_loss2"]) and metrics["Update/q_loss2"] >= 0 + assert np.isfinite(metrics["Update/policy_loss"]) From 3f8b2de79b2159e2b3d854135a88dad19f6bc3a3 Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Thu, 31 Jul 2025 12:44:08 +0200 Subject: [PATCH 5/5] dacbench need torchvision --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9e7ff247..05e97863 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ [project.optional-dependencies] dev = ["ruff", "mypy", "build", "pytest", "pytest-cov"] carl = ["carl_bench[brax]==1.1.1"] -dacbench = ["dacbench==0.4.0"] +dacbench = ["dacbench==0.4.0", "torchvision"] pufferlib = ["pufferlib==2.0.6"] docs = ["mkdocs", "mkdocs-material", "mkdocs-autorefs", "mkdocs-gen-files", "mkdocs-literate-nav",