Merged
6 changes: 3 additions & 3 deletions README.md
@@ -7,7 +7,7 @@
<div align="center">

[![PyPI Version](https://img.shields.io/pypi/v/mighty-rl.svg)](https://pypi.org/project/Mighty-RL/)
![Python](https://img.shields.io/badge/Python-3.10-3776AB)
![Python](https://img.shields.io/badge/Python-3.11-3776AB)
![License](https://img.shields.io/badge/License-BSD3-orange)
[![Test](https://github.com/automl/Mighty/actions/workflows/test.yaml/badge.svg)](https://github.com/automl/Mighty/actions/workflows/test.yaml)
[![Doc Status](https://github.com/automl/Mighty/actions/workflows/docs_test.yaml/badge.svg)](https://github.com/automl/Mighty/actions/workflows/docs_test.yaml)
@@ -42,12 +42,12 @@ Mighty features:

## Installation
We recommend using uv to install and run Mighty in a virtual environment.
The code has been tested with python 3.10/3.11 on Unix systems.
The code has been tested with python 3.11 on Unix systems.

First create a clean python environment:

```bash
uv venv --python=3.10
uv venv --python=3.11
source .venv/bin/activate
```

2 changes: 1 addition & 1 deletion docs/installation.md
@@ -1,5 +1,5 @@
We recommend using uv to install and run Mighty in a virtual environment.
The code has been tested with python 3.10.
The code has been tested with python 3.11 on Unix systems.

First create a clean python environment:

4 changes: 2 additions & 2 deletions mighty/configs/base.yaml
@@ -1,8 +1,8 @@
defaults:
- _self_
- algorithm: sac_mujoco
- /cluster: local
- algorithm: ppo
- environment: pufferlib_ocean/bandit
# - search_space: dqn_gym_classic
- override hydra/job_logging: colorlog
- override hydra/hydra_logging: colorlog
- override hydra/help: mighty_help
11 changes: 11 additions & 0 deletions mighty/configs/cluster/example_cpu.yaml
@@ -0,0 +1,11 @@
# @package _global_
defaults:
- override /hydra/launcher: submitit_slurm

hydra:
launcher:
partition: your_partition
cpus_per_task: 1
name: mighty_cpu_experiment
timeout_min: 30
mem_gb: 10
12 changes: 12 additions & 0 deletions mighty/configs/cluster/example_gpu.yaml
@@ -0,0 +1,12 @@
# @package _global_
defaults:
- override /hydra/launcher: submitit_slurm

hydra:
launcher:
partition: your_partition
gpus_per_task: 1
gres: "gpu:1"
timeout_min: 30
mem_gb: 10
name: mighty_gpu_experiment
13 changes: 0 additions & 13 deletions mighty/configs/cluster/local.yaml
@@ -1,13 +0,0 @@
# @package _global_
# defaults:
# - override /hydra/launcher: joblib

# hydra:
# launcher:
# n_jobs: 16

cluster:
_target_: distributed.deploy.local.LocalCluster
n_workers: ${hydra.sweeper.scenario.n_workers}
processes: false
threads_per_worker: 1
17 changes: 0 additions & 17 deletions mighty/configs/cluster/luis.yaml

This file was deleted.

25 changes: 0 additions & 25 deletions mighty/configs/cluster/noctua.yaml

This file was deleted.

15 changes: 0 additions & 15 deletions mighty/configs/cluster/tnt.yaml

This file was deleted.

1 change: 1 addition & 0 deletions mighty/configs/cmaes_hpo.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
- /cluster: local
- algorithm: dqn
- environment: pufferlib_ocean/bandit
- search_space: dqn_gym_classic
1 change: 1 addition & 0 deletions mighty/configs/nes.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
- /cluster: local
- algorithm: dqn
- environment: pufferlib_ocean/bandit
- search_space: dqn_gym_classic
5 changes: 3 additions & 2 deletions mighty/configs/ppo_smac.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
- /cluster: local
- algorithm: ppo_mujoco
- environment: gymnasium/pendulum
- search_space: ppo_rs
@@ -45,6 +46,6 @@ hydra:
output_directory: ${hydra.sweep.dir}
search_space: ${search_space}
run:
dir: ./tmp/branin_smac/
dir: ${output_dir}/${experiment_name}_${seed}
sweep:
dir: ./tmp/branin_smac/
dir: ${output_dir}/${experiment_name}_${seed}
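Replacing the hardcoded `./tmp/branin_smac/` with interpolations gives each experiment and seed its own output directory. A rough stand-in for how the interpolation resolves, using plain Python string formatting rather than Hydra's actual resolver (the config values below are made up for illustration):

```python
# Hypothetical config values; Hydra would resolve ${...} against the composed config.
cfg = {"output_dir": "results", "experiment_name": "ppo_smac", "seed": 3}

# ${output_dir}/${experiment_name}_${seed} behaves roughly like:
run_dir = "{output_dir}/{experiment_name}_{seed}".format(**cfg)
assert run_dir == "results/ppo_smac_3"
```

Two runs that differ only in `seed` therefore no longer overwrite each other's outputs.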
5 changes: 3 additions & 2 deletions mighty/configs/sac_smac.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
- /cluster: local
- algorithm: sac_mujoco
- environment: gymnasium/pendulum
- search_space: sac_rs
@@ -45,6 +46,6 @@ hydra:
output_directory: ${hydra.sweep.dir}
search_space: ${search_space}
run:
dir: ./tmp/branin_smac/
dir: ${output_dir}/${experiment_name}_${seed}
sweep:
dir: ./tmp/branin_smac/
dir: ${output_dir}/${experiment_name}_${seed}
1 change: 1 addition & 0 deletions mighty/configs/sweep_ppo_pbt.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
- /cluster: local
- algorithm: ppo
- environment: gymnasium/pendulum
- search_space: ppo_rs
1 change: 1 addition & 0 deletions mighty/configs/sweep_rs.yaml
@@ -1,5 +1,6 @@
defaults:
- _self_
- /cluster: local
- algorithm: ppo
- environment: gymnasium/pendulum
- search_space: ppo_rs
88 changes: 20 additions & 68 deletions mighty/mighty_agents/base_agent.py
@@ -34,52 +34,14 @@
from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward


def seed_everything(seed: int, env: gym.Env | None = None):
"""
Seed Python, NumPy, Torch (including cuDNN), plus Gym (action_space, observation_space,
and the environment's own RNG). If `env` is vectorized (has `env.envs`), we dig into each
sub-env. Always call this before ANY network-building or RNG usage.
"""
# 1) Python/NumPy/Torch/Hash
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) # enforce strictly deterministic ops
os.environ["PYTHONHASHSEED"] = str(seed)

# 2) Gym environment seeding
if env is not None:
# If the env is wrapped, try to unwrap to the core
try:
core_env = env.unwrapped
core_env.seed(seed)
except Exception:
pass

# If vectorized (e.g. SyncVectorEnv), seed each sub-env separately
if hasattr(env, "envs") and isinstance(env.envs, list):
sub_seeds = [seed for _ in range(len(env.envs))]
for sub_seed, subenv in zip(sub_seeds, env.envs):
subenv.action_space.seed(sub_seed)
subenv.observation_space.seed(sub_seed)
try:
subenv.unwrapped.seed(sub_seed)
except Exception:
pass
# Reset the vectorized env with explicit seeds
env.reset(seed=sub_seeds)
else:
# Single environment
env.action_space.seed(seed)
env.observation_space.seed(seed)
try:
env.unwrapped.seed(seed)
except Exception:
pass
env.reset(seed=seed)
def seed_env_spaces(env: gym.VectorEnv, seed: int) -> None:
env.action_space.seed(seed)
env.single_action_space.seed(seed)
env.observation_space.seed(seed)
env.single_observation_space.seed(seed)
for i in range(len(env.envs)):
env.envs[i].action_space.seed(seed)
env.envs[i].observation_space.seed(seed)
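The new `seed_env_spaces` replaces the broad `seed_everything` helper with targeted seeding of the RNGs owned by the env's action and observation spaces. A minimal self-contained sketch of that pattern (the `FakeSpace` class below is a hypothetical stand-in, not Gymnasium's `Space`):

```python
import numpy as np


class FakeSpace:
    """Illustrative stand-in for a gym.Space: owns the RNG used by sample()."""

    def __init__(self, n: int):
        self.n = n
        self.rng = np.random.default_rng()

    def seed(self, seed: int) -> None:
        self.rng = np.random.default_rng(seed)

    def sample(self) -> int:
        return int(self.rng.integers(self.n))


def seed_spaces(spaces, seed: int) -> None:
    # Mirror of seed_env_spaces: push one seed into every space's RNG.
    for space in spaces:
        space.seed(seed)


spaces = [FakeSpace(10) for _ in range(3)]
seed_spaces(spaces, 42)
first = [s.sample() for s in spaces]
seed_spaces(spaces, 42)
second = [s.sample() for s in spaces]
assert first == second  # identical streams after identical seeding
```

Each space samples from its own generator, so seeding the environment alone is not enough for reproducible `space.sample()` calls; the spaces must be seeded explicitly, which is exactly what the helper does for both the vector-level and per-sub-env spaces.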


def update_buffer(buffer, new_data):
@@ -218,28 +180,12 @@ def __init__( # noqa: PLR0915, PLR0912

self.seed = seed
if self.seed is not None:
seed_everything(self.seed, env=None) # type: ignore
# Re-seed Python/NumPy/Torch here again.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(seed)

# Also seed any RNGs you will use later (e.g. for buffer / policy):
self.rng = np.random.default_rng(seed) # for any numpy-based sampling
# If you ever use torch.Generator for sampling actions, you could do:
self.torch_gen = torch.Generator().manual_seed(seed)

# If you use Python’s random inside your policy, call random.seed(seed)
random.seed(seed)

else:
self.rng = np.random.default_rng()
self.seed = 0

# Replay Buffer
replay_buffer_class = retrieve_class(
cls=replay_buffer_class,
@@ -279,6 +225,10 @@ def __init__( # noqa: PLR0915, PLR0912
else:
self.eval_env = eval_env

if self.seed is not None:
seed_env_spaces(self.env, self.seed)
seed_env_spaces(self.eval_env, self.seed)

self.render_progress = render_progress
self.output_dir = output_dir
if self.output_dir is not None:
@@ -288,9 +238,9 @@ def __init__( # noqa: PLR0915, PLR0912
self.meta_modules = {}
for i, m in enumerate(meta_methods):
meta_class = retrieve_class(cls=m, default_cls=None) # type: ignore
assert (
meta_class is not None
), f"Class {m} not found, did you specify the correct loading path?"
assert meta_class is not None, (
f"Class {m} not found, did you specify the correct loading path?"
)
kwargs: Dict = {}
if len(meta_kwargs) > i:
kwargs = meta_kwargs[i]
@@ -345,11 +295,13 @@ def __init__( # noqa: PLR0915, PLR0912
wandb.log(starting_hps)

self.initialize_agent()
if self.seed is not None:
self.buffer.seed(self.seed)
self.policy.seed(self.seed)
for m in self.meta_modules.values():
m.seed(self.seed)
self.steps = 0

seed_everything(self.seed, self.env) # type: ignore
seed_everything(self.seed, self.eval_env) # type: ignore

def _initialize_agent(self) -> None:
"""Agent/algorithm specific initializations."""
raise NotImplementedError
10 changes: 5 additions & 5 deletions mighty/mighty_agents/sac.py
@@ -120,11 +120,11 @@ def __init__(

# Initialize loss buffer for logging
self.loss_buffer = {
"q_loss1": [],
"q_loss2": [],
"policy_loss": [],
"td_error1": [],
"td_error2": [],
"Update/q_loss1": [],
"Update/q_loss2": [],
"Update/policy_loss": [],
"Update/td_error1": [],
"Update/td_error2": [],
"step": [],
}
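Namespacing the keys with `Update/` groups these metrics under one section in dashboards that split panel groups on `/` (as wandb does). The rename amounts to the following, with the old key names taken from the removed lines:

```python
# Old flat key names, as deleted from sac.py
flat_keys = ["q_loss1", "q_loss2", "policy_loss", "td_error1", "td_error2"]

# New namespaced buffer: loss metrics live under Update/, the step counter stays flat
loss_buffer = {f"Update/{k}": [] for k in flat_keys}
loss_buffer["step"] = []
```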

4 changes: 4 additions & 0 deletions mighty/mighty_exploration/mighty_exploration_policy.py
@@ -65,6 +65,10 @@ def __init__(
else:
self.sample_action = self.sample_func_logits

def seed(self, seed: int) -> None:
"""Set the random seed for reproducibility."""
self.rng = np.random.default_rng(seed)

def sample_func_q(self, state_array):
"""
Q-learning branch:
7 changes: 7 additions & 0 deletions mighty/mighty_meta/mighty_component.py
@@ -2,6 +2,8 @@

from __future__ import annotations

import numpy as np


class MightyMetaComponent:
"""Component for registering meta-control methods."""
@@ -17,6 +19,7 @@ def __init__(self) -> None:
self.post_update_methods = []
self.pre_episode_methods = []
self.post_episode_methods = []
self.rng = np.random.default_rng()

def pre_step(self, metrics):
"""Execute methods before a step.
@@ -71,3 +74,7 @@ def post_episode(self, metrics):
"""
for m in self.post_episode_methods:
m(metrics)

def seed(self, seed: int) -> None:
"""Set the random seed for reproducibility."""
self.rng = np.random.default_rng(seed)
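Both `MightyExplorationPolicy.seed` and `MightyMetaComponent.seed` follow the same idiom: swap in a freshly constructed, seeded generator rather than trying to re-seed the existing one. A small sketch of why that makes components reproducible (the class name is illustrative, not from the codebase):

```python
import numpy as np


class SeededComponent:
    def __init__(self) -> None:
        self.rng = np.random.default_rng()  # unseeded by default

    def seed(self, seed: int) -> None:
        # default_rng(seed) produces an identical stream for an identical seed
        self.rng = np.random.default_rng(seed)


a, b = SeededComponent(), SeededComponent()
a.seed(7)
b.seed(7)
draws_a = a.rng.integers(0, 100, size=5).tolist()
draws_b = b.rng.integers(0, 100, size=5).tolist()
assert draws_a == draws_b  # same seed, same draws
```

This also explains the deletion in `plr.py` below: with the base class now constructing `self.rng`, the subclass no longer needs its own copy of that line.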
1 change: 0 additions & 1 deletion mighty/mighty_meta/plr.py
@@ -43,7 +43,6 @@
:return:
"""
super().__init__()
self.rng = np.random.default_rng()
self.alpha = alpha
self.rho = rho
self.staleness_coef = staleness_coeff