diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 0000000..622041c
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,43 @@
+name: Build Docker Image
+
+on: [push]
+
+jobs:
+ docker-build:
+ runs-on: ubuntu-latest
+ env:
+ IMAGE_NAME: ghcr.io/${{ github.repository_owner }}/posggym:latest
+
+
+ permissions:
+ contents: read
+ packages: write
+
+ steps:
+ - name: Checkout Repository
+ uses: actions/checkout@v3
+ with:
+ submodules: recursive
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Log in to GitHub Container Registry
+ uses: docker/login-action@v3
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Normalize tag name
+ run: echo "IMAGE_TAG=$(echo $IMAGE_NAME | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV
+
+ - name: Build and Push Docker Image
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ file: Dockerfile
+ push: true
+ tags: ${{ env.IMAGE_TAG }}
+ cache-from: type=registry,ref=${{ env.IMAGE_TAG }}
+ cache-to: type=registry,ref=${{ env.IMAGE_TAG }},mode=max
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index e29ba55..9a10e74 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -5,14 +5,24 @@ on: [push]
jobs:
test:
runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.10", "3.11", "3.12"]
steps:
- name: Checkout code
uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
- - name: pip install
- run: pip install --upgrade pip && pip install --user -e .[all] && pip install --user -e .[testing]
+ - name: Install dependencies
+ run: |
+ pip install --upgrade pip
+ pip install --user -e .[all]
+ pip install --user -e .[testing]
- name: Run tests
- run : pytest
+ run: pytest
diff --git a/.gitignore b/.gitignore
index 3d12201..fbf9187 100644
--- a/.gitignore
+++ b/.gitignore
@@ -140,3 +140,4 @@ dmypy.json
# Ruff linter
.ruff_cache/
+*.pickle
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9209a0b..2c41ba6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -30,9 +30,12 @@ repos:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
- rev: 'v0.0.254'
+ rev: 'v0.4.10'
hooks:
- id: ruff
+ args:
+ - --fix
+ - --unsafe-fixes
- repo: local
hooks:
- id: pyright
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..1ca3ca8
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,12 @@
+FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime
+
+WORKDIR /app
+
+COPY pyproject.toml ./
+COPY setup.py ./
+COPY ./posggym/__init__.py /app/posggym/__init__.py
+
+RUN pip install -e .[all]
+
+
+COPY . .
diff --git a/docs/conf.py b/docs/conf.py
index 5176e7a..51b017c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -7,6 +7,7 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import posggym
+
project = "POSGGym"
copyright = "2023, Jonathon Schwartz"
author = "Jonathon Schwartz"
diff --git a/docs/scripts/gen_agent_gifs.py b/docs/scripts/gen_agent_gifs.py
index ac2d3ba..298ee05 100644
--- a/docs/scripts/gen_agent_gifs.py
+++ b/docs/scripts/gen_agent_gifs.py
@@ -9,7 +9,7 @@
import re
from pathlib import Path
from pprint import pprint
-from typing import Any, Dict, List
+from typing import Any
import posggym
import posggym.agents as pga
@@ -27,7 +27,7 @@
def gen_agent_gif(
env_id: str,
- policy_ids: List[str],
+ policy_ids: list[str],
ignore_existing: bool = False,
length: int = 300,
custom_env: bool = False,
@@ -43,7 +43,7 @@ def gen_agent_gif(
for policy_id in policy_ids:
try:
pi_spec = pga.spec(policy_id)
- except posggym.error.NameNotFound as e:
+ except posggym.error.NameNotFoundError as e:
if "/" not in policy_id:
# try prepending env id
policy_id = f"{env_id}/{policy_id}"
@@ -65,8 +65,6 @@ def gen_agent_gif(
env = posggym.make(
env_id, disable_env_checker=True, render_mode="rgb_array", **env_args
)
- # env = posggym.wrappers.RescaleObservations(env, min_obs=-1.0, max_obs=1.0)
- # env = posggym.wrappers.RescaleActions(env, min_action=-1.0, max_action=1.0)
policies = {}
for idx, spec in enumerate(policy_specs):
@@ -107,7 +105,7 @@ def gen_agent_gif(
for _ in range(repeat):
frames.append(Image.fromarray(frame))
- actions: Dict[str, Any] = {}
+ actions: dict[str, Any] = {}
for i in env.agents:
if policies[i].observes_state:
actions[i] = policies[i].step(env.state)
diff --git a/docs/scripts/gen_agent_mds.py b/docs/scripts/gen_agent_mds.py
index 86180d6..68ba695 100644
--- a/docs/scripts/gen_agent_mds.py
+++ b/docs/scripts/gen_agent_mds.py
@@ -5,7 +5,6 @@
"""
import re
-from typing import Dict, List
from pathlib import Path
import posggym
@@ -22,7 +21,7 @@
all_agents = list(pga.registry.values())
# env_type -> env_name -> [PolicySpec]
-filtered_agents_by_env_type: Dict[str, Dict[str, List[PolicySpec]]] = {}
+filtered_agents_by_env_type: dict[str, dict[str, list[PolicySpec]]] = {}
# Obtain filtered list
for pi_spec in tqdm(all_agents):
@@ -100,7 +99,7 @@
else:
info = (
"These policies are for the "
- + f""
+ f""
f"{title_env_name} environment. Read environment page for detailed "
"information about the environment."
)
@@ -124,7 +123,7 @@
env_args_ids.sort()
if None in filtered_agents_by_env_args_id:
- env_args_ids = [None] + env_args_ids
+ env_args_ids = [None, *env_args_ids]
for env_args_id in env_args_ids:
policy_specs = filtered_agents_by_env_args_id[env_args_id]
diff --git a/docs/scripts/gen_env_mds.py b/docs/scripts/gen_env_mds.py
index 4dfb511..e8e1b60 100644
--- a/docs/scripts/gen_env_mds.py
+++ b/docs/scripts/gen_env_mds.py
@@ -9,20 +9,21 @@
import re
from functools import reduce
-from typing import Dict, List
from pathlib import Path
-from tqdm import tqdm
-from utils import kill_strs, trim
import posggym
from posggym.envs.registration import EnvSpec
+from tqdm import tqdm
+
+from utils import kill_strs, trim
+
pattern = re.compile(r"(?{env_type_title} environments. "
- + "Please read that page first for general information."
+ f"{env_type_title} environments. "
+ "Please read that page first for general information."
)
act_spaces_str = str(env.action_spaces)
@@ -144,32 +145,17 @@
env_table += f"| Symmetric | {env.is_symmetric} |\n"
# if env.observation_space.shape:
- # env_table += f"| Observation Shape | {env.observation_space.shape} |\n"
# if hasattr(env.observation_space, "high"):
- # high = env.observation_space.high
# if hasattr(high, "shape"):
# if len(high.shape) == 3:
- # high = high[0][0][0]
# if env_type == "mujoco":
- # high = high[0]
- # high = np.round(high, 2)
- # high = str(high).replace("\n", " ")
- # env_table += f"| Observation High | {high} |\n"
# if hasattr(env.observation_space, "low"):
- # low = env.observation_space.low
# if hasattr(low, "shape"):
# if len(low.shape) == 3:
- # low = low[0][0][0]
# if env_type == "mujoco":
- # low = low[0]
- # low = np.round(low, 2)
- # low = str(low).replace("\n", " ")
- # env_table += f"| Observation Low | {low} |\n"
- # else:
- # env_table += f"| Observation Space | {env.observation_space} |\n"
env_table += f'| Import | `posggym.make("{env_spec.id}")` |\n'
diff --git a/docs/scripts/gen_envs_display.py b/docs/scripts/gen_envs_display.py
index 8669435..55e7223 100644
--- a/docs/scripts/gen_envs_display.py
+++ b/docs/scripts/gen_envs_display.py
@@ -7,6 +7,7 @@
import sys
from pathlib import Path
+
DOCS_DIR = Path(__file__).resolve().parent.parent
all_envs = [
@@ -100,7 +101,7 @@ def generate_page(env, limit=-1, base_path=""):
type_arg = sys.argv[1]
for env in all_envs:
- if type_arg == env["id"] or type_arg == "":
+ if type_arg in {env["id"], ""}:
type_dict_arr.append(env)
for type_dict in type_dict_arr:
@@ -127,7 +128,7 @@ def generate_page(env, limit=-1, base_path=""):
env_name = " ".join(type_id.split("_")).title()
fp.write(
f"# Complete List - {env_name}\n\n"
- + "```{raw} html\n:file: complete_list.html\n```"
+ "```{raw} html\n:file: complete_list.html\n```"
)
else:
page = generate_page(type_dict)
diff --git a/docs/scripts/gen_gifs.py b/docs/scripts/gen_gifs.py
index d0cbae3..0190eef 100644
--- a/docs/scripts/gen_gifs.py
+++ b/docs/scripts/gen_gifs.py
@@ -71,7 +71,7 @@ def gen_gif(
repeat = (
int(60 / env.metadata["render_fps"]) if env_type == "classic" else 1
)
- for i in range(repeat):
+ for _i in range(repeat):
frames.append(Image.fromarray(frame))
action = {i: env.action_spaces[i].sample() for i in env.agents}
_, _, _, _, done, _ = env.step(action)
diff --git a/examples/custom_envs/hurdle_race.py b/examples/custom_envs/hurdle_race.py
index a4fdbfa..7a67c11 100644
--- a/examples/custom_envs/hurdle_race.py
+++ b/examples/custom_envs/hurdle_race.py
@@ -1,4 +1,4 @@
-"""Race for glory in HurdleRace!
+"""Race for glory in HurdleRace!.
This file contains an example of a simple custom POSGGym environment, and can be used
as a reference for implementing your own.
@@ -10,19 +10,19 @@
"""
from __future__ import annotations
-from typing import Any, Dict, List, Tuple
+from typing import Any, ClassVar
import numpy as np
import posggym
import posggym.model as M
-import posggym.utils.seeding as seeding
from gymnasium import spaces
+from posggym.utils import seeding
try:
import pygame
except ImportError as e:
- raise posggym.error.DependencyNotInstalled(
+ raise posggym.error.DependencyNotInstalledError(
"pygame is not installed, run `pip install pygame` or "
"`pip install posggym[all]`"
) from e
@@ -31,7 +31,7 @@
# The type of an individual states
# This is used for type hinting, and is optional, but encouraged if you plan to share
# your environment with others
-HurdleRaceState = Tuple[int, int, int, int, int]
+HurdleRaceState = tuple[int, int, int, int, int]
class HurdleRaceEnv(posggym.DefaultEnv[HurdleRaceState, int, int]):
@@ -95,9 +95,9 @@ class HurdleRaceEnv(posggym.DefaultEnv[HurdleRaceState, int, int]):
# Here we specify the meta-data, this should include as a minimum:
# 'render_modes' - the render modes supported by the environment
# 'render_fps' - the render framerate to use
- metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
+ metadata: ClassVar[dict] = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
- def __init__(self, render_mode: str | None):
+ def __init__(self, render_mode: str | None) -> None:
model = HurdleRaceModel()
"""
@@ -157,7 +157,7 @@ def _render_frame(self):
)
# next we draw the agents
- for idx, i in enumerate(self.agents):
+ for idx, _i in enumerate(self.agents):
pygame.draw.circle(
canvas,
color=(0, 0, 255) if idx == 0 else (255, 0, 0),
@@ -228,7 +228,7 @@ class HurdleRaceModel(M.POSGModel[HurdleRaceState, int, int]):
R_DRAW = 0.0
R_LOSS = -1.0
- def __init__(self):
+ def __init__(self) -> None:
# tuple of possible agents in our environment
self.possible_agents = ("0", "1")
# The state space is actually optional to define, but can be helpful for some
@@ -249,7 +249,7 @@ def __init__(self):
self.is_symmetric = True
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
# This contains the minimum and maximum reward each agent can receive
return {i: (self.R_LOSS, self.R_WIN) for i in self.possible_agents}
@@ -264,7 +264,7 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: HurdleRaceState) -> List[str]:
+ def get_agents(self, state: HurdleRaceState) -> list[str]:
# This is the list of agents active in a given state
# For our problem both agents are always active, but for some environments
# agents may leave or join (e.g. via finishing early) and so the active agents
@@ -284,7 +284,7 @@ def sample_initial_state(self) -> HurdleRaceState:
# the full state
return (agent_0_pos, agent_1_pos, hurdle_0_pos, hurdle_1_pos, hurdle_2_pos)
- def sample_initial_obs(self, state: HurdleRaceState) -> Dict[str, int]:
+ def sample_initial_obs(self, state: HurdleRaceState) -> dict[str, int]:
# we get the initial observation for an agent (before any action is taken)
# For this environment the observation is independent of action, so this is easy
# each agent observes whether the next cell contains a hurdle or not
@@ -292,7 +292,7 @@ def sample_initial_obs(self, state: HurdleRaceState) -> Dict[str, int]:
return self._get_obs(state)
def step(
- self, state: HurdleRaceState, actions: Dict[str, int]
+ self, state: HurdleRaceState, actions: dict[str, int]
) -> M.JointTimestep[HurdleRaceState, int]:
# first we get the next state
next_state = self._get_next_state(state, actions)
@@ -326,7 +326,7 @@ def step(
)
def _get_next_state(
- self, state: HurdleRaceState, actions: Dict[str, int]
+ self, state: HurdleRaceState, actions: dict[str, int]
) -> HurdleRaceState:
agent_positions = []
for idx, i in enumerate(self.possible_agents):
@@ -337,18 +337,17 @@ def _get_next_state(
if not any(pos + 1 == h_pos for h_pos in state[2:]):
pos += 1
pos = min(self.TRACK_LENGTH, pos)
- else:
+ elif (
+ not any(pos + 1 == h_pos for h_pos in state[2:])
+ or self.rng.random() < self.JUMP_SUCCESS_RATE
+ ):
# JUMP
- if (
- not any(pos + 1 == h_pos for h_pos in state[2:])
- or self.rng.random() < self.JUMP_SUCCESS_RATE
- ):
- pos += 1
+ pos += 1
agent_positions.append(pos)
# the hurdle positions remain unchanged from previous state
return (agent_positions[0], agent_positions[1], *state[2:])
- def _get_obs(self, state: HurdleRaceState) -> Dict[str, int]:
+ def _get_obs(self, state: HurdleRaceState) -> dict[str, int]:
# each agent observes whether the next cell contains a hurdle or not
obs = {}
for idx, i in enumerate(self.possible_agents):
@@ -362,7 +361,7 @@ def _get_obs(self, state: HurdleRaceState) -> Dict[str, int]:
obs[i] = self.HURDLE if hurdle_present else self.NOHURDLE
return obs
- def _get_rewards(self, state: HurdleRaceState) -> Dict[str, float]:
+ def _get_rewards(self, state: HurdleRaceState) -> dict[str, float]:
# agents only receive a reward when at least one agent reaches the end of their
# track, otherwise the step reward is 0 for both agents
agent_0_pos, agent_1_pos = state[0], state[1]
@@ -376,10 +375,10 @@ def _get_rewards(self, state: HurdleRaceState) -> Dict[str, float]:
agent_0_reward, agent_1_reward = 0, 0
return {"0": agent_0_reward, "1": agent_1_reward}
- def _get_info(self, state: HurdleRaceState) -> Dict[str, Dict]:
+ def _get_info(self, state: HurdleRaceState) -> dict[str, dict]:
# we return the position of the agent each step in the auxiliary information
# as well as the final outcome
- infos: Dict[str, Dict[str, Any]] = {
+ infos: dict[str, dict[str, Any]] = {
i: {"pos": state[idx]} for idx, i in enumerate(self.possible_agents)
}
agent_0_pos, agent_1_pos = state[0], state[1]
@@ -423,8 +422,6 @@ def run_hurdle_race():
if __name__ == "__main__":
- # run_hurdle_race()
-
import sys
sys.path.insert(0, "/home/jonathon/code/posggym/docs/scripts")
diff --git a/examples/record_video.py b/examples/record_video.py
index 75f1bb4..fca021d 100644
--- a/examples/record_video.py
+++ b/examples/record_video.py
@@ -18,7 +18,6 @@
import argparse
from pathlib import Path
-from typing import Dict, List, Optional
import posggym
@@ -26,8 +25,8 @@
def record_env(
env_id: str,
num_episodes: int,
- max_episode_steps: Optional[int] = None,
- seed: Optional[int] = None,
+ max_episode_steps: int | None = None,
+ seed: int | None = None,
):
"""Run random agents."""
if max_episode_steps is not None:
@@ -48,7 +47,7 @@ def record_env(
dones = 0
episode_steps = []
- episode_rewards: Dict[str, List[float]] = {i: [] for i in env.possible_agents}
+ episode_rewards: dict[str, list[float]] = {i: [] for i in env.possible_agents}
for ep_num in range(num_episodes):
t = 0
done = False
diff --git a/examples/run_agents.py b/examples/run_agents.py
index 078fb8d..1585e27 100644
--- a/examples/run_agents.py
+++ b/examples/run_agents.py
@@ -15,7 +15,6 @@
"""
import argparse
-from typing import Dict, List, Optional
import posggym
import posggym.agents as pga
@@ -23,19 +22,19 @@
def run_agents(
env_id: str,
- policy_ids: List[str],
+ policy_ids: list[str],
num_episodes: int,
- seed: Optional[int] = None,
- render_mode: Optional[str] = "human",
+ seed: int | None = None,
+ render_mode: str | None = "human",
):
"""Run agents."""
print("\n== Running Agents ==")
policy_specs = []
env_args, env_args_id = None, None
- for i, policy_id in enumerate(policy_ids):
+ for _i, policy_id in enumerate(policy_ids):
try:
pi_spec = pga.spec(policy_id)
- except posggym.error.NameNotFound as e:
+ except posggym.error.NameNotFoundError as e:
if "/" not in policy_id:
# try prepending env id
policy_id = f"{env_id}/{policy_id}"
@@ -67,8 +66,8 @@ def run_agents(
policy.reset(seed=seed + i)
episode_steps = []
- episode_rewards: Dict[str, List[float]] = {i: [] for i in env.possible_agents}
- for ep_num in range(num_episodes):
+ episode_rewards: dict[str, list[float]] = {i: [] for i in env.possible_agents}
+ for _ep_num in range(num_episodes):
obs, _ = env.reset()
env.render()
for policy in policies.values():
diff --git a/examples/run_all_agents.py b/examples/run_all_agents.py
index ad97376..64c8de5 100644
--- a/examples/run_all_agents.py
+++ b/examples/run_all_agents.py
@@ -35,7 +35,6 @@
"""
import argparse
-from typing import Dict, Optional, Tuple
import posggym
import posggym.agents as pga
@@ -43,8 +42,8 @@
def try_make_policy(
- spec: PolicySpec, render_mode: Optional[str]
-) -> Tuple[Optional[posggym.Env], Optional[Dict[str, pga.Policy]]]:
+ spec: PolicySpec, render_mode: str | None
+) -> tuple[posggym.Env | None, dict[str, pga.Policy] | None]:
"""Tries to make the policy showing if it is possible."""
try:
env_id = "Driving-v1" if spec.env_id is None else spec.env_id
@@ -62,10 +61,10 @@ def try_make_policy(
return env, policies
except (
ImportError,
- posggym.error.DependencyNotInstalled,
- posggym.error.MissingArgument,
+ posggym.error.DependencyNotInstalledError,
+ posggym.error.MissingArgumentError,
) as e:
- posggym.logger.warn(
+ posggym.logger.warning(
f"Not testing posggym.agents policy spec `{spec.id}` due to error: {e}"
)
except RuntimeError as e:
@@ -78,8 +77,8 @@ def try_make_policy(
def run_policy(
spec: PolicySpec,
num_episodes: int,
- seed: Optional[int],
- render_mode: Optional[str] = "human",
+ seed: int | None,
+ render_mode: str | None = "human",
):
"""Run a posggym.policy."""
print(f"Running policy={spec.id}")
@@ -127,9 +126,9 @@ def run_policy(
def run_all_agents(
- env_id_prefix: Optional[str],
+ env_id_prefix: str | None,
num_episodes: int,
- seed: Optional[int],
+ seed: int | None,
render_mode: str = "human",
):
"""Run all agents."""
diff --git a/examples/run_keyboard_agent.py b/examples/run_keyboard_agent.py
index 0f866b3..b5a07f4 100644
--- a/examples/run_keyboard_agent.py
+++ b/examples/run_keyboard_agent.py
@@ -19,14 +19,12 @@
import argparse
import math
import sys
-from typing import Dict, List, Optional, Tuple
import numpy as np
+import posggym
import pygame
from gymnasium import spaces
-import posggym
-
grid_world_key_action_map = {
"Driving-v1": {
@@ -106,8 +104,8 @@ def display_vector_obs(obs: np.ndarray, width: int):
def run_discrete_env_manual_keyboard_agent(
- env: posggym.Env, keyboard_agent_id: List[str], pause_each_step: bool = False
-) -> Tuple[Dict[str, float], int]:
+ env: posggym.Env, keyboard_agent_id: list[str], pause_each_step: bool = False
+) -> tuple[dict[str, float], int]:
"""Run manual keyboard agent in discrete environment.
Assumes environment actions are discrete. So user will be prompted to input an
@@ -161,7 +159,7 @@ def run_discrete_env_manual_keyboard_agent(
def run_continuous_env_manual_keyboard_agent(
env: posggym.Env, keyboard_agent_id: str, pause_each_step: bool = False
-) -> Tuple[Dict[str, float], int]:
+) -> tuple[dict[str, float], int]:
"""Run manual keyboard agent in continuous environment.
Assumes environment actions are continuous (i.e. space.Box). So user will be
@@ -224,7 +222,7 @@ def run_continuous_env_manual_keyboard_agent(
def run_grid_world_env_keyboard_agent(
env: posggym.Env, keyboard_agent_id: str, pause_each_step: bool = False
-) -> Tuple[Dict[str, float], int]:
+) -> tuple[dict[str, float], int]:
"""Run keyboard agent in grid-world environment.
Assumes environment actions are angular and linear velocity.
@@ -275,7 +273,7 @@ def run_grid_world_env_keyboard_agent(
def run_continuous_env_keyboard_agent(
env: posggym.Env, keyboard_agent_id: str, pause_each_step: bool = False
-) -> Tuple[Dict[str, float], int]:
+) -> tuple[dict[str, float], int]:
"""Run keyboard agent in continuous environment.
Assumes environment actions are angular and linear velocity.
@@ -343,17 +341,19 @@ def run_continuous_env_keyboard_agent(
def run_keyboard_agent(
env_id: str,
- keyboard_agent_ids: List[str],
+ keyboard_agent_ids: list[str],
num_episodes: int,
- max_episode_steps: Optional[int] = None,
- seed: Optional[int] = None,
+ max_episode_steps: int | None = None,
+ seed: int | None = None,
pause_each_step: bool = False,
manual_input: bool = False,
):
"""Run keyboard agents."""
if max_episode_steps is not None:
env = posggym.make(
- env_id, render_mode="human", max_episode_steps=max_episode_steps
+ env_id,
+ render_mode="human",
+ max_episode_steps=max_episode_steps,
)
else:
env = posggym.make(env_id, render_mode="human")
@@ -384,7 +384,7 @@ def run_keyboard_agent(
env.reset(seed=seed)
episode_steps = []
- episode_rewards: Dict[str, List[float]] = {i: [] for i in env.possible_agents}
+ episode_rewards: dict[str, list[float]] = {i: [] for i in env.possible_agents}
for _ in range(num_episodes):
if manual_input:
rewards, steps = run_env_episode_fn(
diff --git a/examples/run_random_agents.py b/examples/run_random_agents.py
index d46ee6c..fcc01d4 100644
--- a/examples/run_random_agents.py
+++ b/examples/run_random_agents.py
@@ -16,70 +16,8 @@
"""
import argparse
-from typing import Dict, List, Optional
-import posggym
-
-
-def run_random_agent(
- env_id: str,
- num_episodes: int,
- max_episode_steps: Optional[int] = None,
- seed: Optional[int] = None,
- render_mode: Optional[str] = None,
-):
- """Run random agents."""
- if max_episode_steps is not None:
- env = posggym.make(
- env_id, render_mode=render_mode, max_episode_steps=max_episode_steps
- )
- else:
- env = posggym.make(env_id, render_mode=render_mode)
-
- env.reset(seed=seed)
-
- dones = 0
- episode_steps = []
- episode_rewards: Dict[str, List[float]] = {i: [] for i in env.possible_agents}
- for ep_num in range(num_episodes):
- env.render()
-
- t = 0
- done = False
- rewards = {i: 0.0 for i in env.possible_agents}
- while not done and (max_episode_steps is None or t < max_episode_steps):
- a = {i: env.action_spaces[i].sample() for i in env.agents}
- _, r, _, _, done, _ = env.step(a)
- t += 1
-
- env.render()
-
- for i, r_i in r.items():
- rewards[i] += r_i
-
- print(f"End episode {ep_num}")
- dones += int(done)
- episode_steps.append(t)
-
- env.reset()
-
- for i, r_i in rewards.items():
- episode_rewards[i].append(r_i)
-
- if done:
- print(t, rewards)
-
- env.close()
-
- print("All episodes finished")
- print(
- f"Completed episodes (i.e. where 'done=True') = {dones} out of {num_episodes}"
- )
- mean_steps = sum(episode_steps) / len(episode_steps)
- print(f"Mean episode steps = {mean_steps:.2f}")
- mean_returns = {i: sum(r) / len(r) for i, r in episode_rewards.items()}
- print(f"Mean Episode returns {mean_returns}")
- return mean_steps, mean_returns
+from posggym.utils.run_random_agents import run_random_agent
if __name__ == "__main__":
diff --git a/notebooks/generate_payoffs_figs.py b/notebooks/generate_payoffs_figs.py
index 344cda1..b52d6ae 100644
--- a/notebooks/generate_payoffs_figs.py
+++ b/notebooks/generate_payoffs_figs.py
@@ -4,12 +4,13 @@
import sys
from datetime import datetime
from pathlib import Path
-from typing import Optional
from posggym.config import BASE_RESULTS_DIR, REPO_DIR
+
sys.path.append(str(REPO_DIR / "notebooks"))
-import plot_utils # noqa: E402
+import plot_utils
+
results_dir = REPO_DIR / "notebooks" / "results" / "pairwise_agent_comparison"
@@ -70,7 +71,7 @@ def generate_fig(
)
-def main(env_id: Optional[str], output_dir: Optional[str] = None):
+def main(env_id: str | None, output_dir: Path | None = None):
available_env_result_dirs = [x.name for x in results_dir.glob("*")]
available_env_result_dirs.sort()
@@ -98,10 +99,10 @@ def main(env_id: Optional[str], output_dir: Optional[str] = None):
result_file = result_path.name
print(f"Generating figures for {result_file}")
- df = plot_utils.import_results(result_path)
+ results_df = plot_utils.import_results(result_path)
generate_fig(
- df,
+ results_df,
env_output_dir,
result_file,
policy_key="policy_name",
@@ -109,7 +110,7 @@ def main(env_id: Optional[str], output_dir: Optional[str] = None):
)
generate_fig(
- df,
+ results_df,
env_output_dir,
result_file,
policy_key="policy_type",
diff --git a/notebooks/plot_utils.py b/notebooks/plot_utils.py
index 703bcc5..df3a51b 100644
--- a/notebooks/plot_utils.py
+++ b/notebooks/plot_utils.py
@@ -1,11 +1,12 @@
"""Plotting functions for posggym.agents analysis."""
+from functools import partial
from itertools import product
-from typing import List, Optional, Tuple
import numpy as np
+
try:
- import matplotlib
+ import matplotlib as mpl
import matplotlib.pyplot as plt
except ImportError as e:
raise ImportError(
@@ -31,12 +32,13 @@ def conf_int(row, prefix):
n = row["num_episodes"]
return 1.96 * (std / np.sqrt(n))
- prefix = ""
for col in df.columns:
if not col.endswith("_std"):
continue
prefix = col.replace("_std", "")
- df[f"{prefix}_CI"] = df.apply(lambda row: conf_int(row, prefix), axis=1)
+ conf_int_with_prefix = partial(conf_int, prefix=prefix)
+
+ df[f"{prefix}_CI"] = df.apply(conf_int_with_prefix, axis=1)
return df
@@ -50,18 +52,18 @@ def prop(row, col_name):
columns = ["num_LOSS", "num_DRAW", "num_WIN", "num_NA"]
new_column_names = ["prop_LOSS", "prop_DRAW", "prop_WIN", "prop_NA"]
- for col_name, new_name in zip(columns, new_column_names):
+ for col_name, new_name in zip(columns, new_column_names, strict=False):
if col_name in df.columns:
- df[new_name] = df.apply(lambda row: prop(row, col_name), axis=1)
+ prop_with_col = partial(prop, col_name=col_name)
+ df[new_name] = df.apply(prop_with_col, axis=1)
return df
-def get_policy_type_and_seed(policy_name: str) -> Tuple[str, str]:
+def get_policy_type_and_seed(policy_name: str) -> tuple[str, str]:
"""Get policy type and seed from policy name."""
if "seed" not in policy_name:
return policy_name, "None"
- # policy_name = "policy_type_seed[seed]"
tokens = policy_name.split("_")
policy_type = []
seed_token = None
@@ -87,38 +89,38 @@ def add_policy_type_and_seed(df: pd.DataFrame) -> pd.DataFrame:
return df
-def add_co_team_name(df: pd.DataFrame) -> pd.DataFrame:
+def add_co_team_name(results_df: pd.DataFrame) -> pd.DataFrame:
"""Add co team name to dataframe.
Also removes unwanted rows.
"""
# For each policy we want to group rows where that policy is paired with equivalent
# co-player policies.
- env_symmetric = df["symmetric"].unique().tolist()[0]
+ env_symmetric = results_df["symmetric"].unique().tolist()[0]
if env_symmetric:
# For symmetric environments we group rows where the policy is paired with the
# same co-player policies, independent of the ordering
same_co_team_ids = set()
- for team_id in df["co_team_id"].unique().tolist():
+ for team_id in results_df["co_team_id"].unique().tolist():
# ignore ( and ) and start and end
pi_names = team_id[1:-1].split(",")
if all(name == pi_names[0] for name in pi_names):
same_co_team_ids.add(team_id)
- df = df[df["co_team_id"].isin(same_co_team_ids)]
+ results_df = results_df[results_df["co_team_id"].isin(same_co_team_ids)]
def get_team_name(row):
team_id = row["co_team_id"]
return team_id[1:-1].split(",")[0]
- df["co_team_name"] = df.apply(get_team_name, axis=1)
+ results_df["co_team_name"] = results_df.apply(get_team_name, axis=1)
else:
# for asymmetric environments ordering matters so can't reduce team IDs
def get_team_name_asymmetric(row):
team_id = row["co_team_id"]
return team_id[1:-1]
- df["co_team_name"] = df.apply(get_team_name_asymmetric, axis=1)
+ results_df["co_team_name"] = results_df.apply(get_team_name_asymmetric, axis=1)
def get_team_type(row):
pi_names = row["co_team_name"].split(",")
@@ -134,9 +136,9 @@ def get_team_seed(row):
return pi_seeds[0]
return ",".join(pi_seeds)
- df["co_team_type"] = df.apply(get_team_type, axis=1)
- df["co_team_seed"] = df.apply(get_team_seed, axis=1)
- return df
+ results_df["co_team_type"] = results_df.apply(get_team_type, axis=1)
+ results_df["co_team_seed"] = results_df.apply(get_team_seed, axis=1)
+ return results_df
def import_results(
@@ -145,16 +147,16 @@ def import_results(
"""Import experiment results."""
# disable annoying warning
pd.options.mode.chained_assignment = None
- df = pd.read_csv(result_file)
+ results_df = pd.read_csv(result_file)
- df = add_95CI(df)
- df = add_outcome_proportions(df)
- df = add_policy_type_and_seed(df)
- df = add_co_team_name(df)
+ results_df = add_95CI(results_df)
+ results_df = add_outcome_proportions(results_df)
+ results_df = add_policy_type_and_seed(results_df)
+ results_df = add_co_team_name(results_df)
- # enable annoyin warning
+ # enable annoying warning
pd.options.mode.chained_assignment = "warn"
- return df
+ return results_df
def heatmap(
@@ -163,7 +165,7 @@ def heatmap(
col_labels,
ax=None,
show_cbar=True,
- cbar_kw={},
+ cbar_kw=None,
cbarlabel="",
**kwargs,
):
@@ -202,7 +204,7 @@ def heatmap(
# Create colorbar
if show_cbar:
- cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
+ cbar = ax.figure.colorbar(im, ax=ax, **(cbar_kw or {}))
cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
else:
cbar = None
@@ -260,7 +262,7 @@ def annotate_heatmap(
the text labels.
"""
- if not isinstance(data, (list, np.ndarray)):
+ if not isinstance(data, list | np.ndarray):
data = im.get_array()
# Normalize the threshold to the images color range.
@@ -276,7 +278,7 @@ def annotate_heatmap(
# Get the formatter in case a string is supplied
if isinstance(valfmt, str):
- valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
+ valfmt = mpl.ticker.StrMethodFormatter(valfmt)
# Loop over the data and create a `Text` for each "pixel".
# Change the text's color depending on the data.
@@ -292,11 +294,11 @@ def annotate_heatmap(
def plot_pairwise_heatmap(
ax,
- labels: Tuple[List[str], List[str]],
+ labels: tuple[list[str], list[str]],
values: np.ndarray,
- title: Optional[str] = None,
- vrange: Optional[Tuple[float, float]] = None,
- valfmt: Optional[str] = None,
+ title: str | None = None,
+ vrange: tuple[float, float] | None = None,
+ valfmt: str | None = None,
):
"""Plot pairwise values as a heatmap."""
# Note numpy arrays by default have (0, 0) in the top-left corner.
@@ -339,7 +341,7 @@ def get_pairwise_values(
y_key: str,
policy_key: str = "policy_id",
coplayer_policy_key: str = "coplayer_policy_id",
- coplayer_policies: Optional[List[str]] = None,
+ coplayer_policies: list[str] | None = None,
):
"""Get values for each policy pairing."""
policies = plot_df[policy_key].unique().tolist()
@@ -370,11 +372,11 @@ def plot_pairwise_comparison(
y_key: str,
policy_key: str = "policy_id",
coplayer_policy_key: str = "coplayer_policy_id",
- y_err_key: Optional[str] = None,
+ y_err_key: str | None = None,
vrange=None,
figsize=(20, 20),
valfmt=None,
- coplayer_policies: Optional[List[str]] = None,
+ coplayer_policies: list[str] | None = None,
):
"""Plot results for each policy pairings.
@@ -391,7 +393,6 @@ def plot_pairwise_comparison(
fig, axs = plt.subplots(
nrows=1,
ncols=ncols,
- # figsize=figsize,
squeeze=False,
sharey=True,
)
@@ -566,7 +567,7 @@ def plot_mean_pairwise_comparison(
pop_key: str,
coplayer_policy_key: str,
coplayer_pop_key: str,
- vrange: Optional[Tuple[float, float]] = None,
+ vrange: tuple[float, float] | None = None,
figsize=(12, 6),
valfmt=None,
):
diff --git a/posggym/agents/continuous/driving_continuous/__init__.py b/posggym/agents/continuous/driving_continuous/__init__.py
index c73a0c3..6d92e26 100644
--- a/posggym/agents/continuous/driving_continuous/__init__.py
+++ b/posggym/agents/continuous/driving_continuous/__init__.py
@@ -3,6 +3,7 @@
from posggym.agents.utils import processors
from posggym.config import AGENT_MODEL_DIR
+
ENV_ID = "DrivingContinuous-v0"
agent_model_dir = AGENT_MODEL_DIR / "continuous" / "driving_continuous"
policy_specs = {}
diff --git a/posggym/agents/continuous/drone_team_capture/heuristic.py b/posggym/agents/continuous/drone_team_capture/heuristic.py
index c3540f8..4d383a4 100644
--- a/posggym/agents/continuous/drone_team_capture/heuristic.py
+++ b/posggym/agents/continuous/drone_team_capture/heuristic.py
@@ -30,12 +30,12 @@ def __init__(
model: DroneTeamCaptureModel,
agent_id: str,
policy_id: PolicyID,
- ):
+ ) -> None:
if model.n_com_pursuers < model.n_pursuers - 1 or (
model.observation_limit is not None
and model.observation_limit < 2 * model.r_arena
):
- logger.warn(
+ logger.warning(
"The DroneTeamCapture Heuristic policies are designed for the case "
"where each pursuer can see every other pursuer "
"(i.e. `n_com_pursuers = n_pursuers - 1` and `observation_limit = None`"
@@ -354,7 +354,7 @@ class DTCDPPHeuristicPolicy(DTCHeuristicPolicy):
Souza, C., Castillo, P., & Vidolov, B. (2022). Local interaction and navigation
guidance for hunters drones: a chase behavior approach with real-time tests.
- Robotica, 40(8), 2697–2715.
+ Robotica, 40(8), 2697-2715.
"""
diff --git a/posggym/agents/continuous/predator_prey_continuous/__init__.py b/posggym/agents/continuous/predator_prey_continuous/__init__.py
index 605b631..532f12e 100644
--- a/posggym/agents/continuous/predator_prey_continuous/__init__.py
+++ b/posggym/agents/continuous/predator_prey_continuous/__init__.py
@@ -1,10 +1,10 @@
"""Policies for the PredatorPreyContinuous-v0 environment."""
-from posggym.agents.registration import PolicySpec
from posggym.agents.continuous.predator_prey_continuous import heuristic
+from posggym.agents.registration import PolicySpec
from posggym.agents.torch_policy import PPOPolicy
-from posggym.config import AGENT_MODEL_DIR
from posggym.agents.utils import processors
+from posggym.config import AGENT_MODEL_DIR
ENV_ID = "PredatorPreyContinuous-v0"
diff --git a/posggym/agents/continuous/predator_prey_continuous/heuristic.py b/posggym/agents/continuous/predator_prey_continuous/heuristic.py
index 912794b..ab73ea7 100644
--- a/posggym/agents/continuous/predator_prey_continuous/heuristic.py
+++ b/posggym/agents/continuous/predator_prey_continuous/heuristic.py
@@ -3,7 +3,7 @@
import abc
import math
-from typing import TYPE_CHECKING, Tuple, cast
+from typing import TYPE_CHECKING, cast
import numpy as np
@@ -16,6 +16,7 @@
)
from posggym.utils import seeding
+
if TYPE_CHECKING:
from posggym.posggym.model import POSGModel
from posggym.utils.history import AgentHistory
@@ -24,7 +25,7 @@
class PPCHeuristicPolicy(Policy[PPAction, PPObs], abc.ABC):
"""Base class for heuristic policies for Predator-Prey continuous environment."""
- def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID) -> None:
super().__init__(model, agent_id, policy_id)
self.model = cast(PredatorPreyContinuousModel, model)
self._rng, _ = seeding.np_random()
@@ -91,7 +92,7 @@ def get_value(self, state: PolicyState) -> float:
def _get_pi_from_obs(self, obs: PPObs) -> action_distributions.ActionDistribution:
raise NotImplementedError
- def _get_closest_prey(self, obs: PPObs) -> Tuple[float, float] | None:
+ def _get_closest_prey(self, obs: PPObs) -> tuple[float, float] | None:
prey_obs = obs[2 * self.n_sensors : 3 * self.n_sensors]
closest_idx = np.argmin(prey_obs)
if prey_obs[closest_idx] == self.model.obs_dist:
@@ -100,7 +101,7 @@ def _get_closest_prey(self, obs: PPObs) -> Tuple[float, float] | None:
closest_dist = prey_obs[closest_idx]
return closest_dist, closest_angle
- def _get_closest_predator(self, obs: PPObs) -> Tuple[float, float] | None:
+ def _get_closest_predator(self, obs: PPObs) -> tuple[float, float] | None:
pred_obs = obs[self.n_sensors : 2 * self.n_sensors]
closest_idx = np.argmin(pred_obs)
if pred_obs[closest_idx] == self.model.obs_dist:
@@ -111,7 +112,7 @@ def _get_closest_predator(self, obs: PPObs) -> Tuple[float, float] | None:
def _get_closest_prey_to_predator(
self, obs: PPObs, pred_dist: float, pred_angle: float
- ) -> Tuple[float, float] | None:
+ ) -> tuple[float, float] | None:
# find prey with minimum distance to predator
# d^2 = P^2 + p^2 - 2Pp cos(theta)
# d = distance between predator and prey
diff --git a/posggym/agents/continuous/pursuit_evasion_continuous/__init__.py b/posggym/agents/continuous/pursuit_evasion_continuous/__init__.py
index c4b672d..646bd2c 100644
--- a/posggym/agents/continuous/pursuit_evasion_continuous/__init__.py
+++ b/posggym/agents/continuous/pursuit_evasion_continuous/__init__.py
@@ -7,6 +7,7 @@
from posggym.agents.utils import processors
from posggym.config import AGENT_MODEL_DIR
+
ENV_ID = "PursuitEvasionContinuous-v0"
agent_model_dir = AGENT_MODEL_DIR / "continuous" / "pursuit_evasion_continuous"
policy_specs = {}
diff --git a/posggym/agents/continuous/pursuit_evasion_continuous/shortest_path.py b/posggym/agents/continuous/pursuit_evasion_continuous/shortest_path.py
index fbcae71..40aab88 100644
--- a/posggym/agents/continuous/pursuit_evasion_continuous/shortest_path.py
+++ b/posggym/agents/continuous/pursuit_evasion_continuous/shortest_path.py
@@ -3,7 +3,7 @@
import math
from itertools import product
-from typing import TYPE_CHECKING, List, Tuple, cast
+from typing import TYPE_CHECKING, cast
import numpy as np
@@ -22,11 +22,13 @@
from posggym.posggym.model import POSGModel
from posggym.utils.history import AgentHistory
+ZERO_ACTION = 0.0
+
class PECShortestPathPolicy(Policy[PEAction, PEObs]):
"""Shortest path policy for pursuit evasion continuous environment."""
- def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID) -> None:
super().__init__(model, agent_id, policy_id)
self.model = cast(PursuitEvasionContinuousModel, model)
self._rng, _ = seeding.np_random()
@@ -168,7 +170,7 @@ def _get_shortest_path_action(
prev_body_state: PMBodyState,
body_state: PMBodyState,
target_coord: np.ndarray,
- ) -> Tuple[List[PEAction], action_distributions.ActionDistribution]:
+ ) -> tuple[list[PEAction], action_distributions.ActionDistribution]:
angle_vels = [
-self.model.dyaw_limit,
-self.model.dyaw_limit / 2.0,
@@ -234,10 +236,10 @@ def _get_shortest_path_action(
rtol=0.0,
atol=1e-1,
).all():
- if all(a[1] == 0.0 for a in sp_actions):
+ if all(a[1] == ZERO_ACTION for a in sp_actions):
# try to move forward
sp_actions = [(a[0], a[1] + 0.5) for a in sp_actions]
- elif all(a[0] == 0.0 for a in sp_actions):
+ elif all(a[0] == ZERO_ACTION for a in sp_actions):
# try to turn
old_sp_actions = sp_actions
sp_actions = []
diff --git a/posggym/agents/evaluation/diversity.py b/posggym/agents/evaluation/diversity.py
index 465170f..cbb3c75 100644
--- a/posggym/agents/evaluation/diversity.py
+++ b/posggym/agents/evaluation/diversity.py
@@ -2,7 +2,8 @@
from __future__ import annotations
from itertools import product
-from typing import Dict, List, TYPE_CHECKING
+from typing import TYPE_CHECKING
+
if TYPE_CHECKING:
from pathlib import Path
@@ -14,12 +15,15 @@
from posggym.agents.evaluation import pairwise
-def measure_return_diversity(
+MIN_POLICIES_FOR_RANDOM = 2
+
+
+def measure_return_diversity( # noqa: PLR0912
pw_returns: np.ndarray,
- policies: List[str],
- co_teams: List[str],
+ policies: list[str],
+ co_teams: list[str],
verbose: bool = False,
-) -> Dict[str, np.ndarray]:
+) -> dict[str, np.ndarray]:
"""Measure return diversity for a set of pairwise returns.
Diversity is measured in terms of the Euclidean distance between the pairwise
@@ -41,7 +45,7 @@ def measure_return_diversity(
# get max and min returns, excluding random policy if it exists
random_co_team_idxs = []
random_idx = None
- if len(policies) > 2 and "Random" in policies:
+ if len(policies) > MIN_POLICIES_FOR_RANDOM and "Random" in policies:
random_idx = policies.index("Random")
pw_returns_excl_random = np.delete(pw_returns, random_idx, axis=0)
random_co_team_idxs = [
@@ -57,7 +61,6 @@ def measure_return_diversity(
if verbose:
with np.printoptions(precision=2, suppress=True):
- # print(f"{pw_returns=}")
print(f"{max_return=:.2f}, {min_return=:.2f}")
# normalize pairwise distributions into [0.0, 1.0]
@@ -90,7 +93,7 @@ def measure_return_diversity(
print(f"{policy_ed=}")
# Group similar policies by relative MSE
- if len(policies) > 2 and "Random" in policies:
+ if len(policies) > MIN_POLICIES_FOR_RANDOM and "Random" in policies:
# exclude random policy from calculating bin sizes for grouping
pw_ed_excl_random = np.delete(pw_ed, random_idx, axis=0)
pw_ed_excl_random = np.delete(pw_ed_excl_random, random_idx, axis=1)
@@ -173,14 +176,16 @@ def run_return_diversity_analysis(
continue
print(f" {args_id=}")
- df = pairwise.load_pairwise_comparison_results(env_id, output_dir, args_id)
+ results_df = pairwise.load_pairwise_comparison_results(
+ env_id, output_dir, args_id
+ )
- if df["symmetric"].unique().tolist()[0]:
+ if results_df["symmetric"].unique().tolist()[0]:
# only keep results for one agent
- agent_ids = df["agent_id"].unique().tolist()
- df = df[df["agent_id"] == agent_ids[0]]
+ agent_ids = results_df["agent_id"].unique().tolist()
+ results_df = results_df[results_df["agent_id"] == agent_ids[0]]
- pw_returns_per_agent = pairwise.get_pairwise_returns_matrix(df)
+ pw_returns_per_agent = pairwise.get_pairwise_returns_matrix(results_df)
pairwise.generate_pairwise_returns_plot(
env_id=env_id,
@@ -202,7 +207,7 @@ def run_return_diversity_analysis(
)
results[i] = (results_i, policy_ids, co_teams_ids)
- for k in list(results.values())[0][0]:
+ for k in next(iter(results.values()))[0]:
fig, axs = plt.subplots(
nrows=1,
ncols=len(results),
@@ -213,7 +218,7 @@ def run_return_diversity_analysis(
for idx, i in enumerate(results):
div_results, policy_ids, co_teams_ids = results[i]
- if len(results) > 2:
+ if len(results) > MIN_POLICIES_FOR_RANDOM:
# only show first policy name for each co-team
co_team_labels = [
team_id.replace("(", "").replace(")", "").split(",")[0]
diff --git a/posggym/agents/evaluation/pairwise.py b/posggym/agents/evaluation/pairwise.py
index 96623e2..9b32dcf 100644
--- a/posggym/agents/evaluation/pairwise.py
+++ b/posggym/agents/evaluation/pairwise.py
@@ -5,7 +5,8 @@
import multiprocessing as mp
from datetime import datetime
from itertools import product
-from typing import Dict, List, NamedTuple, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, NamedTuple
+
if TYPE_CHECKING:
from pathlib import Path
@@ -43,8 +44,8 @@ class PWCParams(NamedTuple):
env_id: str
env_args_id: str | None
- env_args: Dict[str, str]
- policy_ids: Dict[str, str]
+ env_args: dict[str, str]
+ policy_ids: dict[str, str]
output_file: str
num_episodes: int
seed: int | None
@@ -52,7 +53,7 @@ class PWCParams(NamedTuple):
verbose: bool = False
-def run_episodes(args) -> Dict[str, Dict[str, float]]:
+def run_episodes(args) -> dict[str, dict[str, float]]:
"""Run episodes and return the average reward for each policy."""
params, total_runs = args
policy_names = {
@@ -143,14 +144,14 @@ def run_episodes(args) -> Dict[str, Dict[str, float]]:
print(f"Run {params.run_num}/{total_runs} complete.")
-def get_pairwise_comparison_params(
+def get_pairwise_comparison_params( # noqa: PLR0912
env_id: str,
output_dir: Path,
env_args_id: str | None = None,
num_episodes: int = 1000,
seed: int | None = None,
verbose: bool = False,
-) -> List[PWCParams]:
+) -> list[PWCParams]:
"""Get parameters for pairwise comparisons of all of an environment's policies."""
# attempt to make env to check if it is registered (displays nicer error msg)
posggym.make(env_id)
@@ -181,7 +182,7 @@ def get_pairwise_comparison_params(
output_file = env_output_dir / f"{args_id}.csv"
if output_file.exists():
- logger.warn(f"{output_file} exists. Rewriting.")
+ logger.warning(f"{output_file} exists. Rewriting.")
headers = [
"env_id",
@@ -249,8 +250,9 @@ def get_pairwise_comparison_params(
# run all pairwise combinations of policies
for policy_specs in product(*all_policy_specs):
- policy_ids = {i: spec.id for i, spec in zip(agent_ids, policy_specs)}
- # print(f" {policy_ids=}")
+ policy_ids = {
+ i: spec.id for i, spec in zip(agent_ids, policy_specs, strict=False)
+ }
pairwise_policy_ids.append(policy_ids)
if verbose:
@@ -358,7 +360,7 @@ def load_pairwise_comparison_results(
def get_pairwise_returns_matrix(
df: pd.DataFrame,
-) -> Dict[str, Tuple[np.ndarray, List[str], List[str]]]:
+) -> dict[str, tuple[np.ndarray, list[str], list[str]]]:
"""Get pairwise returns matrix for each agent in environment."""
num_episodes = df["num_episodes"].unique().tolist()[0]
agent_ids = df["agent_id"].unique().tolist()
@@ -417,12 +419,12 @@ def get_pairwise_returns_matrix(
(df["agent_id"] == i)
& (df["policy_name"] == policy_name)
& (df["co_team_id"] == co_team_id)
- ]["episode_reward_mean"].values[0]
+ ]["episode_reward_mean"].to_numpy()[0]
pw_returns_i[1, policy_idx, co_team_idx] = df[
(df["agent_id"] == i)
& (df["policy_name"] == policy_name)
& (df["co_team_id"] == co_team_id)
- ]["episode_reward_std"].values[0]
+ ]["episode_reward_std"].to_numpy()[0]
# compute 95% CI for mean: 1.96*std/sqrt(N)
pw_returns_i[2, policy_idx, co_team_idx] = (
1.96 * pw_returns_i[1, policy_idx, co_team_idx] / np.sqrt(num_episodes)
@@ -434,7 +436,7 @@ def generate_pairwise_returns_plot(
env_id: str,
output_dir: Path,
env_args_id: str | None,
- pw_returns_per_agent: Dict[str, Tuple[np.ndarray, List[str], List[str]]],
+ pw_returns_per_agent: dict[str, tuple[np.ndarray, list[str], list[str]]],
show: bool = True,
save: bool = True,
mean_only: bool = False,
@@ -534,14 +536,14 @@ def plot_pairwise_comparison_results(
continue
print(f" {args_id=}")
- df = load_pairwise_comparison_results(env_id, output_dir, args_id)
+ results_df = load_pairwise_comparison_results(env_id, output_dir, args_id)
- if df["symmetric"].unique().tolist()[0]:
+ if results_df["symmetric"].unique().tolist()[0]:
# only keep results for one agent
- agent_ids = df["agent_id"].unique().tolist()
- df = df[df["agent_id"] == agent_ids[0]]
+ agent_ids = results_df["agent_id"].unique().tolist()
+ results_df = results_df[results_df["agent_id"] == agent_ids[0]]
- pw_returns_per_agent = get_pairwise_returns_matrix(df)
+ pw_returns_per_agent = get_pairwise_returns_matrix(results_df)
generate_pairwise_returns_plot(
env_id=env_id,
diff --git a/posggym/agents/grid_world/cooperative_reaching/__init__.py b/posggym/agents/grid_world/cooperative_reaching/__init__.py
index 47b7d1b..1b55b4d 100644
--- a/posggym/agents/grid_world/cooperative_reaching/__init__.py
+++ b/posggym/agents/grid_world/cooperative_reaching/__init__.py
@@ -2,6 +2,7 @@
from posggym.agents.grid_world.cooperative_reaching import heuristic
from posggym.agents.registration import PolicySpec
+
policy_specs = {}
for policy_class in [
heuristic.CRHeuristic1,
diff --git a/posggym/agents/grid_world/cooperative_reaching/heuristic.py b/posggym/agents/grid_world/cooperative_reaching/heuristic.py
index edeb856..324b9e0 100644
--- a/posggym/agents/grid_world/cooperative_reaching/heuristic.py
+++ b/posggym/agents/grid_world/cooperative_reaching/heuristic.py
@@ -11,7 +11,7 @@
from __future__ import annotations
import random
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING
from posggym.agents.policy import Policy, PolicyID, PolicyState
from posggym.agents.utils import action_distributions
@@ -26,6 +26,7 @@
CRObs,
)
+
if TYPE_CHECKING:
from posggym.envs.grid_world.core import Coord
@@ -39,7 +40,7 @@ class CRHeuristicPolicy(Policy[CRAction, CRObs]):
def __init__(
self, model: CooperativeReachingModel, agent_id: str, policy_id: PolicyID
- ):
+ ) -> None:
super().__init__(model, agent_id, policy_id)
self._rng = random.Random()
self.grid_size = model.size
@@ -83,7 +84,7 @@ def get_value(self, state: PolicyState) -> float:
f"`get_value()` no implemented by {self.__class__.__name__} policy"
)
- def _move_towards(self, target_pos: Coord, agent_pos: Coord) -> List[CRAction]:
+ def _move_towards(self, target_pos: Coord, agent_pos: Coord) -> list[CRAction]:
"""Get list of actions that move towards target_pos from agent_pos."""
valid_actions = []
if target_pos[1] < agent_pos[1]:
@@ -100,7 +101,7 @@ def _move_towards(self, target_pos: Coord, agent_pos: Coord) -> List[CRAction]:
valid_actions.append(DO_NOTHING)
return valid_actions
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
"""Get target position from observation."""
raise NotImplementedError
@@ -145,7 +146,7 @@ def _get_target_goal(
desired_dist_to_goal = min(dist_to_goal) if closest else max(dist_to_goal)
valid_target_goals = [
g
- for g, dist in zip(goal_list, dist_to_goal)
+ for g, dist in zip(goal_list, dist_to_goal, strict=False)
if dist == desired_dist_to_goal
]
return self._rng.choice(valid_target_goals)
@@ -154,7 +155,7 @@ def _get_target_goal(
class CRHeuristic1(CRHeuristicPolicy):
"""H1 always goes to the closest rewarding goal."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
if target_goal is None:
target_goal = self._get_target_goal(obs, closest=True, optimal=None)
return target_goal
@@ -163,7 +164,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic2(CRHeuristicPolicy):
"""H2 always goes to the furthest rewarding goal."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
if target_goal is None:
target_goal = self._get_target_goal(obs, closest=False, optimal=None)
return target_goal
@@ -172,7 +173,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic3(CRHeuristicPolicy):
"""H3 always goes to the closest optimal goal."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
if target_goal is None:
target_goal = self._get_target_goal(obs, closest=True, optimal=True)
return target_goal
@@ -181,7 +182,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic4(CRHeuristicPolicy):
"""H4 always goes to the furthest optimal goal."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
if target_goal is None:
target_goal = self._get_target_goal(obs, closest=False, optimal=True)
return target_goal
@@ -190,7 +191,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic5(CRHeuristicPolicy):
"""H5 always goes to the closest suboptimal goal."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
if target_goal is None:
target_goal = self._get_target_goal(obs, closest=True, optimal=False)
return target_goal
@@ -199,7 +200,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic6(CRHeuristicPolicy):
"""H6 always goes to the furthest suboptimal goal."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
if target_goal is None:
target_goal = self._get_target_goal(obs, closest=False, optimal=False)
return target_goal
@@ -208,7 +209,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic7(CRHeuristicPolicy):
"""H7 goes to a randomly selected goal."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
if target_goal is None:
target_goal = self._rng.choice(list(self.goals))
return target_goal
@@ -217,7 +218,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic8(CRHeuristicPolicy):
"""H8 goes to the goal closest to the other agent at each time step."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
other_pos = obs[1]
if other_pos == (self.grid_size, self.grid_size):
# cannot see other agent, so just go towards a random goal
@@ -229,7 +230,9 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
]
min_dist_to_goal = min(dist_to_goal)
closest_goals = [
- g for g, dist in zip(self.goals, dist_to_goal) if dist == min_dist_to_goal
+ g
+ for g, dist in zip(self.goals, dist_to_goal, strict=False)
+ if dist == min_dist_to_goal
]
return self._rng.choice(closest_goals)
@@ -237,7 +240,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic9(CRHeuristicPolicy):
"""H9 goes to the optimal goal closest to the other agent."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
other_pos = obs[1]
if other_pos == (self.grid_size, self.grid_size):
# cannot see other agent, so just go towards a random optimal goal
@@ -255,7 +258,9 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
]
min_dist_to_goal = min(dist_to_goal)
closest_goals = [
- g for g, dist in zip(goal_list, dist_to_goal) if dist == min_dist_to_goal
+ g
+ for g, dist in zip(goal_list, dist_to_goal, strict=False)
+ if dist == min_dist_to_goal
]
return self._rng.choice(closest_goals)
@@ -263,7 +268,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic10(CRHeuristicPolicy):
"""H10 goes to the sub-optimal goal closest to the other agent."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
other_pos = obs[1]
if other_pos == (self.grid_size, self.grid_size):
# cannot see other agent, so just go towards a random optimal goal
@@ -281,7 +286,9 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
]
min_dist_to_goal = min(dist_to_goal)
closest_goals = [
- g for g, dist in zip(goal_list, dist_to_goal) if dist == min_dist_to_goal
+ g
+ for g, dist in zip(goal_list, dist_to_goal, strict=False)
+ if dist == min_dist_to_goal
]
return self._rng.choice(closest_goals)
@@ -289,7 +296,7 @@ def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
class CRHeuristic11(CRHeuristicPolicy):
"""H11 follows the other agent."""
- def _get_target_pos(self, obs: CRObs, target_goal: Optional[Coord]) -> Coord:
+ def _get_target_pos(self, obs: CRObs, target_goal: Coord | None) -> Coord:
other_pos = obs[1]
if other_pos == (self.grid_size, self.grid_size):
# cannot see other agent, so just go towards a random goal
diff --git a/posggym/agents/grid_world/driving/shortest_path.py b/posggym/agents/grid_world/driving/shortest_path.py
index ef1a46a..a095977 100644
--- a/posggym/agents/grid_world/driving/shortest_path.py
+++ b/posggym/agents/grid_world/driving/shortest_path.py
@@ -3,7 +3,7 @@
from itertools import product
from queue import PriorityQueue
-from typing import TYPE_CHECKING, Dict, Set, Tuple, cast
+from typing import TYPE_CHECKING, ClassVar, cast
from posggym.agents.policy import Policy, PolicyID, PolicyState
from posggym.agents.utils import action_distributions
@@ -30,9 +30,12 @@
from posggym.model import POSGModel
from posggym.utils.history import AgentHistory
+MIN_AGGRESSIVENESS = 0.0
+MAX_AGGRESSIVENESS = 1.0
+AVG_AGGRESSIVENESS = (MIN_AGGRESSIVENESS + MAX_AGGRESSIVENESS) / 2
# Current coord, speed, facing direction
-Pos = Tuple[Coord, Speed, Direction]
+Pos = tuple[Coord, Speed, Direction]
class DrivingShortestPathPolicy(Policy[DAction, DObs]):
@@ -42,7 +45,7 @@ class DrivingShortestPathPolicy(Policy[DAction, DObs]):
goal. If there are multiple actions on the shortest path then selects uniformly
at random from those actions.
- Arguments
+ Arguments:
---------
aggressiveness : float
The aggressiveness of the policy towards other vehicles. A value of 0.0 means
@@ -58,8 +61,7 @@ class DrivingShortestPathPolicy(Policy[DAction, DObs]):
# this shares shortest path computation and storage between all instances of class
# which is useful if running a vectorized environment or with many shortest path
# agents
- # shortest_paths: Dict[Coord, Dict[Pos, Dict[Pos, int]]] = {}
- shortest_paths: Dict[Coord, Dict[Pos, int]] = {}
+ shortest_paths: ClassVar[dict[Coord, dict[Pos, int]]] = {}
def __init__(
self,
@@ -67,12 +69,13 @@ def __init__(
agent_id: str,
policy_id: PolicyID,
aggressiveness: float = 1.0,
- ):
+ ) -> None:
super().__init__(model, agent_id, policy_id)
self.model = cast(DrivingModel, model)
- assert (
- 0.0 <= aggressiveness <= 1.0
- ), f"Aggressiveness must be between 0.0 and 1.0, got {aggressiveness}"
+ assert MIN_AGGRESSIVENESS <= aggressiveness <= MAX_AGGRESSIVENESS, (
+ f"Aggressiveness must be between {MIN_AGGRESSIVENESS}"
+ f"and {MAX_AGGRESSIVENESS}, got {aggressiveness}"
+ )
self.aggressiveness = aggressiveness
self._grid = self.model.grid
self._action_space = list(range(self.model.action_spaces[agent_id].n))
@@ -167,7 +170,7 @@ def get_pi(self, state: PolicyState) -> action_distributions.ActionDistribution:
dists = []
for a in self._action_space:
if (
- self.aggressiveness < 0.5
+ self.aggressiveness < AVG_AGGRESSIVENESS
and state["speed"] >= Speed.FORWARD_SLOW
and a == ACCELERATE
):
@@ -176,7 +179,7 @@ def get_pi(self, state: PolicyState) -> action_distributions.ActionDistribution:
dists.append(float("inf"))
continue
- a_speed = self.model.get_next_speed(a, state["speed"])
+ a_speed = self.model.get_next_speed_(a, state["speed"])
a_facing_dir = self.model.get_next_direction(
a, a_speed, state["facing_dir"]
)
@@ -211,7 +214,7 @@ def get_value(self, state: PolicyState) -> float:
f"`get_value()` no implemented by {self.__class__.__name__} policy"
)
- def get_min_other_vehicle_dist(self, local_obs: Tuple[int, ...]) -> int:
+ def get_min_other_vehicle_dist(self, local_obs: tuple[int, ...]) -> int:
"""Get minimum distance to other vehicle in local observation."""
min_other_vehicle_dist = self.max_obs_dist + 1
for idx, cell_obs in enumerate(local_obs):
@@ -237,7 +240,7 @@ def get_shortest_path(
origin: Pos,
dest: Coord,
grid: DrivingGrid,
- lookup_table: Dict[Pos, int],
+ lookup_table: dict[Pos, int],
) -> int:
"""Get shortest path to given origin to given destination.
@@ -295,7 +298,7 @@ def get_shortest_path(
return lookup_table[origin]
@staticmethod
- def get_next_positions(pos: Pos, grid: DrivingGrid) -> Set[Pos]:
+ def get_next_positions(pos: Pos, grid: DrivingGrid) -> set[Pos]:
coord, speed, facing_dir = pos
next_positions = set()
diff --git a/posggym/agents/grid_world/driving_gen/__init__.py b/posggym/agents/grid_world/driving_gen/__init__.py
index a3e5d6a..7b6e8c2 100644
--- a/posggym/agents/grid_world/driving_gen/__init__.py
+++ b/posggym/agents/grid_world/driving_gen/__init__.py
@@ -4,6 +4,7 @@
)
from posggym.agents.registration import PolicySpec
+
policy_specs = {}
for a, description in [
(
diff --git a/posggym/agents/grid_world/driving_gen/shortest_path.py b/posggym/agents/grid_world/driving_gen/shortest_path.py
index 8a22122..f08a405 100644
--- a/posggym/agents/grid_world/driving_gen/shortest_path.py
+++ b/posggym/agents/grid_world/driving_gen/shortest_path.py
@@ -30,7 +30,7 @@ class DrivingGenShortestPathPolicy(DrivingShortestPathPolicy):
computes the shortest path for the agent's current destination, rather than for
all destinations.
- Arguments
+ Arguments:
---------
aggressiveness : float
The aggressiveness of the policy towards other vehicles. A value of 0.0 means
@@ -47,7 +47,7 @@ def __init__(
agent_id: str,
policy_id: PolicyID,
aggressiveness: float = 1.0,
- ):
+ ) -> None:
super().__init__(
model,
agent_id,
diff --git a/posggym/agents/grid_world/level_based_foraging/heuristic.py b/posggym/agents/grid_world/level_based_foraging/heuristic.py
index 3db6198..074de30 100644
--- a/posggym/agents/grid_world/level_based_foraging/heuristic.py
+++ b/posggym/agents/grid_world/level_based_foraging/heuristic.py
@@ -8,7 +8,7 @@
from __future__ import annotations
import random
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, cast
from posggym.agents.policy import Policy, PolicyID, PolicyState
from posggym.agents.utils import action_distributions
@@ -18,6 +18,7 @@
LevelBasedForagingModel,
)
+
if TYPE_CHECKING:
from posggym.envs.grid_world.core import Coord
@@ -31,7 +32,7 @@ class LBFHeuristicPolicy(Policy[LBFAction, LBFObs]):
def __init__(
self, model: LevelBasedForagingModel, agent_id: str, policy_id: PolicyID
- ):
+ ) -> None:
super().__init__(model, agent_id, policy_id)
assert model.observation_mode in ("vector", "tuple")
self._rng = random.Random()
@@ -86,23 +87,23 @@ def get_value(self, state: PolicyState) -> float:
def _get_target_pos(
self,
- agent_obs: Tuple[int, int, int],
- food_obs: List[Tuple[int, int, int]],
- other_agent_obs: List[Tuple[int, int, int]],
+ agent_obs: tuple[int, int, int],
+ food_obs: list[tuple[int, int, int]],
+ other_agent_obs: list[tuple[int, int, int]],
last_action: LBFAction,
- target_pos: Optional[Coord],
- ) -> Optional[Coord]:
+ target_pos: Coord | None,
+ ) -> Coord | None:
"""Get target position from observations."""
raise NotImplementedError
def _get_food_by_distance(
self,
agent_pos: Coord,
- food_obs: List[Tuple[int, int, int]],
+ food_obs: list[tuple[int, int, int]],
closest: bool = True,
- max_food_level: Optional[int] = None,
- ) -> Optional[Coord]:
- food_distances: Dict[int, List[Coord]] = {}
+ max_food_level: int | None = None,
+ ) -> Coord | None:
+ food_distances: dict[int, list[Coord]] = {}
for y, x, level in food_obs:
if x == -1 or (max_food_level is not None and level > max_food_level):
continue
@@ -118,14 +119,14 @@ def _get_food_by_distance(
desired_dist = min(food_distances) if closest else max(food_distances)
return self._rng.choice(food_distances[desired_dist])
- def _center_of_agents(self, agent_obs: List[Tuple[int, int, int]]) -> Coord:
+ def _center_of_agents(self, agent_obs: list[tuple[int, int, int]]) -> Coord:
y_mean = sum(o[0] for o in agent_obs) / len(agent_obs)
x_mean = sum(o[1] for o in agent_obs) / len(agent_obs)
return round(x_mean), round(y_mean)
def _move_towards(
self, agent_pos: Coord, target: Coord, load_if_adjacent: bool = True
- ) -> List[LBFAction]:
+ ) -> list[LBFAction]:
if (
load_if_adjacent
and abs(target[0] - agent_pos[0]) + abs(target[1] - agent_pos[1]) == 1
@@ -168,12 +169,12 @@ class LBFHeuristic1(LBFHeuristicPolicy):
def _get_target_pos(
self,
- agent_obs: Tuple[int, int, int],
- food_obs: List[Tuple[int, int, int]],
- other_agent_obs: List[Tuple[int, int, int]],
+ agent_obs: tuple[int, int, int],
+ food_obs: list[tuple[int, int, int]],
+ other_agent_obs: list[tuple[int, int, int]],
last_action: LBFAction,
- target_pos: Optional[Coord],
- ) -> Optional[Coord]:
+ target_pos: Coord | None,
+ ) -> Coord | None:
agent_pos = agent_obs[:2]
return self._get_food_by_distance(
agent_pos, food_obs, closest=True, max_food_level=None
@@ -187,12 +188,12 @@ class LBFHeuristic2(LBFHeuristicPolicy):
def _get_target_pos(
self,
- agent_obs: Tuple[int, int, int],
- food_obs: List[Tuple[int, int, int]],
- other_agent_obs: List[Tuple[int, int, int]],
+ agent_obs: tuple[int, int, int],
+ food_obs: list[tuple[int, int, int]],
+ other_agent_obs: list[tuple[int, int, int]],
last_action: LBFAction,
- target_pos: Optional[Coord],
- ) -> Optional[Coord]:
+ target_pos: Coord | None,
+ ) -> Coord | None:
if not other_agent_obs:
return None
@@ -207,12 +208,12 @@ class LBFHeuristic3(LBFHeuristicPolicy):
def _get_target_pos(
self,
- agent_obs: Tuple[int, int, int],
- food_obs: List[Tuple[int, int, int]],
- other_agent_obs: List[Tuple[int, int, int]],
+ agent_obs: tuple[int, int, int],
+ food_obs: list[tuple[int, int, int]],
+ other_agent_obs: list[tuple[int, int, int]],
last_action: LBFAction,
- target_pos: Optional[Coord],
- ) -> Optional[Coord]:
+ target_pos: Coord | None,
+ ) -> Coord | None:
agent_pos, agent_level = agent_obs[:2], agent_obs[2]
return self._get_food_by_distance(
agent_pos, food_obs, closest=True, max_food_level=agent_level
@@ -226,12 +227,12 @@ class LBFHeuristic4(LBFHeuristicPolicy):
def _get_target_pos(
self,
- agent_obs: Tuple[int, int, int],
- food_obs: List[Tuple[int, int, int]],
- other_agent_obs: List[Tuple[int, int, int]],
+ agent_obs: tuple[int, int, int],
+ food_obs: list[tuple[int, int, int]],
+ other_agent_obs: list[tuple[int, int, int]],
last_action: LBFAction,
- target_pos: Optional[Coord],
- ) -> Optional[Coord]:
+ target_pos: Coord | None,
+ ) -> Coord | None:
if target_pos is not None:
# At the start of an episode it will select a target food and move towards
# it. Each time it's current target food is collected it then selects a new
@@ -258,12 +259,12 @@ class LBFHeuristic5(LBFHeuristicPolicy):
def _get_target_pos(
self,
- agent_obs: Tuple[int, int, int],
- food_obs: List[Tuple[int, int, int]],
- other_agent_obs: List[Tuple[int, int, int]],
+ agent_obs: tuple[int, int, int],
+ food_obs: list[tuple[int, int, int]],
+ other_agent_obs: list[tuple[int, int, int]],
last_action: LBFAction,
- target_pos: Optional[Coord],
- ) -> Optional[Coord]:
+ target_pos: Coord | None,
+ ) -> Coord | None:
if target_pos is not None:
# At the start of an episode it will select a target food and move towards
# it. Each time it's current target food is collected it then selects a new
diff --git a/posggym/agents/grid_world/predator_prey/__init__.py b/posggym/agents/grid_world/predator_prey/__init__.py
index f811672..2b83af2 100644
--- a/posggym/agents/grid_world/predator_prey/__init__.py
+++ b/posggym/agents/grid_world/predator_prey/__init__.py
@@ -6,6 +6,7 @@
from posggym.agents.utils import processors
from posggym.config import AGENT_MODEL_DIR
+
agent_model_dir = AGENT_MODEL_DIR / "grid_world" / "predator_prey"
policy_specs = {}
diff --git a/posggym/agents/grid_world/predator_prey/heuristic.py b/posggym/agents/grid_world/predator_prey/heuristic.py
index a2e2c12..c3f0387 100644
--- a/posggym/agents/grid_world/predator_prey/heuristic.py
+++ b/posggym/agents/grid_world/predator_prey/heuristic.py
@@ -1,7 +1,7 @@
"""Heuristic policies for the PredatorPrey grid world environment."""
from __future__ import annotations
-from typing import TYPE_CHECKING, List, Tuple, cast
+from typing import TYPE_CHECKING, ClassVar, cast
import posggym.envs.grid_world.predator_prey as pp
from posggym.agents.policy import Policy, PolicyID, PolicyState
@@ -9,6 +9,7 @@
from posggym.envs.grid_world.core import Direction
from posggym.utils import seeding
+
if TYPE_CHECKING:
from posggym.envs.grid_world.core import Coord
from posggym.model import POSGModel
@@ -21,7 +22,7 @@
class PPHeuristicPolicy(Policy[pp.PPAction, pp.PPObs]):
"""Base class for PredatorPrey environment heuristic policies."""
- VALID_EXPLORE_STRATEGIES = [
+ VALID_EXPLORE_STRATEGIES: ClassVar[list[str]] = [
"uniform_random",
"spiral",
]
@@ -33,7 +34,7 @@ def __init__(
policy_id: PolicyID,
explore_strategy: str = "uniform_random",
explore_epsilon: float = 0.05,
- ):
+ ) -> None:
super().__init__(model, agent_id, policy_id)
assert explore_strategy in self.VALID_EXPLORE_STRATEGIES
assert 0 <= explore_epsilon <= 1
@@ -113,15 +114,15 @@ def get_value(self, state: PolicyState) -> float:
)
def get_actions_from_obs(
- self, pred_coords: List[Coord], prey_coords: List[Coord]
- ) -> List[pp.PPAction]:
+ self, pred_coords: list[Coord], prey_coords: list[Coord]
+ ) -> list[pp.PPAction]:
raise NotImplementedError(
f"`get_pi_from_obs()` not implemented by {self.__class__.__name__} policy"
)
def get_explore_actions_from_obs(
- self, wall_obs: List[bool], explore_dir: Direction | None
- ) -> Tuple[List[pp.PPAction], Direction | None]:
+ self, wall_obs: list[bool], explore_dir: Direction | None
+ ) -> tuple[list[pp.PPAction], Direction | None]:
# using list slice for quick copy of list of primitives
if self.explore_strategy == "uniform_random":
# random explore
@@ -146,7 +147,7 @@ def get_explore_actions_from_obs(
f"{self.__class__.__name__} policy"
)
- def parse_obs(self, obs: pp.PPObs) -> Tuple[List[Coord], List[Coord], List[bool]]:
+ def parse_obs(self, obs: pp.PPObs) -> tuple[list[Coord], list[Coord], list[bool]]:
"""Parse obs into list of predator coords, prey coords, and wall directions."""
pred_coords = []
prey_coords = []
@@ -171,14 +172,14 @@ def parse_obs(self, obs: pp.PPObs) -> Tuple[List[Coord], List[Coord], List[bool]
walls_obs[Direction.SOUTH] = True
return pred_coords, prey_coords, walls_obs
- def get_closest_coord(self, origin: Coord, coords: List[Coord]) -> Coord | None:
+ def get_closest_coord(self, origin: Coord, coords: list[Coord]) -> Coord | None:
"""Get coord of from list that is closest to the origin coord."""
return min(
coords,
key=lambda coord: self._grid.manhattan_dist(origin, coord),
)
- def get_actions_towards_target(self, target_coord: Coord) -> List[pp.PPAction]:
+ def get_actions_towards_target(self, target_coord: Coord) -> list[pp.PPAction]:
"""Get action towards target coord."""
agent_coord = self.agent_obs_coord
@@ -203,12 +204,12 @@ class PPHeuristic1(PPHeuristicPolicy):
randomly, in that order.
"""
- def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID) -> None:
super().__init__(model, agent_id, policy_id, "uniform_random")
def get_actions_from_obs(
- self, pred_coords: List[Coord], prey_coords: List[Coord]
- ) -> List[pp.PPAction]:
+ self, pred_coords: list[Coord], prey_coords: list[Coord]
+ ) -> list[pp.PPAction]:
if len(prey_coords) != 0:
closest_prey_coord = self.get_closest_coord(
self.agent_obs_coord, prey_coords
@@ -229,12 +230,12 @@ class PPHeuristic2(PPHeuristicPolicy):
a clockwise spiral around arena, in that order.
"""
- def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID) -> None:
super().__init__(model, agent_id, policy_id, "spiral")
def get_actions_from_obs(
- self, pred_coords: List[Coord], prey_coords: List[Coord]
- ) -> List[pp.PPAction]:
+ self, pred_coords: list[Coord], prey_coords: list[Coord]
+ ) -> list[pp.PPAction]:
if len(prey_coords) != 0:
closest_prey_coord = self.get_closest_coord(
self.agent_obs_coord, prey_coords
@@ -255,12 +256,12 @@ class PPHeuristic3(PPHeuristicPolicy):
explores in a clockwise spiral around arena, in that order.
"""
- def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID) -> None:
super().__init__(model, agent_id, policy_id, "spiral")
def get_actions_from_obs(
- self, pred_coords: List[Coord], prey_coords: List[Coord]
- ) -> List[pp.PPAction]:
+ self, pred_coords: list[Coord], prey_coords: list[Coord]
+ ) -> list[pp.PPAction]:
if len(prey_coords) == 0:
return []
diff --git a/posggym/agents/grid_world/pursuit_evasion/__init__.py b/posggym/agents/grid_world/pursuit_evasion/__init__.py
index 1ae6ad5..8561d89 100644
--- a/posggym/agents/grid_world/pursuit_evasion/__init__.py
+++ b/posggym/agents/grid_world/pursuit_evasion/__init__.py
@@ -6,6 +6,7 @@
from posggym.agents.utils import processors
from posggym.config import AGENT_MODEL_DIR
+
agent_model_dir = AGENT_MODEL_DIR / "grid_world" / "pursuit_evasion"
policy_specs = {}
@@ -42,7 +43,6 @@ def _get_policy_description(policy_file_name):
# PursuitEvasion 16x16
-# Evader (agent=0)
for policy_file_name in [
"KLR0_i0.pkl",
"KLR1_i0.pkl",
@@ -86,7 +86,6 @@ def _get_policy_description(policy_file_name):
policy_specs[spec.id] = spec
-# Pursuer (agent=1)
for policy_file_name in [
"KLR0_i1.pkl",
"KLR1_i1.pkl",
@@ -130,7 +129,6 @@ def _get_policy_description(policy_file_name):
policy_specs[spec.id] = spec
# PursuitEvasion 8x8
-# Evader (agent=0)
for policy_file_name in [
"KLR0_i0.pkl",
"KLR1_i0.pkl",
@@ -170,7 +168,6 @@ def _get_policy_description(policy_file_name):
policy_specs[spec.id] = spec
-# Pursuer (agent=1)
for policy_file_name in [
"KLR0_i1.pkl",
"KLR1_i1.pkl",
diff --git a/posggym/agents/grid_world/pursuit_evasion/shortest_path.py b/posggym/agents/grid_world/pursuit_evasion/shortest_path.py
index 00e51fb..279baa9 100644
--- a/posggym/agents/grid_world/pursuit_evasion/shortest_path.py
+++ b/posggym/agents/grid_world/pursuit_evasion/shortest_path.py
@@ -2,7 +2,7 @@
from __future__ import annotations
from queue import PriorityQueue
-from typing import TYPE_CHECKING, Dict, Set, Tuple, cast
+from typing import TYPE_CHECKING, cast
import posggym.envs.grid_world.pursuit_evasion as pe
from posggym.agents.policy import Policy, PolicyID, PolicyState
@@ -10,13 +10,14 @@
from posggym.envs.grid_world.core import Coord, Direction
from posggym.utils import seeding
+
if TYPE_CHECKING:
from posggym.model import POSGModel
from posggym.utils.history import AgentHistory
# Current coord, facing direction
-Pos = Tuple[Coord, Direction]
+Pos = tuple[Coord, Direction]
class PEShortestPathPolicy(Policy[pe.PEAction, pe.PEObs]):
@@ -33,7 +34,7 @@ class PEShortestPathPolicy(Policy[pe.PEAction, pe.PEObs]):
"""
- def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID) -> None:
super().__init__(model, agent_id, policy_id)
self.model = cast(pe.PursuitEvasionModel, model)
self._grid = self.model.grid
@@ -187,7 +188,7 @@ def get_dest_shortest_path_dist(self, dest_coord: Coord, pos: Pos) -> int:
min_dist = min(min_dist, dists.get(pos, float("inf")))
return min_dist
- def get_all_shortest_paths(self, origin: Coord) -> Dict[Pos, Dict[Pos, int]]:
+ def get_all_shortest_paths(self, origin: Coord) -> dict[Pos, dict[Pos, int]]:
"""Get shortest paths from given origin to all other positions.
Note, this is a search over agent configurations (coord, facing_dir), rather
@@ -200,7 +201,7 @@ def get_all_shortest_paths(self, origin: Coord) -> Dict[Pos, Dict[Pos, int]]:
src_dists[pos] = self.dijkstra(pos)
return src_dists
- def dijkstra(self, origin: Pos) -> Dict[Pos, int]:
+ def dijkstra(self, origin: Pos) -> dict[Pos, int]:
"""Get shortest path distance to origin from all other positions."""
dist = {origin: 0}
pq = PriorityQueue() # type: ignore
@@ -220,7 +221,7 @@ def dijkstra(self, origin: Pos) -> Dict[Pos, int]:
visited.add(adj_pos)
return dist
- def get_prev_positions(self, pos: Pos) -> Set[Pos]:
+ def get_prev_positions(self, pos: Pos) -> set[Pos]:
"""Get all positions reachable from given position."""
coord, facing_dir = pos
prev_positions = set()
diff --git a/posggym/agents/policy.py b/posggym/agents/policy.py
index fe7d380..02f3a17 100644
--- a/posggym/agents/policy.py
+++ b/posggym/agents/policy.py
@@ -4,7 +4,7 @@
import abc
import copy
-from typing import TYPE_CHECKING, Any, Dict, Generic
+from typing import TYPE_CHECKING, Any, Generic
from posggym.model import ActType, ObsType
@@ -18,7 +18,7 @@
# Convenient type definitions
PolicyID = str
-PolicyState = Dict[str, Any]
+PolicyState = dict[str, Any]
class Policy(abc.ABC, Generic[ActType, ObsType]):
@@ -80,7 +80,7 @@ class Policy(abc.ABC, Generic[ActType, ObsType]):
# Whether the policy expects the full state as it's observation or not
observes_state: bool = False
- def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: POSGModel, agent_id: str, policy_id: PolicyID) -> None:
self.model = model
self.agent_id = agent_id
self.policy_id = policy_id
@@ -93,12 +93,12 @@ def step(self, obs: ObsType) -> ActType:
This function updates the policy's current internal state given the most recent
observation, and returns the next action for the policy.
- Arguments
+ Arguments:
---------
obs : ObsType
The latest observation.
- Returns
+ Returns:
-------
action : ActType
The next action
@@ -119,7 +119,7 @@ def reset(self, *, seed: int | None = None):
this is that the seed provided once by the user, just after the policy is first
created and before it interacts with an environment.
- Arguments
+ Arguments:
---------
seed : int, optional
Seed for random number generator.
@@ -133,7 +133,6 @@ def close(self):
Should be overridden in subclasses as necessary.
"""
- pass
def get_initial_state(self) -> PolicyState:
"""Get the policy's initial state.
@@ -161,7 +160,7 @@ def get_next_state(
Subclasses must implement this method.
- Arguments
+ Arguments:
---------
action : ActType, optional
The action performed. May be None if this is the first observation.
@@ -170,7 +169,7 @@ def get_next_state(
state : PolicyState
The policy's state before action was performed and obs received
- Returns
+ Returns:
-------
next_state : PolicyState
The next policy state
@@ -187,12 +186,12 @@ def sample_action(self, state: PolicyState) -> ActType:
Subclasses must implement this method.
- Arguments
+ Arguments:
---------
state : PolicyState
The policy's current state.
- Returns
+ Returns:
-------
action : ActType
The sampled action.
@@ -208,12 +207,12 @@ def get_pi(self, state: PolicyState) -> ActionDistribution:
:py:class:`posggym.agents.utils.action_distributions.ActionDistribution`
class.
- Arguments
+ Arguments:
---------
state : PolicyState
The policy's current state.
- Returns
+ Returns:
-------
pi : ActionDistribution
The policy's distribution over actions.
@@ -227,12 +226,12 @@ def get_value(self, state: PolicyState) -> float:
Subclasses must implement this method, but may set it to raise a
NotImplementedError if the policy does not support value estimates.
- Arguments
+ Arguments:
---------
state : PolicyState
The policy's current state.
- Returns
+ Returns:
-------
value : float
The value estimate.
@@ -246,7 +245,7 @@ def set_state(self, state: PolicyState, last_action: ActType | None = None):
override this method, to set any attributes used for the by the class to store
policy state.
- Arguments
+ Arguments:
---------
state : PolicyState
The new policy state.
@@ -254,7 +253,7 @@ def set_state(self, state: PolicyState, last_action: ActType | None = None):
The last action taken by the policy. If not provided then the last action
will be set to None.
- Raises
+ Raises:
------
AssertionError
If new policy state is not valid.
@@ -285,12 +284,12 @@ def get_state_from_history(self, history: AgentHistory) -> PolicyState:
Note, this function will return None for the action in the final output state,
as this would correspond to the action that was selected by the policy to action
- Arguments
+ Arguments:
---------
history : AgentHistory
The agent's action-observation history.
- Returns
+ Returns:
-------
state : PolicyState
Policy state given history.
diff --git a/posggym/agents/random_policies.py b/posggym/agents/random_policies.py
index fea50c9..111b9ca 100644
--- a/posggym/agents/random_policies.py
+++ b/posggym/agents/random_policies.py
@@ -30,7 +30,7 @@ def __init__(
agent_id: str,
policy_id: PolicyID,
dist: action_distributions.ActionDistribution | None = None,
- ):
+ ) -> None:
super().__init__(model, agent_id, policy_id)
self.dist = dist
self._rng, _ = seeding.np_random()
@@ -97,7 +97,7 @@ class RandomPolicy(Policy[ActType, ObsType]):
"""
- def __init__(self, model: M.POSGModel, agent_id: str, policy_id: PolicyID):
+ def __init__(self, model: M.POSGModel, agent_id: str, policy_id: PolicyID) -> None:
super().__init__(model, agent_id, policy_id)
self._action_space = model.action_spaces[agent_id]
self._rng = random.Random()
diff --git a/posggym/agents/registration.py b/posggym/agents/registration.py
index 35ec48f..8729353 100644
--- a/posggym/agents/registration.py
+++ b/posggym/agents/registration.py
@@ -13,7 +13,7 @@
import re
from collections import defaultdict
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Protocol, Tuple
+from typing import TYPE_CHECKING, Any, Protocol
import posggym
from posggym import error, logger
@@ -44,12 +44,12 @@ def __call__(
def load(name: str) -> PolicyEntryPoint:
"""Loads policy with name and returns a policy entry point.
- Arguments
+ Arguments:
---------
name : str
The policy name.
- Returns
+ Returns:
-------
entry_point : PolicyEntryPoint
Policy creation function.
@@ -61,19 +61,19 @@ def load(name: str) -> PolicyEntryPoint:
return fn
-def parse_policy_id(policy_id: str) -> Tuple[str | None, str | None, str, int | None]:
+def parse_policy_id(policy_id: str) -> tuple[str | None, str | None, str, int | None]:
"""Parse policy ID string format.
env_id is group 1, env_args_id is group 2. policy_id is group 2, version is group 3
[env_id/][env_args_id/](policy_id)-v(version)
- Arguments
+ Arguments:
---------
policy_id : str
The policy id to parse.
- Returns
+ Returns:
-------
env_id : str | None
The environment ID.
@@ -84,9 +84,11 @@ def parse_policy_id(policy_id: str) -> Tuple[str | None, str | None, str, int |
version : int | None
The policy version.
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
If the policy id does not a valid environment regex.
"""
@@ -120,7 +122,7 @@ def get_policy_id(
Inverse of :meth:`parse_policy_id`.
- Arguments
+ Arguments:
---------
env_id : str | None
The environment ID.
@@ -131,7 +133,7 @@ def get_policy_id(
version : int | None
The policy version.
- Returns
+ Returns:
-------
policy_id : str
The policy id.
@@ -153,7 +155,7 @@ def get_policy_id(
return full_name
-def get_env_args_id(env_args: Dict[str, Any]) -> str:
+def get_env_args_id(env_args: dict[str, Any]) -> str:
"""Get string representation of environment keyword arguments.
Converts keyword dictionary {k1: v1, k2: v2, k3: v3} into a string:
@@ -163,12 +165,12 @@ def get_env_args_id(env_args: Dict[str, Any]) -> str:
Note we assume keywords are valid python variable names and so do not contain
any hyphen '-' characters.
- Arguments
+ Arguments:
---------
env_args : Dict[str, Any]
Environment keyword arguments.
- Returns
+ Returns:
-------
env_args_id : str
String representation of the envrinment keyword arguments.
@@ -184,7 +186,7 @@ class PolicySpec:
Used to register agent policies that can then be dynamically loaded using
posggym_agents.make.
- Arguments
+ Arguments:
---------
policy_name
The name of the policy.
@@ -231,15 +233,15 @@ class PolicySpec:
# Environment attributes
env_id: str | None = field(default=None)
- env_args: Dict[str, Any] | None = field(default=None)
+ env_args: dict[str, Any] | None = field(default=None)
# Policy attributes
- valid_agent_ids: List[str] | None = field(default=None)
+ valid_agent_ids: list[str] | None = field(default=None)
nondeterministic: bool = field(default=False)
description: str | None = field(default=None)
# Policy Arguments
- kwargs: Dict = field(default_factory=dict)
+ kwargs: dict = field(default_factory=dict)
# post-init attributes
env_args_id: str | None = field(default=None)
@@ -285,7 +287,7 @@ def _check_env_id_exists(env_id: str | None, env_args_id: str | None):
if suggestion
else f"Have you installed the proper package for {env_id}?"
)
- raise error.PolicyEnvIDNotFound(
+ raise error.PolicyEnvIDNotFoundError(
f"Environment ID {env_id} not found. {suggestion_msg}"
)
@@ -303,7 +305,7 @@ def _check_env_id_exists(env_id: str | None, env_args_id: str | None):
else None
)
suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else ""
- raise error.PolicyEnvArgsIDNotFound(
+ raise error.PolicyEnvArgsIDNotFoundError(
f"Environment Arguments {env_args_id} for environment ID {env_id} not "
f"found. {suggestion_msg}"
)
@@ -337,7 +339,7 @@ def _check_name_exists(env_id: str | None, env_args_id: str | None, policy_name:
env_id_msg = f" for env ID {env_id}" if env_id else ""
suggestion_msg = f"Did you mean: `{names[suggestion[0]]}`?" if suggestion else ""
- raise error.PolicyNameNotFound(
+ raise error.PolicyNameNotFoundError(
f"Policy {policy_name} doesn't exist{env_id_msg}. {suggestion_msg}"
)
@@ -350,7 +352,7 @@ def _check_version_exists(
This is a complete test whether an policy ID is valid, and will provide the best
available hints.
- Arguments
+ Arguments:
---------
env_id : str | None
The environment ID.
@@ -361,12 +363,12 @@ def _check_version_exists(
version : int | None
The policy version.
- Raises
+ Raises:
------
- DeprecatedPolicy
+ DeprecatedPolicyError
The policy doesn't exist but a default version does or the policy version is
deprecated
- VersionNotFound
+ VersionNotFoundError
The ``version`` used doesn't exist
"""
@@ -397,7 +399,7 @@ def _check_version_exists(
if default_spec:
message += f" It provides the default version {default_spec[0].id}`."
if len(policy_specs) == 1:
- raise error.DeprecatedPolicy(message)
+ raise error.DeprecatedPolicyError(message)
# Process possible versioned environments
versioned_specs = [spec_ for spec_ in policy_specs if spec_.version is not None]
@@ -411,10 +413,10 @@ def _check_version_exists(
if version > latest_spec.version:
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in policy_specs)
message += f" It provides versioned policies: [ {version_list_msg} ]."
- raise error.PolicyVersionNotFound(message)
+ raise error.PolicyVersionNotFoundError(message)
if version < latest_spec.version:
- raise error.DeprecatedPolicy(
+ raise error.DeprecatedPolicyError(
f"Policy version v{version} for "
f"`{get_policy_id(env_id, env_args_id, policy_name, None)}` "
f"is deprecated. Please use `{latest_spec.id}` instead."
@@ -499,8 +501,8 @@ def register(
entry_point: PolicyEntryPoint | str,
version: int | None = None,
env_id: str | None = None,
- env_args: Dict[str, Any] | None = None,
- valid_agent_ids: List[str] | None = None,
+ env_args: dict[str, Any] | None = None,
+ valid_agent_ids: list[str] | None = None,
nondeterministic: bool = False,
description: str | None = None,
**kwargs,
@@ -510,7 +512,7 @@ def register(
The policy is registered in posggym so it can be used with
:py:method:`posggym.agents.make`
- Arguments
+ Arguments:
---------
policy_name : str
The name of the policy
@@ -553,7 +555,7 @@ def register(
def register_spec(spec: PolicySpec):
"""Register a policy spec with posggym-agents.
- Arguments
+ Arguments:
---------
spec : PolicySpec
The policy spec.
@@ -562,7 +564,7 @@ def register_spec(spec: PolicySpec):
global registry
_check_spec_register(spec)
if spec.id in registry:
- logger.warn(f"Overriding policy {spec.id} already in registry.")
+ logger.warning(f"Overriding policy {spec.id} already in registry.")
registry[spec.id] = spec
@@ -572,7 +574,7 @@ def make(id: str | PolicySpec, model: M.POSGModel, agent_id: str, **kwargs) -> P
To find all available policies use `posggym_agents.agents.registry.keys()` for
all valid ids.
- Arguments
+ Arguments:
---------
id : str | PolicySpec
Unique identifier of the policy or a policy spec.
@@ -583,14 +585,16 @@ def make(id: str | PolicySpec, model: M.POSGModel, agent_id: str, **kwargs) -> P
**kwargs
Additional arguments to pass to the policy constructor.
- Returns
+ Returns:
-------
Policy
An instance of the policy.
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
If the ``id`` doesn't exist then an error is raised
"""
@@ -616,7 +620,7 @@ def make(id: str | PolicySpec, model: M.POSGModel, agent_id: str, **kwargs) -> P
and latest_version is not None
and latest_version > version
):
- logger.warn(
+ logger.warning(
f"The policy {id} is out of date. You should consider "
f"upgrading to version `v{latest_version}`."
)
@@ -625,7 +629,7 @@ def make(id: str | PolicySpec, model: M.POSGModel, agent_id: str, **kwargs) -> P
version = latest_version
new_policy_id = get_policy_id(env_id, env_args_id, policy_name, version)
spec_ = registry.get(new_policy_id) # type: ignore
- logger.warn(
+ logger.warning(
f"Using the latest versioned policy `{new_policy_id}` "
f"instead of the unversioned policy `{id}`."
)
@@ -668,19 +672,21 @@ def make(id: str | PolicySpec, model: M.POSGModel, agent_id: str, **kwargs) -> P
def spec(id: str) -> PolicySpec:
"""Retrieve the spec for the given policy from the global registry.
- Arguments
+ Arguments:
---------
id : str
The policy id.
- Returns
+ Returns:
-------
PolicySpec
The policy spec from the global registry.
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
If policy with given ``id`` doesn't exist in global registry.
"""
@@ -702,15 +708,15 @@ def spec(id: str) -> PolicySpec:
def pprint_registry(
- _registry: Dict = registry,
+ _registry: dict = registry,
num_cols: int = 3,
- include_env_ids: List[str] | None = None,
- exclude_env_ids: List[str] | None = None,
+ include_env_ids: list[str] | None = None,
+ exclude_env_ids: list[str] | None = None,
disable_print: bool = False,
) -> str | None:
"""Pretty print the policies in the registry.
- Arguments
+ Arguments:
---------
_registry : Dict
Policy registry to be printed.
@@ -725,7 +731,7 @@ def pprint_registry(
Whether to return a string of all the policy IDs instead of printing it to
console.
- Returns
+ Returns:
-------
str | None
Formatted str representation of registry, if ``disable_print=True``, otherwise
@@ -733,7 +739,7 @@ def pprint_registry(
"""
# Defaultdict to store policy names according to env_id.
- env_policies = defaultdict(lambda: defaultdict(lambda: []))
+ env_policies = defaultdict(lambda: defaultdict(list))
max_justify = 0
for spec in _registry.values():
env_id = "Generic" if spec.env_id is None else spec.env_id
@@ -774,13 +780,13 @@ def pprint_registry(
def get_all_env_policies(
env_id: str,
- env_args: Dict[str, Any] | str | None = None,
- _registry: Dict = registry,
+ env_args: dict[str, Any] | str | None = None,
+ _registry: dict = registry,
include_generic_policies: bool = True,
-) -> List[PolicySpec]:
+) -> list[PolicySpec]:
"""Get all PolicySpecs that are associated with a given environment ID.
- Arguments
+ Arguments:
---------
env_id : str
The ID of the environment.
@@ -793,7 +799,7 @@ def get_all_env_policies(
Whether to also return policies that are valid for all environments (e.g. the
random-v0 policy).
- Returns
+ Returns:
-------
List[PolicySpec]
List of specs for policies associated with given environment.
@@ -819,13 +825,13 @@ def get_all_env_policies(
def get_env_agent_policies(
env_id: str,
- env_args: Dict[str, Any] | None = None,
- _registry: Dict = registry,
+ env_args: dict[str, Any] | None = None,
+ _registry: dict = registry,
include_generic_policies: bool = True,
-) -> Dict[str, List[PolicySpec]]:
+) -> dict[str, list[PolicySpec]]:
"""Get each agent's policy specs associated with given environment.
- Arguments
+ Arguments:
---------
env_id : str
The ID of the environment.
@@ -838,7 +844,7 @@ def get_env_agent_policies(
Whether to also return policies that are valid for all environments (e.g. the
random-v0 policy) and environment args.
- Returns
+ Returns:
-------
Dict[str, List[PolicySpec]]
List of specs for policies associated with given environment.
@@ -846,7 +852,7 @@ def get_env_agent_policies(
"""
env = posggym.make(env_id) if env_args is None else posggym.make(env_id, **env_args)
- policies: Dict[str, List[PolicySpec]] = {i: [] for i in env.possible_agents}
+ policies: dict[str, list[PolicySpec]] = {i: [] for i in env.possible_agents}
for spec in get_all_env_policies(
env_id,
env_args,
@@ -860,23 +866,23 @@ def get_env_agent_policies(
def get_all_envs(
- _registry: Dict = registry,
-) -> Dict[str, Dict[str | None, Dict[str, Any] | None]]:
+ _registry: dict = registry,
+) -> dict[str, dict[str | None, dict[str, Any] | None]]:
"""Get all the environments that have at least one registered policy.
- Arguments
+ Arguments:
---------
_registry : Dict
The policy registry.
- Returns
+ Returns:
-------
Dict[str, Dict[str | None, Dict[str, Any] | None]]
A dictionary with env IDs as keys as list of (env_args, env_args_id) tuples as
the values.
"""
- envs: Dict[str, Dict[str | None, Dict[str, Any] | None]] = {}
+ envs: dict[str, dict[str | None, dict[str, Any] | None]] = {}
for spec in _registry.values():
if spec.env_id is not None:
envs.setdefault(spec.env_id, {})
diff --git a/posggym/agents/torch_policy.py b/posggym/agents/torch_policy.py
index ce56a14..3edd90c 100644
--- a/posggym/agents/torch_policy.py
+++ b/posggym/agents/torch_policy.py
@@ -8,21 +8,14 @@
from typing import (
TYPE_CHECKING,
Any,
- Callable,
- Dict,
- List,
NamedTuple,
- Optional,
- Tuple,
- Type,
- Union,
)
import numpy as np
import torch
-import torch.nn as nn
import torch.nn.functional as F
from gymnasium import spaces
+from torch import nn
from torch.distributions import Categorical, Normal
from posggym import logger
@@ -32,18 +25,25 @@
from posggym.agents.utils.download import download_from_repo
from posggym.utils import seeding
+
+BATCH_OBSERVATION_DIMENSION = 2
+SINGLE_OBSERVATION_DIMENSION = 1
+MIN_CUDA_VERSION = 10.2
+
if TYPE_CHECKING:
+ from collections.abc import Callable
+
import posggym.model as M
class PPOTorchModelSaveFileFormat(NamedTuple):
"""Format for saving and loading POSGGym PPOLSTMModel."""
- weights: Dict[str, Any]
- trunk_sizes: List[int]
+ weights: dict[str, Any]
+ trunk_sizes: list[int]
lstm_size: int
lstm_layers: int
- head_sizes: List[int]
+ head_sizes: list[int]
activation: str
lstm_use_prev_action: bool
lstm_use_prev_reward: bool
@@ -73,14 +73,14 @@ def __init__(
self,
obs_space: spaces.Space,
action_space: spaces.Space,
- trunk_sizes: List[int],
+ trunk_sizes: list[int],
lstm_size: int,
lstm_layers: int,
- head_sizes: List[int],
+ head_sizes: list[int],
activation: str,
lstm_use_prev_action: bool,
lstm_use_prev_reward: bool,
- ):
+ ) -> None:
assert isinstance(obs_space, spaces.Box) and len(obs_space.shape) == 1, (
"Only 1D Box observation spaces are supported for PPO PyTorch Policy "
"models. Look into using `gymansium.spaces.flatten_space` to flatten your "
@@ -112,7 +112,7 @@ def __init__(
"Expected either Discrete, MultiDiscrete, or Box."
)
- activation_fn: Optional[Callable] = None
+ activation_fn: Callable | None = None
if activation == "tanh":
activation_fn = nn.Tanh
elif activation == "relu":
@@ -167,16 +167,16 @@ def device(self) -> torch.device:
def get_next_state(
self,
- obs: Union[np.ndarray, torch.Tensor],
- lstm_state: Tuple[torch.Tensor, torch.Tensor],
- prev_action: Optional[Union[np.ndarray, torch.Tensor]] = None,
- prev_reward: Optional[Union[np.ndarray, torch.Tensor]] = None,
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ obs: np.ndarray | torch.Tensor,
+ lstm_state: tuple[torch.Tensor, torch.Tensor],
+ prev_action: np.ndarray | torch.Tensor | None = None,
+ prev_reward: np.ndarray | torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""Get next lstm output and state.
If obs is not batched, adds batch dimension with batch size of 1.
- Arguments
+ Arguments:
---------
obs
the observation, shape=(batch_size, obs_size) | (obs_size, )
@@ -188,7 +188,7 @@ def get_next_state(
prev_reward
the previous reward, shape=(batch_size, 1) | (1, )
- Returns
+ Returns:
-------
lstm_output
the lstm output, shape=(batch_size, cell_size)
@@ -199,11 +199,11 @@ def get_next_state(
if isinstance(obs, np.ndarray):
obs = torch.tensor(obs, dtype=torch.float32)
- if len(obs.shape) == 1:
+ if len(obs.shape) == SINGLE_OBSERVATION_DIMENSION:
# Single observation
# Add batch and sequence length dimensions
obs = obs.reshape(1, 1, -1)
- elif len(obs.shape) == 2:
+ elif len(obs.shape) == BATCH_OBSERVATION_DIMENSION:
# Batch of observations
# Add sequence length dimension
obs = obs.unsqueeze(1)
@@ -212,7 +212,7 @@ def get_next_state(
prev_action_reward = []
if self.use_prev_action:
- if isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
+ if isinstance(self.action_space, spaces.Discrete | spaces.MultiDiscrete):
# One-hot encode discrete actions
prev_action = F.one_hot(
torch.tensor(prev_action, dtype=torch.int64),
@@ -227,7 +227,7 @@ def get_next_state(
)
if len(prev_action_reward) > 0:
- hidden = torch.cat([hidden] + prev_action_reward, dim=-1)
+ hidden = torch.cat([hidden, *prev_action_reward], dim=-1)
lstm_output, lstm_state = self.lstm(hidden, lstm_state)
# remove sequence length dimension
@@ -236,16 +236,16 @@ def get_next_state(
def get_value(
self,
- obs: Union[np.ndarray, torch.Tensor],
- lstm_state: Tuple[torch.Tensor, torch.Tensor],
- prev_action: Optional[Union[np.ndarray, torch.Tensor]] = None,
- prev_reward: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ obs: np.ndarray | torch.Tensor,
+ lstm_state: tuple[torch.Tensor, torch.Tensor],
+ prev_action: np.ndarray | torch.Tensor | None = None,
+ prev_reward: np.ndarray | torch.Tensor | None = None,
) -> torch.Tensor:
"""Get value function output.
If input is not batched, then adds batch dimension with batch size of 1.
- Arguments
+ Arguments:
---------
obs
the observation, shape=(batch_size, obs_size) | (obs_size, )
@@ -257,7 +257,7 @@ def get_value(
prev_reward
the previous reward, shape=(batch_size, 1) | (1, )
- Returns
+ Returns:
-------
value
output of value function, shape=(batch_size, 1)
@@ -268,19 +268,19 @@ def get_value(
def get_action_and_value(
self,
- obs: Union[np.ndarray, torch.Tensor],
- lstm_state: Tuple[torch.Tensor, torch.Tensor],
- prev_action: Optional[Union[np.ndarray, torch.Tensor]] = None,
- prev_reward: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ obs: np.ndarray | torch.Tensor,
+ lstm_state: tuple[torch.Tensor, torch.Tensor],
+ prev_action: np.ndarray | torch.Tensor | None = None,
+ prev_reward: np.ndarray | torch.Tensor | None = None,
deterministic: bool = False,
- ) -> Tuple[
- torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
+ ) -> tuple[
+ torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
]:
"""Get next action and value.
If input is not batched, then adds batch dimension with batch size of 1.
- Arguments
+ Arguments:
---------
obs
the observation, shape=(batch_size, obs_size) | (obs_size, )
@@ -295,7 +295,7 @@ def get_action_and_value(
whether to sample action from action distribution or deterministicly select
action with highest probability.
- Returns
+ Returns:
-------
action
next action, shape=(batch_size, action_size)
@@ -342,15 +342,15 @@ def get_action_and_value(
def get_initial_state(
self, batch_size: int = 1
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor]:
"""Get the initial LSTM state.
- Arguments
+ Arguments:
---------
batch_size
the batch size of the LSTM state
- Returns
+ Returns:
-------
initial_state
the initial LSTM state, this is a tuple of two tensors, each with
@@ -366,7 +366,7 @@ def get_initial_state(
class PPOPolicy(Policy[ActType, ObsType]):
"""A PyTorch PPO Policy.
- Arguments
+ Arguments:
---------
model
the model of the environment
@@ -398,7 +398,7 @@ def __init__(
obs_processor: processors.Processor | None = None,
action_processor: processors.Processor | None = None,
deterministic: bool = False,
- ):
+ ) -> None:
self.policy_model = policy_model
self.deterministic = deterministic
self.action_space = model.action_spaces[agent_id]
@@ -442,7 +442,7 @@ def reset(self, *, seed: int | None = None):
cuda_version = torch.version.cuda
if (
cuda_version is not None
- and float(cuda_version) >= 10.2
+ and float(cuda_version) >= MIN_CUDA_VERSION
and (
"CUBLAS_WORKSPACE_CONFIG" not in os.environ
or (
@@ -530,10 +530,10 @@ def load_from_path(
policy_id: str,
policy_file_path: Path,
deterministic: bool = False,
- obs_processor_cls: Type[processors.Processor] | None = None,
- obs_processor_config: Dict[str, Any] | None = None,
- action_processor_cls: Type[processors.Processor] | None = None,
- action_processor_config: Dict[str, Any] | None = None,
+ obs_processor_cls: type[processors.Processor] | None = None,
+ obs_processor_config: dict[str, Any] | None = None,
+ action_processor_cls: type[processors.Processor] | None = None,
+ action_processor_config: dict[str, Any] | None = None,
) -> PPOPolicy:
if not policy_file_path.exists():
logger.info(
@@ -601,17 +601,17 @@ def load_from_path(
def get_spec_from_path(
policy_file_path: Path,
env_id: str,
- env_args: Dict[str, Any] | None,
+ env_args: dict[str, Any] | None,
env_args_id: str | None = None,
version: int = 0,
- valid_agent_ids: List[str] | None = None,
+ valid_agent_ids: list[str] | None = None,
nondeterministic: bool = False,
description: str | None = None,
**kwargs,
) -> PolicySpec:
"""Load PPO policy spec from policy file.
- Arguments
+ Arguments:
---------
policy_file_path
Path to the policy file.
@@ -637,7 +637,7 @@ def get_spec_from_path(
Additional kwargs, if any, to pass to the agent initializing function.
- Returns
+ Returns:
-------
spec
Policy specs for PPO Policy loaded from policy file.
diff --git a/posggym/agents/utils/action_distributions.py b/posggym/agents/utils/action_distributions.py
index 9fd4055..965e81f 100644
--- a/posggym/agents/utils/action_distributions.py
+++ b/posggym/agents/utils/action_distributions.py
@@ -3,7 +3,7 @@
import abc
import random
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any
import numpy as np
@@ -27,7 +27,7 @@ def pdf(self, action: Any) -> float:
class DiscreteActionDistribution(ActionDistribution):
"""Action distribution for discrete actions."""
- def __init__(self, probs: Dict[Any, float], rng: seeding.RNG | None = None):
+ def __init__(self, probs: dict[Any, float], rng: seeding.RNG | None = None) -> None:
self.probs = probs
self._rng = rng
@@ -61,7 +61,9 @@ def __eq__(self, other: object) -> bool:
class MultiDiscreteActionDistribution(ActionDistribution):
"""Action distribution for multi-discrete actions."""
- def __init__(self, probs: List[Dict[Any, float]], rng: seeding.RNG | None = None):
+ def __init__(
+ self, probs: list[dict[Any, float]], rng: seeding.RNG | None = None
+ ) -> None:
self.probs = probs
self._rng = rng
@@ -94,18 +96,17 @@ class NormalActionDistribution(ActionDistribution):
def __init__(
self,
- mean: Union[float, np.ndarray],
- stddev: Union[float, np.ndarray],
+ mean: float | np.ndarray,
+ stddev: float | np.ndarray,
rng: np.random.Generator | None = None,
- ):
+ ) -> None:
self.mean = mean
self.stddev = stddev
self._rng = rng
def sample(self) -> Any:
- if self._rng is None:
- return np.random.normal(loc=self.mean, scale=self.stddev)
- return self._rng.normal(loc=self.mean, scale=self.stddev)
+ rng = self._rng or np.random.default_rng()
+ return rng.normal(loc=self.mean, scale=self.stddev)
def pdf(self, action: Any) -> float:
# ref:
@@ -126,7 +127,7 @@ def __eq__(self, other: object) -> bool:
class DeterministicActionDistribution(ActionDistribution):
"""Action distribution for deterministic action distribution."""
- def __init__(self, action: Union[int, float, np.ndarray]):
+ def __init__(self, action: int | float | np.ndarray) -> None:
self.action = action
def sample(self) -> Any:
@@ -146,18 +147,17 @@ class ContinousUniformActionDistribution(ActionDistribution):
def __init__(
self,
- low: Union[float, np.ndarray],
- high: Union[float, np.ndarray],
+ low: float | np.ndarray,
+ high: float | np.ndarray,
rng: np.random.Generator | None = None,
- ):
+ ) -> None:
self.low = low
self.high = high
self._rng = rng
def sample(self) -> Any:
- if self._rng is None:
- return np.random.uniform(low=self.low, high=self.high)
- return self._rng.uniform(low=self.low, high=self.high)
+ rng = self._rng or np.random.default_rng()
+ return rng.uniform(low=self.low, high=self.high)
def pdf(self, action: Any) -> float:
return 1.0 / np.prod(self.high - self.low)
@@ -180,18 +180,17 @@ class DiscreteUniformActionDistribution(ActionDistribution):
def __init__(
self,
- low: Union[int, np.ndarray],
- high: Union[int, np.ndarray],
+ low: int | np.ndarray,
+ high: int | np.ndarray,
rng: np.random.Generator | None = None,
- ):
+ ) -> None:
self.low = low
self.high = high
self._rng = rng
def sample(self) -> Any:
- if self._rng is None:
- return np.random.randint(low=self.low, high=self.high + 1)
- return self._rng.integers(low=self.low, high=self.high + 1)
+ rng = self._rng or np.random.default_rng()
+ return rng.integers(low=self.low, high=self.high + 1)
def pdf(self, action: Any) -> float:
return 1.0 / np.prod(self.high - self.low + 1)
diff --git a/posggym/agents/utils/download.py b/posggym/agents/utils/download.py
index ff18e9a..c939d8b 100644
--- a/posggym/agents/utils/download.py
+++ b/posggym/agents/utils/download.py
@@ -17,14 +17,14 @@
def download_to_file(url: str, dest_file_path: Path):
"""Download file from URL and store at specified destination.
- Arguments
+ Arguments:
---------
url
Full url to download file from.
dest_file_path
File path to write downloaded file to.
- Raises
+ Raises:
------
posggym.error.DownloadError
If error occurred while trying to download file.
@@ -33,14 +33,14 @@ def download_to_file(url: str, dest_file_path: Path):
dest_dir = dest_file_path.parent
dest_dir.mkdir(exist_ok=True)
- r = requests.get(url, stream=True)
+ r = requests.get(url, stream=True, timeout=600)
if r.ok:
with open(dest_file_path, "wb") as f:
content_len = r.headers.get("content-length")
if isinstance(content_len, str):
try:
total_length = int(content_len)
- except (TypeError,):
+ except TypeError:
total_length = LARGEST_FILE_SIZE
else:
total_length = LARGEST_FILE_SIZE
@@ -60,23 +60,23 @@ def download_to_file(url: str, dest_file_path: Path):
except requests.exceptions.HTTPError as e:
# wrap exception in posggym-agents error
raise error.DownloadError(
- f"Error while downloading file, caused by: {type(e).__name__}: {str(e)}"
+ f"Error while downloading file, caused by: {type(e).__name__}: {e!s}"
) from e
def download_from_repo(file_path: Path, rewrite_existing: bool = False):
"""Download file from the posggym-agent-models github repo.
- Arguments
+ Arguments:
---------
file_path
Local path to posgym package file.
rewrite_existing
Whether to re-download and rewrite an existing copy of the file.
- Raises
+ Raises:
------
- posggym.error.InvalidFile
+ posggym.error.InvalidFileError
If file_path is not a valid posggym-agents package file.
posggym.error.DownloadError
If error occurred while trying to download file.
@@ -86,7 +86,7 @@ def download_from_repo(file_path: Path, rewrite_existing: bool = False):
return
if "agents" not in file_path.parts:
- raise error.InvalidFile(
+ raise error.InvalidFileError(
f"Invalid posggym.agents file path '{file_path}'. Path must contain the "
"`agents` directory."
)
diff --git a/posggym/agents/utils/processors.py b/posggym/agents/utils/processors.py
index dc8f1bc..6644543 100644
--- a/posggym/agents/utils/processors.py
+++ b/posggym/agents/utils/processors.py
@@ -13,7 +13,7 @@ class Processor(abc.ABC):
example observations or actions.
"""
- def __init__(self, input_space: spaces.Space):
+ def __init__(self, input_space: spaces.Space) -> None:
self.input_space = input_space
@abc.abstractmethod
@@ -67,7 +67,7 @@ def __init__(
min_val: float = -1.0,
max_val: float = 1.0,
clip: bool = True,
- ):
+ ) -> None:
assert isinstance(input_space, spaces.Box)
super().__init__(input_space)
self.min_val = min_val
diff --git a/posggym/agents/wrappers/agent_env.py b/posggym/agents/wrappers/agent_env.py
index 65cfd43..080eb16 100644
--- a/posggym/agents/wrappers/agent_env.py
+++ b/posggym/agents/wrappers/agent_env.py
@@ -1,5 +1,5 @@
"""Wrapper for incorporating posggym.agents as part of the environment."""
-from typing import Callable, Dict, List, Tuple
+from collections.abc import Callable
from gymnasium import spaces
@@ -14,7 +14,7 @@ class AgentEnvWrapper(posggym.Wrapper):
actions determined internally by a posggym.agent policy. The environment wrapper
will only return observations, rewards, etc, for agents which are not controlled.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
@@ -27,8 +27,8 @@ class AgentEnvWrapper(posggym.Wrapper):
def __init__(
self,
env: posggym.Env,
- agent_fn: Callable[[posggym.POSGModel], Dict[str, pga.Policy]],
- ):
+ agent_fn: Callable[[posggym.POSGModel], dict[str, pga.Policy]],
+ ) -> None:
"""Initializes the wrapper."""
super().__init__(env)
self.agent_fn = agent_fn
@@ -39,17 +39,17 @@ def __init__(
self.last_terminateds = {}
@property
- def possible_agents(self) -> Tuple[str, ...]:
+ def possible_agents(self) -> tuple[str, ...]:
return tuple(
i for i in super().possible_agents if i not in self.controlled_agents
)
@property
- def agents(self) -> List[str]:
+ def agents(self) -> list[str]:
return [i for i in super().agents if i not in self.controlled_agents]
@property
- def action_spaces(self) -> Dict[str, spaces.Space]:
+ def action_spaces(self) -> dict[str, spaces.Space]:
return {
i: act_space
for i, act_space in super().action_spaces.items()
@@ -57,7 +57,7 @@ def action_spaces(self) -> Dict[str, spaces.Space]:
}
@property
- def observation_spaces(self) -> Dict[str, spaces.Space]:
+ def observation_spaces(self) -> dict[str, spaces.Space]:
return {
i: obs_space
for i, obs_space in super().observation_spaces.items()
@@ -65,7 +65,7 @@ def observation_spaces(self) -> Dict[str, spaces.Space]:
}
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {
i: rew_range
for i, rew_range in super().reward_ranges.items()
diff --git a/posggym/core.py b/posggym/core.py
index cd75920..6d1cd92 100644
--- a/posggym/core.py
+++ b/posggym/core.py
@@ -12,10 +12,11 @@
import abc
import copy
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar
from posggym.model import ActType, ObsType, POSGModel, StateType
+
if TYPE_CHECKING:
import numpy as np
from gymnasium import spaces
@@ -70,7 +71,7 @@ class Env(abc.ABC, Generic[StateType, ObsType, ActType]):
"""
# Set this in SOME subclasses
- metadata: Dict[str, Any] = {"render_modes": []}
+ metadata: ClassVar[dict[str, Any]] = {"render_modes": []}
# Define render_mode if your environment supports rendering
render_mode: str | None = None
@@ -84,26 +85,26 @@ class Env(abc.ABC, Generic[StateType, ObsType, ActType]):
@abc.abstractmethod
def step(
- self, actions: Dict[str, ActType]
- ) -> Tuple[
- Dict[str, ObsType],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, ActType]
+ ) -> tuple[
+ dict[str, ObsType],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict[str, Any]],
+ dict[str, dict[str, Any]],
]:
"""Run one timestep in the environment using the agents' actions.
When the end of an episode is reached, the user is responsible for
calling :meth:`reset()` to reset this environments state.
- Arguments
+ Arguments:
---------
actions : Dict[str, ActType]
a joint action containing one action per active agent in the environment.
- Returns
+ Returns:
-------
observations : Dict[str, ObsType]
the joint observation containing one observation per agent.
@@ -133,8 +134,8 @@ def step(
"""
def reset(
- self, *, seed: int | None = None, options: Dict[str, Any] | None = None
- ) -> Tuple[Dict[str, ObsType], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, ObsType], dict[str, dict]]:
"""Resets the environment and returns an initial observations and info.
This method generates a new starting state often with some randomness. This
@@ -149,7 +150,7 @@ def reset(
For Custom environments, the first line of :meth:`reset` should be
``super().reset(seed=seed)`` which implements the seeding correctly.
- Arguments
+ Arguments:
---------
seed : int, optional
The seed that is used to initialize the environment's RNG. If the
@@ -161,7 +162,7 @@ def reset(
Additional information to specify how the environment is reset (optional,
depending on the specific environment)
- Returns
+ Returns:
-------
observations : Dict[str, ObsType]
The joint observation containing one observation per agent in the
@@ -178,7 +179,7 @@ def reset(
def render(
self,
- ) -> None | np.ndarray | str | Dict[str, np.ndarray] | Dict[str, str]:
+ ) -> None | np.ndarray | str | dict[str, np.ndarray] | dict[str, str]:
"""Render the environment as specified by environment :attr:`render_mode`.
The render mode attribute :attr:`render_mode` is set during the initialization
@@ -203,7 +204,7 @@ def render(
render for the entire environment (like `"rgb_array"` and `"ansi"` render
modes) which should be mapped to the `"env"` key in the dictionary by default.
- Note
+ Note:
----
Make sure that your class's :attr:`metadata` ``"render_modes"`` key includes
the list of supported modes.
@@ -216,7 +217,6 @@ def close(self):
Should be overridden in subclasses as necessary.
"""
- pass
@property
@abc.abstractmethod
@@ -232,7 +232,7 @@ def state(self) -> StateType:
"""
@property
- def possible_agents(self) -> Tuple[str, ...]:
+ def possible_agents(self) -> tuple[str, ...]:
"""The list of all possible agents that may appear in the environment.
Returns
@@ -243,7 +243,7 @@ def possible_agents(self) -> Tuple[str, ...]:
return self.model.possible_agents
@property
- def agents(self) -> List[str]:
+ def agents(self) -> list[str]:
"""The list of agents active in the environment for current state.
This will be :attr:`possible_agents`, independent of state, for any environment
@@ -257,7 +257,7 @@ def agents(self) -> List[str]:
return self.model.get_agents(self.state)
@property
- def action_spaces(self) -> Dict[str, spaces.Space]:
+ def action_spaces(self) -> dict[str, spaces.Space]:
"""A mapping from Agent ID to the space of valid actions for that agent.
Returns
@@ -268,7 +268,7 @@ def action_spaces(self) -> Dict[str, spaces.Space]:
return self.model.action_spaces
@property
- def observation_spaces(self) -> Dict[str, spaces.Space]:
+ def observation_spaces(self) -> dict[str, spaces.Space]:
"""A mapping from Agent ID to the space of valid observations for that agent.
Returns
@@ -279,7 +279,7 @@ def observation_spaces(self) -> Dict[str, spaces.Space]:
return self.model.observation_spaces
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
r"""A mapping from Agent ID to the min and max possible rewards for that agent.
Each reward tuple corresponding to the minimum and maximum possible rewards for
@@ -320,7 +320,7 @@ def is_symmetric(self) -> bool:
return self.model.is_symmetric
@property
- def unwrapped(self) -> "Env":
+ def unwrapped(self) -> Env:
"""Completely unwrap this env.
Returns
@@ -374,7 +374,11 @@ class DefaultEnv(Env[StateType, ObsType, ActType]):
"""
- def __init__(self, model: POSGModel, render_mode: Optional[str] = None):
+ def __init__(
+ self,
+ model: POSGModel,
+ render_mode: str | None = None,
+ ) -> None:
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.model = model
self.render_mode = render_mode
@@ -382,18 +386,18 @@ def __init__(self, model: POSGModel, render_mode: Optional[str] = None):
self._state = self.model.sample_initial_state()
self._last_obs = self.model.sample_initial_obs(self._state)
self._step_num = 0
- self._last_actions: Dict[str, ActType] | None = None
- self._last_rewards: Dict[str, float] | None = None
+ self._last_actions: dict[str, ActType] | None = None
+ self._last_rewards: dict[str, float] | None = None
def step(
- self, actions: Dict[str, ActType]
- ) -> Tuple[
- Dict[str, ObsType],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, ActType]
+ ) -> tuple[
+ dict[str, ObsType],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict],
+ dict[str, dict],
]:
step = self.model.step(self._state, actions)
self._step_num += 1
@@ -411,8 +415,8 @@ def step(
)
def reset(
- self, *, seed: int | None = None, options: Dict[str, Any] | None = None
- ) -> Tuple[Dict[str, ObsType], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, ObsType], dict[str, dict]]:
super().reset(seed=seed)
self._state = self.model.sample_initial_state()
self._last_obs = self.model.sample_initial_obs(self._state)
@@ -444,19 +448,19 @@ class Wrapper(Env[WrapperStateType, WrapperObsType, WrapperActType]):
back to the wrapper's environment (i.e. to the corresponding attributes of
:attr:`env`).
- Note
+ Note:
----
If you inherit from :class:`Wrapper`, don't forget to call ``super().__init__(env)``
if the subclass overrides the `__init__` method.
"""
- def __init__(self, env: Env[StateType, ObsType, ActType]):
+ def __init__(self, env: Env[StateType, ObsType, ActType]) -> None:
self.env = env
- self._action_spaces: Dict[str, spaces.Space] | None = None
- self._observation_spaces: Dict[str, spaces.Space] | None = None
- self._reward_ranges: Dict[str, Tuple[float, float]] | None = None
- self._metadata: Dict[str, Any] | None = None
+ self._action_spaces: dict[str, spaces.Space] | None = None
+ self._observation_spaces: dict[str, spaces.Space] | None = None
+ self._reward_ranges: dict[str, tuple[float, float]] | None = None
+ self._metadata: dict[str, Any] | None = None
def __getattr__(self, name):
"""Returns attribute with ``name``, unless ``name`` starts with underscore."""
@@ -484,17 +488,17 @@ def state(self) -> WrapperStateType:
return self.env.state # type: ignore
@property
- def possible_agents(self) -> Tuple[str, ...]:
+ def possible_agents(self) -> tuple[str, ...]:
"""Returns the :attr:`Env` :attr:`possible_agents`."""
return self.env.possible_agents
@property
- def agents(self) -> List[str]:
+ def agents(self) -> list[str]:
"""Returns the :attr:`Env` :attr:`agents`."""
return self.env.agents
@property
- def action_spaces(self) -> Dict[str, spaces.Space]:
+ def action_spaces(self) -> dict[str, spaces.Space]:
"""Return the :attr:`Env` :attr:`action_spaces`.
This is the :attr:`Env` :attr:`action_spaces` unless it's overwritten then the
@@ -505,11 +509,11 @@ def action_spaces(self) -> Dict[str, spaces.Space]:
return self._action_spaces
@action_spaces.setter
- def action_spaces(self, action_spaces: Dict[str, spaces.Space]):
+ def action_spaces(self, action_spaces: dict[str, spaces.Space]):
self._action_spaces = action_spaces
@property
- def observation_spaces(self) -> Dict[str, spaces.Space]:
+ def observation_spaces(self) -> dict[str, spaces.Space]:
"""Return the :attr:`Env` :attr:`observation_spaces`.
This is the :attr:`Env` :attr:`observation_spaces` unless it's overwritten then
@@ -520,11 +524,11 @@ def observation_spaces(self) -> Dict[str, spaces.Space]:
return self._observation_spaces
@observation_spaces.setter
- def observation_spaces(self, observation_spaces: Dict[str, spaces.Space]):
+ def observation_spaces(self, observation_spaces: dict[str, spaces.Space]):
self._observation_spaces = observation_spaces
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
"""Return the :attr:`Env` :attr:`reward_ranges`.
This is the :attr:`Env` :attr:`reward_ranges`, unless it's overwritten, then
@@ -535,18 +539,18 @@ def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
return self._reward_ranges
@reward_ranges.setter
- def reward_ranges(self, reward_ranges: Dict[str, Tuple[float, float]]):
+ def reward_ranges(self, reward_ranges: dict[str, tuple[float, float]]):
self._reward_ranges = reward_ranges
@property
- def metadata(self) -> Dict[str, Any]:
+ def metadata(self) -> dict[str, Any]:
"""Returns the :attr:`Env` :attr:`metadata`."""
if self._metadata is None:
return self.env.metadata
return self._metadata
@metadata.setter
- def metadata(self, value: Dict[str, Any]):
+ def metadata(self, value: dict[str, Any]):
self._metadata = value
@property
@@ -568,14 +572,14 @@ def render_mode(self, render_mode: str | None):
self.env.render_mode = render_mode
def step(
- self, actions: Dict[str, WrapperActType]
- ) -> Tuple[
- Dict[str, WrapperObsType],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, WrapperActType]
+ ) -> tuple[
+ dict[str, WrapperObsType],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict],
+ dict[str, dict],
]:
"""Uses the :meth:`step` of the :attr:`env`.
@@ -584,8 +588,8 @@ def step(
return self.env.step(actions) # type: ignore
def reset(
- self, *, seed: int | None = None, options: Dict[str, Any] | None = None
- ) -> Tuple[Dict[str, WrapperObsType], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, WrapperObsType], dict[str, dict]]:
"""Uses the :meth:`reset` of the :attr:`env`.
Can be overwritten to change the returned data.
@@ -594,7 +598,7 @@ def reset(
def render(
self,
- ) -> None | np.ndarray | str | Dict[str, np.ndarray] | Dict[str, str]:
+ ) -> None | np.ndarray | str | dict[str, np.ndarray] | dict[str, str]:
"""Uses the :meth:`render` of the :attr:`env`.
Can be overwritten to change the returned data.
@@ -629,31 +633,31 @@ class ObservationWrapper(Wrapper[StateType, WrapperObsType, ActType]):
Subclasses should at least implement the observations function.
"""
- def __init__(self, env: Env[StateType, ObsType, ActType]):
+ def __init__(self, env: Env[StateType, ObsType, ActType]) -> None:
super().__init__(env)
def reset(
- self, *, seed: int | None = None, options: Dict[str, Any] | None = None
- ) -> Tuple[Dict[str, WrapperObsType], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, WrapperObsType], dict[str, dict]]:
obs, info = self.env.reset(seed=seed, options=options)
if obs is None:
return obs, info
return self.observations(obs), info
def step(
- self, actions: Dict[str, ActType]
- ) -> Tuple[
- Dict[str, WrapperObsType],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, ActType]
+ ) -> tuple[
+ dict[str, WrapperObsType],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict],
+ dict[str, dict],
]:
obs, reward, term, trunc, done, infos = self.env.step(actions) # type: ignore
return self.observations(obs), reward, term, trunc, done, infos
- def observations(self, obs: Dict[str, ObsType]) -> Dict[str, WrapperObsType]:
+ def observations(self, obs: dict[str, ObsType]) -> dict[str, WrapperObsType]:
"""Transforms observations received from wrapped environment."""
raise NotImplementedError
@@ -664,23 +668,23 @@ class RewardWrapper(Wrapper[StateType, ObsType, ActType]):
Subclasses should at least implement the rewards function.
"""
- def __init__(self, env: Env[StateType, ObsType, ActType]):
+ def __init__(self, env: Env[StateType, ObsType, ActType]) -> None:
super().__init__(env)
def step(
- self, actions: Dict[str, ActType]
- ) -> Tuple[
- Dict[str, ObsType],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, ActType]
+ ) -> tuple[
+ dict[str, ObsType],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict],
+ dict[str, dict],
]:
obs, reward, term, trunc, done, info = self.env.step(actions) # type: ignore
return obs, self.rewards(reward), term, trunc, done, info # type: ignore
- def rewards(self, rewards: Dict[str, float]) -> Dict[str, float]:
+ def rewards(self, rewards: dict[str, float]) -> dict[str, float]:
"""Transforms rewards received from wrapped environment."""
raise NotImplementedError
@@ -691,21 +695,21 @@ class ActionWrapper(Wrapper[StateType, ObsType, WrapperActType]):
Subclasses should at least implement the actions function.
"""
- def __init__(self, env: Env[StateType, ObsType, ActType]):
+ def __init__(self, env: Env[StateType, ObsType, ActType]) -> None:
super().__init__(env)
def step(
- self, actions: Dict[str, ActType]
- ) -> Tuple[
- Dict[str, ObsType],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, ActType]
+ ) -> tuple[
+ dict[str, ObsType],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict],
+ dict[str, dict],
]:
return self.env.step(self.actions(actions)) # type: ignore
- def actions(self, actions: Dict[str, ActType]) -> Dict[str, WrapperActType]:
+ def actions(self, actions: dict[str, ActType]) -> dict[str, WrapperActType]:
"""Transform actions for wrapped environment."""
raise NotImplementedError
diff --git a/posggym/envs/__init__.py b/posggym/envs/__init__.py
index d40a510..39cfa31 100644
--- a/posggym/envs/__init__.py
+++ b/posggym/envs/__init__.py
@@ -8,6 +8,7 @@
from posggym.envs.registration import make, pprint_registry, register, registry, spec
+
# Classic
# -------------------------------------------
@@ -82,6 +83,17 @@
},
)
+
+# Differentiable
+# -------------------------------------------
+
+register(
+ id="PredatorPreyDifferentiable-v0",
+ entry_point="posggym.envs.differentiable.predator_prey_diff:PredatorPreyDiff",
+ max_episode_steps=100,
+ kwargs={"world": "10x10", "batch_size": 1},
+)
+
# Grid World
# -------------------------------------------
diff --git a/posggym/envs/classic/mabc.py b/posggym/envs/classic/mabc.py
index 882f32e..c50fe22 100644
--- a/posggym/envs/classic/mabc.py
+++ b/posggym/envs/classic/mabc.py
@@ -1,7 +1,7 @@
"""The Multi-Access Broadcast problem."""
import sys
from itertools import product
-from typing import Dict, List, Optional, Tuple, Union
+from typing import ClassVar
from gymnasium import spaces
@@ -10,7 +10,8 @@
from posggym.core import DefaultEnv
from posggym.utils import seeding
-MABCState = Tuple[int, ...]
+
+MABCState = tuple[int, ...]
EMPTY = 0
FULL = 1
NODE_STATES = [EMPTY, FULL]
@@ -90,9 +91,8 @@ class MABCEnv(DefaultEnv[MABCState, MABCObs, MABCAction]):
By default episodes continue infinitely long. To set a step limit, specify
`max_episode_steps` when initializing the environment with `posggym.make`.
- Arguments
+ Arguments:
---------
-
- `num_nodes` - the number of nodes (i.e. agents) in the network (default=`2.0`)
- `fill_probs` - the probability each nodes buffer is filled, should be a tuple with
an entry for each node (default = `None` = `(0.9, 0.1)`)
@@ -105,28 +105,28 @@ class MABCEnv(DefaultEnv[MABCState, MABCObs, MABCAction]):
---------------
- `v0`: Initial version
- References
+ References:
----------
- Ooi, J. M., and Wornell, G. W. 1996. Decentralized control of a multiple
access broadcast channel: Performance bounds. In Proceedings of the 35th
- Conference on Decision and Control, 293–298.
+ Conference on Decision and Control, 293-298.
- Hansen, Eric A., Daniel S. Bernstein, and Shlomo Zilberstein. “Dynamic
Programming for Partially Observable Stochastic Games.” In Proceedings of
- the 19th National Conference on Artificial Intelligence, 709–715. AAAI’04.
+ the 19th National Conference on Artificial Intelligence, 709-715. AAAI`04.
San Jose, California: AAAI Press, 2004.
"""
- metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}
+ metadata: ClassVar[dict] = {"render_modes": ["human", "ansi"], "render_fps": 4}
def __init__(
self,
num_nodes: int = 2,
- fill_probs: Optional[Tuple[float, ...]] = None,
+ fill_probs: tuple[float, ...] | None = None,
observation_prob: float = 0.9,
- init_buffer_dist: Optional[Tuple[float, ...]] = None,
- render_mode: Optional[str] = None,
- ):
+ init_buffer_dist: tuple[float, ...] | None = None,
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
MABCModel(num_nodes, fill_probs, observation_prob, init_buffer_dist),
render_mode=render_mode,
@@ -135,7 +135,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -167,6 +167,9 @@ def render(self):
return output_str
+MIN_NODES = 2
+
+
class MABCModel(M.POSGFullModel[MABCState, MABCObs, MABCAction]):
"""POSG Model for the Multi-Access Broadcast Channel problem."""
@@ -180,11 +183,11 @@ class MABCModel(M.POSGFullModel[MABCState, MABCObs, MABCAction]):
def __init__(
self,
num_nodes: int = 2,
- fill_probs: Optional[Tuple[float, ...]] = None,
+ fill_probs: tuple[float, ...] | None = None,
observation_prob: float = 0.9,
- init_buffer_dist: Optional[Tuple[float, ...]] = None,
- ):
- assert num_nodes >= 2
+ init_buffer_dist: tuple[float, ...] | None = None,
+ ) -> None:
+ assert num_nodes >= MIN_NODES
if fill_probs is None:
fill_probs = self.DEFAULT_FILL_PROBS
@@ -225,7 +228,7 @@ def __init__(
self._obs_map = self._construct_obs_func()
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (self.R_NO_SEND, self.R_SEND) for i in self.possible_agents}
@property
@@ -234,7 +237,7 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: MABCState) -> List[str]:
+ def get_agents(self, state: MABCState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> MABCState:
@@ -246,11 +249,11 @@ def sample_initial_state(self) -> MABCState:
node_states.append(EMPTY)
return tuple(node_states)
- def sample_initial_obs(self, state: MABCState) -> Dict[str, MABCObs]:
+ def sample_initial_obs(self, state: MABCState) -> dict[str, MABCObs]:
return {i: NOCOLLISION for i in self.possible_agents}
def step(
- self, state: MABCState, actions: Dict[str, MABCAction]
+ self, state: MABCState, actions: dict[str, MABCAction]
) -> M.JointTimestep[MABCState, MABCObs]:
assert all(a_i in ACTIONS for a_i in actions.values())
next_state = self._sample_next_state(state, actions)
@@ -260,14 +263,14 @@ def step(
terminated = {i: False for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
all_done = False
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
return M.JointTimestep(
next_state, obs, rewards, terminated, truncated, all_done, info
)
def _sample_next_state(
- self, state: MABCState, actions: Dict[str, MABCAction]
+ self, state: MABCState, actions: dict[str, MABCAction]
) -> MABCState:
next_node_states = list(state)
for i, a_i in actions.items():
@@ -279,7 +282,7 @@ def _sample_next_state(
next_node_states[idx] = FULL
return tuple(next_node_states)
- def _sample_obs(self, actions: Dict[str, MABCAction]) -> Dict[str, MABCObs]:
+ def _sample_obs(self, actions: dict[str, MABCAction]) -> dict[str, MABCObs]:
senders = sum(int(a_i == SEND) for a_i in actions.values())
if senders > 1:
correct_obs = COLLISION
@@ -296,8 +299,8 @@ def _sample_obs(self, actions: Dict[str, MABCAction]) -> Dict[str, MABCObs]:
obs[i] = wrong_obs
return obs
- def get_initial_belief(self) -> Dict[MABCState, float]:
- b_map: Dict[MABCState, float] = {}
+ def get_initial_belief(self) -> dict[MABCState, float]:
+ b_map: dict[MABCState, float] = {}
s_prob_sum = 0.0
for s in self._state_space:
s_prob = 1.0
@@ -317,13 +320,13 @@ def get_initial_belief(self) -> Dict[MABCState, float]:
def transition_fn(
self,
state: MABCState,
- actions: Dict[str, MABCAction],
+ actions: dict[str, MABCAction],
next_state: MABCState,
) -> float:
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._trans_map[(state, action_tuple, next_state)]
- def _construct_trans_func(self) -> Dict:
+ def _construct_trans_func(self) -> dict:
trans_map = {}
agent_ids = [int(i) for i in self.possible_agents]
for s, a, s_next in product(
@@ -343,15 +346,15 @@ def _construct_trans_func(self) -> Dict:
def observation_fn(
self,
- obs: Dict[str, MABCObs],
+ obs: dict[str, MABCObs],
next_state: MABCState,
- actions: Dict[str, MABCAction],
+ actions: dict[str, MABCAction],
) -> float:
obs_tuple = tuple(obs[i] for i in self.possible_agents)
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._obs_map[(next_state, action_tuple, obs_tuple)]
- def _construct_obs_func(self) -> Dict:
+ def _construct_obs_func(self) -> dict:
obs_map = {}
agent_ids = [int(i) for i in self.possible_agents]
for s_next, a, o in product(
@@ -372,12 +375,12 @@ def _construct_obs_func(self) -> Dict:
return obs_map
def reward_fn(
- self, state: MABCState, actions: Dict[str, MABCAction]
- ) -> Dict[str, float]:
+ self, state: MABCState, actions: dict[str, MABCAction]
+ ) -> dict[str, float]:
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._rew_map[(state, action_tuple)]
- def _construct_rew_func(self) -> Dict:
+ def _construct_rew_func(self) -> dict:
rew_map = {}
joint_actions_space = product(*self._action_spaces)
for s, a in product(self._state_space, joint_actions_space):
@@ -388,7 +391,7 @@ def _construct_rew_func(self) -> Dict:
def _message_sent(
self,
state: MABCState,
- actions: Union[Dict[str, MABCAction], Tuple[MABCAction, ...]],
+ actions: dict[str, MABCAction] | tuple[MABCAction, ...],
) -> bool:
if isinstance(actions, dict):
actions = tuple(actions[i] for i in self.possible_agents)
diff --git a/posggym/envs/classic/rock_paper_scissors.py b/posggym/envs/classic/rock_paper_scissors.py
index 7adb0c5..f1424c9 100644
--- a/posggym/envs/classic/rock_paper_scissors.py
+++ b/posggym/envs/classic/rock_paper_scissors.py
@@ -1,7 +1,7 @@
"""The classic Rock Paper Scissors problem."""
import sys
from itertools import product
-from typing import Dict, List, Optional, Tuple
+from typing import ClassVar
from gymnasium import spaces
@@ -10,6 +10,7 @@
from posggym.core import DefaultEnv
from posggym.utils import seeding
+
RPSState = int
STATE0 = 0
STATES = [STATE0]
@@ -23,7 +24,7 @@
ACTION_STR = ["R", "P", "S"]
RPSObs = int
-RPSJointObs = Tuple[RPSObs, ...]
+RPSJointObs = tuple[RPSObs, ...]
OBS_SPACE = ACTIONS
OBS_STR = ACTION_STR
@@ -79,7 +80,7 @@ class RockPaperScissorsEnv(DefaultEnv):
By default episodes continue infinitely long. To set a step limit, specify
`max_episode_steps` when initializing the environment with `posggym.make`.
- Arguments
+ Arguments:
---------
No additional arguments are currently supported during construction.
@@ -89,15 +90,15 @@ class RockPaperScissorsEnv(DefaultEnv):
"""
- metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}
+ metadata: ClassVar[dict] = {"render_modes": ["human", "ansi"], "render_fps": 4}
- def __init__(self, render_mode: Optional[str] = None):
+ def __init__(self, render_mode: str | None = None) -> None:
super().__init__(RockPaperScissorsModel(), render_mode=render_mode)
def render(self, mode: str = "human"):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -127,9 +128,13 @@ class RockPaperScissorsModel(M.POSGFullModel[RPSState, RPSObs, RPSAction]):
NUM_AGENTS = 2
- R_MATRIX = [[0, -1.0, 1.0], [1.0, 0, -1.0], [-1.0, 1.0, 0]]
+ R_MATRIX: ClassVar[list[list[float]]] = [
+ [0, -1.0, 1.0],
+ [1.0, 0, -1.0],
+ [-1.0, 1.0, 0],
+ ]
- def __init__(self):
+ def __init__(self) -> None:
self.possible_agents = tuple(str(i) for i in range(self.NUM_AGENTS))
self.state_space = spaces.Discrete(len(STATES))
self.action_spaces = {
@@ -150,7 +155,7 @@ def __init__(self):
self._obs_map = self._construct_obs_func()
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (-1.0, 1.0) for i in self.possible_agents}
@property
@@ -159,45 +164,45 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: RPSState) -> List[str]:
+ def get_agents(self, state: RPSState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> RPSState:
return STATE0
- def sample_initial_obs(self, state: RPSState) -> Dict[str, RPSObs]:
+ def sample_initial_obs(self, state: RPSState) -> dict[str, RPSObs]:
return {i: ROCK for i in self.possible_agents}
def step(
- self, state: RPSState, actions: Dict[str, RPSAction]
+ self, state: RPSState, actions: dict[str, RPSAction]
) -> M.JointTimestep[RPSState, RPSObs]:
assert all(a_i in ACTIONS for a_i in actions.values())
- obs: Dict[str, RPSObs] = {"0": actions["1"], "1": actions["0"]}
+ obs: dict[str, RPSObs] = {"0": actions["1"], "1": actions["0"]}
rewards = self._get_reward(actions)
terminated = {i: False for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
all_done = False
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
return M.JointTimestep(
STATE0, obs, rewards, terminated, truncated, all_done, info
)
- def _get_reward(self, actions: Dict[str, RPSAction]) -> Dict[str, float]:
+ def _get_reward(self, actions: dict[str, RPSAction]) -> dict[str, float]:
return {
"0": self.R_MATRIX[actions["0"]][actions["1"]],
"1": self.R_MATRIX[actions["1"]][actions["0"]],
}
- def get_initial_belief(self) -> Dict[RPSState, float]:
+ def get_initial_belief(self) -> dict[RPSState, float]:
return {STATE0: 1.0}
def transition_fn(
- self, state: RPSState, actions: Dict[str, RPSAction], next_state: RPSState
+ self, state: RPSState, actions: dict[str, RPSAction], next_state: RPSState
) -> float:
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._trans_map[(state, action_tuple, next_state)]
- def _construct_trans_func(self) -> Dict:
+ def _construct_trans_func(self) -> dict:
trans_map = {}
for a in product(*self._action_spaces):
trans_map[(STATE0, a, STATE0)] = 1.0
@@ -205,27 +210,27 @@ def _construct_trans_func(self) -> Dict:
def observation_fn(
self,
- obs: Dict[str, RPSObs],
+ obs: dict[str, RPSObs],
next_state: RPSState,
- actions: Dict[str, RPSAction],
+ actions: dict[str, RPSAction],
) -> float:
obs_tuple = tuple(obs[i] for i in self.possible_agents)
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._obs_map[(next_state, action_tuple, obs_tuple)]
- def _construct_obs_func(self) -> Dict:
+ def _construct_obs_func(self) -> dict:
obs_func = {}
for a, o in product(product(*self._action_spaces), product(*self._obs_spaces)):
obs_func[(STATE0, a, o)] = 1.0 if a == o else 0.0
return obs_func
def reward_fn(
- self, state: RPSState, actions: Dict[str, RPSAction]
- ) -> Dict[str, float]:
+ self, state: RPSState, actions: dict[str, RPSAction]
+ ) -> dict[str, float]:
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._rew_map[(state, action_tuple)]
- def _construct_rew_func(self) -> Dict:
+ def _construct_rew_func(self) -> dict:
rew_map = {}
for a in product(*self._action_spaces):
rew_map[(STATE0, a)] = {
diff --git a/posggym/envs/classic/tiger.py b/posggym/envs/classic/tiger.py
index 963615c..8bc42cc 100644
--- a/posggym/envs/classic/tiger.py
+++ b/posggym/envs/classic/tiger.py
@@ -1,7 +1,7 @@
"""Model for the classic Multi-Agent Tiger problem."""
import sys
from itertools import product
-from typing import Dict, List, Optional, Tuple
+from typing import ClassVar
from gymnasium import spaces
@@ -24,7 +24,7 @@
ACTIONS = [OPENLEFT, OPENRIGHT, LISTEN]
ACTION_STR = ["OL", "OR", "L"]
-MATObs = Tuple[int, int]
+MATObs = tuple[int, int]
GROWLLEFT = 0
GROWLRIGHT = 1
CREAKLEFT = 0
@@ -42,6 +42,8 @@
OBS_STR = [("GL", "GR"), ("CL", "CR", "S")]
OTHER_AGENT_ID = {"0": "1", "1": "0"}
+MIN_PROBABILITY = 0.0
+MAX_PROBABILITY = 1.0
class MultiAgentTigerEnv(DefaultEnv):
@@ -122,9 +124,8 @@ class MultiAgentTigerEnv(DefaultEnv):
By default episodes continue infinitely long. To set a step limit, specify
`max_episode_steps` when initializing the environment with `posggym.make`.
- Arguments
+ Arguments:
---------
-
- `observation_prob` - the probability of correctly observing the position of the
tiger (default = `0.85`)
- `creak_observation_prob` - the probability of correctly observing which door was
@@ -134,21 +135,21 @@ class MultiAgentTigerEnv(DefaultEnv):
---------------
- `v0`: Initial version
- References
+ References:
----------
- Gmytrasiewicz, Piotr J., and Prashant Doshi. “A Framework for Sequential Planning
in Multi-Agent Settings.” Journal of Artificial Intelligence Research 24 (2005).
"""
- metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}
+ metadata: ClassVar[dict] = {"render_modes": ["human", "ansi"], "render_fps": 4}
def __init__(
self,
observation_prob: float = 0.85,
creak_observation_prob: float = 0.9,
- render_mode: Optional[str] = None,
- ):
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
MultiAgentTigerModel(observation_prob, creak_observation_prob),
render_mode=render_mode,
@@ -157,7 +158,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -202,9 +203,9 @@ class MultiAgentTigerModel(M.POSGFullModel[MATState, MATObs, MATAction]):
def __init__(
self, observation_prob: float = 0.85, creak_observation_prob: float = 0.9
- ):
- assert 0 <= observation_prob <= 1.0
- assert 0 <= creak_observation_prob <= 1.0
+ ) -> None:
+ assert MIN_PROBABILITY <= observation_prob <= MAX_PROBABILITY
+ assert MIN_PROBABILITY <= creak_observation_prob <= MAX_PROBABILITY
self._obs_prob = observation_prob
self._creak_obs_prob = creak_observation_prob
@@ -233,7 +234,7 @@ def __init__(
self._obs_map = self._construct_obs_func()
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (self.OPEN_BAD_R, self.OPEN_GOOD_R) for i in self.possible_agents}
@property
@@ -242,17 +243,17 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: MATState) -> List[str]:
+ def get_agents(self, state: MATState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> MATState:
return self.rng.choice(STATES)
- def sample_initial_obs(self, state: MATState) -> Dict[str, MATObs]:
+ def sample_initial_obs(self, state: MATState) -> dict[str, MATObs]:
return {i: OBS_SPACE[2] for i in self.possible_agents}
def step(
- self, state: MATState, actions: Dict[str, MATAction]
+ self, state: MATState, actions: dict[str, MATAction]
) -> M.JointTimestep[MATState, MATObs]:
assert all(a_i in ACTIONS for a_i in actions.values())
next_state = self._sample_next_state(state, actions)
@@ -261,13 +262,13 @@ def step(
terminated = {i: False for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
all_done = False
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
return M.JointTimestep(
next_state, obs, rewards, terminated, truncated, all_done, info
)
def _sample_next_state(
- self, state: MATState, actions: Dict[str, MATAction]
+ self, state: MATState, actions: dict[str, MATAction]
) -> MATState:
next_state = state
if any(a != LISTEN for a in actions.values()):
@@ -275,9 +276,9 @@ def _sample_next_state(
return next_state
def _sample_obs(
- self, state: MATState, actions: Dict[str, MATAction]
- ) -> Dict[str, MATObs]:
- obs: Dict[str, MATObs] = {}
+ self, state: MATState, actions: dict[str, MATAction]
+ ) -> dict[str, MATObs]:
+ obs: dict[str, MATObs] = {}
for i, a in actions.items():
if a != LISTEN:
obs[i] = self.rng.choice(OBS_SPACE)
@@ -311,9 +312,9 @@ def _sample_creak_obs(self, a_j: MATAction) -> int:
return self.rng.choice([CREAKRIGHT, SILENCE])
def _get_reward(
- self, state: MATState, actions: Dict[str, MATAction]
- ) -> Dict[str, float]:
- rewards: Dict[str, float] = {}
+ self, state: MATState, actions: dict[str, MATAction]
+ ) -> dict[str, float]:
+ rewards: dict[str, float] = {}
for i, a in actions.items():
if a == LISTEN:
rewards[i] = self.LISTEN_R
@@ -323,20 +324,20 @@ def _get_reward(
rewards[i] = self.OPEN_GOOD_R
return rewards
- def get_initial_belief(self) -> Dict[MATState, float]:
- b_map: Dict[MATState, float] = {}
+ def get_initial_belief(self) -> dict[MATState, float]:
+ b_map: dict[MATState, float] = {}
for s in STATES:
s_prob = 1.0 / len(STATES)
b_map[s] = s_prob
return b_map
def transition_fn(
- self, state: MATState, actions: Dict[str, MATAction], next_state: MATState
+ self, state: MATState, actions: dict[str, MATAction], next_state: MATState
) -> float:
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._trans_map[(state, action_tuple, next_state)]
- def _construct_trans_func(self) -> Dict:
+ def _construct_trans_func(self) -> dict:
trans_map = {}
uniform_prob = 1.0 / len(STATES)
for s, a, s_next in product(
@@ -349,15 +350,15 @@ def _construct_trans_func(self) -> Dict:
def observation_fn(
self,
- obs: Dict[str, MATObs],
+ obs: dict[str, MATObs],
next_state: MATState,
- actions: Dict[str, MATAction],
+ actions: dict[str, MATAction],
) -> float:
obs_tuple = tuple(obs[i] for i in self.possible_agents)
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._obs_map[(next_state, action_tuple, obs_tuple)]
- def _construct_obs_func(self) -> Dict:
+ def _construct_obs_func(self) -> dict:
obs_func = {}
uniform_o_prob = 1.0 / len(OBS_SPACE)
for s_next, a, o in product(
@@ -406,12 +407,12 @@ def _construct_obs_func(self) -> Dict:
return obs_func
def reward_fn(
- self, state: MATState, actions: Dict[str, MATAction]
- ) -> Dict[str, float]:
+ self, state: MATState, actions: dict[str, MATAction]
+ ) -> dict[str, float]:
action_tuple = tuple(actions[i] for i in self.possible_agents)
return self._rew_map[(state, action_tuple)]
- def _construct_rew_func(self) -> Dict:
+ def _construct_rew_func(self) -> dict:
rew_map = {}
joint_actions_space = product(*self._action_spaces)
for s, a in product(self._state_space, joint_actions_space):
diff --git a/posggym/envs/continuous/core.py b/posggym/envs/continuous/core.py
index b64eda9..7348711 100644
--- a/posggym/envs/continuous/core.py
+++ b/posggym/envs/continuous/core.py
@@ -8,34 +8,41 @@
from enum import Enum
from itertools import product
from queue import PriorityQueue
-from typing import Dict, Iterable, List, NamedTuple, Set, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ NamedTuple,
+)
import numpy as np
from gymnasium import spaces
-from posggym.error import DependencyNotInstalled
+from posggym.error import DependencyNotInstalledError
+
+
+if TYPE_CHECKING:
+ from collections.abc import Iterable
try:
import pymunk
from pymunk import Vec2d
except ImportError as e:
- raise DependencyNotInstalled(
+ raise DependencyNotInstalledError(
"pymunk is not installed, run `pip install posggym[continuous]`"
) from e
# (x, y) coord = (col, row) coord
-Coord = Tuple[int, int]
-FloatCoord = Tuple[float, float]
+Coord = tuple[int, int]
+FloatCoord = tuple[float, float]
# (x, y, yaw) in continuous world
-Position = Tuple[float, float, float]
-Location = Union[Coord, FloatCoord, Position, np.ndarray]
+Position = tuple[float, float, float]
+Location = Coord | FloatCoord | Position | np.ndarray
# Position, radius
-CircleEntity = Tuple[Position, float]
-# start (x, y), end (x, y)
-Line = Tuple[FloatCoord, FloatCoord]
-IntLine = Tuple[Tuple[int, int], Tuple[int, int]]
+CircleEntity = tuple[Position, float]
+Line = tuple[FloatCoord, FloatCoord]
+IntLine = tuple[tuple[int, int], tuple[int, int]]
# (Main, Alternative) agent colors (from pygame.colordict.THECOLORS)
AGENT_COLORS = [
@@ -49,6 +56,35 @@
((255, 165, 0, 255), (205, 133, 0, 255)), # orange, orange3
]
+ZERO = 0.0
+ONE = 1.0
+
+
+def clamp(x, lower, upper):
+ return lower if x < lower else upper if x > upper else x
+
+
+class ControlType(Enum):
+ VelocityHolonomoic = 0, "VelocityHolonomoic"
+ ForceHolonomoic = 1, "ForceHolonomoic"
+ VelocityNonHolonomoic = 2, "VelocityNonHolonomoic"
+ ForceNonHolonomoic = 3, "ForceNonHolonomoic"
+ WheeledRobot = 4, "WheeledRobot"
+ Ackermann = 5, "Ackermann"
+
+ def __int__(self):
+ return self.value[0]
+
+ def __str__(self):
+ return self.value[1]
+
+ @classmethod
+ def from_str(cls, name):
+ for member in cls:
+ if member.value[1] == name:
+ return member
+ raise ValueError(f"'{name}' is not a valid {cls.__name__}")
+
class CollisionType(Enum):
"""Type of collision in world."""
@@ -60,6 +96,25 @@ class CollisionType(Enum):
INTERIOR_WALL = 4
+MIN_POSITION_ARRAY_LENGTH = 3
+
+
+def generate_parameters(control_type: ControlType) -> dict[str, float]:
+ match control_type:
+ case ControlType.WheeledRobot:
+ return {
+ "wheel_radius": 0.1,
+ "L": 0.1,
+ }
+ case ControlType.Ackermann:
+ return {
+ "L": 0.1,
+ }
+
+ case _:
+ return {}
+
+
class PMBodyState(NamedTuple):
"""State of a Pymunk Body."""
@@ -79,7 +134,6 @@ def num_features() -> int:
def get_space(world_size: float) -> spaces.Box:
"""Get the space for a pymunk body's state."""
# x, y, angle, vx, vy, vangle
- # shape = (1, 6)
size, angle = world_size, 2 * math.pi
low = np.array([-1, -1, -angle, -1, -1, -angle], dtype=np.float32)
high = np.array(
@@ -89,6 +143,109 @@ def get_space(world_size: float) -> spaces.Box:
return spaces.Box(low=low, high=high)
+X_IDX = PMBodyState._fields.index("x")
+Y_IDX = PMBodyState._fields.index("y")
+ANGLE_IDX = PMBodyState._fields.index("angle")
+VX_IDX = PMBodyState._fields.index("vx")
+VY_IDX = PMBodyState._fields.index("vy")
+VANGLE_IDX = PMBodyState._fields.index("vangle")
+
+
+def generate_action_space(
+ possible_agents: tuple[str, ...],
+ dyaw_limit: float | tuple[float, float],
+ dvel_limit: float | tuple[float, float],
+ fyaw_limit: float | tuple[float, float],
+ fvel_limit: float | tuple[float, float],
+):
+ action_spaces_per_control = {}
+ for i in ControlType:
+ action_spaces_per_control[i] = generate_action_space_per_control(
+ i,
+ possible_agents,
+ dyaw_limit,
+ dvel_limit,
+ fyaw_limit,
+ fvel_limit,
+ )
+ return action_spaces_per_control
+
+
+def scale_action(
+ action: np.ndarray, source_space: spaces.Space, target_space: spaces.Box
+) -> np.ndarray:
+ assert isinstance(source_space, spaces.Box)
+
+ source_low = source_space.low
+ source_high = source_space.high
+ target_low = target_space.low
+ target_high = target_space.high
+
+ # Apply the scaling formula for each dimension of the action
+ scaled_action = target_low + (action - source_low) * (target_high - target_low) / (
+ source_high - source_low
+ )
+
+ return scaled_action
+
+
+def generate_action_space_per_control(
+ control_type: ControlType,
+ possible_agents: tuple[str, ...],
+ dyaw_limit: float | tuple[float, float] | None = None,
+ dvel_limit: float | tuple[float, float] | None = None,
+ fyaw_limit: float | tuple[float, float] | None = None,
+ fvel_limit: float | tuple[float, float] | None = None,
+) -> dict[str, spaces.Space]:
+ if isinstance(dyaw_limit, float):
+ dyaw_limit = (-dyaw_limit, dyaw_limit)
+ if isinstance(dvel_limit, float):
+ dvel_limit = (-dvel_limit, dvel_limit)
+ if isinstance(fyaw_limit, float):
+ fyaw_limit = (-fyaw_limit, fyaw_limit)
+ if isinstance(fvel_limit, float):
+ fvel_limit = (-fvel_limit, fvel_limit)
+
+ assert isinstance(dyaw_limit, tuple)
+ assert isinstance(dvel_limit, tuple)
+ assert isinstance(fyaw_limit, tuple)
+ assert isinstance(fvel_limit, tuple)
+
+ match control_type:
+ case ControlType.VelocityNonHolonomoic:
+ assert dyaw_limit is not None and dvel_limit is not None
+ neg_limits = np.array([dyaw_limit[0], dvel_limit[0]], dtype=np.float32)
+ pos_limits = np.array([dyaw_limit[1], dvel_limit[1]], dtype=np.float32)
+ case ControlType.VelocityHolonomoic:
+ assert dvel_limit is not None
+ neg_limits = np.array([dvel_limit[0], dvel_limit[0]], dtype=np.float32)
+ pos_limits = np.array([dvel_limit[1], dvel_limit[1]], dtype=np.float32)
+ case ControlType.ForceNonHolonomoic:
+ assert fyaw_limit is not None and fvel_limit is not None
+ neg_limits = np.array([fyaw_limit[0], fvel_limit[0]], dtype=np.float32)
+ pos_limits = np.array([fyaw_limit[1], fvel_limit[1]], dtype=np.float32)
+ case ControlType.ForceHolonomoic:
+ assert fvel_limit is not None
+ neg_limits = np.array([fvel_limit[0], fvel_limit[0]], dtype=np.float32)
+ pos_limits = np.array([fvel_limit[1], fvel_limit[1]], dtype=np.float32)
+ case ControlType.WheeledRobot:
+ assert dvel_limit is not None
+ neg_limits = np.array([dvel_limit[0], dvel_limit[0]], dtype=np.float32)
+ pos_limits = np.array([dvel_limit[1], dvel_limit[1]], dtype=np.float32)
+ case ControlType.Ackermann:
+ assert dyaw_limit is not None and dvel_limit is not None
+ neg_limits = np.array([dyaw_limit[0], dvel_limit[0]], dtype=np.float32)
+ pos_limits = np.array([dyaw_limit[1], dvel_limit[1]], dtype=np.float32)
+
+ return {
+ i: spaces.Box(
+ low=neg_limits,
+ high=pos_limits,
+ )
+ for i in possible_agents
+ }
+
+
# This function needs to be in global scope or we get pickle errors
def ignore_collisions(arbiter, space, data):
"""Pymunk collision handler which ignores collisions."""
@@ -96,8 +253,8 @@ def ignore_collisions(arbiter, space, data):
def clip_actions(
- actions: Dict[str, np.ndarray], action_spaces: Dict[str, spaces.Space]
-) -> Dict[str, np.ndarray]:
+ actions: dict[str, np.ndarray], action_spaces: dict[str, spaces.Space]
+) -> dict[str, np.ndarray]:
"""Clip continuous actions so they are within the agents action space dims."""
clipped_actions = {}
for i, a in actions.items():
@@ -116,19 +273,19 @@ class AbstractContinuousWorld(ABC):
def __init__(
self,
size: float,
- blocks: List[CircleEntity] | None = None,
- interior_walls: List[Line] | None = None,
+ blocks: list[CircleEntity] | None = None,
+ interior_walls: list[Line] | None = None,
agent_radius: float = 0.5,
border_thickness: float = 0.1,
enable_agent_collisions: bool = True,
- ):
+ ) -> None:
self.size = size
self.blocks = blocks or []
self.interior_walls = interior_walls or []
self.agent_radius = agent_radius
self.border_thickness = border_thickness
# access via blocked_coords property
- self._blocked_coords: Set[Coord] | None = None
+ self._blocked_coords: set[Coord] | None = None
self.collision_id = 0
self.enable_agent_collisions = enable_agent_collisions
@@ -154,29 +311,25 @@ def __init__(
self.space.add(body, shape)
# moveable entities in the world
- self.entities: Dict[str, Tuple[pymunk.Body, pymunk.Circle]] = {}
+ self.entities: dict[str, tuple[pymunk.Body, pymunk.Circle]] = {}
@abstractmethod
def add_border_to_space(self, size: float):
"""Adds solid border to the world physics space."""
- pass
@abstractmethod
def check_border_collisions(
self, ray_start_coords: np.ndarray, ray_end_coords: np.ndarray
) -> np.ndarray:
"""Check for collision between rays and world border."""
- pass
@abstractmethod
def clip_position(self, position: Vec2d) -> Vec2d:
- """Clip the position of an agent to be inside the border"""
- pass
+ """Clip the position of an agent to be inside the border."""
@abstractmethod
- def copy(self) -> "AbstractContinuousWorld":
+ def copy(self) -> AbstractContinuousWorld:
"""Get a deep copy of this world."""
- pass
def simulate(
self,
@@ -192,7 +345,7 @@ def simulate(
Also performing multiple steps `t` with a smaller `dt` creates a more stable
and accurate simulation.
- Arguments
+ Arguments:
---------
dt : float
the step size
@@ -217,16 +370,89 @@ def simulate(
body.angular_velocity
)
+ def compute_vel_force(
+ self,
+ control_type: ControlType,
+ current_ang: float,
+ current_vel: tuple[float, float] | None,
+ action_i: np.ndarray,
+ vel_limit_norm: float | None,
+ kinematic_parameters: dict[str, float],
+ ) -> dict[str, Any]:
+ """Compute appropriate velocity, force, and torque based on the
+ given control type and action.
+
+ Parameters
+ ----------
+ - control_type (ControlType): The type of control being used.
+ - current_ang (float): The current angle of the agent.
+ - current_vel (Tuple[float, float] | None): The current vel of the agent,
+ if given, velocity will be relative to the current agent
+ - action_i (np.ndarray): The action input for the agent.
+ - vel_limit_norm (float | None): The limit of velocity norm
+ if given, velcoity will be relative to the current agent
+ """
+ angle, vel, torque, local_force, global_force = (
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+ match control_type:
+ case ControlType.VelocityNonHolonomoic:
+ angle = current_ang + action_i[0]
+ vel = self.linear_to_xy_velocity(action_i[1], angle)
+ if current_vel is not None and vel_limit_norm is not None:
+ vel += Vec2d(*current_vel).rotated(action_i[0])
+ vel = self.clamp_norm(vel[0], vel[1], vel_limit_norm)
+
+ case ControlType.VelocityHolonomoic:
+ angle = 0
+ if current_vel is not None and vel_limit_norm is not None:
+ vel = (current_vel[0] + action_i[0], current_vel[1] + action_i[1])
+ vel = self.clamp_norm(vel[0], vel[1], vel_limit_norm)
+ else:
+ vel = (action_i[0], action_i[1])
+
+ case ControlType.ForceHolonomoic:
+ local_force = (action_i[0], 0)
+ torque = action_i[1]
+ case ControlType.ForceNonHolonomoic:
+ global_force = (action_i[0], action_i[1])
+ angle = 0
+ case ControlType.WheeledRobot:
+ wheel_radius = kinematic_parameters["wheel_radius"]
+ L = kinematic_parameters["L"]
+ v = wheel_radius / 2 * (action_i[0] + action_i[1])
+ omega = wheel_radius / L * (action_i[0] - action_i[1])
+ vel = v * np.array([np.cos(current_ang), np.sin(current_ang)])
+ angle = current_ang + omega
+ case ControlType.Ackermann:
+ v, phi = action_i
+ L = kinematic_parameters["L"]
+ omega = v / L * np.tan(phi)
+ vel = v * np.array([np.cos(current_ang), np.sin(current_ang)])
+ angle = current_ang + omega
+
+ return {
+ "angle": angle,
+ "vel": vel,
+ "torque": torque,
+ "local_force": local_force,
+ "global_force": global_force,
+ }
+
def add_entity(
self,
id: str,
radius: float | None,
- color: Tuple[int, int, int, int] | None,
+ color: tuple[int, int, int, int] | None,
is_static: bool = False,
- ) -> Tuple[pymunk.Body, pymunk.Circle]:
+ ) -> tuple[pymunk.Body, pymunk.Circle]:
"""Add moveable entity to the world.
- Arguments
+ Arguments:
---------
id : str
the unique ID of the entity
@@ -235,7 +461,7 @@ def add_entity(
color : Tuple[int, int, int, int] | None
optional color for the entity. This only impacts rendering of the world.
- Returns
+ Returns:
-------
body : pymunk.Body
underlying physics Body of the entity
@@ -251,8 +477,6 @@ def add_entity(
body = pymunk.Body(mass, inertia, body_type=body_type)
shape = pymunk.Circle(body, radius)
- shape.collision_type = self.get_collision_id()
-
shape.elasticity = 0.0 # no bouncing
shape.collision_type = self.get_collision_id()
if color is not None:
@@ -268,7 +492,25 @@ def remove_entity(self, id: str):
self.space.remove(body, shape)
del self.entities[id]
- def add_interior_walls_to_space(self, walls: List[Line]):
+ def change_entity_dynamics(
+ self,
+ id: str,
+ mass: float | None = None,
+ friction: float | None = None,
+ elasticity: float | None = None,
+ ):
+ body, shape = self.entities[id]
+ if mass is not None:
+ body.mass = mass
+
+ if friction is not None:
+ shape.friction = friction
+
+ if elasticity is not None:
+ elasticity = clamp(elasticity, 0, 0.99)
+ shape.elasticity = elasticity
+
+ def add_interior_walls_to_space(self, walls: list[Line]):
"""Adds interior walls to the world physics space."""
self.interior_walls_array = (
np.array([ln[0] for ln in walls], dtype=np.float32),
@@ -313,10 +555,13 @@ def update_entity_state(
self,
id: str,
*,
- coord: FloatCoord | List[float] | np.ndarray | Vec2d | None = None,
+ coord: FloatCoord | list[float] | np.ndarray | Vec2d | None = None,
angle: float | None = None,
- vel: FloatCoord | List[float] | np.ndarray | Vec2d | None = None,
- vangle: float | None = None,
+ vel: FloatCoord | list[float] | np.ndarray | Vec2d | None = None,
+ v_angle: float | None = None,
+ local_force: tuple[float, float] | None = None,
+ global_force: tuple[float, float] | None = None,
+ torque: float | None = None,
):
"""Update the state of an entity.
@@ -333,15 +578,26 @@ def update_entity_state(
if vel is not None:
body.velocity = Vec2d(vel[0], vel[1])
- if vangle is not None:
- body.angular_velocity = vangle
+ if v_angle is not None:
+ body.angular_velocity = v_angle
+
+ if local_force is not None:
+ body.apply_force_at_local_point(local_force, (0, 0))
+
+ if global_force is not None:
+ body.apply_force_at_world_point(
+ global_force, (body.position[0], body.position[1])
+ )
+
+ if torque is not None:
+ body.torque = torque
- def get_bounds(self) -> Tuple[FloatCoord, FloatCoord]:
+ def get_bounds(self) -> tuple[FloatCoord, FloatCoord]:
"""Get (min x, max_x), (min y, max y) bounds of the world."""
return (0, self.size), (0, self.size)
@property
- def blocked_coords(self) -> Set[Coord]:
+ def blocked_coords(self) -> set[Coord]:
"""The set of all integer coordinates that contain at least part of a block."""
if self._blocked_coords is None:
self._blocked_coords = set()
@@ -379,7 +635,7 @@ def manhattan_dist(loc1: Location, loc2: Location) -> float:
@staticmethod
def euclidean_dist(loc1: Location, loc2: Location) -> float:
"""Get Euclidean distance between two positions on the grid."""
- return math.sqrt((loc1[0] - loc2[0]) ** 2 + (loc1[1] - loc2[1]) ** 2)
+ return math.sqrt(AbstractContinuousWorld.squared_euclidean_dist(loc1, loc2))
@staticmethod
def squared_euclidean_dist(loc1: Location, loc2: Location) -> float:
@@ -393,7 +649,7 @@ def convert_angle_to_0_2pi_interval(angle: float) -> float:
@staticmethod
def convert_angle_to_negpi_pi_interval(angle: float) -> float:
"""Convert angle in radians to be in (-pi, pi] interval."""
- angle = angle % (2 * math.pi)
+ angle = AbstractContinuousWorld.convert_angle_to_0_2pi_interval(angle)
if angle > math.pi:
angle -= 2 * math.pi
return angle
@@ -401,13 +657,13 @@ def convert_angle_to_negpi_pi_interval(angle: float) -> float:
@staticmethod
def array_to_position(arr: np.ndarray) -> Position:
"""Convert from numpy array to tuple representation of a Position."""
- assert arr.shape[0] >= 3
+ assert arr.shape[0] >= MIN_POSITION_ARRAY_LENGTH
return (arr[0], arr[1], arr[2])
@staticmethod
def linear_to_xy_velocity(linear_vel: float, angle: float) -> Vec2d:
"""Convert from linear velocity to velocity along x and y axis."""
- return linear_vel * Vec2d(1, 0).rotated(angle)
+ return linear_vel * AbstractContinuousWorld.rotate_vector(1, 0, angle)
@staticmethod
def rotate_vector(vx: float, vy: float, angle: float) -> Vec2d:
@@ -415,9 +671,9 @@ def rotate_vector(vx: float, vy: float, angle: float) -> Vec2d:
return Vec2d(vx, vy).rotated(angle)
@staticmethod
- def clamp_norm(vx: float, vy: float, norm_max: float) -> Tuple[float, float]:
+ def clamp_norm(vx: float, vy: float, norm_max: float) -> tuple[float, float]:
"""Clamp x, y vector to within a given max norm."""
- if vx == 0.0 and vy == 0.0:
+ if vx == ZERO and vy == ZERO:
return vx, vy
norm = math.sqrt(vx**2 + vy**2)
f = min(norm, norm_max) / norm
@@ -455,7 +711,7 @@ def check_circle_line_intersection(
) -> np.ndarray:
"""Check if lines intersect circles.
- Arguments
+ Arguments:
---------
circle_coords
array containing the `(x, y)` of the center of each circle. Should have
@@ -469,7 +725,7 @@ def check_circle_line_intersection(
array containing the `(x, y)` coords of the end of of each of the lines.
Should have shape `(n_lines, 2)`
- Returns
+ Returns:
-------
distances
An array containing the euclidean distance from each lines start to
@@ -494,8 +750,8 @@ def check_circle_line_intersection(
t1 = (-b - sqrtdisc) / (2 * a)
t2 = (-b + sqrtdisc) / (2 * a)
- t1 = np.where(((t1 >= 0.0) & (t1 <= 1.0)), t1, np.nan)
- t2 = np.where(((t2 >= 0.0) & (t2 <= 1.0)), t2, np.nan)
+ t1 = np.where(((t1 >= ZERO) & (t1 <= ONE)), t1, np.nan)
+ t2 = np.where(((t2 >= ZERO) & (t2 <= ONE)), t2, np.nan)
t = np.where(t1 <= t2, t1, t2)
t = np.expand_dims(t, axis=-1)
@@ -507,12 +763,12 @@ def check_line_line_intersection(
l1_end_coords: np.ndarray,
l2_start_coords: np.ndarray,
l2_end_coords: np.ndarray,
- ) -> Tuple[np.ndarray, np.ndarray]:
+ ) -> tuple[np.ndarray, np.ndarray]:
"""Check if lines intersect.
Checks for each line in `l1` if it intersects with any line in `l2`.
- Arguments
+ Arguments:
---------
l1_start_coords
array with shape `(n_lines1, 2)` containing the (x, y) coord for the start
@@ -527,7 +783,7 @@ def check_line_line_intersection(
array with shape `(n_lines2, 2)` containing the (x, y) coord for the end of
each of the second set of lines.
- Returns
+ Returns:
-------
intersection_coords
array with shape `(n_lines1, n_lines2, 2)` containing the (x, y) coords for
@@ -564,18 +820,15 @@ def check_line_line_intersection(
dl2p[:, 0] = -dl2[:, 1]
dl2p[:, 1] = dl2[:, 0]
- # mult (n_lines1, 2) @ (n_lines1, 2, nlines2) = (n_lines1, n_lines2)
# each i in n_lines1 is multiplied with one of the n_lines1 matrices in dl1l2
# l1[i] @ (l1[i] - l2[j]) for i in [0, n_lines1], j in [0, n_lines2]
u_num = np.stack([np.matmul(dl1p[i], dl1l2_T[i]) for i in range(dl1p.shape[0])])
- # mult (n_lines2, 2) @ (n_lines2, 2, nlines1) = (n_lines2, n_lines1)
# same as above except for l2 lines
t_num = np.stack(
[np.matmul(dl2p[j], dl1l2_T2[j]) for j in range(dl2p.shape[0])]
)
- # mult (n_lines1, 2) @ (2, n_lines2) = (n_lines1, n_lines2)
# get l1[i] dot l2[j] for i in [0, n_lines1], j in [0, n_lines2]
# but using perpendicular lines to l1,
denom = np.matmul(dl1p, dl2.transpose())
@@ -605,10 +858,10 @@ def check_ray_collisions(
other_agents: np.ndarray | None = None,
include_blocks: bool = True,
check_walls: bool = True,
- ) -> Tuple[np.ndarray, np.ndarray]:
+ ) -> tuple[np.ndarray, np.ndarray]:
"""Check for collision along rays.
- Arguments
+ Arguments:
---------
ray_start_coords
start coords of rays. Should be 2D array with shape `(n_rays, 2`),
@@ -624,7 +877,7 @@ def check_ray_collisions(
check_walls
whether to check for collisions with the world border.
- Returns
+ Returns:
-------
distances
the distance each ray extends sway from the origin, up to a max of
@@ -667,14 +920,12 @@ def check_ray_collisions(
np.fmin(closest_distances, min_dists, out=closest_distances)
if check_walls:
- # shape = (n_lines, walls, 2)
wall_intersect_coords = self.check_border_collisions(
ray_start_coords, ray_end_coords
)
# Need to get coords of intersected walls, each ray can intersect a max of
# of 1 wall, so we just find the minimum non nan coords
- # shape = (n_lines, 2)
with warnings.catch_warnings():
# if no wall intersected, we take min of all NaN which throws a warning
# but this is acceptable behevaiour, so we suppress the warning
@@ -713,14 +964,14 @@ def check_collision_circular_rays(
include_blocks: bool = True,
check_walls: bool = True,
use_relative_angle: bool = True,
- angle_bounds: Tuple[float, float] = (0.0, 2 * np.pi),
- ) -> Tuple[np.ndarray, np.ndarray]:
+ angle_bounds: tuple[float, float] = (0.0, 2 * np.pi),
+ ) -> tuple[np.ndarray, np.ndarray]:
"""Check for collision along rays that radiate away from the origin.
Rays are evenly spaced around the origin, with the number of rays controlled
by the `n_rays` arguments.
- Arguments
+ Arguments:
---------
origin
the origin position
@@ -744,7 +995,7 @@ def check_collision_circular_rays(
have a full FOV as a circle around them. This will be between 0 and 2π.
This can be decreased as needed.
- Returns
+ Returns:
-------
distances
the distance each ray extends sway from the origin, up to a max of
@@ -776,7 +1027,7 @@ def check_collision_circular_rays(
def get_all_shortest_paths(
self, origins: Iterable[FloatCoord | Coord | Position]
- ) -> Dict[Tuple[int, int], Dict[Tuple[int, int], int]]:
+ ) -> dict[tuple[int, int], dict[tuple[int, int], int]]:
"""Get shortest path distance from every origin to all other coords."""
src_dists = {}
for origin in origins:
@@ -788,12 +1039,12 @@ def convert_to_coord(self, origin: FloatCoord | Coord | Position) -> Coord:
"""Convert a position/float coord to a integer coords."""
return (math.floor(origin[0]), math.floor(origin[1]))
- def dijkstra(self, origin: FloatCoord | Coord | Position) -> Dict[Coord, int]:
+ def dijkstra(self, origin: FloatCoord | Coord | Position) -> dict[Coord, int]:
"""Get shortest path distance between origin and all other coords."""
coord_origin = self.convert_to_coord(origin)
dist = {coord_origin: 0}
- pq: PriorityQueue[Tuple[int, Coord]] = PriorityQueue()
+ pq: PriorityQueue[tuple[int, Coord]] = PriorityQueue()
pq.put((dist[coord_origin], coord_origin))
visited = {coord_origin}
@@ -816,7 +1067,7 @@ def get_cardinal_neighbours_coords(
coord: Coord,
ignore_blocks: bool = False,
include_out_of_bounds: bool = False,
- ) -> List[Coord]:
+ ) -> list[Coord]:
"""Get set of adjacent non-blocked coords."""
(min_x, max_x), (min_y, max_y) = self.get_bounds()
neighbours = []
@@ -842,7 +1093,7 @@ def get_cardinal_neighbours_coords(
class SquareContinuousWorld(AbstractContinuousWorld):
"""A continuous world with a square border."""
- def copy(self) -> "SquareContinuousWorld":
+ def copy(self) -> SquareContinuousWorld:
world = SquareContinuousWorld(
size=self.size,
blocks=self.blocks,
@@ -873,7 +1124,7 @@ def add_border_to_space(self, size: float):
),
)
- for w_start, w_end in zip(*self.border):
+ for w_start, w_end in zip(*self.border, strict=False):
wall = pymunk.Segment(
self.space.static_body,
(w_start[0], w_start[1]),
@@ -901,7 +1152,7 @@ def check_border_collisions(
class CircularContinuousWorld(AbstractContinuousWorld):
"""A 2D continuous world with a circular border."""
- def copy(self) -> "CircularContinuousWorld":
+ def copy(self) -> CircularContinuousWorld:
world = CircularContinuousWorld(
size=self.size,
blocks=self.blocks,
@@ -970,7 +1221,7 @@ def clip_position(self, position: Vec2d) -> Vec2d:
def generate_interior_walls(
width: int, height: int, blocked_coords: Iterable[Coord]
-) -> List[Line]:
+) -> list[Line]:
"""Generate interior walls for rectangular world based on blocked coordinates."""
# dx, dy
# north, east, south, west
@@ -983,10 +1234,10 @@ def generate_interior_walls(
]
# get line for each block face adjacent to empty cell
- lines_map: Dict[Coord, Set[Coord]] = {}
- lines: Set[IntLine] = set()
+ lines_map: dict[Coord, set[Coord]] = {}
+ lines: set[IntLine] = set()
for x, y in blocked_coords:
- for (dx, dy), line_offset in zip(directions, line_offsets):
+ for (dx, dy), line_offset in zip(directions, line_offsets, strict=False):
if (x + dx, y + dy) in blocked_coords:
# adjacent cell blocked
continue
@@ -1009,7 +1260,7 @@ def generate_interior_walls(
lines.add((l_start, l_end))
# merge lines
- merged_lines: List[IntLine] = []
+ merged_lines: list[IntLine] = []
# lines l1 and l2 can merge if
# 1. l1[1] == l2[0] and (l1[0][0] == l2[1][0] or l1[0][1] == l2[1][1]
@@ -1020,7 +1271,7 @@ def generate_interior_walls(
stack = list(lines)
stack.sort(reverse=True)
- visited: Set[IntLine] = set()
+ visited: set[IntLine] = set()
while len(stack):
line = stack.pop()
if line in visited:
diff --git a/posggym/envs/continuous/driving_continuous.py b/posggym/envs/continuous/driving_continuous.py
index a26d50f..e6728ae 100644
--- a/posggym/envs/continuous/driving_continuous.py
+++ b/posggym/envs/continuous/driving_continuous.py
@@ -1,25 +1,40 @@
"""The Driving Continuous Environment."""
+from __future__ import annotations
import math
from itertools import product
-from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast
+from typing import (
+ Any,
+ ClassVar,
+ NamedTuple,
+ cast,
+)
import numpy as np
from gymnasium import spaces
-from pymunk import Vec2d
import posggym.model as M
from posggym import logger
from posggym.core import DefaultEnv
from posggym.envs.continuous.core import (
AGENT_COLORS,
+ ANGLE_IDX,
+ VX_IDX,
+ VY_IDX,
+ X_IDX,
+ Y_IDX,
CollisionType,
+ ControlType,
Coord,
FloatCoord,
PMBodyState,
SquareContinuousWorld,
+ clamp,
clip_actions,
+ generate_action_space,
generate_interior_walls,
+ generate_parameters,
+ scale_action,
)
from posggym.utils import seeding
@@ -30,10 +45,11 @@ class VehicleState(NamedTuple):
body: np.ndarray
dest_coord: np.ndarray
status: np.ndarray
+ dest_dist: np.ndarray
min_dest_dist: np.ndarray
-DState = Tuple[VehicleState, ...]
+DState = tuple[VehicleState, ...]
# Obs = (sensor obs, dir, vx, vy, dest_coord)
DObs = np.ndarray
@@ -144,9 +160,8 @@ class DrivingContinuousEnv(DefaultEnv[DState, DObs, DAction]):
larger worlds (this can be done by manually specifying a value for
`max_episode_steps` when creating the environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `world` - the world layout to use. This can either be a string specifying one of
the supported worlds, or a custom :class:`DrivingWorld` object
(default = `"14x14RoundAbout"`).
@@ -185,32 +200,50 @@ class DrivingContinuousEnv(DefaultEnv[DState, DObs, DAction]):
---------------
- `v0`: Initial version
- References
+ References:
----------
- Adam Lerer and Alexander Peysakhovich. 2019. Learning Existing Social Conventions
via Observationally Augmented Self-Play. In Proceedings of the 2019 AAAI/ACM
- Conference on AI, Ethics, and Society. 107–114.
+ Conference on AI, Ethics, and Society. 107-114.
- Kevin R. McKee, Joel Z. Leibo, Charlie Beattie, and Richard Everett. 2022.
Quantifying the Effects of Environment and Population Diversity in Multi-Agent
- Reinforcement Learning. Autonomous Agents and Multi-Agent Systems 36, 1 (2022), 1–16
+ Reinforcement Learning. Autonomous Agents and Multi-Agent Systems 36, 1 (2022), 1-16
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "rgb_array"],
"render_fps": 15,
}
def __init__(
self,
- world: Union[str, "DrivingWorld"] = "14x14RoundAbout",
- num_agents: int = 2,
+ world: str | DrivingWorld = "14x14RoundAbout",
+ num_agents: int = 1,
obs_dist: float = 5.0,
n_sensors: int = 16,
- render_mode: Optional[str] = None,
- ):
+ obs_self_model: bool = False,
+ control_type: ControlType | str = ControlType.VelocityNonHolonomoic,
+ render_mode: str | None = None,
+ ) -> None:
+ if isinstance(control_type, str):
+ try:
+ control_type = ControlType.from_str(control_type)
+ except ValueError:
+ logger.warning(
+ "Invalid control type, defaulting to VelocityNonHolonomoic"
+ )
+ control_type = ControlType.VelocityNonHolonomoic
+
super().__init__(
- DrivingContinuousModel(world, num_agents, obs_dist, n_sensors),
+ DrivingContinuousModel(
+ world,
+ num_agents,
+ obs_dist,
+ n_sensors,
+ obs_self_model,
+ control_type,
+ ),
render_mode=render_mode,
)
self.window_surface = None
@@ -223,7 +256,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -232,14 +265,12 @@ def render(self):
return self._render_img()
def _render_img(self):
- # import posggym.envs.continuous.render as render_lib
import pygame
from pymunk import Transform, pygame_util
model = cast(DrivingContinuousModel, self.model)
state = cast(DState, self.state)
scale_factor = self.window_size / model.world.size
-
if self.window_surface is None:
pygame.init()
if self.render_mode == "human":
@@ -300,10 +331,11 @@ def _render_img(self):
lines_colors = ["red", "green", "black"]
# draw sensor lines
+
n_sensors = model.n_sensors
for i, obs_i in self._last_obs.items():
line_obs = obs_i[: model.sensor_obs_dim]
- x, y, agent_angle = state[int(i)].body[:3]
+ x, y, agent_angle = state[int(i)].body[[X_IDX, Y_IDX, ANGLE_IDX]]
angle_inc = 2 * math.pi / n_sensors
for k in range(n_sensors):
values = [
@@ -318,7 +350,7 @@ def _render_img(self):
end_x = x + dist * math.cos(angle)
end_y = y + dist * math.sin(angle)
scaled_start = (int(x * scale_factor), int(y * scale_factor))
- scaled_end = int(end_x * scale_factor), (end_y * scale_factor)
+ scaled_end = (int(end_x * scale_factor), int(end_y * scale_factor))
pygame.draw.line(
self.window_surface,
@@ -376,17 +408,19 @@ class DrivingContinuousModel(M.POSGModel[DState, DObs, DAction]):
"""
R_STEP_COST = 0.00
- R_CRASH_VEHICLE = -1.0
+ R_CRASH_VEHICLE = -5.0
R_DESTINATION_REACHED = 1.0
R_PROGRESS = 0.05
def __init__(
self,
- world: Union[str, "DrivingWorld"],
+ world: str | DrivingWorld,
num_agents: int,
obs_dist: float,
n_sensors: int,
- ):
+ obs_self_model: bool,
+ control_type: ControlType,
+ ) -> None:
if isinstance(world, str):
assert world in SUPPORTED_WORLDS, (
f"Unsupported world '{world}'. If world argument is a string it must "
@@ -394,7 +428,8 @@ def __init__(
)
world_info = SUPPORTED_WORLDS[world]
world = parseworld_str(
- world_info["world_str"], world_info["supported_num_agents"]
+ world_info["world_str"],
+ world_info["supported_num_agents"],
)
assert 0 < num_agents <= world.supported_num_agents, (
f"Supplied DrivingWorld `{world}` does not support {num_agents} "
@@ -405,7 +440,11 @@ def __init__(
self.world = world
self.n_sensors = n_sensors
self.obs_dist = obs_dist
+ self.obs_self_model = obs_self_model
self.vehicle_collision_dist = 2.1 * self.world.agent_radius
+ self.control_type = control_type
+ self.dt = 1.0
+ self.substeps = 10
self.possible_agents = tuple(str(i) for i in range(num_agents))
self.state_space = spaces.Tuple(
@@ -425,6 +464,10 @@ def __init__(
low=np.array([0], dtype=np.float32),
high=np.array([self.world.size**2], dtype=np.float32),
),
+ spaces.Box(
+ low=np.array([0], dtype=np.float32),
+ high=np.array([self.world.size**2], dtype=np.float32),
+ ),
)
)
for _ in range(len(self.possible_agents))
@@ -433,31 +476,51 @@ def __init__(
self.dyaw_limit = math.pi / 4
self.dvel_limit = 0.25
- self.vel_limit_norm = 1.0
- # dyaw, dvel
+
+ self.fyaw_limit = math.pi
+ self.fvel_limit = 3.0
+
+ self.action_spaces_per_control = generate_action_space(
+ self.possible_agents,
+ self.dyaw_limit,
+ self.dvel_limit,
+ self.fyaw_limit,
+ self.fvel_limit,
+ )
+
self.action_spaces = {
- i: spaces.Box(
- low=np.array([-self.dyaw_limit, -self.dvel_limit], dtype=np.float32),
- high=np.array([self.dyaw_limit, self.dvel_limit], dtype=np.float32),
- )
+ i: spaces.Box(np.array([-1, -1]), np.array([1, 1]))
for i in self.possible_agents
}
+ self.control_types = {i: self.control_type for i in self.possible_agents}
+ self.init_kinematics()
+ self.vel_limit_norm = 1.0
# Observes entity and distance to entity along a n_sensors rays from the agent
# 0 to n_sensors = wall distance obs
# n_sensors to (2 * n_sensors) = other vehicle dist
# Also observs angle, vx, vy, dest dx, desy dy
self.sensor_obs_dim = self.n_sensors * 2
- self.obs_dim = self.sensor_obs_dim + 5
+ self.obs_dim = self.sensor_obs_dim + 5 + int(self.obs_self_model)
sensor_low = [0.0] * self.sensor_obs_dim
sensor_high = [self.obs_dist] * self.sensor_obs_dim
self.observation_spaces = {
i: spaces.Box(
low=np.array(
- [*sensor_low, -2 * math.pi, -1, -1, 0, 0], dtype=np.float32
+ [
+ *sensor_low,
+ -2 * math.pi,
+ -1,
+ -1,
+ -self.world.size,
+ -self.world.size,
+ ]
+ + ([0] if self.obs_self_model else []),
+ dtype=np.float32,
),
high=np.array(
- [*sensor_high, 2 * math.pi, 1, 1, self.world.size, self.world.size],
+ [*sensor_high, 2 * math.pi, 1, 1, self.world.size, self.world.size]
+ + ([len(ControlType)] if self.obs_self_model else []),
dtype=np.float32,
),
dtype=np.float32,
@@ -473,7 +536,7 @@ def __init__(
self.is_symmetric = True
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {
i: (self.R_CRASH_VEHICLE, self.R_DESTINATION_REACHED)
for i in self.possible_agents
@@ -485,13 +548,13 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: DState) -> List[str]:
+ def get_agents(self, state: DState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> DState:
state = []
- chosen_start_coords: Set[FloatCoord] = set()
- chosen_dest_coords: Set[FloatCoord] = set()
+ chosen_start_coords: set[FloatCoord] = set()
+ chosen_dest_coords: set[FloatCoord] = set()
for i in range(len(self.possible_agents)):
start_coords_i = self.world.start_coords[i]
avail_start_coords = start_coords_i.difference(chosen_start_coords)
@@ -504,9 +567,10 @@ def sample_initial_state(self) -> DState:
avail_dest_coords.remove(start_coord)
body_state = np.zeros((PMBodyState.num_features()), dtype=np.float32)
- body_state[:2] = start_coord
+ body_state[[X_IDX, Y_IDX]] = start_coord
_dest_coord = self.rng.choice(list(avail_dest_coords))
+
chosen_dest_coords.add(_dest_coord)
dest_coord = np.array(_dest_coord, dtype=np.float32)
@@ -516,28 +580,30 @@ def sample_initial_state(self) -> DState:
body=body_state,
dest_coord=dest_coord,
status=np.array([int(False), int(False)], dtype=np.int8),
+ dest_dist=np.array([dest_dist], dtype=np.float32),
min_dest_dist=np.array([dest_dist], dtype=np.float32),
)
state.append(state_i)
return tuple(state)
- def sample_initial_obs(self, state: DState) -> Dict[str, DObs]:
+ def sample_initial_obs(self, state: DState) -> dict[str, DObs]:
return self._get_obs(state)
def step(
- self, state: DState, actions: Dict[str, DAction]
+ self, state: DState, actions: dict[str, DAction]
) -> M.JointTimestep[DState, DObs]:
clipped_actions = clip_actions(actions, self.action_spaces)
next_state, collision_types = self._get_next_state(state, clipped_actions)
obs = self._get_obs(next_state)
+
rewards = self._get_rewards(state, next_state, collision_types)
terminated = {i: any(next_state[int(i)].status) for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
all_done = all(terminated.values())
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
for idx in range(len(self.possible_agents)):
if next_state[idx].status[0]:
outcome_i = M.Outcome.WIN
@@ -551,55 +617,79 @@ def step(
next_state, obs, rewards, terminated, truncated, all_done, info
)
+ def init_kinematics(self):
+ self.kinematic_parameters = {
+ i: generate_parameters(self.control_types[i]) for i in self.possible_agents
+ }
+
def _get_next_state(
- self, state: DState, actions: Dict[str, DAction]
- ) -> Tuple[DState, List[CollisionType]]:
+ self, state: DState, actions: dict[str, DAction]
+ ) -> tuple[DState, list[CollisionType]]:
for i in range(len(self.possible_agents)):
state_i = state[i]
+ action_i = actions[str(i)]
+
self.world.set_entity_state(f"vehicle_{i}", state_i.body)
if state[i].status[0] or state[i].status[1]:
self.world.update_entity_state(f"vehicle_{i}", vel=(0.0, 0.0))
continue
- action_i = actions[str(i)]
- v_angle = state_i.body[2] + action_i[0]
- v_vel = Vec2d(*state_i.body[3:5]).rotated(action_i[0]) + (
- action_i[1] * Vec2d(1, 0).rotated(v_angle)
- )
- self.world.update_entity_state(
- f"vehicle_{i}",
- angle=v_angle,
- vel=self.world.clamp_norm(v_vel[0], v_vel[1], self.vel_limit_norm),
+ action_scaled = scale_action(
+ action_i,
+ self.action_spaces[str(i)],
+ self.action_spaces_per_control[self.control_types[str(i)]][str(i)],
)
- self.world.simulate(1.0 / 10, 10)
+ result = self.world.compute_vel_force(
+ self.control_types[str(i)],
+ state_i.body[ANGLE_IDX],
+ (state_i.body[VX_IDX], state_i.body[VY_IDX]),
+ action_scaled,
+ self.vel_limit_norm,
+ self.kinematic_parameters[str(i)],
+ )
+ self.world.update_entity_state(f"vehicle_{i}", **result)
+ self.world.simulate(self.dt / self.substeps, self.substeps)
collision_types = [CollisionType.NONE] * len(self.possible_agents)
- new_state: List[Optional[VehicleState]] = [None] * len(self.possible_agents)
+ new_state: list[VehicleState | None] = [None] * len(self.possible_agents)
for idx in range(len(self.possible_agents)):
next_v_body_state = np.array(
self.world.get_entity_state(f"vehicle_{idx}"), dtype=np.float32
)
- next_v_body_state[2] = self.world.convert_angle_to_0_2pi_interval(
- next_v_body_state[2]
+ next_v_body_state[ANGLE_IDX] = self.world.convert_angle_to_0_2pi_interval(
+ next_v_body_state[ANGLE_IDX]
)
# ensure vx, vy is in [-1, 1]
# with collisions, etc pymunk can sometime push it over this limit
- next_v_body_state[3] = max(-1.0, min(1.0, next_v_body_state[3]))
- next_v_body_state[4] = max(-1.0, min(1.0, next_v_body_state[4]))
+ next_v_body_state[VX_IDX] = clamp(next_v_body_state[VX_IDX], -1.0, 1.0)
+ next_v_body_state[VY_IDX] = clamp(next_v_body_state[VY_IDX], -1.0, 1.0)
state_i = state[idx]
- next_v_coords = next_v_body_state[:2]
- dest_distance = np.linalg.norm(state_i.dest_coord - next_v_coords)
+ next_v_coords = next_v_body_state[[X_IDX, Y_IDX]]
+ current_v_coords = state_i.body[[X_IDX, Y_IDX]]
+
+ # Interpolate between start and end, in case it was between states.
+ fractions = np.array([0, 0.2, 0.4, 0.6, 0.8, 1]) # Array of fractions
+ intermediate_vectors = (
+ current_v_coords[:, np.newaxis]
+ + (next_v_coords - current_v_coords)[:, np.newaxis] * fractions
+ ).T
+ dest_distance = np.linalg.norm(
+ state_i.dest_coord - intermediate_vectors, axis=1
+ )
+
crashed = False
for other_idx, other_v_state in enumerate(new_state):
if other_v_state is None:
continue
- dist = np.linalg.norm(other_v_state.body[:2] - next_v_coords)
+ dist = np.linalg.norm(
+ other_v_state.body[[X_IDX, Y_IDX]] - next_v_coords
+ )
if dist <= self.vehicle_collision_dist:
crashed = True
collision_types[idx] = CollisionType.AGENT
@@ -611,21 +701,24 @@ def _get_next_state(
crashed = crashed or bool(state_i.status[1])
- min_dest_dist = min(
- state_i.min_dest_dist[0],
- self.world.get_shortest_path_distance(
- (next_v_body_state[0], next_v_body_state[1]),
- (state_i.dest_coord[0], state_i.dest_coord[1]),
- ),
+ dest_dist = self.world.get_shortest_path_distance(
+ (next_v_body_state[X_IDX], next_v_body_state[Y_IDX]),
+ (state_i.dest_coord[X_IDX], state_i.dest_coord[Y_IDX]),
)
+ min_dest_dist = min(dest_dist, state_i.min_dest_dist[0])
+
new_state[idx] = VehicleState(
body=next_v_body_state,
dest_coord=state_i.dest_coord,
status=np.array(
- [int(dest_distance <= self.world.agent_radius), int(crashed)],
+ [
+ int((dest_distance <= self.world.agent_radius).any()),
+ int(crashed),
+ ],
dtype=np.int8,
),
+ dest_dist=np.array([dest_dist], dtype=np.float32),
min_dest_dist=np.array([min_dest_dist], dtype=np.float32),
)
@@ -634,7 +727,7 @@ def _get_next_state(
return tuple(final_state), collision_types
- def _get_obs(self, state: DState) -> Dict[str, DObs]:
+ def _get_obs(self, state: DState) -> dict[str, DObs]:
return {i: self._get_agent_obs(i, state) for i in self.possible_agents}
def _get_agent_obs(self, agent_id: str, state: DState) -> np.ndarray:
@@ -642,9 +735,13 @@ def _get_agent_obs(self, agent_id: str, state: DState) -> np.ndarray:
if state_i.status[0] or state_i.status[1]:
return np.zeros((self.obs_dim,), dtype=np.float32)
- pos_i = (state_i.body[0], state_i.body[1], state_i.body[2])
+ pos_i = (state_i.body[X_IDX], state_i.body[Y_IDX], state_i.body[ANGLE_IDX])
vehicle_coords = np.array(
- [[s.body[0], s.body[1]] for i, s in enumerate(state) if i != int(agent_id)]
+ [
+ [s.body[X_IDX], s.body[Y_IDX]]
+ for i, s in enumerate(state)
+ if i != int(agent_id)
+ ]
)
ray_dists, ray_col_type = self.world.check_collision_circular_rays(
@@ -667,20 +764,24 @@ def _get_agent_obs(self, agent_id: str, state: DState) -> np.ndarray:
obs[flat_obs_idx] = ray_dists
d = self.sensor_obs_dim
- obs[d] = self.world.convert_angle_to_0_2pi_interval(state_i.body[2])
- obs[d + 1] = max(-1.0, min(1.0, state_i.body[3]))
- obs[d + 2] = max(-1.0, min(1.0, state_i.body[4]))
- obs[d + 3] = abs(state_i.dest_coord[0] - pos_i[0])
- obs[d + 4] = abs(state_i.dest_coord[1] - pos_i[1])
+ obs[d] = self.world.convert_angle_to_0_2pi_interval(state_i.body[ANGLE_IDX])
+ obs[d + 1] = clamp(state_i.body[VX_IDX], -1.0, 1.0)
+ obs[d + 2] = clamp(state_i.body[VY_IDX], -1.0, 1.0)
+ obs[d + 3] = state_i.dest_coord[X_IDX] - pos_i[X_IDX]
+ obs[d + 4] = state_i.dest_coord[Y_IDX] - pos_i[Y_IDX]
+ if self.obs_self_model:
+ obs[d + 5] = int(self.control_types[agent_id])
return obs
def _get_rewards(
- self, state: DState, next_state: DState, collision_types: List[CollisionType]
- ) -> Dict[str, float]:
- rewards: Dict[str, float] = {}
- for i in self.possible_agents:
- idx = int(i)
+ self,
+ state: DState,
+ next_state: DState,
+ collision_types: list[CollisionType],
+ ) -> dict[str, float]:
+ rewards: dict[str, float] = {}
+ for idx in map(int, self.possible_agents):
if any(state[idx].status):
# already in terminal/rewarded state
r_i = 0.0
@@ -694,7 +795,7 @@ def _get_rewards(
progress = (state[idx].min_dest_dist - next_state[idx].min_dest_dist)[0]
r_i += max(0, progress) * self.R_PROGRESS
- rewards[i] = r_i
+ rewards[str(idx)] = r_i
return rewards
@@ -704,10 +805,10 @@ class DrivingWorld(SquareContinuousWorld):
def __init__(
self,
size: int,
- blocked_coords: Set[Coord],
- start_coords: List[Set[FloatCoord]],
- dest_coords: List[Set[FloatCoord]],
- ):
+ blocked_coords: set[Coord],
+ start_coords: list[set[FloatCoord]],
+ dest_coords: list[set[FloatCoord]],
+ ) -> None:
interior_walls = generate_interior_walls(size, size, blocked_coords)
super().__init__(
size=size,
@@ -721,11 +822,9 @@ def __init__(
self._blocked_coords = blocked_coords
self.start_coords = start_coords
self.dest_coords = dest_coords
- self.shortest_paths = self.get_all_shortest_paths(
- set.union(*dest_coords) # type: ignore
- )
+ self.shortest_paths = self.get_all_shortest_paths(set.union(*dest_coords))
- def copy(self) -> "DrivingWorld":
+ def copy(self) -> DrivingWorld:
assert self._blocked_coords is not None
world = DrivingWorld(
size=int(self.size),
@@ -748,7 +847,9 @@ def supported_num_agents(self) -> int:
"""Get the number of agents supported by this world."""
return len(self.start_coords)
- def get_shortest_path_distance(self, coord: FloatCoord, dest: FloatCoord) -> int:
+ def get_shortest_path_distance(
+ self, coord: FloatCoord, dest: FloatCoord
+ ) -> int | float:
"""Get the shortest path distance from coord to destination."""
coord_c = self.convert_to_coord(coord)
dest_c = self.convert_to_coord(dest)
@@ -801,13 +902,13 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
width = len(row_strs[0])
agent_start_chars = set(["+"] + [str(i) for i in range(10)])
- agent_dest_chars = set(["-"] + list("abcdefghij"))
+ agent_dest_chars = {"-", *list("abcdefghij")}
- blocked_coords: Set[Coord] = set()
- shared_start_coords: Set[FloatCoord] = set()
- agent_start_coords_map: Dict[int, Set[FloatCoord]] = {}
- shared_dest_coords: Set[FloatCoord] = set()
- agent_dest_coords_map: Dict[int, Set[FloatCoord]] = {}
+ blocked_coords: set[Coord] = set()
+ shared_start_coords: set[FloatCoord] = set()
+ agent_start_coords_map: dict[int, set[FloatCoord]] = {}
+ shared_dest_coords: set[FloatCoord] = set()
+ agent_dest_coords_map: dict[int, set[FloatCoord]] = {}
for r, c in product(range(height), range(width)):
coord = (c + 0.5, r + 0.5)
char = row_strs[r][c]
@@ -840,8 +941,8 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
if len(included_agent_ids) > 0:
assert max(included_agent_ids) < supported_num_agents
- start_coords: List[Set[FloatCoord]] = []
- dest_coords: List[Set[FloatCoord]] = []
+ start_coords: list[set[FloatCoord]] = []
+ dest_coords: list[set[FloatCoord]] = []
for i in range(supported_num_agents):
agent_start_coords = set(shared_start_coords)
agent_start_coords.update(agent_start_coords_map.get(i, {}))
@@ -859,11 +960,18 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
)
-SUPPORTED_WORLDS: Dict[str, Dict[str, Any]] = {
+SUPPORTED_WORLDS: dict[str, dict[str, Any]] = {
"6x6Intersection": {
+ # fmt: off
"world_str": (
- "##0b##\n" "##..##\n" "d....3\n" "2....c\n" "##..##\n" "##a1##\n"
+ "##0b##\n"
+ "##..##\n"
+ "d....3\n"
+ "2....c\n"
+ "##..##\n"
+ "##a1##\n"
),
+ # fmt: on
"supported_num_agents": 4,
"max_episode_steps": 20,
},
@@ -878,7 +986,7 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
"#+...+#\n"
),
"supported_num_agents": 4,
- "max_episode_steps": 50,
+ "max_episode_steps": 500,
},
"7x7CrissCross": {
"world_str": (
@@ -891,7 +999,7 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
"#+#+#+#\n"
),
"supported_num_agents": 6,
- "max_episode_steps": 50,
+ "max_episode_steps": 500,
},
"7x7RoundAbout": {
"world_str": (
@@ -904,7 +1012,7 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
"#+...+#\n"
),
"supported_num_agents": 4,
- "max_episode_steps": 50,
+ "max_episode_steps": 500,
},
"14x14Blocks": {
"world_str": (
@@ -924,7 +1032,7 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
"#+..........+#\n"
),
"supported_num_agents": 4,
- "max_episode_steps": 50,
+ "max_episode_steps": 500,
},
"14x14CrissCross": {
"world_str": (
@@ -944,7 +1052,7 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
"##+##+##+##+##\n"
),
"supported_num_agents": 8,
- "max_episode_steps": 50,
+ "max_episode_steps": 500,
},
"14x14RoundAbout": {
"world_str": (
@@ -964,6 +1072,42 @@ def parseworld_str(world_str: str, supported_num_agents: int) -> DrivingWorld:
"#+..........+#\n"
),
"supported_num_agents": 4,
- "max_episode_steps": 50,
+ "max_episode_steps": 500,
+ },
+ "14x14Empty": {
+ "world_str": (
+ ".-------------\n"
+ "-------------+\n"
+ "--------------\n"
+ "--------------\n"
+ "--------------\n"
+ "--------------\n"
+ "--------------\n"
+ "--------------\n"
+ "++++++++++++++\n"
+ "++++++++++++++\n"
+ "++++++++++++++\n"
+ "++++++++++++++\n"
+ "-+++++++++++++\n"
+ "++++++++++++++\n"
+ ),
+ "supported_num_agents": 4,
+ "max_episode_steps": 5000,
},
}
+
+if __name__ == "__main__":
+ from posggym.utils.run_random_agents import run_random
+
+ env = DrivingContinuousEnv(
+ render_mode="human",
+ obs_self_model=True,
+ num_agents=1,
+ control_type=ControlType.VelocityHolonomoic,
+ )
+
+ run_random(
+ env=env,
+ num_episodes=1,
+ max_episode_steps=1000,
+ )
diff --git a/posggym/envs/continuous/drone_team_capture.py b/posggym/envs/continuous/drone_team_capture.py
index 65afb23..7c93b46 100644
--- a/posggym/envs/continuous/drone_team_capture.py
+++ b/posggym/envs/continuous/drone_team_capture.py
@@ -1,6 +1,6 @@
"""The Drone Team Capture Environment."""
import math
-from typing import Dict, List, NamedTuple, Optional, Tuple, cast
+from typing import ClassVar, NamedTuple, cast
import numpy as np
from gymnasium import spaces
@@ -9,6 +9,9 @@
from posggym import logger
from posggym.core import DefaultEnv
from posggym.envs.continuous.core import (
+ ANGLE_IDX,
+ X_IDX,
+ Y_IDX,
CircularContinuousWorld,
PMBodyState,
clip_actions,
@@ -140,9 +143,8 @@ class DroneTeamCaptureEnv(DefaultEnv[DTCState, DTCObs, DTCAction]):
`max_episode_steps` when creating the environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `num_agents` - The number of agents which exist in the environment
Must be between 1 and 8 (default = `3`)
- `n_communicating_pursuers - The maximum number of agents which an
@@ -191,7 +193,7 @@ class DroneTeamCaptureEnv(DefaultEnv[DTCState, DTCObs, DTCAction]):
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "rgb_array"],
"render_fps": 15,
}
@@ -199,14 +201,14 @@ class DroneTeamCaptureEnv(DefaultEnv[DTCState, DTCObs, DTCAction]):
def __init__(
self,
num_agents: int = 3,
- n_communicating_pursuers: Optional[int] = None,
+ n_communicating_pursuers: int | None = None,
arena_radius: float = 430,
- observation_limit: Optional[float] = None,
+ observation_limit: float | None = None,
velocity_control: bool = False,
capture_radius: float = 30,
use_q_reward: bool = False,
- render_mode: Optional[str] = None,
- ):
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
DroneTeamCaptureModel(
num_agents,
@@ -233,7 +235,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -343,18 +345,19 @@ class DroneTeamCaptureModel(M.POSGModel[DTCState, DTCObs, DTCAction]):
PURSUER_COLOR = (55, 155, 205, 255) # blueish
EVADER_COLOR = (110, 55, 155, 255) # purpleish
+ MAX_AGENTS = 8
def __init__(
self,
num_agents: int,
- n_communicating_pursuers: Optional[int] = None,
+ n_communicating_pursuers: int | None = None,
arena_radius: float = 430,
- observation_limit: Optional[float] = None,
+ observation_limit: float | None = None,
velocity_control: bool = False,
capture_radius: float = 30,
use_q_reward: bool = False,
- ):
- assert 1 < num_agents <= 8
+ ) -> None:
+ assert 1 < num_agents <= self.MAX_AGENTS
assert (
n_communicating_pursuers is None
or 0 < n_communicating_pursuers < num_agents
@@ -394,7 +397,6 @@ def __init__(
for i in self.possible_agents
}
else:
- # act[0] = angular velocity
self.action_spaces = {
i: spaces.Box(
low=np.array([-self.dyaw_limit], dtype=np.float32),
@@ -460,11 +462,11 @@ def __init__(
self.world.add_entity(f"pursuer_{i}", None, color=self.PURSUER_COLOR)
self.world.add_entity("evader", None, color=self.EVADER_COLOR)
- def get_agents(self, state: DTCState) -> List[str]:
+ def get_agents(self, state: DTCState) -> list[str]:
return list(self.possible_agents)
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
min_reward = self.R_TARGET_DIST_COEFF * 2 * self.r_arena
max_reward = self.R_CAPTURE_TEAM + self.R_CAPTURE
if self.use_q_reward:
@@ -487,7 +489,7 @@ def sample_initial_state(self) -> DTCState:
for i in range(self.n_pursuers):
# distributes the agents based on their index
x = 50.0 * (-math.floor(self.n_pursuers / 2) + i) + self.r_arena
- pursuer_states[i][:3] = (x, self.r_arena, 0.0)
+ pursuer_states[i][[X_IDX, Y_IDX, ANGLE_IDX]] = (x, self.r_arena, 0.0)
# Target is placed randomly in sphere,
# excluding area near center where pursuers start
@@ -507,7 +509,7 @@ def sample_initial_state(self) -> DTCState:
target_vel = relative_target_vel * self.max_pursuer_vel
target_state = np.zeros((PMBodyState.num_features()), dtype=np.float32)
- target_state[:3] = (x, y, 0.0)
+ target_state[[X_IDX, Y_IDX, ANGLE_IDX]] = (x, y, 0.0)
return DTCState(
pursuer_states,
@@ -517,11 +519,11 @@ def sample_initial_state(self) -> DTCState:
target_vel,
)
- def sample_initial_obs(self, state: DTCState) -> Dict[str, DTCObs]:
+ def sample_initial_obs(self, state: DTCState) -> dict[str, DTCObs]:
return self._get_obs(state)
def step(
- self, state: DTCState, actions: Dict[str, DTCAction]
+ self, state: DTCState, actions: dict[str, DTCAction]
) -> M.JointTimestep[DTCState, DTCObs]:
clipped_actions = clip_actions(actions, self.action_spaces)
next_state = self._get_next_state(state, clipped_actions)
@@ -529,13 +531,13 @@ def step(
all_done, rewards = self._get_rewards(next_state)
terminations = {i: all_done for i in self.possible_agents}
truncations = {i: False for i in self.possible_agents}
- infos: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ infos: dict[str, dict] = {i: {} for i in self.possible_agents}
return M.JointTimestep(
next_state, obs, rewards, terminations, truncations, all_done, infos
)
def _get_next_state(
- self, state: DTCState, actions: Dict[str, DTCAction]
+ self, state: DTCState, actions: dict[str, DTCAction]
) -> DTCState:
for i in range(self.n_pursuers):
self.world.set_entity_state(f"pursuer_{i}", state.pursuer_states[i])
@@ -571,7 +573,7 @@ def _get_next_state(
state.target_vel,
)
- def _get_obs(self, state: DTCState) -> Dict[str, DTCObs]:
+ def _get_obs(self, state: DTCState) -> dict[str, DTCObs]:
observation = {}
for i in range(self.n_pursuers):
# getting the target engagement
@@ -622,7 +624,7 @@ def _get_obs(self, state: DTCState) -> Dict[str, DTCObs]:
engagement = sorted(
engagement, key=lambda t: float("inf") if t[1] == -1.0 else t[1]
)
- alphas, dists = list(zip(*engagement))
+ alphas, dists = list(zip(*engagement, strict=False))
angle_i = (
self.world.convert_angle_to_negpi_pi_interval(
@@ -663,7 +665,7 @@ def _get_obs(self, state: DTCState) -> Dict[str, DTCObs]:
def _engagement(
self, agent_i: np.ndarray, agent_j: np.ndarray, dist_norm_factor: float
- ) -> Tuple[Tuple[float, float], bool]:
+ ) -> tuple[tuple[float, float], bool]:
"""Get engagement between two agents.
Engagement here is the angle (in radians) from agent_i's current position and
@@ -679,19 +681,19 @@ def _engagement(
return (-1.0, -1.0), False
# Rotation matrix of yaw
- yaw = agent_i[2]
+ yaw = agent_i[ANGLE_IDX]
rot = np.array(
[[math.cos(yaw), math.sin(yaw)], [-math.sin(yaw), math.cos(yaw)]]
)
- rel_xy = agent_j[:2] - agent_i[:2]
+ rel_xy = agent_j[[X_IDX, Y_IDX]] - agent_i[[X_IDX, Y_IDX]]
rel_xy = rot.dot(rel_xy)
- alpha = math.atan2(rel_xy[1], rel_xy[0])
+ alpha = math.atan2(rel_xy[Y_IDX], rel_xy[X_IDX])
alpha = self.world.convert_angle_to_negpi_pi_interval(alpha)
return (alpha / math.pi, dist / dist_norm_factor), True
- def _get_rewards(self, state: DTCState) -> Tuple[bool, Dict[str, float]]:
+ def _get_rewards(self, state: DTCState) -> tuple[bool, dict[str, float]]:
done = False
- reward: Dict[str, float] = {}
+ reward: dict[str, float] = {}
# q_formation reward: [-1 * (n-1) / n, 3 * (n-1) / n]
q_formation = self._q_parameter(state) if self.use_q_reward else 0.0
for i in self.possible_agents:
@@ -713,7 +715,6 @@ def _get_rewards(self, state: DTCState) -> Tuple[bool, Dict[str, float]]:
def _q_parameter(self, state: DTCState) -> float:
"""Calculate Q-formation value."""
- # min = -1 * (n-1) / n, max = 3 * (n-1) / n
closest = self._get_closest_pursuer(state)
unit = self._get_unit_vectors(state)
Qk = 0.0
@@ -735,16 +736,16 @@ def _get_closest_pursuer(self, state: DTCState) -> int:
raise Exception("No closest index found. Something has gone wrong.")
return min_index
- def _get_unit_vectors(self, state: DTCState) -> List[List[float]]:
+ def _get_unit_vectors(self, state: DTCState) -> list[list[float]]:
"""Get unit vectors between target and each pursuer."""
unit = []
- q = state.target_state[:2]
+ q = state.target_state[[X_IDX, Y_IDX]]
for p in state.pursuer_states:
dist = self.world.euclidean_dist(p, state.target_state)
- unit.append([(q[0] - p[0]) / dist, (q[1] - p[1]) / dist])
+ unit.append([(q[X_IDX] - p[X_IDX]) / dist, (q[Y_IDX] - p[Y_IDX]) / dist])
return unit
- def _get_target_move_repulsive(self, state: DTCState) -> Tuple[float, float]:
+ def _get_target_move_repulsive(self, state: DTCState) -> tuple[float, float]:
xy_pos = state.target_state[:2]
x, y = xy_pos
@@ -753,7 +754,7 @@ def scale_fn(z):
final_vector = [0.0, 0.0]
for s in state.pursuer_states:
- vector = s[:2] - xy_pos
+ vector = s[[X_IDX, Y_IDX]] - xy_pos
final_vector = self._scale_vector(vector, scale_fn, final_vector)
# Find closest point on border then put it in to the vectorial sum
@@ -771,20 +772,20 @@ def scale_fn(z):
)
scaled_move_dir = [v / self._abs_sum(final_vector) for v in final_vector]
- dx, dy = scaled_move_dir[0], scaled_move_dir[1]
+ dx, dy = scaled_move_dir[X_IDX], scaled_move_dir[Y_IDX]
d = np.linalg.norm([dx, dy])
dx = float(state.target_vel * dx / d)
dy = float(state.target_vel * dy / d)
return dx, dy
- def _scale_vector(self, vector, scale_fn, final_vector, factor=1.0) -> List[float]:
+ def _scale_vector(self, vector, scale_fn, final_vector, factor=1.0) -> list[float]:
vec_sum = self._abs_sum(vector)
div = max(0.00001, vec_sum)
f = -factor * scale_fn(vec_sum) / div
vector = f * vector
- return [x + y for x, y in zip(final_vector, vector)]
+ return [x + y for x, y in zip(final_vector, vector, strict=False)]
- def _abs_sum(self, vector: List[float]) -> float:
+ def _abs_sum(self, vector: list[float]) -> float:
return sum([abs(x) for x in vector])
def _target_distance(self, state: DTCState, pursuer_idx: int) -> float:
diff --git a/posggym/envs/continuous/predator_prey_continuous.py b/posggym/envs/continuous/predator_prey_continuous.py
index d96d754..6753e18 100644
--- a/posggym/envs/continuous/predator_prey_continuous.py
+++ b/posggym/envs/continuous/predator_prey_continuous.py
@@ -1,8 +1,9 @@
"""The Continuous Predator-Prey Environment."""
+from __future__ import annotations
import math
from itertools import product
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast
+from typing import ClassVar, NamedTuple, cast
import numpy as np
from gymnasium import spaces
@@ -11,11 +12,18 @@
from posggym import logger
from posggym.core import DefaultEnv
from posggym.envs.continuous.core import (
+ ANGLE_IDX,
+ X_IDX,
+ Y_IDX,
CircleEntity,
+ ControlType,
PMBodyState,
Position,
SquareContinuousWorld,
clip_actions,
+ generate_action_space,
+ generate_parameters,
+ scale_action,
)
from posggym.utils import seeding
@@ -132,9 +140,8 @@ class PredatorPreyContinuousEnv(DefaultEnv[PPState, PPObs, PPAction]):
worlds (this can be done by manually specifying a value for `max_episode_steps` when
creating the environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `world` - the world layout to use. This can either be a string specifying one of
the supported worlds, or a custom :class:`PPWorld` object (default = `"10x10"`).
- `num_predators` - the number of predator (and thus controlled agents)
@@ -192,32 +199,43 @@ class PredatorPreyContinuousEnv(DefaultEnv[PPState, PPObs, PPAction]):
---------
- Ming Tan. 1993. Multi-Agent Reinforcement Learning: Independent vs. Cooperative
Agents. In Proceedings of the Tenth International Conference on Machine Learning.
- 330–337.
+ 330-337.
- J. Z. Leibo, V. F. Zambaldi, M. Lanctot, J. Marecki, and T. Graepel. 2017.
Multi-Agent Reinforcement Learning in Sequential Social Dilemmas. In AAMAS,
- Vol. 16. ACM, 464–473
+ Vol. 16. ACM, 464-473
- Lowe, Ryan, Yi I. Wu, Aviv Tamar, Jean Harb, OpenAI Pieter Abbeel, and Igor
Mordatch. 2017. “Multi-Agent Actor-Critic for Mixed Cooperative-Competitive
Environments.” Advances in Neural Information Processing Systems 30.
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "rgb_array"],
"render_fps": 15,
}
def __init__(
self,
- world: Union[str, "PPWorld"] = "10x10",
+ world: str | PPWorld = "10x10",
num_predators: int = 2,
num_prey: int = 3,
cooperative: bool = True,
- prey_strength: Optional[int] = None,
+ prey_strength: int | None = None,
obs_dist: float = 4,
n_sensors: int = 16,
- render_mode: Optional[str] = None,
- ):
+ control_type: ControlType | str = ControlType.VelocityNonHolonomoic,
+ obs_self_model: bool = False,
+ render_mode: str | None = None,
+ ) -> None:
+ if isinstance(control_type, str):
+ try:
+ control_type = ControlType.from_str(control_type)
+ except ValueError:
+ logger.warning(
+ "Invalid control type, defaulting to VelocityNonHolonomoic"
+ )
+ control_type = ControlType.VelocityNonHolonomoic
+
super().__init__(
PredatorPreyContinuousModel(
world=world,
@@ -227,6 +245,8 @@ def __init__(
prey_strength=prey_strength,
obs_dist=obs_dist,
n_sensors=n_sensors,
+ control_type=control_type,
+ obs_self_model=obs_self_model,
),
render_mode=render_mode,
)
@@ -239,7 +259,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -304,7 +324,7 @@ def _render_img(self):
n_sensors = model.n_sensors
for i, obs_i in self._last_obs.items():
p_state = state.predator_states[int(i)]
- x, y, agent_angle = p_state[:3]
+ x, y, agent_angle = p_state[[X_IDX, Y_IDX, ANGLE_IDX]]
angle_inc = 2 * math.pi / n_sensors
for k in range(n_sensors):
dist = min(obs_i[k], obs_i[n_sensors + k], obs_i[2 * n_sensors + k])
@@ -312,7 +332,7 @@ def _render_img(self):
end_x = x + dist * math.cos(angle)
end_y = y + dist * math.sin(angle)
scaled_start = (int(x * scale_factor), int(y * scale_factor))
- scaled_end = int(end_x * scale_factor), (end_y * scale_factor)
+ scaled_end = (int(end_x * scale_factor), int(end_y * scale_factor))
pygame.draw.line(
self.window_surface, pygame.Color("red"), scaled_start, scaled_end
@@ -368,18 +388,21 @@ class PredatorPreyContinuousModel(M.POSGModel[PPState, PPObs, PPAction]):
PREDATOR_COLOR = (55, 155, 205, 255) # Blueish
PREY_COLOR = (110, 55, 155, 255) # purpleish
+ MAX_AGENTS = 8
def __init__(
self,
- world: Union[str, "PPWorld"],
+ world: str | PPWorld,
num_predators: int,
num_prey: int,
cooperative: bool,
- prey_strength: Optional[int],
+ prey_strength: int | None,
obs_dist: float,
n_sensors: int,
- ):
- assert 1 < num_predators <= 8
+ control_type: ControlType,
+ obs_self_model: bool,
+ ) -> None:
+ assert 1 < num_predators <= self.MAX_AGENTS
assert num_prey > 0
assert obs_dist > 0
@@ -418,12 +441,12 @@ def __init__(
self.n_sensors = n_sensors
# capture radius large enough so prey in corner can be captured by 3 predators
self.prey_capture_dist = 2.75 * self.world.agent_radius
- self.possible_agents = tuple((str(x) for x in range(self.num_predators)))
+ self.possible_agents = tuple(str(x) for x in range(self.num_predators))
+ self.obs_self_model = obs_self_model
def _pos_space(n_agents: int):
# x, y, angle, vx, vy, vangle
# stacked n_agents time
- # shape = (n_agents, 6)
size, angle = self.world.size, 2 * math.pi
low = np.array([-1, -1, -angle, -1, -1, -angle], dtype=np.float32)
high = np.array(
@@ -445,20 +468,45 @@ def _pos_space(n_agents: int):
)
)
+ self.control_type = control_type
+ self.dt = 1.0
+ self.substeps = 10
+
# can turn up to 45 degrees per step
self.dyaw_limit = math.pi / 4
+ self.dvel_limit = 1.0
+
+ self.fyaw_limit = math.pi
+ self.fvel_limit = 3.0
+
+ self.action_spaces_per_control = generate_action_space(
+ self.possible_agents,
+ self.dyaw_limit,
+ self.dvel_limit,
+ self.fyaw_limit,
+ self.fvel_limit,
+ )
+
self.action_spaces = {
- i: spaces.Box(
- low=np.array([-self.dyaw_limit, 0.0], dtype=np.float32),
- high=np.array([self.dyaw_limit, 1.0], dtype=np.float32),
- )
+ i: spaces.Box(np.array([-1, -1]), np.array([1, 1]))
for i in self.possible_agents
}
+ self.control_types = {i: self.control_type for i in self.possible_agents}
+ self.init_kinematics()
+
self.obs_dim = self.n_sensors * 3
self.observation_spaces = {
i: spaces.Box(
- low=0.0, high=self.obs_dist, shape=(self.obs_dim,), dtype=np.float32
+ low=np.array(
+ [0.0] * self.obs_dim + ([0] if self.obs_self_model else [])
+ ),
+ high=np.array(
+ [self.obs_dist] * self.obs_dim
+ + ([len(ControlType)] if self.obs_self_model else [])
+ ),
+ shape=(self.obs_dim + int(self.obs_self_model),),
+ dtype=np.float32,
)
for i in self.possible_agents
}
@@ -474,7 +522,7 @@ def _pos_space(n_agents: int):
self.world.add_entity(f"prey_{i}", None, color=self.PREY_COLOR)
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (0.0, self.R_MAX) for i in self.possible_agents}
@property
@@ -483,9 +531,14 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: PPState) -> List[str]:
+ def get_agents(self, state: PPState) -> list[str]:
return list(self.possible_agents)
+ def init_kinematics(self):
+ self.kinematic_parameters = {
+ i: generate_parameters(self.control_types[i]) for i in self.possible_agents
+ }
+
def sample_initial_state(self) -> PPState:
predator_positions = [*self.world.predator_start_positions]
self.rng.shuffle(predator_positions)
@@ -509,11 +562,11 @@ def sample_initial_state(self) -> PPState:
np.zeros(self.num_prey, dtype=np.int8),
)
- def sample_initial_obs(self, state: PPState) -> Dict[str, PPObs]:
+ def sample_initial_obs(self, state: PPState) -> dict[str, PPObs]:
return self.get_obs(state)
def step(
- self, state: PPState, actions: Dict[str, PPAction]
+ self, state: PPState, actions: dict[str, PPAction]
) -> M.JointTimestep[PPState, PPObs]:
clipped_actions = clip_actions(actions, self.action_spaces)
@@ -525,7 +578,7 @@ def step(
truncated = {i: False for i in self.possible_agents}
terminated = {i: all_done for i in self.possible_agents}
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
if all_done:
for i in self.possible_agents:
info[i]["outcome"] = M.Outcome.WIN
@@ -534,7 +587,7 @@ def step(
next_state, obs, rewards, terminated, truncated, all_done, info
)
- def _get_next_state(self, state: PPState, actions: Dict[str, PPAction]) -> PPState:
+ def _get_next_state(self, state: PPState, actions: dict[str, PPAction]) -> PPState:
prey_move_angles = self._get_prey_move_angles(state)
# apply prey actions
@@ -553,17 +606,29 @@ def _get_next_state(self, state: PPState, actions: Dict[str, PPAction]) -> PPSta
# apply predator actions
for i in range(self.num_predators):
- action = actions[str(i)]
+ action_i = actions[str(i)]
+
+ action_scaled = scale_action(
+ action_i,
+ self.action_spaces[str(i)],
+ self.action_spaces_per_control[self.control_types[str(i)]][str(i)],
+ )
+
self.world.set_entity_state(f"pred_{i}", state.predator_states[i])
- angle = state.predator_states[i][2] + action[0]
- self.world.update_entity_state(
- f"pred_{i}",
- angle=angle,
- vel=self.world.linear_to_xy_velocity(action[1], angle),
+
+ result = self.world.compute_vel_force(
+ self.control_types[str(i)],
+ state.predator_states[i][ANGLE_IDX],
+ current_vel=None,
+ action_i=action_scaled,
+ vel_limit_norm=None,
+ kinematic_parameters=self.kinematic_parameters[str(i)],
)
+ self.world.update_entity_state(f"pred_{i}", **result)
+
# simulate
- self.world.simulate(1.0 / 10, 10)
+ self.world.simulate(self.dt / self.substeps, self.substeps)
# extract next state
next_pred_states = np.array(
@@ -585,7 +650,9 @@ def _get_next_state(self, state: PPState, actions: Dict[str, PPAction]) -> PPSta
next_prey_states[i] = [-1.0, -1.0, 0.0, 0.0, 0.0, 0.0]
continue
pred_dists = np.linalg.norm(
- next_prey_states[i][:2] - next_pred_states[:, :2], axis=1
+ next_prey_states[i][[X_IDX, Y_IDX]]
+ - next_pred_states[:, [X_IDX, Y_IDX]],
+ axis=1,
)
if (
np.where(pred_dists <= self.prey_capture_dist, 1, 0).sum()
@@ -596,7 +663,7 @@ def _get_next_state(self, state: PPState, actions: Dict[str, PPAction]) -> PPSta
return PPState(next_pred_states, next_prey_states, next_prey_caught)
- def _get_prey_move_angles(self, state: PPState) -> List[float]:
+ def _get_prey_move_angles(self, state: PPState) -> list[float]:
prey_actions = []
active_prey = self.num_prey - state.prey_caught.sum()
for i in range(self.num_prey):
@@ -608,14 +675,17 @@ def _get_prey_move_angles(self, state: PPState) -> List[float]:
prey_state = state.prey_states[i]
# try move away from predators
pred_states = state.predator_states
- pred_dists = np.linalg.norm(prey_state[:2] - pred_states[:, :2], axis=1)
+ pred_dists = np.linalg.norm(
+ prey_state[[X_IDX, Y_IDX]] - pred_states[:, [X_IDX, Y_IDX]], axis=1
+ )
min_pred_dist = pred_dists.min()
if min_pred_dist <= self.prey_obs_dist:
# get any predators within obs distance
pred_idx = self.rng.choice(np.where(pred_dists == min_pred_dist)[0])
pred_state = state.predator_states[pred_idx]
angle = math.atan2(
- prey_state[1] - pred_state[1], prey_state[0] - pred_state[0]
+ prey_state[Y_IDX] - pred_state[Y_IDX],
+ prey_state[X_IDX] - pred_state[X_IDX],
)
prey_actions.append(angle)
continue
@@ -628,7 +698,7 @@ def _get_prey_move_angles(self, state: PPState) -> List[float]:
# try move away from prey
prey_dists = [
- np.linalg.norm(prey_state[:2] - p[:2])
+ np.linalg.norm(prey_state[[X_IDX, Y_IDX]] - p[[X_IDX, Y_IDX]])
for j, p in enumerate(state.prey_states)
if not state.prey_caught[j] and j != i
]
@@ -639,8 +709,8 @@ def _get_prey_move_angles(self, state: PPState) -> List[float]:
)
other_prey_state = state.prey_states[other_prey_idx]
angle = math.atan2(
- prey_state[1] - other_prey_state[1],
- prey_state[0] - other_prey_state[0],
+ prey_state[Y_IDX] - other_prey_state[Y_IDX],
+ prey_state[X_IDX] - other_prey_state[X_IDX],
)
prey_actions.append(angle)
continue
@@ -651,14 +721,13 @@ def _get_prey_move_angles(self, state: PPState) -> List[float]:
return prey_actions
- def get_obs(self, state: PPState) -> Dict[str, PPObs]:
+ def get_obs(self, state: PPState) -> dict[str, PPObs]:
return {i: self._get_local_obs(i, state) for i in self.possible_agents}
def _get_local_obs(self, agent_id: str, state: PPState) -> np.ndarray:
state_i = state.predator_states[int(agent_id)]
- pos_i = (state_i[0], state_i[1], state_i[2])
-
- prey_coords = state.prey_states[state.prey_caught == 0, :2]
+ pos_i = (state_i[X_IDX], state_i[Y_IDX], state_i[ANGLE_IDX])
+ prey_coords = state.prey_states[state.prey_caught == 0][:, [X_IDX, Y_IDX]]
prey_obs, _ = self.world.check_collision_circular_rays(
pos_i,
self.obs_dist,
@@ -671,7 +740,7 @@ def _get_local_obs(self, agent_id: str, state: PPState) -> np.ndarray:
mask = np.ones(len(state.predator_states), dtype=bool)
mask[int(agent_id)] = False
- pred_coords = state.predator_states[mask, :2]
+ pred_coords = state.predator_states[mask][:, [X_IDX, Y_IDX]]
pred_obs, _ = self.world.check_collision_circular_rays(
pos_i,
self.obs_dist,
@@ -703,9 +772,12 @@ def _get_local_obs(self, agent_id: str, state: PPState) -> np.ndarray:
)
obs[idx] = np.minimum(min_val, obs[idx])
+ if self.obs_self_model:
+ obs[-1] = int(self.control_types[agent_id])
+
return obs
- def _get_rewards(self, state: PPState, next_state: PPState) -> Dict[str, float]:
+ def _get_rewards(self, state: PPState, next_state: PPState) -> dict[str, float]:
new_caught_prey = []
for i in range(self.num_prey):
if not state.prey_caught[i] and next_state.prey_caught[i]:
@@ -721,7 +793,9 @@ def _get_rewards(self, state: PPState, next_state: PPState) -> Dict[str, float]:
rewards = {i: 0.0 for i in self.possible_agents}
pred_states = next_state.predator_states
for prey_state in new_caught_prey:
- pred_dists = np.linalg.norm(prey_state[:2] - pred_states[:, :2], axis=1)
+ pred_dists = np.linalg.norm(
+ prey_state[[X_IDX, Y_IDX]] - pred_states[:, [X_IDX, Y_IDX]], axis=1
+ )
involved_predators = np.where(pred_dists <= self.prey_capture_dist)[0]
predator_reward = self.per_prey_reward / len(involved_predators)
for i in involved_predators:
@@ -733,14 +807,16 @@ def _get_rewards(self, state: PPState, next_state: PPState) -> Dict[str, float]:
class PPWorld(SquareContinuousWorld):
"""A continuous 2D world for the Predator-Prey Problem."""
+ MIN_GRID_SIZE = 3
+
def __init__(
self,
size: int,
- blocks: Optional[List[CircleEntity]],
- predator_start_positions: Optional[List[Position]] = None,
- prey_start_positions: Optional[List[Position]] = None,
- ):
- assert size >= 3
+ blocks: list[CircleEntity] | None,
+ predator_start_positions: list[Position] | None = None,
+ prey_start_positions: list[Position] | None = None,
+ ) -> None:
+ assert size >= self.MIN_GRID_SIZE
super().__init__(
size=size,
blocks=blocks,
@@ -794,7 +870,7 @@ def __init__(
self.prey_start_positions = prey_start_positions
- def copy(self) -> "PPWorld":
+ def copy(self) -> PPWorld:
world = PPWorld(
size=int(self.size),
blocks=self.blocks,
@@ -862,7 +938,7 @@ def parse_world_str(world_str: str) -> PPWorld:
assert len(row_strs) == len(row_strs[0])
size = len(row_strs)
- blocks: Set[CircleEntity] = set()
+ blocks: set[CircleEntity] = set()
predator_coords = set()
prey_coords = set()
for r, c in product(range(size), repeat=2):
@@ -954,7 +1030,6 @@ def get_20x20_blocks_world() -> PPWorld:
return get_default_world(20, include_blocks=True)
-# world: world_make_fn
SUPPORTED_WORLDS = {
"5x5": get_5x5_world,
"5x5Blocks": get_5x5_blocks_world,
@@ -965,3 +1040,15 @@ def get_20x20_blocks_world() -> PPWorld:
"20x20": get_20x20_world,
"20x20Blocks": get_20x20_blocks_world,
}
+
+if __name__ == "__main__":
+ from posggym.utils.run_random_agents import run_random
+
+ run_random(
+ PredatorPreyContinuousEnv(
+ render_mode="human",
+ obs_self_model=True,
+ ),
+ num_episodes=5,
+ max_episode_steps=100,
+ )
diff --git a/posggym/envs/continuous/pursuit_evasion_continuous.py b/posggym/envs/continuous/pursuit_evasion_continuous.py
index 15ed6a6..bdc47e8 100644
--- a/posggym/envs/continuous/pursuit_evasion_continuous.py
+++ b/posggym/envs/continuous/pursuit_evasion_continuous.py
@@ -1,18 +1,9 @@
"""The Pursuit-Evasion World World Environment."""
+from __future__ import annotations
+
import math
from itertools import product
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- NamedTuple,
- Optional,
- Set,
- Tuple,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, cast
import numpy as np
from gymnasium import spaces
@@ -21,17 +12,28 @@
from posggym import logger
from posggym.core import DefaultEnv
from posggym.envs.continuous.core import (
+ ANGLE_IDX,
+ X_IDX,
+ Y_IDX,
CollisionType,
+ ControlType,
Coord,
FloatCoord,
PMBodyState,
SquareContinuousWorld,
clip_actions,
+ generate_action_space,
generate_interior_walls,
+ generate_parameters,
+ scale_action,
)
from posggym.utils import seeding
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+
class PEState(NamedTuple):
"""Environment state in Pursuit Evastion problem."""
@@ -163,9 +165,8 @@ class PursuitEvasionContinuousEnv(DefaultEnv):
by manually specifying a value for `max_episode_steps` when creating the environment
with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `world` - the world layout to use. This can either be a string specifying one of
the supported worlds (see SUPPORTED_WORLDS), or a custom :class:`PEWorld`
object (default = `"16x16"`).
@@ -210,7 +211,7 @@ class PursuitEvasionContinuousEnv(DefaultEnv):
)
```
- References
+ References:
----------
- [This Pursuit-Evasion implementation is directly inspired by the problem] Seaman,
Iris Rubi, Jan-Willem van de Meent, and David Wingate. 2018. “Nested Reasoning
@@ -223,22 +224,33 @@ class PursuitEvasionContinuousEnv(DefaultEnv):
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "rgb_array"],
"render_fps": 15,
}
def __init__(
self,
- world: Union[str, "PEWorld"] = "16x16",
- max_obs_distance: Optional[float] = None,
+ world: str | PEWorld = "16x16",
+ max_obs_distance: float | None = None,
fov: float = np.pi / 3,
n_sensors: int = 16,
normalize_reward: bool = True,
use_progress_reward: bool = True,
- render_mode: Optional[str] = None,
+ obs_self_model: bool = False,
+ control_type: ControlType | str = ControlType.VelocityNonHolonomoic,
+ render_mode: str | None = None,
**kwargs,
- ):
+ ) -> None:
+ if isinstance(control_type, str):
+ try:
+ control_type = ControlType.from_str(control_type)
+ except ValueError:
+ logger.warning(
+ "Invalid control type, defaulting to VelocityNonHolonomoic"
+ )
+ control_type = ControlType.VelocityNonHolonomoic
+
model = PursuitEvasionContinuousModel(
world,
max_obs_distance=max_obs_distance,
@@ -246,9 +258,14 @@ def __init__(
use_progress_reward=use_progress_reward,
fov=fov,
n_sensors=n_sensors,
+ control_type=control_type,
+ obs_self_model=obs_self_model,
**kwargs,
)
- super().__init__(model, render_mode=render_mode)
+ super().__init__(
+ model,
+ render_mode=render_mode,
+ )
self.window_surface = None
self.blocked_surface = None
self.clock = None
@@ -258,8 +275,8 @@ def __init__(
self.fov = fov
def reset(
- self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
- ) -> Tuple[Dict[str, M.ObsType], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, M.ObsType], dict[str, dict]]:
# reset renderer since goal location can change between episodes
self._renderer = None
return super().reset(seed=seed, options=options)
@@ -267,10 +284,11 @@ def reset(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
- f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
+ 'e.g. posggym.make("%s", render_mode="rgb_array")',
+ self.spec.id,
)
return
return self._render_img()
@@ -348,7 +366,7 @@ def _render_img(self):
(model.PURSUER_IDX, state.pursuer_state),
]:
obs_i = self._last_obs[str(idx)]
- x, y, agent_angle = p_state[:3]
+ x, y, agent_angle = p_state[[X_IDX, Y_IDX, ANGLE_IDX]]
angles = np.linspace(
-self.fov / 2, self.fov / 2, n_sensors, endpoint=False, dtype=np.float32
@@ -359,7 +377,7 @@ def _render_img(self):
end_x = x + dist * math.cos(angle)
end_y = y + dist * math.sin(angle)
scaled_start = (int(x * scale_factor), int(y * scale_factor))
- scaled_end = int(end_x * scale_factor), (end_y * scale_factor)
+ scaled_end = (int(end_x * scale_factor), int(end_y * scale_factor))
pygame.draw.line(
self.window_surface,
pygame.Color("red"),
@@ -400,7 +418,7 @@ def close(self) -> None:
class PursuitEvasionContinuousModel(M.POSGModel[PEState, PEObs, PEAction]):
"""Continuous Pursuit-Evasion Model.
- Arguments
+ Arguments:
---------
world : str, PEWorld
the world layout to use. This can either be a string specifying one of
@@ -442,13 +460,15 @@ class PursuitEvasionContinuousModel(M.POSGModel[PEState, PEObs, PEAction]):
def __init__(
self,
- world: Union[str, "PEWorld"],
- max_obs_distance: Optional[float] = None,
+ world: str | PEWorld,
+ max_obs_distance: float | None = None,
fov: float = np.pi / 3,
n_sensors: int = 16,
normalize_reward: bool = True,
use_progress_reward: bool = True,
- ):
+ obs_self_model: bool = False,
+ control_type: ControlType = ControlType.VelocityNonHolonomoic,
+ ) -> None:
assert 0 < fov < 2 * np.pi, "fov must be in (0, 2 * pi)"
assert n_sensors > 0, "n_sensors must be positive"
@@ -468,6 +488,10 @@ def __init__(
self._normalize_reward = normalize_reward
self._use_progress_reward = use_progress_reward
self.fov = fov
+ self.control_type = control_type
+ self.obs_self_model = obs_self_model
+ self.dt = 1.0
+ self.substeps = 10
self._max_sp_distance = self.world.get_max_shortest_path_distance()
self._max_raw_return = self.R_EVASION
@@ -496,14 +520,26 @@ def _coord_space():
# can turn by up to 45 degrees per timestep
self.dyaw_limit = math.pi / 4
+ self.dvel_limit = 1.0
+
+ self.fyaw_limit = math.pi
+ self.fvel_limit = 3.0
+
+ self.action_spaces_per_control = generate_action_space(
+ self.possible_agents,
+ self.dyaw_limit,
+ self.dvel_limit,
+ self.fyaw_limit,
+ self.fvel_limit,
+ )
+
self.action_spaces = {
- i: spaces.Box(
- low=np.array([-self.dyaw_limit, 0.0], dtype=np.float32),
- high=np.array([self.dyaw_limit, 1.0], dtype=np.float32),
- )
+ i: spaces.Box(np.array([-1, -1]), np.array([1, 1]))
for i in self.possible_agents
}
+ self.control_types = {i: self.control_type for i in self.possible_agents}
+ self.init_kinematics()
self.obs_dist = self.max_obs_distance
self.n_sensors = n_sensors
self.sensor_obs_dim = self.n_sensors * 2
@@ -513,9 +549,14 @@ def _coord_space():
size = self.world.size
self.observation_spaces = {
i: spaces.Box(
- low=np.array([*sensor_low, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32),
+ low=np.array(
+ [*sensor_low, 0, 0, 0, 0, 0, 0, 0]
+ + ([0] if self.obs_self_model else []),
+ dtype=np.float32,
+ ),
high=np.array(
- [*sensor_high, 1, size, size, size, size, size, size],
+ [*sensor_high, 1, size, size, size, size, size, size]
+ + ([len(ControlType)] if self.obs_self_model else []),
dtype=np.float32,
),
dtype=np.float32,
@@ -529,7 +570,7 @@ def _coord_space():
self.world.add_entity("evader", None, color=self.EVADER_COLOR)
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
max_reward = self.R_EVASION
if self._use_progress_reward:
max_reward += self.R_PROGRESS
@@ -537,39 +578,44 @@ def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
max_reward = self._get_normalized_reward(max_reward)
return {i: (-max_reward, max_reward) for i in self.possible_agents}
+ def init_kinematics(self):
+ self.kinematic_parameters = {
+ i: generate_parameters(self.control_types[i]) for i in self.possible_agents
+ }
+
@property
def rng(self) -> seeding.RNG:
if self._rng is None:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: PEState) -> List[str]:
+ def get_agents(self, state: PEState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> PEState:
evader_coord = self.rng.choice(self.world.evader_start_coords)
evader_state = np.zeros(PMBodyState.num_features(), dtype=np.float32)
- evader_state[:2] = evader_coord
+ evader_state[[X_IDX, Y_IDX]] = evader_coord
pursuer_coord = self.rng.choice(self.world.pursuer_start_coords)
pursuer_state = np.zeros(PMBodyState.num_features(), dtype=np.float32)
- pursuer_state[:2] = pursuer_coord
+ pursuer_state[[X_IDX, Y_IDX]] = pursuer_coord
goal_coord = self.rng.choice(self.world.get_goal_coords(evader_coord))
return PEState(
evader_state,
pursuer_state,
- evader_state[:2],
- pursuer_state[:2],
+ evader_state[[X_IDX, Y_IDX]],
+ pursuer_state[[X_IDX, Y_IDX]],
np.array(goal_coord, dtype=np.float32),
self.world.get_shortest_path_distance(evader_coord, goal_coord),
)
- def sample_initial_obs(self, state: PEState) -> Dict[str, PEObs]:
+ def sample_initial_obs(self, state: PEState) -> dict[str, PEObs]:
return self._get_obs(state)[0]
def step(
- self, state: PEState, actions: Dict[str, PEAction]
+ self, state: PEState, actions: dict[str, PEAction]
) -> M.JointTimestep[PEState, PEObs]:
clipped_actions = clip_actions(actions, self.action_spaces)
next_state = self._get_next_state(state, clipped_actions)
@@ -579,7 +625,7 @@ def step(
all_done = self._is_done(next_state, evader_seen)
terminated = {i: all_done for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
if all_done:
for i, outcome in self._get_outcome(next_state, evader_seen).items():
info[i]["outcome"] = outcome
@@ -588,23 +634,52 @@ def step(
next_state, obs, rewards, terminated, truncated, all_done, info
)
- def _get_next_state(self, state: PEState, actions: Dict[str, PEAction]) -> PEState:
+ def _get_next_state(self, state: PEState, actions: dict[str, PEAction]) -> PEState:
evader_a = actions[str(self.EVADER_IDX)]
pursuer_a = actions[str(self.PURSUER_IDX)]
-
self.world.set_entity_state("pursuer", state.pursuer_state)
self.world.set_entity_state("evader", state.evader_state)
- pursuer_angle = state.pursuer_state[2] + pursuer_a[0]
- pursuer_vel = self.world.linear_to_xy_velocity(pursuer_a[1], pursuer_angle)
- self.world.update_entity_state("pursuer", angle=pursuer_angle, vel=pursuer_vel)
+ pursuer_a_scaled = scale_action(
+ pursuer_a,
+ self.action_spaces[str(self.PURSUER_IDX)],
+ self.action_spaces_per_control[self.control_types[str(self.PURSUER_IDX)]][
+ str(self.PURSUER_IDX)
+ ],
+ )
- evader_angle = state.evader_state[2] + evader_a[0]
- evader_vel = self.world.linear_to_xy_velocity(evader_a[1], evader_angle)
- self.world.update_entity_state("evader", angle=evader_angle, vel=evader_vel)
+ result = self.world.compute_vel_force(
+ self.control_types[str(self.PURSUER_IDX)],
+ state.pursuer_state[ANGLE_IDX],
+ current_vel=None,
+ action_i=pursuer_a_scaled,
+ vel_limit_norm=None,
+ kinematic_parameters=self.kinematic_parameters[str(self.PURSUER_IDX)],
+ )
+
+ self.world.update_entity_state("pursuer", **result)
+
+ evader_a_scaled = scale_action(
+ evader_a,
+ self.action_spaces[str(self.EVADER_IDX)],
+ self.action_spaces_per_control[self.control_types[str(self.EVADER_IDX)]][
+ str(self.EVADER_IDX)
+ ],
+ )
+
+ result = self.world.compute_vel_force(
+ self.control_types[str(self.EVADER_IDX)],
+ state.evader_state[ANGLE_IDX],
+ current_vel=None,
+ action_i=evader_a_scaled,
+ vel_limit_norm=None,
+ kinematic_parameters=self.kinematic_parameters[str(self.PURSUER_IDX)],
+ )
+
+ self.world.update_entity_state("evader", **result)
# simulate
- self.world.simulate(1.0 / 10, 10)
+ self.world.simulate(self.dt / self.substeps, self.substeps)
pursuer_next_state = np.array(
self.world.get_entity_state("pursuer"),
@@ -614,8 +689,8 @@ def _get_next_state(self, state: PEState, actions: Dict[str, PEAction]) -> PESta
self.world.get_entity_state("evader"),
dtype=np.float32,
)
- evader_coord = (evader_next_state[0], evader_next_state[1])
- goal_coord = (state.evader_goal_coord[0], state.evader_goal_coord[1])
+ evader_coord = (evader_next_state[X_IDX], evader_next_state[Y_IDX])
+ goal_coord = (state.evader_goal_coord[X_IDX], state.evader_goal_coord[Y_IDX])
min_sp_distance = min(
state.min_goal_dist,
self.world.get_shortest_path_distance(evader_coord, goal_coord),
@@ -630,7 +705,7 @@ def _get_next_state(self, state: PEState, actions: Dict[str, PEAction]) -> PESta
min_sp_distance,
)
- def _get_obs(self, state: PEState) -> Tuple[Dict[str, PEObs], bool]:
+ def _get_obs(self, state: PEState) -> tuple[dict[str, PEObs], bool]:
evader_obs, _ = self._get_agent_obs(state, evader=True)
pursuer_obs, evader_seen = self._get_agent_obs(state, evader=False)
@@ -643,22 +718,21 @@ def _get_agent_obs(
self,
state: PEState,
evader: bool,
- ) -> Tuple[np.ndarray, bool]:
+ ) -> tuple[np.ndarray, bool]:
if evader:
agent_pos = (
- state.evader_state[0],
- state.evader_state[1],
- state.evader_state[2],
+ state.evader_state[X_IDX],
+ state.evader_state[Y_IDX],
+ state.evader_state[ANGLE_IDX],
)
- opp_coord = (state.pursuer_state[0], state.pursuer_state[1])
+ opp_coord = (state.pursuer_state[X_IDX], state.pursuer_state[Y_IDX])
else:
agent_pos = (
- state.pursuer_state[0],
- state.pursuer_state[1],
- state.pursuer_state[2],
+ state.pursuer_state[X_IDX],
+ state.pursuer_state[Y_IDX],
+ state.pursuer_state[ANGLE_IDX],
)
- opp_coord = (state.evader_state[0], state.evader_state[1])
-
+ opp_coord = (state.evader_state[X_IDX], state.evader_state[Y_IDX])
ray_dists, ray_col_type = self.world.check_collision_circular_rays(
agent_pos,
self.max_obs_distance,
@@ -692,11 +766,16 @@ def _get_agent_obs(
else:
obs[aux_obs_idx + 5 : aux_obs_idx + 7] = [0, 0]
+ if self.obs_self_model:
+ if evader:
+ obs[-1] = int(self.control_types[str(self.EVADER_IDX)])
+ else:
+ obs[-1] = int(self.control_types[str(self.PURSUER_IDX)])
return obs, other_agent_seen
def _get_reward(
self, state: PEState, next_state: PEState, evader_seen: bool
- ) -> Dict[str, float]:
+ ) -> dict[str, float]:
evader_reward = 0.0
if self._use_progress_reward and next_state.min_goal_dist < state.min_goal_dist:
evader_reward += self.R_PROGRESS
@@ -726,7 +805,7 @@ def _is_done(self, state: PEState, evader_seen: bool) -> bool:
or self.world.agents_collide(state.evader_state, state.evader_goal_coord)
)
- def _get_outcome(self, state: PEState, evader_seen: bool) -> Dict[str, M.Outcome]:
+ def _get_outcome(self, state: PEState, evader_seen: bool) -> dict[str, M.Outcome]:
evader_id, pursuer_id = str(self.EVADER_IDX), str(self.PURSUER_IDX)
if evader_seen or self.world.agents_collide(
state.evader_state, state.pursuer_state
@@ -745,7 +824,7 @@ def _get_normalized_reward(self, reward: float) -> float:
class PEWorld(SquareContinuousWorld):
"""A world for the Pursuit Evasion Problem.
- Arguments
+ Arguments:
---------
size : int
height and width of the world.
@@ -763,11 +842,11 @@ class PEWorld(SquareContinuousWorld):
def __init__(
self,
size: int,
- blocked_coords: Set[Coord],
- goal_coords_map: Dict[FloatCoord, List[FloatCoord]],
- evader_start_coords: List[FloatCoord],
- pursuer_start_coords: List[FloatCoord],
- ):
+ blocked_coords: set[Coord],
+ goal_coords_map: dict[FloatCoord, list[FloatCoord]],
+ evader_start_coords: list[FloatCoord],
+ pursuer_start_coords: list[FloatCoord],
+ ) -> None:
interior_walls = generate_interior_walls(size, size, blocked_coords)
super().__init__(
size=size,
@@ -783,7 +862,7 @@ def __init__(
self.pursuer_start_coords = pursuer_start_coords
self.shortest_paths = self.get_all_shortest_paths(self.all_goal_coords)
- def copy(self) -> "PEWorld":
+ def copy(self) -> PEWorld:
world = PEWorld(
size=int(self.size),
blocked_coords=self.blocked_coords,
@@ -802,14 +881,14 @@ def copy(self) -> "PEWorld":
return world
@property
- def all_goal_coords(self) -> List[FloatCoord]:
+ def all_goal_coords(self) -> list[FloatCoord]:
"""The list of all evader goal locations."""
all_locs = set()
for v in self._goal_coords_map.values():
all_locs.update(v)
return list(all_locs)
- def get_goal_coords(self, evader_start_coord: FloatCoord) -> List[FloatCoord]:
+ def get_goal_coords(self, evader_start_coord: FloatCoord) -> list[FloatCoord]:
"""Get list of possible evader goal coords for given start coords."""
return self._goal_coords_map[evader_start_coord]
@@ -915,9 +994,9 @@ def convert_map_to_world(
height: int,
width: int,
block_symbol: str = "#",
- pursuer_start_symbols: Optional[Set[str]] = None,
- evader_start_symbols: Optional[Set[str]] = None,
- evader_goal_symbol_map: Optional[Dict] = None,
+ pursuer_start_symbols: set[str] | None = None,
+ evader_start_symbols: set[str] | None = None,
+ evader_goal_symbol_map: dict | None = None,
) -> PEWorld:
"""Generate PE world layout from ascii map.
@@ -950,7 +1029,7 @@ def convert_map_to_world(
"9": ["0", "1", "2"],
}
- blocked_coords: Set[Coord] = set()
+ blocked_coords: set[Coord] = set()
evader_start_coords = []
pursuer_start_coords = []
evader_symbol_coord_map = {}
@@ -981,9 +1060,22 @@ def convert_map_to_world(
)
-# world_name: world_make_fn
-SUPPORTED_WORLDS: Dict[str, Callable[[], PEWorld]] = {
+SUPPORTED_WORLDS: dict[str, Callable[[], PEWorld]] = {
"8x8": get_8x8_world,
"16x16": get_16x16_world,
"32x32": get_32x32_world,
}
+
+
+if __name__ == "__main__":
+ from posggym.utils.run_random_agents import run_random
+
+ run_random(
+ PursuitEvasionContinuousEnv(
+ render_mode="human",
+ obs_self_model=True,
+ control_type=ControlType.WheeledRobot,
+ ),
+ num_episodes=5,
+ max_episode_steps=100,
+ )
diff --git a/posggym/envs/differentiable/__init__.py b/posggym/envs/differentiable/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/posggym/envs/differentiable/predator_prey_diff.py b/posggym/envs/differentiable/predator_prey_diff.py
new file mode 100644
index 0000000..4893d66
--- /dev/null
+++ b/posggym/envs/differentiable/predator_prey_diff.py
@@ -0,0 +1,910 @@
+from ctypes import byref
+from functools import partial
+from itertools import product
+from typing import NamedTuple, cast
+
+import numpy as np
+import torch
+from gymnasium import spaces
+from vmas.simulator.core import Agent, EntityState, Line, Sphere, World
+from vmas.simulator.dynamics.common import Dynamics
+from vmas.simulator.dynamics.diff_drive import DiffDrive
+from vmas.simulator.dynamics.holonomic import Holonomic
+from vmas.simulator.utils import (
+ Color,
+ TorchUtils,
+ X,
+ Y,
+)
+
+import posggym.model as M
+from posggym.core import DefaultEnv
+from posggym.envs.differentiable.utils import (
+ AgentStateWrapper,
+ POSGGymLandmark,
+ POSGGymLidar,
+ POSGGymSensor,
+ TensorJointTimestep,
+ clip_actions,
+ clone_state,
+)
+
+
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+class PPState(NamedTuple):
+ """A state in the Continuous Predator-Prey Environment."""
+
+ predator_states: dict[str, AgentStateWrapper]
+ prey_states: dict[str, AgentStateWrapper]
+ prey_caught: torch.Tensor
+
+
+class PPAgent(Agent):
+ def __init__(self, batch_size, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.caught = torch.zeros(batch_size, 1, dtype=torch.bool)
+
+ self._state = AgentStateWrapper()
+ self.rew = torch.Tensor()
+
+ def set_caught(self, caught):
+ self.caught = caught
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ # Remove lambda function or other unpicklable attributes
+ if "_collision_filter" in state:
+ del state["_collision_filter"]
+ return state
+
+ # Optional: Override __setstate__ to restore state
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ # Optionally re-create the lambda function after unpickling
+ self._collision_filter = lambda _: True
+
+ @property
+ def state(self) -> AgentStateWrapper:
+ return self._state
+
+ @property
+ def sensors(self) -> list[POSGGymSensor]:
+ return self._sensors
+
+ def set_all_pos(self, pos: torch.Tensor):
+ self._set_all_state_property(EntityState.pos, self.state, pos)
+
+ def set_all_vel(self, vel: torch.Tensor):
+ self._set_all_state_property(EntityState.vel, self.state, vel)
+
+ def set_all_rot(self, rot: torch.Tensor):
+ self._set_all_state_property(EntityState.rot, self.state, rot)
+
+ def set_all_ang_vel(self, ang_vel: torch.Tensor):
+ self._set_all_state_property(EntityState.ang_vel, self.state, ang_vel)
+
+ def _set_all_state_property(self, prop, entity: EntityState, new: torch.Tensor):
+ value = prop.fget(entity)
+ value[:, ...] = new
+ self.notify_observers()
+
+
+class PPWorld(World):
+ def __init__(
+ self,
+ bound: float,
+ blocks: list[tuple[tuple[float, float, float], float]] | None = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, x_semidim=bound, y_semidim=bound, **kwargs)
+ self.bound = bound
+ self._agents: list[PPAgent] = []
+
+ for i in range(4):
+ self.add_landmark(
+ POSGGymLandmark(
+ name=f"landmark-wall{i}",
+ collide=True,
+ shape=Line(length=self.bound * 2),
+ color=Color.WHITE,
+ )
+ )
+ # Add landmarks
+ if blocks is not None:
+ for _, r in blocks:
+ landmark = POSGGymLandmark(
+ name=f"landmark {i}",
+ collide=True,
+ shape=Sphere(radius=r),
+ color=Color.BLACK,
+ )
+ self.add_landmark(landmark)
+
+ all_wall_pos: list[tuple[list[float], float]] = [
+ ([0, self.bound], 0),
+ ([0, -self.bound], 0),
+ ([-self.bound, 0], np.pi / 2),
+ ([self.bound, 0], np.pi / 2),
+ ]
+
+ for idx, (pos, rot) in enumerate(all_wall_pos):
+ self.landmarks[idx].set_pos(
+ torch.tensor(
+ pos,
+ device=self.device,
+ ),
+ batch_index=None, # type: ignore
+ )
+ self.landmarks[idx].set_rot(
+ torch.tensor(
+ [rot],
+ device=self.device,
+ ),
+ batch_index=None, # type: ignore
+ )
+ if blocks is not None:
+ for ((x, y, _), _), landmark in zip(
+ blocks, self.landmarks[4:], strict=False
+ ):
+ landmark.set_pos(
+ torch.ones(
+ (self.batch_dim, self.dim_p),
+ device=self.device,
+ dtype=torch.float32,
+ )
+ * torch.tensor([x, y], device=self.device, dtype=torch.float32),
+ batch_index=None, # type: ignore
+ )
+
+ def add_agent(self, agent: PPAgent):
+ super().add_agent(agent)
+
+ def update_state(self, state: PPState):
+ for a in self.agents:
+ if a.name.startswith("adversary"):
+ a_state = state.prey_states[a.name]
+ else:
+ a_state = state.predator_states[a.name]
+
+ a.set_all_pos(a_state.pos_safe.clone())
+ a.set_all_vel(a_state.vel_safe.clone())
+ a.set_all_rot(a_state.rot_safe.clone())
+ a.set_all_ang_vel(a_state.ang_vel_safe.clone())
+
+ def get_state(self) -> PPState:
+ return PPState(
+ {x.name: clone_state(x.state) for x in self.predator},
+ {x.name: clone_state(x.state) for x in self.prey},
+ torch.cat([x.caught for x in self.prey], dim=1),
+ )
+
+ @property
+ def agents(self) -> list[PPAgent]:
+ return self._agents
+
+ @property
+ def prey(self) -> list[PPAgent]:
+ return [x for x in self.agents if x.name.startswith("adversary")]
+
+ @property
+ def predator(self) -> list[PPAgent]:
+ return [x for x in self.agents if x.name.startswith("agent")]
+
+
+class PredatorPreyDiffModel(M.POSGModel[PPState, torch.Tensor, torch.Tensor]):
+ R_MAX = 2
+ MAX_AGENTS = 8
+
+ def __init__(
+ self,
+ world: partial[PPWorld],
+ num_predators: int = 5,
+ num_prey: int = 8,
+ cooperative: bool = False,
+ prey_strength: int | None = None,
+ obs_dist: float = 10,
+ n_sensors: int = 8,
+ batch_size=4,
+ device: str = "cpu",
+ ) -> None:
+ assert 1 < num_predators <= self.MAX_AGENTS
+ assert num_prey > 0
+ assert obs_dist > 0
+
+ self._world = world
+
+ self.num_predators = num_predators
+ self.num_prey = num_prey
+ self.num_landmarks = 2
+ self.num_agents = self.num_predators + self.num_predators
+ self.cooperative = cooperative
+ self.prey_strength = prey_strength
+ self.obs_dist = obs_dist
+ self.n_sensors = n_sensors
+ self.prey_obs_dist = 1.0
+ self.adversaries_share_rew = True
+ self.shape_agent_rew = True
+ self.shape_adversary_rew = True
+ self.agents_share_rew = False
+ self.prey_share_rew = True
+ self.observe_same_team = True
+ self.observe_pos = True
+ self.observe_vel = True
+ self.bound = None
+ self.respawn_at_catch = False
+ self.per_prey_reward = self.R_MAX / self.num_prey
+ self.prey_capture_dist = 0.1
+ self.batch_size = batch_size
+ self.is_symmetric = True
+ self.device = device
+
+ self.action_spaces = {
+ i: spaces.Box(np.array([-1, -1]), np.array([1, 1]), seed=42 + idx)
+ for idx, i in enumerate(self.possible_agents)
+ }
+
+ self.observation_spaces = {
+ i: spaces.Box(
+ low=np.array([0.0] * self.n_sensors * 3),
+ high=np.array([self.obs_dist] * self.n_sensors * 3),
+ )
+ for i in self.possible_agents
+ }
+ self.initialise()
+
+ @property
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
+ return {i: (-12, self.R_MAX) for i in self.possible_agents}
+
+ def get_agents(self, state: PPState) -> list[str]:
+ return list(self.possible_agents)
+
+ def sample_initial_state(self) -> PPState:
+ return self.reset_world_at()
+
+ def gen_dynamics(self) -> Dynamics:
+ idx = torch.randint(0, 2, (1,), generator=self.rng, device=self.device).item()
+ return [Holonomic(), DiffDrive(self.world, integration="rk4")][
+ idx
+ ] # type: ignore
+
+ def initialise(self) -> PPState:
+ self.world = self._world(
+ batch_dim=self.batch_size,
+ device=self.device,
+ substeps=10,
+ collision_force=500,
+ )
+
+ self.bound = self.world.bound
+
+ # set any world properties first
+ num_agents = self.num_predators + self.num_predators
+ self.adversary_radius = 0.075
+
+ # Add agents
+ for i in range(num_agents):
+ adversary = i < self.num_predators
+ name = f"adversary_{i}" if adversary else f"agent_{i - self.num_predators}"
+ agent = PPAgent(
+ batch_size=self.batch_size,
+ name=name,
+ collide=True,
+ shape=Sphere(radius=self.adversary_radius if adversary else 0.05),
+ u_multiplier=3.0 if adversary else 4.0,
+ max_speed=1.0 if adversary else 1.3,
+ color=Color.BLUE if adversary else Color.GREEN,
+ adversary=adversary,
+ dynamics=Holonomic() if adversary else self.gen_dynamics(),
+ sensors=(
+ [
+ POSGGymLidar(
+ self.world,
+ entity_name="landmark",
+ n_rays=self.n_sensors,
+ max_range=self.obs_dist,
+ render_color=Color.GREEN,
+ angle_start=0.05,
+ angle_end=(2 * torch.pi) + 0.05,
+ ),
+ POSGGymLidar(
+ self.world,
+ entity_name="adversary",
+ n_rays=self.n_sensors,
+ max_range=self.obs_dist,
+ render_color=Color.RED,
+ angle_start=0.05,
+ angle_end=(2 * torch.pi) + 0.05,
+ ),
+ POSGGymLidar(
+ self.world,
+ entity_name="agent",
+ n_rays=self.n_sensors,
+ max_range=self.obs_dist,
+ render_color=Color.BLUE,
+ angle_start=0.05,
+ angle_end=(2 * torch.pi) + 0.05,
+ ),
+ ]
+ if not adversary
+ else []
+ ),
+ )
+ self.world.add_agent(agent)
+
+ return self.reset_world_at()
+
+ def reset_world_at(self) -> PPState:
+ assert self.bound is not None
+
+ predator_states, prey_states, prey_caught = (
+ {},
+ {},
+ torch.zeros(self.batch_size, self.num_prey, dtype=torch.bool),
+ )
+ for p in self.world.predator:
+ state = AgentStateWrapper()
+ state.batch_dim = self.world._batch_dim # pyright: ignore
+ state.device = self.world._device # pyright: ignore
+
+ state.pos = torch.zeros(
+ (self.world.batch_dim, self.world.dim_p),
+ device=self.device,
+ dtype=torch.float32,
+ ).uniform_(-self.bound, self.bound, generator=self.rng)
+ state.pos.requires_grad = True
+ state.vel = torch.zeros(
+ (self.world.batch_dim, self.world.dim_p),
+ device=self.device,
+ dtype=torch.float32,
+ requires_grad=True,
+ )
+ state.rot = torch.zeros(
+ self.world.batch_dim,
+ 1,
+ device=self.device,
+ dtype=torch.float32,
+ requires_grad=True,
+ )
+ state.ang_vel = torch.zeros(
+ self.world.batch_dim,
+ 1,
+ device=self.device,
+ dtype=torch.float32,
+ requires_grad=True,
+ )
+ predator_states[p.name] = state
+
+ for p in self.world.prey:
+ state = AgentStateWrapper()
+ state.batch_dim = self.world._batch_dim # pyright: ignore
+ state.device = self.world._device # pyright: ignore
+
+ state.pos = torch.zeros(
+ (self.world.batch_dim, self.world.dim_p),
+ device=self.device,
+ dtype=torch.float32,
+ ).uniform_(-self.bound, self.bound, generator=self.rng)
+ state.pos.requires_grad = True
+ state.vel = torch.zeros(
+ (self.world.batch_dim, self.world.dim_p),
+ device=self.device,
+ dtype=torch.float32,
+ requires_grad=True,
+ )
+ state.rot = torch.zeros(
+ self.world.batch_dim,
+ 1,
+ device=self.device,
+ dtype=torch.float32,
+ requires_grad=True,
+ )
+ state.ang_vel = torch.zeros(
+ self.world.batch_dim,
+ 1,
+ device=self.device,
+ dtype=torch.float32,
+ requires_grad=True,
+ )
+ prey_states[p.name] = state
+
+ return PPState(predator_states, prey_states, prey_caught)
+
+ def is_collision(self, agent1: Agent, agent2: Agent):
+ delta_pos = agent1.state.pos - agent2.state.pos # type: ignore
+ dist = torch.linalg.vector_norm(delta_pos, dim=-1)
+ dist_min = agent1.shape.radius + agent2.shape.radius # type: ignore
+ return dist < dist_min
+
+ # return all adversarial agents
+ def prey(self, world: PPWorld):
+ return [agent for agent in world.agents if agent.adversary]
+
+ def _get_prey_move_angles(self, state: PPState) -> torch.Tensor:
+ pred_states = torch.stack(
+ [x.pos_safe for x in state.predator_states.values()], dim=1
+ )
+ prey_states = torch.stack(
+ [x.pos_safe for x in state.prey_states.values()], dim=1
+ )
+
+ pred_dists = torch.linalg.norm(prey_states - pred_states, axis=-1)
+ prey_dists = torch.linalg.norm(
+ prey_states.unsqueeze(2) - prey_states.unsqueeze(1), dim=-1
+ )
+
+ prey_actions = -torch.ones(state.prey_caught.shape)
+
+ a = []
+
+ for i, prey in enumerate(state.prey_states.values()):
+ prey_actions[state.prey_caught[:, i], i] = 0
+
+ pred_dists = torch.linalg.norm(
+ prey.pos_safe[:, None, :] - pred_states, axis=1
+ )
+
+ min_pred_dist, pred_idx = pred_dists.min(dim=1)
+
+ expanded_idx = pred_idx.view(-1, 1, 1).expand(-1, 1, 2)
+ gathered_values = torch.gather(pred_states, 1, expanded_idx)
+ pred_influence_angle = torch.atan2(
+ prey.pos_safe[:, 1] - gathered_values[:, 0, 1],
+ prey.pos_safe[:, 0] - gathered_values[:, 0, 0],
+ )
+
+ not_current_mask = torch.ones(len(state.prey_states), dtype=torch.bool)
+ not_current_mask[i] = False
+
+ # Compute distances
+ prey_dists = torch.linalg.norm(
+ prey_states[:, not_current_mask, :] - prey_states[:, i : i + 1, :],
+ dim=2,
+ )
+ min_prey_dist, prey_idx = prey_dists.min(dim=1)
+
+ prey_influence_strength = torch.clamp(
+ 1.0 - prey_dists, min=0.0
+ ) # Clamp to ensure no negative values
+ strength_sum = prey_influence_strength.sum(dim=1, keepdim=True)
+ normalized_strength = prey_influence_strength / (
+ strength_sum + 1e-6
+ ) # Avoid division by zero
+
+ expanded_idx = prey_idx.view(-1, 1, 1).expand(-1, 1, 2)
+ gathered_values = torch.gather(pred_states, 1, expanded_idx)
+ prey_influence_angle = torch.atan2(
+ prey.pos_safe[:, 1] - gathered_values[:, 0, 1],
+ prey.pos_safe[:, 0] - gathered_values[:, 0, 0],
+ )
+
+ prey_influence_strength = torch.clamp(
+ 1.0 - min_prey_dist, min=0.0
+ ) # Prey influence strength
+ pred_influence_strength = torch.clamp(
+ 1.0 - min_pred_dist, min=0.0
+ ) # Predator influence strength
+
+ # Step 2: Combine the influence strengths
+ total_influence_strength = torch.stack(
+ (prey_influence_strength, pred_influence_strength), dim=0
+ )
+ row_sums = total_influence_strength.sum(dim=0, keepdim=True)
+ normalized_strength = total_influence_strength / (row_sums + 1e-6)
+ zero_rows = (normalized_strength.sum(dim=0, keepdim=True) == 0).float()
+ normalized_strength += zero_rows / normalized_strength.size(1)
+
+ angles = torch.stack((prey_influence_angle, pred_influence_angle), dim=0)
+ a.append((angles * normalized_strength).sum(dim=0, keepdim=True).T)
+ return torch.stack(a, dim=1).squeeze(2)
+
+ def reward(self, agent: PPAgent):
+ is_first = agent == self.world.predator[0]
+
+ if is_first:
+ for a in self.world.predator:
+ a.rew = self.agent_reward(a)
+
+ self.agents_rew = torch.stack(
+ [a.rew for a in self.world.predator], dim=-1
+ ).sum(-1)
+
+ if self.agents_share_rew:
+ return self.agents_rew
+ else:
+ return agent.rew
+
+ def agent_reward(self, agent: PPAgent):
+ # Agents are negatively rewarded if caught by adversaries
+ rew = torch.zeros(
+ self.world.batch_dim, device=self.world.device, dtype=torch.float32
+ )
+ adversaries = self.world.prey
+ if self.shape_agent_rew:
+ # reward can optionally be shaped
+ # (increased reward for increased distance from adversary)
+ for adv in adversaries:
+ rew += 0.1 * torch.linalg.vector_norm(
+ agent.state.pos_safe - adv.state.pos, dim=-1
+ )
+ if agent.collide:
+ for a in adversaries:
+ # pass
+ rew[self.is_collision(a, agent)] -= 10 / len(adversaries)
+
+ return rew
+
+ def adversary_reward(self, agent: PPAgent):
+ # Adversaries are rewarded for collisions with agents
+ rew = torch.zeros(
+ self.world.batch_dim, device=self.world.device, dtype=torch.float32
+ )
+ agents = self.world.predator
+ if self.shape_adversary_rew: # reward can optionally be shaped
+ # (decreased reward for increased distance from agents)
+ rew -= (
+ 0.1
+ * torch.min(
+ torch.stack(
+ [
+ torch.linalg.vector_norm(
+ a.state.pos_safe - agent.state.pos,
+ dim=-1,
+ )
+ for a in agents
+ ],
+ dim=-1,
+ ),
+ dim=-1,
+ )[0]
+ )
+ if agent.collide:
+ for ag in agents:
+ rew[self.is_collision(ag, agent)] += 10
+ return rew
+
+ def observation(self, name: str):
+ world_agent = next(x for x in self.world.agents if name == x.name)
+ lidar_1_measures = torch.stack(
+ tuple(s.measure(self.world) for s in world_agent.sensors)
+ )
+ return lidar_1_measures.reshape(self.batch_size, -1)
+
+ def sample_initial_obs(self, state: PPState) -> dict[str, torch.Tensor]:
+ obs = {}
+ for name, _agent in state.predator_states.items():
+ observation = TorchUtils.recursive_clone(self.observation(name))
+ obs.update({name: observation})
+ return obs
+
+ def info(self, world: PPWorld, agent: Agent):
+ return {}
+
+ def done(self, world: PPWorld):
+ agents: list[PPAgent] = world.agents # type: ignore
+ return torch.Tensor([x.caught for x in agents])
+
+ def get_from_scenario(
+ self,
+ ):
+ obs, rewards, infos, dones = {}, {}, {}, {}
+
+ for agent in self.world.agents:
+ if agent.name.startswith("agent"):
+ observation = TorchUtils.recursive_clone(self.observation(agent.name))
+ obs.update({agent.name: observation})
+
+ for agent in self.world.predator:
+ reward = self.reward(agent).clone()
+ rewards.update({agent.name: reward})
+
+ for agent in self.world.predator:
+ info = TorchUtils.recursive_clone(self.info(self.world, agent))
+ infos.update({agent.name: info})
+
+ dones = {
+ i: torch.zeros((self.batch_size), dtype=torch.bool, device=self.device)
+ for i in self.possible_agents
+ }
+ truncated = {
+ i: torch.zeros((self.batch_size), dtype=torch.bool, device=self.device)
+ for i in self.possible_agents
+ }
+
+ return [obs, rewards, dones, truncated, infos]
+
+ @property
+ def possible_agents(self):
+ return tuple(f"agent_{x}" for x in range(self.num_predators))
+
+ def render(self):
+ pass
+
+ def step(
+ self, state: PPState, actions: dict[str, torch.Tensor]
+ ) -> TensorJointTimestep:
+ self.world.update_state(state)
+
+ prey_actions = self._get_prey_move_angles(state)
+ prey_actions_ = [
+ torch.stack([torch.cos(angle), torch.sin(angle)]).detach()
+ for angle in prey_actions
+ ]
+ prey_actions_ = torch.stack(
+ [torch.cos(prey_actions), torch.sin(prey_actions)], dim=-1
+ ).permute(1, 0, 2)
+
+ # clip actions
+ clipped_actions = clip_actions(actions, self.action_spaces)
+
+ for idx, agent in enumerate(self.world.predator):
+ action = clipped_actions[f"agent_{idx}"]
+ agent.action.u = action
+ agent.dynamics.process_action()
+
+ for act, agent in zip(prey_actions_, self.world.prey, strict=False):
+ agent.action.u = act
+ agent.state.force = agent.action.u
+ self.world.step()
+
+ next_state = self.world.get_state()
+
+ obs, rewards, terminated, truncated, infos = self.get_from_scenario()
+ all_done = torch.stack(tuple(terminated.values())).transpose(1, 0).all(dim=1)
+
+ return TensorJointTimestep(
+ next_state, obs, rewards, terminated, truncated, all_done, infos
+ )
+
+ @property
+ def rng(self) -> torch.Generator:
+ if self._rng is None:
+ self._rng = torch.Generator(device=self.device)
+
+ return self._rng
+
+
+class PredatorPreyDiff(DefaultEnv[PPState, torch.Tensor, torch.Tensor]):
+ def __init__(
+ self,
+ world: partial[PPWorld] | str,
+ num_predators: int = 4,
+ num_prey: int = 8,
+ cooperative: bool = False,
+ prey_strength: int | None = None,
+ obs_dist: float = 10,
+ n_sensors: int = 32,
+ batch_size: int = 4,
+ device: str | None = None,
+ render_mode: str = "human",
+ ) -> None:
+ if isinstance(world, str):
+ assert world in SUPPORTED_WORLDS, (
+ f"Unsupported world name '{world}'. World name must be one of: "
+ f"{list(SUPPORTED_WORLDS)}."
+ )
+ world = SUPPORTED_WORLDS[world]()
+
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+ model = PredatorPreyDiffModel(
+ world,
+ num_predators,
+ num_prey,
+ cooperative,
+ prey_strength,
+ obs_dist,
+ n_sensors,
+ batch_size,
+ self.device,
+ )
+ self.batch_size = batch_size
+ self.viewer = None
+ self.visible_display = None
+
+ super().__init__(model)
+
+ def render(
+ self,
+ mode="human",
+ env_index=0,
+ agent_index_focus: int | None = None,
+ visualize_when_rgb: bool = False,
+ ):
+ """Render function for environment using pyglet
+ From VMAS.
+ """
+ viewer_size = (700, 700)
+
+ model = cast(PredatorPreyDiffModel, self.model)
+
+ shared_viewer = agent_index_focus is None
+ aspect_ratio = viewer_size[0] / viewer_size[1]
+
+ headless = mode == "rgb_array" and not visualize_when_rgb
+ # First time rendering
+ if self.visible_display is None:
+ self.visible_display = not headless
+ self.headless = headless
+ # All other times headless should be the same
+ else:
+ assert self.visible_display is not headless
+
+ # First time rendering
+ if self.viewer is None:
+ try:
+ import pyglet
+ except ImportError as err:
+ raise ImportError(
+ "Cannot import pyglet: you can install"
+ "pyglet directly via 'pip install pyglet'."
+ ) from err
+
+ try:
+ # Try to use EGL
+ pyglet.lib.load_library("EGL")
+
+ # Only if we have GPUs
+ from pyglet.libs.egl import egl, eglext
+
+ num_devices = egl.EGLint()
+ eglext.eglQueryDevicesEXT(0, None, byref(num_devices))
+ assert num_devices.value > 0
+
+ except (ImportError, AssertionError):
+ self.headless = False
+ pyglet.options["headless"] = self.headless
+
+ self._init_rendering()
+
+ zoom = 1.2
+
+ if aspect_ratio < 1:
+ cam_range = torch.tensor([zoom, zoom / aspect_ratio], device=self.device)
+ else:
+ cam_range = torch.tensor([zoom * aspect_ratio, zoom], device=self.device)
+
+ if shared_viewer:
+ # zoom out to fit everyone
+ all_poses = torch.stack(
+ [
+ agent.state.pos[env_index] # type: ignore
+ for agent in model.world.agents + model.world.landmarks
+ ],
+ dim=0,
+ )
+ max_agent_radius = max(
+ [agent.shape.circumscribed_radius() for agent in model.world.agents]
+ )
+ viewer_size_fit = (
+ torch.stack(
+ [
+ torch.max(torch.abs(all_poses[:, X] - 0)),
+ torch.max(torch.abs(all_poses[:, Y] - 0)),
+ ]
+ )
+ + 2 * max_agent_radius
+ )
+
+ viewer_size = torch.maximum(
+ viewer_size_fit / cam_range,
+ torch.tensor(zoom, device=self.device),
+ )
+ cam_range *= torch.max(viewer_size)
+ assert self.viewer is not None
+
+ self.viewer.set_bounds(
+ -cam_range[X] + 0,
+ cam_range[X] + 0,
+ -cam_range[Y] + 0,
+ cam_range[Y] + 0,
+ )
+
+ for entity in model.world.entities:
+ assert self.viewer is not None
+ self.viewer.add_onetime_list(entity.render(env_index=env_index))
+
+ # render to display or array
+ assert self.viewer is not None
+
+ return self.viewer.render(return_rgb_array=mode == "rgb_array")
+
+ def _init_rendering(self):
+ from vmas.simulator import rendering
+
+ self.viewer = rendering.Viewer(
+ *(700, 700), visible=self.visible_display or False
+ )
+ model = cast(PredatorPreyDiffModel, self.model)
+
+ self.text_lines = []
+ idx = 0
+ if model.world.dim_c > 0:
+ for agent in model.world.agents:
+ if not agent.silent:
+ text_line = rendering.TextLine(y=idx * 40)
+ self.viewer.geoms.append(text_line)
+ self.text_lines.append(text_line)
+ idx += 1
+
+
+def get_5x5_world() -> partial[PPWorld]:
+ """Generate 5x5 world layou`t."""
+ return get_default_world(5, include_blocks=False)
+
+
+def get_5x5_blocks_world() -> partial[PPWorld]:
+ """Generate 5x5 Blocks world layout."""
+ return get_default_world(5, include_blocks=True)
+
+
+def get_10x10_world() -> partial[PPWorld]:
+ """Generate 10x10 world layou`t."""
+ return get_default_world(10, include_blocks=False)
+
+
+def get_10x10_blocks_world() -> partial[PPWorld]:
+ """Generate 10x10 Blocks world layout."""
+ return get_default_world(10, include_blocks=True)
+
+
+def get_15x15_world() -> partial[PPWorld]:
+ """Generate 15x15 world layou`t."""
+ return get_default_world(15, include_blocks=False)
+
+
+def get_15x15_blocks_world() -> partial[PPWorld]:
+ """Generate 15x15 Blocks world layout."""
+ return get_default_world(15, include_blocks=True)
+
+
+def get_20x20_world() -> partial[PPWorld]:
+ """Generate 20x20 world layout."""
+ return get_default_world(20, include_blocks=False)
+
+
+def get_20x20_blocks_world() -> partial[PPWorld]:
+ """Generate 20x20 Blocks world layout."""
+ return get_default_world(20, include_blocks=True)
+
+
+def get_default_world(size: int, include_blocks: bool) -> partial[PPWorld]:
+ """Get function for generaing default world with given size.
+
+ If `include_blocks=True` then world will contain blocks with the following layout:
+
+ .....
+ .#.#.
+ .....
+ .#.#.
+ .....
+
+ Where `#` are the blocks, which will be represented as a single circle.
+ """
+ bound = size / 10
+
+ r = float(bound / 10)
+ if include_blocks:
+ blocks = [
+ ((x, y, 0.0), r)
+ for x, y in product([-3 * bound / 5, 3 * bound / 5], repeat=2)
+ ]
+ else:
+ blocks = []
+ return partial(PPWorld, bound=bound, blocks=blocks)
+
+
+SUPPORTED_WORLDS = {
+ "5x5": get_5x5_world,
+ "5x5Blocks": get_5x5_blocks_world,
+ "10x10": get_10x10_world,
+ "10x10Blocks": get_10x10_blocks_world,
+ "15x15": get_15x15_world,
+ "15x15Blocks": get_15x15_blocks_world,
+ "20x20": get_20x20_world,
+ "20x20Blocks": get_20x20_blocks_world,
+}
diff --git a/posggym/envs/differentiable/utils.py b/posggym/envs/differentiable/utils.py
new file mode 100644
index 0000000..6621bc0
--- /dev/null
+++ b/posggym/envs/differentiable/utils.py
@@ -0,0 +1,454 @@
+import dataclasses
+from abc import abstractmethod
+from collections.abc import Callable
+
+import numpy as np
+import torch
+from gymnasium import spaces
+from vmas.simulator.core import AgentState, Box, Entity, Landmark, Line, Sphere, World
+from vmas.simulator.sensors import Lidar, Sensor
+from vmas.simulator.utils import TorchUtils, X, Y
+
+import posggym.model as M
+
+
+ZERO = 0.0
+ABOVE_VALUE = 0.5
+BELOW_VALUE = -0.5
+
+
+@dataclasses.dataclass(order=True)
+class TensorJointTimestep(M.JointTimestep):
+ """The result of a single step in the model.
+
+ Supports iteration.
+
+ A dataclass is used instead of a Namedtuple so that generic typing is seamlessly
+ supported.
+
+ """
+
+ terminations: dict[str, torch.Tensor]
+ truncations: dict[str, torch.Tensor]
+ all_done: torch.Tensor
+ infos: dict[str, dict]
+
+ def __iter__(self):
+ for field in dataclasses.fields(self):
+ yield getattr(self, field.name)
+
+
+def clone_tensors(obj: AgentState):
+ cloned_attrs = {}
+ for attr_name, attr_value in obj.__dict__.items():
+ if isinstance(attr_value, torch.Tensor):
+ cloned_attrs[attr_name] = attr_value.clone()
+ else:
+ cloned_attrs[attr_name] = attr_value
+ return cloned_attrs
+
+
+class AgentStateWrapper(AgentState):
+ @property
+ def pos_safe(self) -> torch.Tensor:
+ if self._pos is None:
+ raise AttributeError("pos is none")
+
+ return self._pos
+
+ @property
+ def rot_safe(self) -> torch.Tensor:
+ if self._rot is None:
+ raise AttributeError("rot is none")
+
+ return self._rot
+
+ @property
+ def vel_safe(self) -> torch.Tensor:
+ if self._vel is None:
+ raise AttributeError("vel is none")
+
+ return self._vel
+
+ @property
+ def ang_vel_safe(self) -> torch.Tensor:
+ if self._ang_vel is None:
+ raise AttributeError("ang_vel is none")
+
+ return self._ang_vel
+
+ def __eq__(self, value) -> bool:
+ if not isinstance(value, AgentStateWrapper):
+ return False
+
+ return (
+ torch.equal(self.pos_safe, value.pos_safe)
+ and torch.equal(self.rot_safe, value.rot_safe)
+ and torch.equal(self.vel_safe, value.vel_safe)
+ and torch.equal(self.ang_vel_safe, value.ang_vel_safe)
+ )
+
+
+def clone_state(state: AgentStateWrapper):
+ a_s = AgentStateWrapper()
+ t = clone_tensors(state)
+ a_s._batch_dim = t["_batch_dim"] # pyright: ignore
+ a_s._device = t["_device"] # pyright: ignore
+ a_s.pos = t["_pos"]
+ a_s.ang_vel = t["_ang_vel"]
+ a_s.force = t["_force"]
+ a_s.pos = t["_pos"]
+ a_s.rot = t["_rot"]
+ a_s.torque = t["_torque"]
+ a_s.vel = t["_vel"]
+
+ return a_s
+
+
+def clip_actions(
+ actions: dict[str, torch.Tensor], action_spaces: dict[str, spaces.Space]
+) -> dict[str, torch.Tensor]:
+ """Clip continuous actions so they are within the agents action space dims."""
+ clipped_actions = {}
+ for i, a in actions.items():
+ a_space = action_spaces[i]
+ assert isinstance(a_space, spaces.Box)
+ if isinstance(a, torch.Tensor):
+ clipped_actions[i] = torch.clip(
+ a, torch.from_numpy(a_space.low), torch.from_numpy(a_space.high)
+ )
+ else:
+ clipped_actions[i] = torch.from_numpy(np.clip(a, a_space.low, a_space.high))
+
+ return clipped_actions
+
+
+class POSGGymLandmark(Landmark):
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ # Remove lambda function or other unpicklable attributes
+ if "_collision_filter" in state:
+ del state["_collision_filter"]
+ return state
+
+ # Optional: Override __setstate__ to restore state
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ # Optionally re-create the lambda function after unpickling
+ self._collision_filter = lambda _: True
+
+
+class POSGGymSensor(Sensor):
+ @abstractmethod
+ def measure(self, world: World):
+ raise NotImplementedError
+
+
+class POSGGymLidar(Lidar, POSGGymSensor):
+ def __init__(self, world: World, entity_name: str, **kwargs) -> None:
+ self.entity_name = entity_name
+ super().__init__(world, **kwargs)
+
+ def entity_filter(self, e: Entity) -> bool:
+ return e.name.startswith(self.entity_name)
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ # Remove lambda function or other unpicklable attributes
+ if "_entity_filter" in state:
+ del state["_entity_filter"]
+ return state
+
+ # Optional: Override __setstate__ to restore state
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ # Optionally re-create the lambda function after unpickling
+ self._entity_filter = lambda _: True
+
+ def measure(self, world: World):
+ assert self.agent is not None
+ dists = cast_ray(
+ self.agent,
+ world.entities,
+ self._angles,
+ max_range=self._max_range,
+ entity_filter=self.entity_filter,
+ batch_dim=world.batch_dim,
+ device=world.device,
+ )
+ self._last_measurement = dists.swapaxes(1, 0)
+ return torch.clip(dists, 0, self._max_range)
+
+
+# @torch.compile
+def cast_ray(
+ entity: Entity,
+ entities: list[Entity],
+ angles: torch.Tensor,
+ max_range: float,
+ entity_filter: Callable[[Entity], bool] = lambda _: False,
+ batch_dim: int = 0,
+ device: torch.device | None = None,
+):
+ if device is None:
+ device = torch.device("cuda")
+
+ pos = entity.state.pos
+
+ # Initialize with full max_range to avoid
+ # dists being empty when all entities are filtered
+ dists = [
+ torch.full((batch_dim, angles.shape[0]), fill_value=max_range, device=device)
+ ]
+
+ for e in entities:
+ if entity is e or not entity_filter(e):
+ continue
+ assert e.collides(entity) and entity.collides(
+ e
+ ), "Rays are only casted among collidables"
+ if isinstance(e.shape, Box):
+ d = _cast_ray_to_box(e, pos, angles.T, max_range)
+ elif isinstance(e.shape, Sphere):
+ d = _cast_ray_to_sphere(e, pos, angles.T, max_range)
+ elif isinstance(e.shape, Line):
+ d = _cast_ray_to_line(e, pos, angles.T, max_range)
+ else:
+ raise RuntimeError(f"Shape {e.shape} currently not handled by cast_ray")
+ dists.append(d)
+ dist, _ = torch.min(torch.stack(dists, dim=-1), dim=-1)
+ return dist
+
+
+VECTOR_SIZE = 2
+
+
+def rotate_vector(vector: torch.Tensor, angle: torch.Tensor):
+ if len(angle.shape) == len(vector.shape):
+ angle = angle.squeeze(-1)
+
+ if angle.ndim < vector.ndim:
+ angle = angle.view(*([1] * (vector.ndim - angle.ndim)), *angle.shape)
+
+ assert vector.shape[-1] == VECTOR_SIZE
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+
+ return torch.stack(
+ [
+ vector[..., 0] * cos - vector[..., 1] * sin,
+ vector[..., 0] * sin + vector[..., 1] * cos,
+ ],
+ dim=-1,
+ )
+
+
+# @torch.compile
+def _cast_ray_to_sphere(
+ sphere: Entity,
+ ray_origin: torch.Tensor,
+ ray_direction: torch.Tensor,
+ max_range: float,
+):
+ ray_dir_world = torch.stack(
+ [torch.cos(ray_direction), torch.sin(ray_direction)], dim=-1
+ )
+ assert sphere.state.pos is not None
+
+ test_point_pos = sphere.state.pos[:, None, :].repeat(1, ray_dir_world.shape[1], 1)
+ line_rot = ray_direction
+ line_length = max_range
+ ray_origin_ = ray_origin[:, None, :].repeat(1, ray_dir_world.shape[1], 1)
+ line_pos = ray_origin_ + ray_dir_world * (line_length / 2)
+
+ closest_point = _get_closest_point_line(
+ line_pos,
+ line_rot.unsqueeze(-1),
+ line_length,
+ test_point_pos,
+ limit_to_line_length=False,
+ )
+
+ d = test_point_pos - closest_point
+ d_norm = torch.linalg.vector_norm(d, dim=2)
+ ray_intersects = d_norm < sphere.shape.radius
+ a = sphere.shape.radius**2 - d_norm**2
+ m = torch.sqrt(torch.where(a > 0, a, 1e-8))
+
+ u = test_point_pos - ray_origin_
+ u1 = closest_point - ray_origin_
+
+ # Dot product of u and u1
+ u_dot_ray = (u * ray_dir_world).sum(-1)
+ sphere_is_in_front = u_dot_ray > ZERO
+ dist = torch.linalg.vector_norm(u1, dim=2) - m
+ dist[~(ray_intersects & sphere_is_in_front)] = max_range
+
+ return dist
+
+
+def cross(vector_a: torch.Tensor, vector_b: torch.Tensor):
+ # Ensure vector_a is broadcasted to match the shape of vector_b
+ vector_a_expanded = vector_a.unsqueeze(1).expand(-1, vector_b.size(1), -1)
+
+ return (
+ vector_a_expanded[..., 0] * vector_b[..., 1]
+ - vector_a_expanded[..., 1] * vector_b[..., 0]
+ ).unsqueeze(-1)
+
+
+# @torch.compile
+def _cast_ray_to_line(
+ line: Entity,
+ ray_origin: torch.Tensor,
+ ray_direction: torch.Tensor,
+ max_range: float,
+):
+ """Inspired by:
+ https://stackoverflow.com/questions/563198/how-do-you-detect-where-two-line-segments-intersect/565282#565282
+ Computes distance of ray originating from pos at angle to a line an
+ sets distance to max_range if there is no intersection.
+ """
+ assert isinstance(line.shape, Line)
+
+ assert line.state.rot is not None
+
+ p = line.state.pos
+ r = (
+ torch.stack(
+ [
+ torch.cos(line.state.rot.squeeze(1)),
+ torch.sin(line.state.rot.squeeze(1)),
+ ],
+ dim=-1,
+ )
+ * line.shape.length
+ )
+
+ q = ray_origin
+ s = torch.stack(
+ [
+ torch.cos(ray_direction),
+ torch.sin(ray_direction),
+ ],
+ dim=-1,
+ )
+
+ r = r.unsqueeze(1) # Shape becomes [2, 1, 2]
+ r = r.expand(-1, s.shape[1], -1) # Shape becomes [2, 32, 2]
+
+ rxs = TorchUtils.cross(r, s)
+ rxs[rxs == ZERO] = 1e-10
+
+ t = cross(q - p, s / rxs)
+ u = cross(q - p, r / rxs)
+
+ d = torch.linalg.norm(u * s, dim=-1)
+
+ perpendicular = rxs == ZERO
+ above_line = t > ABOVE_VALUE
+ below_line = t < BELOW_VALUE
+ behind_line = u < ZERO
+
+ new_d = d.clone()
+ new_d[perpendicular.squeeze(-1)] = max_range
+ new_d[above_line.squeeze(-1)] = max_range
+ new_d[below_line.squeeze(-1)] = max_range
+ new_d[behind_line.squeeze(-1)] = max_range
+
+ return new_d
+
+
+# @torch.compile
+def _get_closest_point_line(
+ line_pos,
+ line_rot,
+ line_length,
+ test_point_pos,
+ limit_to_line_length: bool = True,
+):
+ if not isinstance(line_length, torch.Tensor):
+ line_length = torch.tensor(
+ line_length, dtype=torch.float32, device=line_pos.device
+ ).expand(line_pos.shape[0])
+ # Rotate it by the angle of the line
+ rotated_vector = torch.cat([line_rot.cos(), line_rot.sin()], dim=-1)
+ # Get distance between line and sphere
+ delta_pos = line_pos - test_point_pos
+ # Dot product of distance and line vector
+ dot_p = (delta_pos * rotated_vector).sum(-1).unsqueeze(-1)
+ # Coordinates of the closes point
+ sign = torch.sign(dot_p)
+ distance_from_line_center = (
+ torch.minimum(
+ torch.abs(dot_p),
+ (line_length / 2).view(dot_p.shape),
+ )
+ if limit_to_line_length
+ else torch.abs(dot_p)
+ )
+ closest_point = line_pos - sign * distance_from_line_center * rotated_vector
+ return closest_point
+
+
+# @torch.compile
+def _cast_ray_to_box(
+ box: Entity,
+ ray_origin: torch.Tensor,
+ ray_direction: torch.Tensor,
+ max_range: float,
+):
+ """Inspired from https://tavianator.com/2011/ray_box.html
+ Computes distance of ray originating from pos at angle to a box and sets distance to
+ max_range if there is no intersection.
+ """
+ assert isinstance(box.shape, Box)
+
+ pos_origin = ray_origin - box.state.pos
+ pos_aabb = rotate_vector(pos_origin, -box.state.rot)[:, :, None, :].repeat(
+ 1, 1, ray_direction.shape[1], 1
+ )
+ ray_dir_world = torch.stack(
+ [torch.cos(ray_direction), torch.sin(ray_direction)], dim=-1
+ )
+
+ ray_dir_aabb = rotate_vector(ray_dir_world, -box.state.rot)
+
+ tx1 = (-box.shape.length / 2 - pos_aabb[..., X]) / ray_dir_aabb[..., X]
+ tx2 = (box.shape.length / 2 - pos_aabb[..., X]) / ray_dir_aabb[..., X]
+ tx = torch.stack([tx1, tx2], dim=-1)
+ tmin, _ = torch.min(tx, dim=-1)
+ tmax, _ = torch.max(tx, dim=-1)
+
+ ty1 = (-box.shape.width / 2 - pos_aabb[..., Y]) / ray_dir_aabb[..., Y]
+
+ ty2 = (box.shape.width / 2 - pos_aabb[..., Y]) / ray_dir_aabb[..., Y]
+
+ ty = torch.stack([ty1, ty2], dim=-1)
+ tymin, _ = torch.min(ty, dim=-1)
+ tymax, _ = torch.max(ty, dim=-1)
+ tmin, _ = torch.max(torch.stack([tmin, tymin], dim=-1), dim=-1)
+ tmax, _ = torch.min(torch.stack([tmax, tymax], dim=-1), dim=-1)
+
+ intersect_aabb = tmin.unsqueeze(tmin.ndim) * ray_dir_aabb + pos_aabb
+
+ assert box.state.pos is not None
+ assert box.state.rot is not None
+
+ intersect_world = rotate_vector(
+ intersect_aabb, box.state.rot.reshape(1, box.state.rot.shape[0], 1, 1)
+ ) + box.state.pos.reshape(1, box.state.pos.shape[0], 1, box.state.pos.shape[1])
+
+ collision = (tmax >= tmin) & (tmin > ZERO)
+
+ dist = torch.linalg.norm(
+ ray_origin[:, None, :].repeat(1, ray_direction.shape[1], 1) - intersect_world,
+ dim=-1,
+ )
+
+ new_dist = dist.clone().squeeze(0).squeeze(0)
+
+ new_dist[~collision.squeeze(0)] = max_range
+
+ return new_dist
diff --git a/posggym/envs/grid_world/cooperative_reaching.py b/posggym/envs/grid_world/cooperative_reaching.py
index 8d4943c..3f92289 100644
--- a/posggym/envs/grid_world/cooperative_reaching.py
+++ b/posggym/envs/grid_world/cooperative_reaching.py
@@ -2,7 +2,7 @@
from itertools import product
from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import ClassVar
from gymnasium import spaces
@@ -12,8 +12,8 @@
from posggym.envs.grid_world.core import Coord, Direction, Grid
from posggym.utils import seeding
-# State = (coord_0, coord_1)
-CRState = Tuple[Coord, Coord]
+
+CRState = tuple[Coord, Coord]
# The actions
CRAction = int
@@ -27,8 +27,7 @@
ACTIONS_STR = ["0", "U", "D", "L", "R"]
ACTION_TO_DIR = [None, Direction.NORTH, Direction.SOUTH, Direction.WEST, Direction.EAST]
-# Obs = (ego_coord, other_coord)
-CRObs = Tuple[Coord, Coord]
+CRObs = tuple[Coord, Coord]
class CooperativeReachingEnv(DefaultEnv[CRState, CRObs, CRAction]):
@@ -89,9 +88,8 @@ class CooperativeReachingEnv(DefaultEnv[CRState, CRObs, CRAction]):
need to be adjusted when using larger grids (this can be done by manually specifying
a value for `max_episode_steps` when creating the environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `size` - the size (width and height) of grid.
- `num_goals` - the number of goal cells in the grid.
- `mode` - the mode of the environment, which determines the layout of goals in the
@@ -148,14 +146,14 @@ class CooperativeReachingEnv(DefaultEnv[CRState, CRObs, CRAction]):
---------------
- `v0`: Initial version
- References
+ References:
----------
- Arrasy Rahman, Elliot Fosong, Ignacio Carlucho, and Stefano V. Albrecht. 2023.
Generating Teammates for Training Robust Ad Hoc Teamwork Agents via Best-Response
Diversity. Transactions on Machine Learning Research.
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_dict"],
"render_fps": 15,
}
@@ -165,9 +163,9 @@ def __init__(
size: int = 5,
num_goals: int = 4,
mode: str = "original",
- obs_distance: Optional[int] = None,
- render_mode: Optional[str] = None,
- ):
+ obs_distance: int | None = None,
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
CooperativeReachingModel(size, num_goals, mode, obs_distance),
render_mode=render_mode,
@@ -178,7 +176,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -294,18 +292,23 @@ class CooperativeReachingModel(M.POSGModel[CRState, CRObs, CRAction]):
"""
NUM_AGENTS = 2
+ MIN_GRID_SIZE = 3
+ MIN_GOALS = 1
+ NUM_GOALS = 4
- MODES = ["square", "line", "original"]
+ MODES: ClassVar[list] = ["square", "line", "original"]
def __init__(
self,
size: int,
num_goals: int,
mode: str,
- obs_distance: Optional[int],
- ):
- assert size >= 3, "Grid size must be at least 3"
- assert num_goals >= 1, "Must have at least one goal"
+ obs_distance: int | None,
+ ) -> None:
+ assert (
+ size >= self.MIN_GRID_SIZE
+ ), f"Grid size must be at least {self.MIN_GRID_SIZE}"
+ assert num_goals >= self.MIN_GOALS, "Must have at least one goal"
assert mode in self.MODES, f"Mode must be one of {self.MODES}"
if obs_distance is None:
obs_distance = 2 * size
@@ -323,10 +326,13 @@ def __init__(
points = equid_points_line(num_goals, self.size)
self.goals = {p: 1.0 for p in points}
else:
- if num_goals != 4:
- logger.warn(
- f"'original' mode only supports 4 goals, but got {num_goals}. "
- "Continuing with 4 goals."
+ if num_goals != self.NUM_GOALS:
+ logger.warning(
+ "'original' mode only supports %s goals, but got %s. "
+ "Continuing with %s goals.",
+ self.NUM_GOALS,
+ num_goals,
+ self.NUM_GOALS,
)
self.goals = {
(0, 0): 1.0,
@@ -362,7 +368,7 @@ def _coord_space(s):
self.is_symmetric = True
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
max_goal_value = max(self.goals.values())
return {i: (0.0, max_goal_value) for i in self.possible_agents}
@@ -372,7 +378,7 @@ def rng(self) -> seeding.RNG:
self._rng, _ = seeding.std_random()
return self._rng
- def get_agents(self, state: CRState) -> List[str]:
+ def get_agents(self, state: CRState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> CRState:
@@ -389,11 +395,11 @@ def sample_agent_initial_state(self, agent_id: str, obs: CRObs) -> CRState:
agent_idx = int(agent_id)
return (obs[0], obs[1]) if agent_idx == 0 else (obs[1], obs[0])
- def sample_initial_obs(self, state: CRState) -> Dict[str, CRObs]:
+ def sample_initial_obs(self, state: CRState) -> dict[str, CRObs]:
return self._get_obs(state)
def step(
- self, state: CRState, actions: Dict[str, CRAction]
+ self, state: CRState, actions: dict[str, CRAction]
) -> M.JointTimestep[CRState, CRObs]:
assert all(0 <= a_i < len(ACTIONS) for a_i in actions.values())
next_state = self._get_next_state(state, actions)
@@ -406,7 +412,7 @@ def step(
terminated = {i: all_done for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
if all_done:
for i in self.possible_agents:
info[i]["outcome"] = M.Outcome.WIN
@@ -415,7 +421,7 @@ def step(
next_state, obs, rewards, terminated, truncated, all_done, info
)
- def _get_next_state(self, state: CRState, actions: Dict[str, CRAction]) -> CRState:
+ def _get_next_state(self, state: CRState, actions: dict[str, CRAction]) -> CRState:
next_state = list(state)
for i, action_i in actions.items():
if action_i == DO_NOTHING:
@@ -426,8 +432,8 @@ def _get_next_state(self, state: CRState, actions: Dict[str, CRAction]) -> CRSta
)
return tuple(next_state)
- def _get_obs(self, state: CRState) -> Dict[str, CRObs]:
- obs: Dict[str, CRObs] = {}
+ def _get_obs(self, state: CRState) -> dict[str, CRObs]:
+ obs: dict[str, CRObs] = {}
for i in self.possible_agents:
idx = int(i)
other_idx = (idx + 1) % 2
@@ -442,10 +448,10 @@ def _get_obs(self, state: CRState) -> Dict[str, CRObs]:
obs[i] = (state_i, (self.size, self.size))
return obs
- def get_obs_coords(self, origin: Coord) -> List[Coord]:
+ def get_obs_coords(self, origin: Coord) -> list[Coord]:
"""Get the list of coords observed from agent at origin."""
obs_size = (2 * self.obs_distance) + 1
- obs_coords: List[Coord] = []
+ obs_coords: list[Coord] = []
for obs_col, obs_row in product(range(obs_size), repeat=2):
grid_col = origin[0] + obs_col - self.obs_distance
grid_row = origin[1] + obs_row - self.obs_distance
@@ -453,7 +459,7 @@ def get_obs_coords(self, origin: Coord) -> List[Coord]:
obs_coords.append((grid_col, grid_row))
return obs_coords
- def _get_rewards(self, state: CRState) -> Dict[str, float]:
+ def _get_rewards(self, state: CRState) -> dict[str, float]:
all_done = all(p == state[0] for p in state) and state[0] in self.goals
if all_done:
goal_value = self.goals[state[0]]
@@ -464,12 +470,16 @@ def _get_rewards(self, state: CRState) -> Dict[str, float]:
class CooperativeReachingGrid(Grid):
"""A grid for the Cooperative Reaching Problem."""
+ MIN_GRID_SIZE = 3
+
def __init__(
self,
size: int,
- goal_coords: List[Coord],
- ):
- assert size >= 3, "Grid size must be at least 3"
+ goal_coords: list[Coord],
+ ) -> None:
+ assert (
+ size >= self.MIN_GRID_SIZE
+ ), f"Grid size must be at least {self.MIN_GRID_SIZE}"
super().__init__(size, size, block_coords=set())
self.size = size
self.goal_coords = goal_coords
@@ -479,7 +489,7 @@ def get_shortest_path_distance(self, coord: Coord, goal: Coord) -> int:
"""Get the shortest path distance from coord to goal."""
return int(self.shortest_paths[goal][coord])
- def get_ascii_repr(self, agent_coords: Optional[CRState]) -> str:
+ def get_ascii_repr(self, agent_coords: CRState | None) -> str:
"""Get ascii repr of grid."""
grid_repr = []
for row in range(self.height):
@@ -501,7 +511,7 @@ def get_ascii_repr(self, agent_coords: Optional[CRState]) -> str:
return "\n".join([" ".join(r) for r in grid_repr])
-def equid_points_square(n_points: int, grid_size: int) -> List[Tuple[int, int]]:
+def equid_points_square(n_points: int, grid_size: int) -> list[tuple[int, int]]:
"""Return n_points equidistant points on square border of grid."""
assert 0 < n_points <= (grid_size - 1) * 4
perimeter_length = (grid_size - 1) * 4
@@ -522,7 +532,7 @@ def equid_points_square(n_points: int, grid_size: int) -> List[Tuple[int, int]]:
return points
-def equid_points_line(n_points: int, grid_size: int) -> List[Tuple[int, int]]:
+def equid_points_line(n_points: int, grid_size: int) -> list[tuple[int, int]]:
"""Return n_points equidistant points on line in middle of grid."""
assert 0 < n_points <= grid_size
col = int(grid_size / 2)
diff --git a/posggym/envs/grid_world/core.py b/posggym/envs/grid_world/core.py
index a89cb4b..8b7b6fa 100644
--- a/posggym/envs/grid_world/core.py
+++ b/posggym/envs/grid_world/core.py
@@ -3,11 +3,13 @@
import enum
import itertools
import random
+from collections.abc import Callable, Iterable
from queue import PriorityQueue, Queue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TypeVar
+
# (x, y) coord = (col, row) coord
-Coord = Tuple[int, int]
+Coord = tuple[int, int]
class Direction(enum.IntEnum):
@@ -29,8 +31,8 @@ def __init__(
self,
grid_width: int,
grid_height: int,
- block_coords: Optional[Set[Coord]] = None,
- ):
+ block_coords: set[Coord] | None = None,
+ ) -> None:
self.width = grid_width
self.height = grid_height
@@ -39,7 +41,7 @@ def __init__(
self.block_coords = block_coords
@property
- def all_coords(self) -> List[Coord]:
+ def all_coords(self) -> list[Coord]:
"""The list of all locations on grid, including blocks."""
return list(itertools.product(range(self.width), range(self.width)))
@@ -49,7 +51,7 @@ def n_coords(self) -> int:
return self.height * self.width
@property
- def unblocked_coords(self) -> List[Coord]:
+ def unblocked_coords(self) -> list[Coord]:
"""The list of all coordinates on the grid excluding blocks."""
return [coord for coord in self.all_coords if coord not in self.block_coords]
@@ -67,7 +69,7 @@ def get_neighbours(
coord: Coord,
ignore_blocks: bool = False,
include_out_of_bounds: bool = False,
- ) -> List[Coord]:
+ ) -> list[Coord]:
"""Get set of adjacent non-blocked coordinates."""
neighbours = []
if coord[1] > 0 or include_out_of_bounds:
@@ -114,7 +116,7 @@ def get_next_coord(
def get_coords_within_dist(
self, origin: Coord, dist: int, ignore_blocks: bool, include_origin: bool
- ) -> Set[Coord]:
+ ) -> set[Coord]:
"""Get set of coords within given distance from origin."""
if dist == 0:
return {origin} if include_origin else set()
@@ -141,14 +143,14 @@ def get_coords_within_dist(
def get_coords_at_dist(
self, origin: Coord, dist: int, ignore_blocks: bool
- ) -> Set[Coord]:
+ ) -> set[Coord]:
"""Get set of coords at given distance from origin."""
if dist == 0:
return {origin}
in_dist_coords = self.get_coords_within_dist(origin, dist, ignore_blocks, False)
- at_dist_coords: Set[Coord] = set()
+ at_dist_coords: set[Coord] = set()
for coord in in_dist_coords:
if self.manhattan_dist(origin, coord) == dist:
at_dist_coords.add(coord)
@@ -157,7 +159,7 @@ def get_coords_at_dist(
def get_min_dist_coords(
self, origin: Coord, coords: Iterable[Coord]
- ) -> List[Coord]:
+ ) -> list[Coord]:
"""Get list of coord in coords closest to origin."""
dists = self.get_coords_by_distance(origin, coords)
if len(dists) == 0:
@@ -166,9 +168,9 @@ def get_min_dist_coords(
def get_coords_by_distance(
self, origin: Coord, coords: Iterable[Coord]
- ) -> Dict[int, List[Coord]]:
+ ) -> dict[int, list[Coord]]:
"""Get mapping from distance to coords at that distance from origin."""
- dists: Dict[int, List[Coord]] = {}
+ dists: dict[int, list[Coord]] = {}
for coord in coords:
d = self.manhattan_dist(origin, coord)
if d not in dists:
@@ -178,14 +180,14 @@ def get_coords_by_distance(
def get_all_shortest_paths(
self, origins: Iterable[Coord]
- ) -> Dict[Coord, Dict[Coord, int]]:
+ ) -> dict[Coord, dict[Coord, int]]:
"""Get shortest path distance from every origin to all other coords."""
src_dists = {}
for origin in origins:
src_dists[origin] = self.dijkstra(origin)
return src_dists
- def dijkstra(self, origin: Coord) -> Dict[Coord, int]:
+ def dijkstra(self, origin: Coord) -> dict[Coord, int]:
"""Get shortest path distance between origin and all other coords."""
dist = {origin: 0}
pq = PriorityQueue() # type: ignore
@@ -205,7 +207,7 @@ def dijkstra(self, origin: Coord) -> Dict[Coord, int]:
visited.add(adj_coord)
return dist
- def get_connected_components(self) -> List[Set[Coord]]:
+ def get_connected_components(self) -> list[set[Coord]]:
"""Get list of connected components.
A connected component is the set of all coords that are connected to
@@ -244,7 +246,7 @@ def map_relative_to_absolute_coord(
This is useful for mapping from coords in agent observations to coords in the
actual grid.
- Arguments
+ Arguments:
---------
rel_coord
the relative coordinate. This is the coordinate relative to the origin
@@ -263,7 +265,7 @@ def map_relative_to_absolute_coord(
relative grid reaches. Defines where the zeroth column of the grid is
relative to the origin.
- Returns
+ Returns:
-------
Coord
The relative coord mapped to the actual grid.
@@ -286,11 +288,11 @@ def map_relative_to_absolute_coord(
return (grid_col, grid_row)
def get_rectangular_bounds(
- self, origin: Coord, facing_dir: Direction, rect_size: Tuple[int, int, int, int]
- ) -> Tuple[int, int, int, int]:
+ self, origin: Coord, facing_dir: Direction, rect_size: tuple[int, int, int, int]
+ ) -> tuple[int, int, int, int]:
"""Get rectangular bounds for a rectangle with given origin and size.
- Arguments
+ Arguments:
---------
origin
the origin coordinate for the rectangle
@@ -303,7 +305,7 @@ def get_rectangular_bounds(
rect_size[2] = number of cells left
rect_size[3] = number of cells right
- Returns
+ Returns:
-------
Tuple[int, int, int, int]
min_col, max_col, min_row, max_row of rectangle that is within the grid's
@@ -332,14 +334,14 @@ def get_rectangular_bounds(
return min_col, max_col, min_row, max_row
def get_rectangular_padding(
- self, origin: Coord, facing_dir: Direction, rect_size: Tuple[int, int, int, int]
- ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+ self, origin: Coord, facing_dir: Direction, rect_size: tuple[int, int, int, int]
+ ) -> tuple[tuple[int, int], tuple[int, int]]:
"""Get padding for a rectangle with given origin and size.
The padding quantities are the number of cells the rectangle is out-of-bounds
on each dimension.
- Arguments
+ Arguments:
---------
origin
the origin coordinate for the rectangle
@@ -352,7 +354,7 @@ def get_rectangular_padding(
rect_size[2] = number of cells left
rect_size[3] = number of cells right
- Returns
+ Returns:
-------
Tuple[Tuple[int, int], Tuple[int, int]]
(before_col, after_col), (before_row, after_row padding)
@@ -390,12 +392,12 @@ def __init__(
self,
width: int,
height: int,
- mask: Set[Coord],
+ mask: set[Coord],
max_obstacle_size: int,
max_num_obstacles: int,
ensure_grid_connected: bool,
- seed: Optional[int] = None,
- ):
+ seed: int | None = None,
+ ) -> None:
assert max_obstacle_size > 0
self.width = width
self.height = height
@@ -407,7 +409,7 @@ def __init__(
def generate(self) -> Grid:
"""Generate a new grid."""
- block_coords: Set[Coord] = set()
+ block_coords: set[Coord] = set()
for _ in range(self.max_num_obstacles):
obstacle = self._get_random_obstacle()
if not self.mask.intersection(obstacle):
@@ -419,12 +421,12 @@ def generate(self) -> Grid:
return grid
- def generate_n(self, n: int) -> List[Grid]:
+ def generate_n(self, n: int) -> list[Grid]:
"""Generate N new grids."""
grids = [self.generate() for _ in range(n)]
return grids
- def _get_random_obstacle(self) -> Set[Coord]:
+ def _get_random_obstacle(self) -> set[Coord]:
obstacle_height = self._rng.randint(1, self.max_obstacle_size)
obstacle_width = self._rng.randint(1, self.max_obstacle_size)
obstacle_x = self._rng.randint(0, self.width - 1)
@@ -475,13 +477,13 @@ def connect_grid_components(self, grid: Grid) -> Grid:
return grid
def _component_distance(
- self, grid: Grid, origin: Coord, component: Set[Coord]
+ self, grid: Grid, origin: Coord, component: set[Coord]
) -> int:
return min([grid.manhattan_dist(origin, coord) for coord in component])
def _get_closest_pair(
- self, grid: Grid, component_0: Set[Coord], component_1: Set[Coord]
- ) -> Tuple[Coord, Coord]:
+ self, grid: Grid, component_0: set[Coord], component_1: set[Coord]
+ ) -> tuple[Coord, Coord]:
min_dist = grid.width * grid.height
min_pair = ((0, 0), (0, 0))
for c0, c1 in itertools.product(component_0, component_1):
@@ -493,7 +495,7 @@ def _get_closest_pair(
def _get_shortest_direct_path(
self, grid: Grid, start_coord: Coord, goal_coord: Coord
- ) -> List[Coord]:
+ ) -> list[Coord]:
"""Get shortest direct path between two coords.
This will possibly include blocks in the path, but will greedily chose
@@ -555,8 +557,8 @@ class GridCycler:
"""Class for handling cycling through a set of generated grids."""
def __init__(
- self, grids: List[Grid], shuffle_each_cycle: bool, seed: Optional[int] = None
- ):
+ self, grids: list[Grid], shuffle_each_cycle: bool, seed: int | None = None
+ ) -> None:
self.grids = grids
self.shuffle = shuffle_each_cycle
self._rng = random.Random(seed)
@@ -573,3 +575,7 @@ def next(self) -> Grid:
grid = self.grids[self._next_idx]
self._next_idx += 1
return grid
+
+
+GridSubClass = TypeVar("GridSubClass", bound=Grid)
+SupportedGridTypes = dict[str, tuple[Callable[[], GridSubClass], int]]
diff --git a/posggym/envs/grid_world/driving.py b/posggym/envs/grid_world/driving.py
index f12ae02..2412380 100644
--- a/posggym/envs/grid_world/driving.py
+++ b/posggym/envs/grid_world/driving.py
@@ -1,8 +1,9 @@
"""The Driving Grid World Environment."""
+from __future__ import annotations
import enum
from itertools import product
-from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple
from gymnasium import spaces
@@ -13,6 +14,18 @@
from posggym.utils import seeding
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+
+def max_enum_value(enum_cls: type[enum.Enum]):
+ return max(enum_member.value for enum_member in enum_cls)
+
+
+def min_enum_value(enum_cls: type[enum.Enum]):
+ return min(enum_member.value for enum_member in enum_cls)
+
+
class Speed(enum.IntEnum):
"""A speed setting for a vehicle."""
@@ -35,7 +48,7 @@ class VehicleState(NamedTuple):
init_dest_dist: int
-DState = Tuple[VehicleState, ...]
+DState = tuple[VehicleState, ...]
# Initial direction and speed of each vehicle
INIT_DIR = Direction.NORTH
@@ -52,13 +65,12 @@ class VehicleState(NamedTuple):
ACTIONS = [DO_NOTHING, ACCELERATE, DECELERATE, TURN_RIGHT, TURN_LEFT]
ACTIONS_STR = ["0", "acc", "dec", "tr", "tl"]
-# Obs = [
# V0 Obs = (adj_obs, speed, dest_coord, dest_reached, crashed)
# V1 Obs = (adj_obs, speed, Coord, dest_coord, dest_reached, crashed)
-DObs = Union[
- Tuple[Tuple[int, ...], Speed, Coord, int, int],
- Tuple[Tuple[int, ...], Speed, Coord, Coord, int, int],
-]
+DObs = (
+ tuple[tuple[int, ...], Speed, Coord, int, int]
+ | tuple[tuple[int, ...], Speed, Coord, Coord, int, int]
+)
# Cell obs
VEHICLE = 0
@@ -169,9 +181,8 @@ class DrivingEnv(DefaultEnv[DState, DObs, DAction]):
can be done by manually specifying a value for `max_episode_steps` when creating the
environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `grid` - the grid layout to use. This can either be a string specifying one of
the supported grids, or a custom :class:`DrivingGrid` object
(default = `"14x14RoundAbout"`).
@@ -219,33 +230,34 @@ class DrivingEnv(DefaultEnv[DState, DObs, DAction]):
from progress) and min return is -1.0 (-1.0 for crashing),
- `v0`: Initial version
- References
+ References:
----------
- Adam Lerer and Alexander Peysakhovich. 2019. Learning Existing Social Conventions
via Observationally Augmented Self-Play. In Proceedings of the 2019 AAAI/ACM
- Conference on AI, Ethics, and Society. 107–114.
+ Conference on AI, Ethics, and Society. 107-114.
- Kevin R. McKee, Joel Z. Leibo, Charlie Beattie, and Richard Everett. 2022.
Quantifying the Effects of Environment and Population Diversity in Multi-Agent
- Reinforcement Learning. Autonomous Agents and Multi-Agent Systems 36, 1 (2022), 1–16
+ Reinforcement Learning. Autonomous Agents and Multi-Agent Systems 36, 1 (2022), 1-16
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_dict"],
"render_fps": 15,
}
def __init__(
self,
- grid: Union[str, "DrivingGrid"] = "14x14RoundAbout",
+ grid: str | DrivingGrid = "14x14RoundAbout",
num_agents: int = 2,
- obs_dim: Tuple[int, int, int] = (3, 1, 1),
- render_mode: Optional[str] = None,
- ):
+ obs_dim: tuple[int, int, int] = (3, 1, 1),
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
DrivingModel(grid, num_agents, obs_dim),
render_mode=render_mode,
)
+
self._obs_dim = obs_dim
self.renderer = None
self._agent_imgs = None
@@ -253,7 +265,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -285,17 +297,20 @@ def _render_ansi(self):
return "\n".join(output) + "\n"
def _render_img(self):
+ assert self.render_mode in ["human", "rgb", "rgb_array", "rgb_array_dict"]
model: DrivingModel = self.model # type: ignore
import posggym.envs.grid_world.render as render_lib
- if self.renderer is None:
+ if self.renderer is None and self.render_mode is not None:
self.renderer = render_lib.GWRenderer(
self.render_mode,
model.grid,
render_fps=self.metadata["render_fps"],
env_name="Driving",
)
+ if self.renderer is None:
+ return
if self._agent_imgs is None:
self._agent_imgs = {
@@ -321,7 +336,7 @@ def _render_img(self):
render_lib.GWRectangle(
vs.dest_coord,
self.renderer.cell_size,
- render_lib.get_agent_color(i)[1],
+ render_lib.get_agent_color(str(i))[1],
)
)
@@ -337,7 +352,7 @@ def _render_img(self):
}
# Add visualization for crashed agents
- for i, vs in enumerate(self._state):
+ for _, vs in enumerate(self._state):
if vs.crashed:
render_objects.append(
render_lib.GWCircle(
@@ -381,10 +396,10 @@ class DrivingModel(M.POSGModel[DState, DObs, DAction]):
def __init__(
self,
- grid: Union[str, "DrivingGrid"],
+ grid: str | DrivingGrid,
num_agents: int,
- obs_dim: Tuple[int, int, int],
- ):
+ obs_dim: tuple[int, int, int],
+ ) -> None:
if isinstance(grid, str):
assert grid in SUPPORTED_GRIDS, (
f"Unsupported grid '{grid}'. If grid argument is a string it must be "
@@ -410,6 +425,10 @@ def __init__(
self._grid = grid
self.obs_dim = obs_dim
self._obs_front, self._obs_back, self._obs_side = obs_dim
+ self.num_agents = num_agents
+ self.max_speeds: list[Speed] = [Speed.FORWARD_FAST] * self.num_agents
+ self.min_speeds: list[Speed] = [Speed.STOPPED] * self.num_agents
+ self.allow_reverse_turn = [False] * self.num_agents
def _coord_space():
return spaces.Tuple(
@@ -464,7 +483,7 @@ def _coord_space():
self.is_symmetric = True
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {
i: (
self.R_CRASH_VEHICLE,
@@ -479,23 +498,23 @@ def rng(self) -> seeding.RNG:
self._rng, _ = seeding.std_random()
return self._rng
- def get_agents(self, state: DState) -> List[str]:
+ def get_agents(self, state: DState) -> list[str]:
return list(self.possible_agents)
@property
- def grid(self) -> "DrivingGrid":
+ def grid(self) -> DrivingGrid:
"""The underlying grid for this model instance."""
return self._grid
@grid.setter
- def grid(self, grid: "DrivingGrid"):
+ def grid(self, grid: DrivingGrid):
assert (self._grid.height, self._grid.width) == (grid.height, grid.width)
self._grid = grid
def sample_initial_state(self) -> DState:
state = []
- chosen_start_coords: Set[Coord] = set()
- chosen_dest_coords: Set[Coord] = set()
+ chosen_start_coords: set[Coord] = set()
+ chosen_dest_coords: set[Coord] = set()
for i in range(len(self.possible_agents)):
start_coords_i = self.grid.start_coords[i]
avail_start_coords = start_coords_i.difference(chosen_start_coords)
@@ -525,13 +544,14 @@ def sample_initial_state(self) -> DState:
return tuple(state)
def sample_agent_initial_state(self, agent_id: str, obs: DObs) -> DState:
+ assert isinstance(obs[3], tuple)
agent_idx = int(agent_id)
agent_start_coord = obs[2]
agent_dest_coord = obs[3]
state = []
- chosen_start_coords: Set[Coord] = set()
- chosen_dest_coords: Set[Coord] = set()
+ chosen_start_coords: set[Coord] = set()
+ chosen_dest_coords: set[Coord] = set()
chosen_start_coords.add(agent_start_coord)
chosen_dest_coords.add(agent_dest_coord)
@@ -570,11 +590,11 @@ def sample_agent_initial_state(self, agent_id: str, obs: DObs) -> DState:
state.append(state_i)
return tuple(state)
- def sample_initial_obs(self, state: DState) -> Dict[str, DObs]:
+ def sample_initial_obs(self, state: DState) -> dict[str, DObs]:
return self._get_obs(state)
def step(
- self, state: DState, actions: Dict[str, DAction]
+ self, state: DState, actions: dict[str, DAction]
) -> M.JointTimestep[DState, DObs]:
assert all(a_i in ACTIONS for a_i in actions.values())
next_state = self._get_next_state(state, actions)
@@ -587,7 +607,7 @@ def step(
truncated = {i: False for i in self.possible_agents}
all_done = all(terminated.values())
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
for idx in range(len(self.possible_agents)):
if next_state[idx].dest_reached:
outcome_i = M.Outcome.WIN
@@ -602,8 +622,8 @@ def step(
)
def _get_next_state(
- self, state: DState, actions: Dict[str, DAction]
- ) -> Tuple[DState, List[bool]]:
+ self, state: DState, actions: dict[str, DAction]
+ ) -> tuple[DState, list[bool]]:
exec_order = list(range(len(self.possible_agents)))
self.rng.shuffle(exec_order)
@@ -620,11 +640,15 @@ def _get_next_state(
vehicle_coords.pop(state_i.coord)
- next_speed = self.get_next_speed(action_i, state_i.speed)
- move_dir = self.get_move_direction(action_i, next_speed, state_i.facing_dir)
- next_dir = self.get_next_direction(action_i, next_speed, state_i.facing_dir)
+ next_speed = self.get_next_speed_(action_i, state_i.speed, idx)
+ move_dir = self.get_move_direction(
+ action_i, next_speed, state_i.facing_dir, self.allow_reverse_turn[idx]
+ )
+ next_dir = self.get_next_direction(
+ action_i, next_speed, state_i.facing_dir, self.allow_reverse_turn[idx]
+ )
next_coord, crashed, hit_vehicle = self._get_next_coord(
- state_i.coord, next_speed, move_dir, vehicle_coords
+ state_i.coord, next_speed, move_dir, set(vehicle_coords.keys())
)
if next_coord == state_i.coord:
# crashed or hit a wall
@@ -635,7 +659,7 @@ def _get_next_state(
self.grid.get_shortest_path_distance(next_coord, state_i.dest_coord),
)
- if crashed:
+ if crashed and hit_vehicle is not None:
# update state of vehicle that was hit
jdx = vehicle_coords[hit_vehicle]
next_state_j = next_state[jdx]
@@ -667,43 +691,74 @@ def _get_next_state(
@staticmethod
def get_move_direction(
- action: DAction, speed: Speed, curr_dir: Direction
+ action: DAction,
+ speed: Speed,
+ curr_dir: Direction,
+ allow_reverse_turn: bool = False,
) -> Direction:
- if speed == Speed.REVERSE:
+ if speed < Speed.STOPPED and not allow_reverse_turn:
# No turning while in reverse,
# so movement dir is always just the opposite of current direction
return Direction((curr_dir + 2) % len(Direction))
- return DrivingModel.get_next_direction(action, speed, curr_dir)
+ return DrivingModel.get_next_direction(
+ action, speed, curr_dir, allow_reverse_turn
+ )
@staticmethod
def get_next_direction(
- action: DAction, speed: Speed, curr_dir: Direction
+ action: DAction,
+ speed: Speed,
+ curr_dir: Direction,
+ allow_reverse_turn: bool = False,
) -> Direction:
- if action == TURN_RIGHT and speed != Speed.REVERSE:
+ if action == TURN_RIGHT and (speed >= Speed.STOPPED and not allow_reverse_turn):
return Direction((curr_dir + 1) % len(Direction))
- if action == TURN_LEFT and speed != Speed.REVERSE:
+ if action == TURN_LEFT and (speed >= Speed.STOPPED and not allow_reverse_turn):
return Direction((curr_dir - 1) % len(Direction))
return curr_dir
+ def get_next_speed_(
+ self, action: DAction, curr_speed: Speed, agent_idx: int = 0
+ ) -> Speed:
+ return DrivingModel.get_next_speed(
+ action, curr_speed, self.max_speeds[agent_idx], self.min_speeds[agent_idx]
+ )
+
@staticmethod
- def get_next_speed(action: DAction, curr_speed: Speed) -> Speed:
+ def get_next_speed(
+ action: DAction,
+ curr_speed: Speed,
+ max_speed: Speed | None = None,
+ min_speed: Speed | None = None,
+ ) -> Speed:
+ if max_speed is None:
+ max_speed = max_enum_value(Speed)
+ if min_speed is None:
+ min_speed = min_enum_value(Speed)
+
if action == DO_NOTHING:
return curr_speed
+
if action in (TURN_LEFT, TURN_RIGHT):
- if curr_speed == Speed.FORWARD_FAST:
- return Speed.FORWARD_SLOW
+ if curr_speed > Speed.STOPPED:
+ return Speed(curr_speed - 1)
+
return curr_speed
+
if action == ACCELERATE:
- return Speed(min(curr_speed + 1, Speed.FORWARD_FAST))
- return Speed(max(curr_speed - 1, Speed.REVERSE))
+ return Speed(min(curr_speed + 1, max_speed))
+ if action == DECELERATE:
+ return Speed(max(curr_speed - 1, min_speed))
+
+ raise ValueError("Invalid Action!")
def _get_next_coord(
self,
curr_coord: Coord,
speed: Speed,
move_dir: Direction,
- vehicle_coords: Set[Coord],
- ) -> Tuple[Coord, bool, Optional[Coord]]:
+ vehicle_coords: set[Coord],
+ ) -> tuple[Coord, bool, Coord] | None:
# assumes curr_coord isn't in vehicle coords
next_coord = curr_coord
crashed = False
@@ -721,8 +776,8 @@ def _get_next_coord(
return (next_coord, crashed, hit_vehicle_coord)
- def _get_obs(self, state: DState) -> Dict[str, DObs]:
- obs: Dict[str, DObs] = {}
+ def _get_obs(self, state: DState) -> dict[str, DObs]:
+ obs: dict[str, DObs] = {}
for i in self.possible_agents:
idx = int(i)
local_cell_obs = self._get_local_cell__obs(
@@ -747,7 +802,7 @@ def _get_local_cell__obs(
vehicle_coords: Sequence[Coord],
facing_dir: Direction,
dest_coord: Coord,
- ) -> Tuple[int, ...]:
+ ) -> tuple[int, ...]:
obs_depth = self._obs_front + self._obs_back + 1
obs_width = (2 * self._obs_side) + 1
agent_coord = vehicle_coords[agent_idx]
@@ -769,7 +824,7 @@ def _get_local_cell__obs(
def _map_obs_to_grid_coord(
self, obs_coord: Coord, agent_coord: Coord, facing_dir: Direction
- ) -> Optional[Coord]:
+ ) -> Coord | None:
if facing_dir == Direction.NORTH:
grid_row = agent_coord[1] + obs_coord[1] - self._obs_front
grid_col = agent_coord[0] + obs_coord[0] - self._obs_side
@@ -787,19 +842,19 @@ def _map_obs_to_grid_coord(
return (grid_col, grid_row)
return None
- def get_obs_coords(self, origin: Coord, facing_dir: Direction) -> List[Coord]:
+ def get_obs_coords(self, origin: Coord, facing_dir: Direction) -> list[Coord]:
"""Get the list of coords observed by agent at origin."""
obs_depth = self._obs_front + self._obs_back + 1
obs_width = (2 * self._obs_side) + 1
- obs_coords: List[Coord] = []
+ obs_coords: list[Coord] = []
for col, row in product(range(obs_width), range(obs_depth)):
obs_grid_coord = self._map_obs_to_grid_coord((col, row), origin, facing_dir)
if obs_grid_coord is not None:
obs_coords.append(obs_grid_coord)
return obs_coords
- def _get_rewards(self, state: DState, next_state: DState) -> Dict[str, float]:
- rewards: Dict[str, float] = {}
+ def _get_rewards(self, state: DState, next_state: DState) -> dict[str, float]:
+ rewards: dict[str, float] = {}
for i in self.possible_agents:
idx = int(i)
if state[idx].crashed or state[idx].dest_reached:
@@ -827,10 +882,10 @@ def __init__(
self,
grid_width: int,
grid_height: int,
- block_coords: Set[Coord],
- start_coords: List[Set[Coord]],
- dest_coords: List[Set[Coord]],
- ):
+ block_coords: set[Coord],
+ start_coords: list[set[Coord]],
+ dest_coords: list[set[Coord]],
+ ) -> None:
super().__init__(grid_width, grid_height, block_coords)
assert len(start_coords) == len(dest_coords)
self.start_coords = start_coords
@@ -852,9 +907,9 @@ def get_max_shortest_path_distance(self) -> int:
def get_ascii_repr(
self,
- vehicle_coords: List[Coord],
- vehicle_dirs: List[Direction],
- vehicle_dests: List[Coord],
+ vehicle_coords: list[Coord],
+ vehicle_dirs: list[Direction],
+ vehicle_dests: list[Coord],
) -> str:
"""Get ascii repr of grid."""
grid_repr = []
@@ -870,7 +925,7 @@ def get_ascii_repr(
row_repr.append(".")
grid_repr.append(row_repr)
- for coord, direction in zip(vehicle_coords, vehicle_dirs):
+ for coord, direction in zip(vehicle_coords, vehicle_dirs, strict=False):
grid_repr[coord[0]][coord[1]] = DIRECTION_ASCII_REPR[direction]
return "\n".join([" ".join(r) for r in grid_repr])
@@ -918,13 +973,13 @@ def parse_grid_str(grid_str: str, supported_num_agents: int) -> DrivingGrid:
grid_width = len(row_strs[0])
agent_start_chars = set(["+"] + [str(i) for i in range(10)])
- agent_dest_chars = set(["-"] + list("abcdefghij"))
+ agent_dest_chars = {"-", *list("abcdefghij")}
- block_coords: Set[Coord] = set()
- shared_start_coords: Set[Coord] = set()
- agent_start_coords_map: Dict[int, Set[Coord]] = {}
- shared_dest_coords: Set[Coord] = set()
- agent_dest_coords_map: Dict[int, Set[Coord]] = {}
+ block_coords: set[Coord] = set()
+ shared_start_coords: set[Coord] = set()
+ agent_start_coords_map: dict[int, set[Coord]] = {}
+ shared_dest_coords: set[Coord] = set()
+ agent_dest_coords_map: dict[int, set[Coord]] = {}
for r, c in product(range(grid_height), range(grid_width)):
coord = (c, r)
char = row_strs[r][c]
@@ -957,8 +1012,8 @@ def parse_grid_str(grid_str: str, supported_num_agents: int) -> DrivingGrid:
if len(included_agent_ids) > 0:
assert max(included_agent_ids) < supported_num_agents
- start_coords: List[Set[Coord]] = []
- dest_coords: List[Set[Coord]] = []
+ start_coords: list[set[Coord]] = []
+ dest_coords: list[set[Coord]] = []
for i in range(supported_num_agents):
agent_start_coords = set(shared_start_coords)
agent_start_coords.update(agent_start_coords_map.get(i, {}))
@@ -978,14 +1033,29 @@ def parse_grid_str(grid_str: str, supported_num_agents: int) -> DrivingGrid:
# (grid_make_fn, max step_limit, )
-SUPPORTED_GRIDS: Dict[str, Dict[str, Any]] = {
+SUPPORTED_GRIDS: dict[str, dict[str, Any]] = {
"3x3": {
- "grid_str": ("a1.\n" ".#.\n" ".0b\n"),
+ # fmt: off
+ "grid_str": (
+ "a1.\n"
+ ".#.\n"
+ ".0b\n"
+ ),
+ # fmt: on
"supported_num_agents": 2,
"max_episode_steps": 15,
},
"6x6Intersection": {
- "grid_str": ("##0b##\n" "##..##\n" "d....3\n" "2....c\n" "##..##\n" "##a1##\n"),
+ # fmt: off
+ "grid_str": (
+ "##0b##\n"
+ "##..##\n"
+ "d....3\n"
+ "2....c\n"
+ "##..##\n"
+ "##a1##\n"
+ ),
+ # fmt: on
"supported_num_agents": 4,
"max_episode_steps": 20,
},
diff --git a/posggym/envs/grid_world/driving_gen.py b/posggym/envs/grid_world/driving_gen.py
index 36cb7b5..9dfcd66 100644
--- a/posggym/envs/grid_world/driving_gen.py
+++ b/posggym/envs/grid_world/driving_gen.py
@@ -1,5 +1,5 @@
"""The Generated Driving Grid World Environment."""
-from typing import Any, Dict, Optional, Set, Tuple, Union
+from typing import Any
from posggym.envs.grid_world.core import Coord, GridCycler, GridGenerator
from posggym.envs.grid_world.driving import DObs, DrivingEnv, DrivingGrid
@@ -14,9 +14,8 @@ class DrivingGenEnv(DrivingEnv):
For environment attributes see [Driving](/environments/grid_world/driving)
environment class documentation.
- Arguments
+ Arguments:
---------
-
- `num_agents` - the number of agents in the environment (default = `2`).
- `obs_dim` - the local observation dimensions, specifying how many cells in front,
behind, and to each side the agent observes (default = `(3, 1, 1)`, resulting
@@ -68,12 +67,12 @@ class DrivingGenEnv(DrivingEnv):
def __init__(
self,
num_agents: int = 2,
- obs_dim: Tuple[int, int, int] = (3, 1, 2),
- generator_params: Union[str, Dict[str, int]] = "14x14",
- n_grids: Optional[int] = None,
+ obs_dim: tuple[int, int, int] = (3, 1, 2),
+ generator_params: str | dict[str, int] = "14x14",
+ n_grids: int | None = None,
shuffle_grid_order: bool = True,
- render_mode: Optional[str] = None,
- ):
+ render_mode: str | None = None,
+ ) -> None:
if isinstance(generator_params, str):
assert generator_params in SUPPORTED_GEN_PARAMS, (
f"Unsupported grid generator parameters'{generator_params}'. If "
@@ -82,7 +81,7 @@ def __init__(
)
generator_params = SUPPORTED_GEN_PARAMS[generator_params][0]
- self._generator_params = generator_params
+ self._generator_params: dict[str, int] = generator_params # type: ignore
self._n_grids = n_grids
self._shuffle_grid_order = shuffle_grid_order
self._gen = DrivingGridGenerator(**self._generator_params)
@@ -90,7 +89,7 @@ def __init__(
if n_grids is not None:
grids = self._gen.generate_n(n_grids)
self._cycler = GridCycler(grids, shuffle_grid_order)
- grid: "DrivingGrid" = grids[0] # type: ignore
+ grid: DrivingGrid = grids[0] # type: ignore
else:
self._cycler = None # type: ignore
grid = self._gen.generate()
@@ -103,8 +102,8 @@ def __init__(
)
def reset(
- self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
- ) -> Tuple[Dict[str, DObs], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, DObs], dict[str, dict]]:
if seed is not None:
self._model_seed = seed
self._gen = DrivingGridGenerator(seed=seed, **self._generator_params)
@@ -137,8 +136,8 @@ def __init__(
height: int,
max_obstacle_size: int,
max_num_obstacles: int,
- seed: Optional[int] = None,
- ):
+ seed: int | None = None,
+ ) -> None:
super().__init__(
width,
height,
@@ -152,7 +151,7 @@ def __init__(
self._start_coords = [self.mask for _ in range(len(self.mask))]
self._dest_coords = [self.mask for _ in range(len(self.mask))]
- def _generate_mask(self, width: int, height: int) -> Set[Coord]:
+ def _generate_mask(self, width: int, height: int) -> set[Coord]:
start = 1
mask = set()
for x in range(start, width, 2):
diff --git a/posggym/envs/grid_world/level_based_foraging.py b/posggym/envs/grid_world/level_based_foraging.py
index 69a903e..67837c5 100644
--- a/posggym/envs/grid_world/level_based_foraging.py
+++ b/posggym/envs/grid_world/level_based_foraging.py
@@ -5,7 +5,7 @@
from collections import defaultdict
from itertools import product
from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import ClassVar, NamedTuple
import numpy as np
from gymnasium import spaces
@@ -35,8 +35,8 @@ class Food(NamedTuple):
class LBFState(NamedTuple):
"""State in Level-Based Foraging environment."""
- players: Tuple[Player, ...]
- food: Tuple[Food, ...]
+ players: tuple[Player, ...]
+ food: tuple[Food, ...]
# sum of levels of all food that have been spawned in the episode
food_spawned: int
@@ -69,7 +69,7 @@ class CellEntity(enum.IntEnum):
AGENT = 3
-LBFObs = Union[Tuple[int, ...], np.ndarray]
+LBFObs = tuple[int, ...] | np.ndarray
class LBFEntityObs(NamedTuple):
@@ -186,9 +186,8 @@ class LevelBasedForagingEnv(DefaultEnv[LBFState, LBFObs, LBFAction]):
need to be adjusted when using larger grids (this can be done by manually specifying
a value for `max_episode_steps` when creating the environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `num_agents` - the number of agents in the environment (default = `2`).
- `max_agent_level` - the maximum level of an agent (default = `3`).
- `size` - the width and height of the square grid world (default = `10`).
@@ -216,19 +215,19 @@ class LevelBasedForagingEnv(DefaultEnv[LBFState, LBFObs, LBFAction]):
(`field_size`/`size` is now a single int)
- `v2`: Version adapted from
- References
+ References:
----------
- Stefano V. Albrecht and Subramanian Ramamoorthy. 2013. A Game-Theoretic Model and
Best-Response Learning Method for Ad Hoc Coordination in Multia-gent Systems.
In Proceedings of the 2013 International Conference on Autonomous Agents and
- Multi-Agent Systems. 1155–1156.
+ Multi-Agent Systems. 1155-1156.
- S. V. Albrecht and Peter Stone. 2017. Reasoning about Hypothetical Agent
Behaviours and Their Parameters. In 16th International Conference on Autonomous
Agents and Multiagent Systems 2017. International Foundation for Autonomous
- Agents and Multiagent Systems, 547–555
+ Agents and Multiagent Systems, 547-555
- Filippos Christianos, Lukas Schäfer, and Stefano Albrecht. 2020. Shared Experience
Actor-Critic for Multi-Agent Reinforcement Learning. Advances in Neural
- Information Processing Systems 33 (2020), 10707–10717
+ Information Processing Systems 33 (2020), 10707-10717
- Georgios Papoudakis, Filippos Christianos, Lukas Schäfer, and Stefano V. Albrecht.
2021. Benchmarking Multi-Agent Deep Reinforcement Learning Algorithms in
Cooperative Tasks. In Thirty-Fifth Conference on Neural Information Processing
@@ -236,7 +235,7 @@ class LevelBasedForagingEnv(DefaultEnv[LBFState, LBFObs, LBFAction]):
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "rgb_array", "rgb_array_dict"],
"render_fps": 15,
}
@@ -251,8 +250,8 @@ def __init__(
force_coop: bool = False,
static_layout: bool = False,
observation_mode: str = "tuple",
- render_mode: Optional[str] = None,
- ):
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
LevelBasedForagingModel(
num_agents,
@@ -274,7 +273,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -335,7 +334,7 @@ def render(self):
img_obj.text = str(food.level)
render_objects.append(img_obj)
- agent_coords_and_dirs = {
+ agent_coords_and_dirs: dict[str, tuple[Coord, Direction]] = {
str(i): (player.coord, Direction.NORTH)
for i, player in enumerate(self._state.players)
}
@@ -381,7 +380,11 @@ class LevelBasedForagingModel(M.POSGModel[LBFState, LBFObs, LBFAction]):
containing integers instead of floats
"""
- OBSERVATION_MODES = ["grid", "vector", "tuple"]
+ OBSERVATION_MODES: ClassVar[list] = ["grid", "vector", "tuple"]
+ MIN_GRID_SIZE = 3
+ MIN_AGENTS = 2
+
+ MAX_ATTEMPTS = 1000
def __init__(
self,
@@ -393,10 +396,10 @@ def __init__(
force_coop: bool,
static_layout: bool,
observation_mode: str = "tuple",
- ):
- assert num_agents >= 2
+ ) -> None:
+ assert num_agents >= self.MIN_AGENTS
assert max_agent_level >= 1
- assert size >= 3
+ assert size >= self.MIN_GRID_SIZE
assert max_food >= 1
assert sight >= 1
assert observation_mode in self.OBSERVATION_MODES
@@ -419,8 +422,8 @@ def __init__(
self.grid = Grid(grid_width=self.size, grid_height=self.size, block_coords=None)
- self._food_locations: Optional[List[Coord]] = None
- self._player_locations: Optional[List[Coord]] = None
+ self._food_locations: list[Coord] | None = None
+ self._player_locations: list[Coord] | None = None
if static_layout:
assert self.max_food <= math.floor((self.size - 1) / 2) ** 2, (
"when using static layout there must be enough space to surround each "
@@ -487,7 +490,6 @@ def get_agent_observation_space(self) -> spaces.Space:
min_obs = np.stack([agents_min, foods_min, access_min])
max_obs = np.stack([agents_max, foods_max, access_max])
- # dtype = np.int8 if self.observation_mode == "vector" else np.uint8
dtype = np.float32
return spaces.Box(
np.array(min_obs, dtype=dtype),
@@ -496,10 +498,10 @@ def get_agent_observation_space(self) -> spaces.Space:
)
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (0.0, 1.0) for i in self.possible_agents}
- def get_agents(self, state: LBFState) -> List[str]:
+ def get_agents(self, state: LBFState) -> list[str]:
return list(self.possible_agents)
@property
@@ -508,13 +510,13 @@ def rng(self) -> seeding.RNG:
self._rng, _ = seeding.std_random()
return self._rng
- def seed(self, seed: Optional[int] = None):
+ def seed(self, seed: int | None = None):
super().seed(seed)
if self.static_layout:
self._food_locations = self._generate_static_food_coords()
self._player_locations = self._generate_static_player_coords()
- def _generate_static_food_coords(self) -> List[Coord]:
+ def _generate_static_food_coords(self) -> list[Coord]:
"""Generate food coords for static layout.
Number and location of food is the same for given pairing of
@@ -532,7 +534,7 @@ def _generate_static_food_coords(self) -> List[Coord]:
food_locs.append((x, y))
return food_locs
- def _generate_static_player_coords(self) -> List[Coord]:
+ def _generate_static_player_coords(self) -> list[Coord]:
"""Generate player start coords for static layout.
Players always start around edge of field.
@@ -550,8 +552,8 @@ def _generate_static_player_coords(self) -> List[Coord]:
product(idxs_reverse[:-1], [0]),
product(idxs[:-1], [self.size - 1]),
]
- available_locations: List[Coord] = []
- for locs in zip(*sides):
+ available_locations: list[Coord] = []
+ for locs in zip(*sides, strict=False):
available_locations.extend(locs)
return available_locations[: len(self.possible_agents)]
@@ -573,7 +575,7 @@ def sample_initial_state(self) -> LBFState:
)
return LBFState(players, food, sum(f.level for f in food))
- def _spawn_players_static(self) -> Tuple[Player, ...]:
+ def _spawn_players_static(self) -> tuple[Player, ...]:
assert self._player_locations is not None
players = []
for i in range(len(self.possible_agents)):
@@ -582,7 +584,7 @@ def _spawn_players_static(self) -> Tuple[Player, ...]:
players.append(Player(i, (x, y), level))
return tuple(players)
- def _spawn_players_generative(self) -> Tuple[Player, ...]:
+ def _spawn_players_generative(self) -> tuple[Player, ...]:
players = []
available_coords = list(product(range(self.size), range(self.size)))
for i in range(len(self.possible_agents)):
@@ -592,7 +594,7 @@ def _spawn_players_generative(self) -> Tuple[Player, ...]:
players.append(Player(i, coord, level))
return tuple(players)
- def _spawn_food_static(self, max_level: int) -> Tuple[Food, ...]:
+ def _spawn_food_static(self, max_level: int) -> tuple[Food, ...]:
"""Spawn food in static layout.
Number and location of food is the same for given pairing of (size, max_food),
@@ -612,13 +614,13 @@ def _spawn_food_generative(
self,
max_food: int,
max_level: int,
- player_coords: List[Coord],
- ) -> Tuple[Food, ...]:
+ player_coords: list[Coord],
+ ) -> tuple[Food, ...]:
attempts = 0
min_level = max_level if self.force_coop else 1
unavailable_coords = set(player_coords)
food = []
- while len(food) < max_food and attempts < 1000:
+ while len(food) < max_food and attempts < self.MAX_ATTEMPTS:
attempts += 1
x = self.rng.randint(1, self.size - 2)
y = self.rng.randint(1, self.size - 2)
@@ -637,11 +639,11 @@ def _spawn_food_generative(
food.append(Food((x, y), level))
return tuple(food)
- def sample_initial_obs(self, state: LBFState) -> Dict[str, LBFObs]:
+ def sample_initial_obs(self, state: LBFState) -> dict[str, LBFObs]:
return self._get_obs(state)
def step(
- self, state: LBFState, actions: Dict[str, LBFAction]
+ self, state: LBFState, actions: dict[str, LBFAction]
) -> M.JointTimestep[LBFState, LBFObs]:
assert all(0 <= a < len(LBFAction) for a in actions.values())
next_state, rewards = self._get_next_state_and_rewards(state, actions)
@@ -655,8 +657,8 @@ def step(
)
def _get_next_state_and_rewards(
- self, state: LBFState, actions: Dict[str, LBFAction]
- ) -> Tuple[LBFState, Dict[str, float]]:
+ self, state: LBFState, actions: dict[str, LBFAction]
+ ) -> tuple[LBFState, dict[str, float]]:
next_food = {f.coord: f for f in state.food}
# try move agents
@@ -712,7 +714,7 @@ def _get_next_state_and_rewards(
)
return next_state, rewards
- def _get_obs(self, state: LBFState) -> Dict[str, LBFObs]:
+ def _get_obs(self, state: LBFState) -> dict[str, LBFObs]:
obs = {}
for i in self.possible_agents:
player_obs, food_obs = self._get_local_obs(state, int(i))
@@ -728,7 +730,7 @@ def _get_obs(self, state: LBFState) -> Dict[str, LBFObs]:
def _get_local_obs(
self, state: LBFState, agent_id: int
- ) -> Tuple[List[LBFEntityObs], List[LBFEntityObs]]:
+ ) -> tuple[list[LBFEntityObs], list[LBFEntityObs]]:
# player is always in center of observable area
player_obs = []
ego_player = state.players[agent_id]
@@ -765,9 +767,9 @@ def _get_local_obs(
return player_obs, food_obs
def _get_tuple_obs(
- self, player_obs: List[LBFEntityObs], food_obs: List[LBFEntityObs]
+ self, player_obs: list[LBFEntityObs], food_obs: list[LBFEntityObs]
) -> LBFObs:
- obs: List[int] = []
+ obs: list[int] = []
for o in player_obs:
if o.is_self:
obs.insert(0, o.level)
@@ -785,8 +787,8 @@ def _get_tuple_obs(
def _get_vector_obs(
self,
agent_id: int,
- player_obs: List[LBFEntityObs],
- food_obs: List[LBFEntityObs],
+ player_obs: list[LBFEntityObs],
+ food_obs: list[LBFEntityObs],
) -> np.ndarray:
# initialize obs array to (-1, -1, 0)
obs = np.full(
@@ -811,8 +813,8 @@ def _get_vector_obs(
def _get_grid_obs(
self,
agent_coord: Coord,
- player_obs: List[LBFEntityObs],
- food_obs: List[LBFEntityObs],
+ player_obs: list[LBFEntityObs],
+ food_obs: list[LBFEntityObs],
) -> np.ndarray:
grid_shape_x, grid_shape_y = (2 * self.sight + 1, 2 * self.sight + 1)
# agent, food, access layers
@@ -844,7 +846,7 @@ def _get_grid_obs(
def parse_obs(
self, obs: LBFObs
- ) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]:
+ ) -> tuple[list[tuple[int, int, int]], list[tuple[int, int, int]]]:
"""Parse observation into (x, y, level) agent and food triplets.
Agent obs are ordered so the observing agent is first, then the
@@ -864,7 +866,7 @@ def parse_obs(
def parse_grid_obs(
self, obs: np.ndarray
- ) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]:
+ ) -> tuple[list[tuple[int, int, int]], list[tuple[int, int, int]]]:
"""Parse grid observation int (x, y, level) agent and food triplets.
Agent obs are ordered so the observing agent is first, then the
@@ -876,7 +878,7 @@ def parse_grid_obs(
def parse_vector_obs(
self, obs: np.ndarray
- ) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]:
+ ) -> tuple[list[tuple[int, int, int]], list[tuple[int, int, int]]]:
"""Parse vector obs into (x, y, level) agent and food triplets.
Agent obs are ordered so the observing agent is first, then the
@@ -896,8 +898,8 @@ def parse_vector_obs(
return agent_obs, food_obs
def parse_tuple_obs(
- self, obs: Tuple[int, ...]
- ) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]:
+ self, obs: tuple[int, ...]
+ ) -> tuple[list[tuple[int, int, int]], list[tuple[int, int, int]]]:
"""Parse tuple obs into (x, y, level) agent and food triplets.
Agent obs are ordered so the observing agent is first, then the
@@ -916,10 +918,10 @@ def parse_tuple_obs(
food_obs.append(triplet)
return agent_obs, food_obs
- def get_obs_coords(self, origin: Coord) -> List[Coord]:
+ def get_obs_coords(self, origin: Coord) -> list[Coord]:
"""Get the list of coords observed from agent at origin."""
obs_size = (2 * self.sight) + 1
- obs_coords: List[Coord] = []
+ obs_coords: list[Coord] = []
for col, row in product(range(obs_size), repeat=2):
obs_grid_coord = self._map_obs_to_grid_coord((col, row), origin)
if obs_grid_coord is not None:
@@ -928,7 +930,7 @@ def get_obs_coords(self, origin: Coord) -> List[Coord]:
def _map_obs_to_grid_coord(
self, obs_coord: Coord, agent_coord: Coord
- ) -> Optional[Coord]:
+ ) -> Coord | None:
grid_col = agent_coord[0] + obs_coord[0] - self.sight
grid_row = agent_coord[1] + obs_coord[1] - self.sight
if 0 <= grid_row < self.size and 0 <= grid_col < self.size:
@@ -941,6 +943,6 @@ def sorted_from_middle(lst):
left = lst[len(lst) // 2 - 1 :: -1]
right = lst[len(lst) // 2 :]
output = [right.pop(0)] if len(lst) % 2 else []
- for t in zip(left, right):
+ for t in zip(left, right, strict=False):
output += sorted(t)
return output
diff --git a/posggym/envs/grid_world/predator_prey.py b/posggym/envs/grid_world/predator_prey.py
index b30c81e..e68305d 100644
--- a/posggym/envs/grid_world/predator_prey.py
+++ b/posggym/envs/grid_world/predator_prey.py
@@ -1,25 +1,30 @@
"""The Predator-Prey Grid World Environment."""
+from __future__ import annotations
import math
from itertools import product
from pathlib import Path
-from typing import Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, ClassVar, NamedTuple
from gymnasium import spaces
import posggym.model as M
from posggym import logger
from posggym.core import DefaultEnv
-from posggym.envs.grid_world.core import Coord, Direction, Grid
+from posggym.envs.grid_world.core import Coord, Direction, Grid, SupportedGridTypes
from posggym.utils import seeding
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+
class PPState(NamedTuple):
"""A state in the Predator-Prey Environment."""
- predator_coords: Tuple[Coord, ...]
- prey_coords: Tuple[Coord, ...]
- prey_caught: Tuple[int, ...]
+ predator_coords: tuple[Coord, ...]
+ prey_coords: tuple[Coord, ...]
+ prey_caught: tuple[int, ...]
# Actions
@@ -34,8 +39,7 @@ class PPState(NamedTuple):
ACTION_TO_DIR = [None, Direction.NORTH, Direction.SOUTH, Direction.WEST, Direction.EAST]
# Observations
-# Obs = (adj_obs)
-PPObs = Tuple[int, ...]
+PPObs = tuple[int, ...]
# Cell Obs
EMPTY = 0
WALL = 1
@@ -122,9 +126,8 @@ class PredatorPreyEnv(DefaultEnv[PPState, PPObs, PPAction]):
grids (this can be done by manually specifying a value for `max_episode_steps` when
creating the environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `grid` - the grid layout to use. This can either be a string specifying one of
the supported grids, or a custom :class:`PredatorPreyGrid` object
(default = `"10x10"`).
@@ -181,28 +184,28 @@ class PredatorPreyEnv(DefaultEnv[PPState, PPObs, PPAction]):
---------
- Ming Tan. 1993. Multi-Agent Reinforcement Learning: Independent vs. Cooperative
Agents. In Proceedings of the Tenth International Conference on Machine Learning.
- 330–337.
+ 330-337.
- J. Z. Leibo, V. F. Zambaldi, M. Lanctot, J. Marecki, and T. Graepel. 2017.
Multi-Agent Reinforcement Learning in Sequential Social Dilemmas. In AAMAS,
- Vol. 16. ACM, 464–473
+ Vol. 16. ACM, 464-473
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_dict"],
"render_fps": 15,
}
def __init__(
self,
- grid: Union[str, "PredatorPreyGrid"] = "10x10",
+ grid: str | PredatorPreyGrid = "10x10",
num_predators: int = 2,
num_prey: int = 3,
cooperative: bool = True,
- prey_strength: Optional[int] = None,
+ prey_strength: int | None = None,
obs_dim: int = 2,
- render_mode: Optional[str] = None,
- ):
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
PredatorPreyModel(
grid,
@@ -222,7 +225,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -237,7 +240,7 @@ def _render_ansi(self):
grid = model.grid
uncaught_prey_coords = [
self._state.prey_coords[i]
- for i in range(self.model.num_prey)
+ for i in range(model.num_prey)
if not self._state.prey_caught[i]
]
grid_str = grid.get_ascii_repr(
@@ -257,6 +260,8 @@ def _render_ansi(self):
return "\n".join(output) + "\n"
def _render_img(self):
+ print(self.render_mode)
+ assert self.render_mode in ["human", "rgb", "rgb_array", "rgb_array_dict"]
model: PredatorPreyModel = self.model # type: ignore
import posggym.envs.grid_world.render as render_lib
@@ -298,7 +303,7 @@ def _render_img(self):
agent_obj.coord = coord
render_objects.append(agent_obj)
- agent_coords_and_dirs = {
+ agent_coords_and_dirs: dict[str, tuple[Coord, Direction]] = {
str(i): (coord, Direction.NORTH)
for i, coord in enumerate(self._state.predator_coords)
}
@@ -343,16 +348,17 @@ class PredatorPreyModel(M.POSGModel[PPState, PPObs, PPAction]):
R_MAX = 1.0
PREY_CAUGHT_COORD = (0, 0)
+ MAX_AGENTS = 8
def __init__(
self,
- grid: Union[str, "PredatorPreyGrid"],
+ grid: str | PredatorPreyGrid,
num_predators: int,
num_prey: int,
cooperative: bool,
- prey_strength: Optional[int],
+ prey_strength: int | None,
obs_dim: int,
- ):
+ ) -> None:
if isinstance(grid, str):
assert grid in SUPPORTED_GRIDS, (
f"Unsupported grid name '{grid}'. Grid name must be one of: "
@@ -363,7 +369,7 @@ def __init__(
if prey_strength is None:
prey_strength = min(4, num_predators)
- assert 1 < num_predators <= 8
+ assert 1 < num_predators <= self.MAX_AGENTS
assert num_prey > 0
assert obs_dim > 0
assert 0 < prey_strength <= min(4, num_predators)
@@ -376,6 +382,7 @@ def __init__(
self.cooperative = cooperative
self.prey_strength = prey_strength
self._per_prey_reward = self.R_MAX / self.num_prey
+ self.action_mask = [DO_NOTHING] * self.num_predators
if self.grid.prey_start_coords is None:
center_coords = self.grid.get_unblocked_center_coords(num_prey)
@@ -412,7 +419,7 @@ def _coord_space():
self.is_symmetric = True
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (0.0, self.R_MAX) for i in self.possible_agents}
@property
@@ -421,7 +428,7 @@ def rng(self) -> seeding.RNG:
self._rng, _ = seeding.std_random()
return self._rng
- def get_agents(self, state: PPState) -> List[str]:
+ def get_agents(self, state: PPState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> PPState:
@@ -441,12 +448,13 @@ def sample_initial_state(self) -> PPState:
return PPState(tuple(predator_coords), tuple(prey_coords_list), prey_caught)
- def sample_initial_obs(self, state: PPState) -> Dict[str, PPObs]:
+ def sample_initial_obs(self, state: PPState) -> dict[str, PPObs]:
return self._get_obs(state, state)
def step(
- self, state: PPState, actions: Dict[str, PPAction]
+ self, state: PPState, actions: dict[str, PPAction]
) -> M.JointTimestep[PPState, PPObs]:
+ assert all(a_i in ACTIONS for a_i in actions.values())
next_state = self._get_next_state(state, actions)
obs = self._get_obs(state, next_state)
rewards = self._get_rewards(state, next_state)
@@ -455,7 +463,7 @@ def step(
truncated = {i: False for i in self.possible_agents}
terminated = {i: all_done for i in self.possible_agents}
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
if all_done:
for i in self.possible_agents:
info[i]["outcome"] = M.Outcome.WIN
@@ -464,15 +472,15 @@ def step(
next_state, obs, rewards, terminated, truncated, all_done, info
)
- def _get_next_state(self, state: PPState, actions: Dict[str, PPAction]) -> PPState:
+ def _get_next_state(self, state: PPState, actions: dict[str, PPAction]) -> PPState:
# prey move first
prey_coords = self._get_next_prey_state(state)
predator_coords = self._get_next_predator_state(state, actions, prey_coords)
prey_caught = self._get_next_prey_caught(state, prey_coords, predator_coords)
return PPState(predator_coords, prey_coords, prey_caught)
- def _get_next_prey_state(self, state: PPState) -> Tuple[Coord, ...]:
- next_prey_coords: List[Optional[Coord]] = [None] * self.num_prey
+ def _get_next_prey_state(self, state: PPState) -> tuple[Coord, ...]:
+ next_prey_coords: list[Coord | None] = [None] * self.num_prey
occupied_coords = set(
state.predator_coords
+ tuple(
@@ -540,9 +548,9 @@ def _get_next_prey_state(self, state: PPState) -> Tuple[Coord, ...]:
def _move_away_from_predators(
self,
prey_coord: Coord,
- predator_coords: Tuple[Coord, ...],
- occupied_coords: Set[Coord],
- ) -> Optional[Coord]:
+ predator_coords: tuple[Coord, ...],
+ occupied_coords: set[Coord],
+ ) -> Coord | None:
# get any predators within obs distance
predator_dists = [
self.grid.manhattan_dist(prey_coord, c) for c in predator_coords
@@ -557,7 +565,7 @@ def _move_away_from_predators(
if all(
abs(a - b) > self.obs_dim
- for a, b in zip(prey_coord, closest_predator_coord)
+ for a, b in zip(prey_coord, closest_predator_coord, strict=False)
):
# closes predator out of obs range
return None
@@ -565,18 +573,18 @@ def _move_away_from_predators(
# move into furthest away free cell, includes current coord
neighbours = [
(self.grid.manhattan_dist(c, closest_predator_coord), c)
- for c in self.grid.get_neighbours(prey_coord) + [prey_coord]
+ for c in [*self.grid.get_neighbours(prey_coord), prey_coord]
]
neighbours.sort()
- for d, c in reversed(neighbours):
+ for _d, c in reversed(neighbours):
if c == prey_coord or self._coord_available_for_prey(c, occupied_coords):
return c
raise AssertionError("Something has gone wrong, please investigate.")
def _move_away_from_preys(
- self, prey_coord: Coord, state: PPState, occupied_coords: Set[Coord]
- ) -> Optional[Coord]:
+ self, prey_coord: Coord, state: PPState, occupied_coords: set[Coord]
+ ) -> Coord | None:
prey_dists = [
(
self.grid.manhattan_dist(prey_coord, c)
@@ -594,7 +602,8 @@ def _move_away_from_preys(
closest_prey_coord = self.rng.choice(all_closest_prey_coords)
if all(
- abs(a - b) > self.obs_dim for a, b in zip(prey_coord, closest_prey_coord)
+ abs(a - b) > self.obs_dim
+ for a, b in zip(prey_coord, closest_prey_coord, strict=False)
):
# closes predator out of obs range
return None
@@ -602,17 +611,17 @@ def _move_away_from_preys(
# move into furthest away free cell, includes current coord
neighbours = [
(self.grid.manhattan_dist(c, closest_prey_coord), c)
- for c in self.grid.get_neighbours(prey_coord) + [prey_coord]
+ for c in [*self.grid.get_neighbours(prey_coord), prey_coord]
]
neighbours.sort()
- for d, c in reversed(neighbours):
+ for _d, c in reversed(neighbours):
if c == prey_coord or self._coord_available_for_prey(c, occupied_coords):
return c
raise AssertionError("Something has gone wrong, please investigate.")
def _coord_available_for_prey(
- self, coord: Coord, occupied_coords: Set[Coord]
+ self, coord: Coord, occupied_coords: set[Coord]
) -> bool:
if coord in occupied_coords:
return False
@@ -627,15 +636,17 @@ def _coord_available_for_prey(
def _get_next_predator_state(
self,
state: PPState,
- actions: Dict[str, PPAction],
- next_prey_coords: Tuple[Coord, ...],
- ) -> Tuple[Coord, ...]:
+ actions: dict[str, PPAction],
+ next_prey_coords: tuple[Coord, ...],
+ ) -> tuple[Coord, ...]:
potential_next_coords = []
occupied_prey_coords = {
c for i, c in enumerate(next_prey_coords) if not state.prey_caught[i]
}
for i, coord in enumerate(state.predator_coords):
- if actions[str(i)] == 0:
+ if actions[str(i)] == self.action_mask[i] or actions[str(i)] == DO_NOTHING:
+ # Current action is `masked` out, e.g., impossible
+ # `or` do nothing
next_coord = coord
else:
a_dir = ACTION_TO_DIR[actions[str(i)]]
@@ -664,9 +675,9 @@ def _get_next_predator_state(
def _get_next_prey_caught(
self,
state: PPState,
- next_prey_coords: Tuple[Coord, ...],
- next_predator_coords: Tuple[Coord, ...],
- ) -> Tuple[int, ...]:
+ next_prey_coords: tuple[Coord, ...],
+ next_predator_coords: tuple[Coord, ...],
+ ) -> tuple[int, ...]:
prey_caught = []
for i in range(self.num_prey):
if state.prey_caught[i]:
@@ -680,7 +691,7 @@ def _get_next_prey_caught(
prey_caught.append(int(num_adj_predators >= self.prey_strength))
return tuple(prey_caught)
- def _get_obs(self, state: PPState, next_state: PPState) -> Dict[str, PPObs]:
+ def _get_obs(self, state: PPState, next_state: PPState) -> dict[str, PPObs]:
return {
i: self._get_local_cell__obs(int(i), state, next_state)
for i in self.possible_agents
@@ -688,7 +699,7 @@ def _get_obs(self, state: PPState, next_state: PPState) -> Dict[str, PPObs]:
def _get_local_cell__obs(
self, agent_idx: int, state: PPState, next_state: PPState
- ) -> Tuple[int, ...]:
+ ) -> tuple[int, ...]:
obs_size = (2 * self.obs_dim) + 1
agent_coord = next_state.predator_coords[agent_idx]
@@ -716,7 +727,7 @@ def _get_local_cell__obs(
def _map_obs_to_grid_coord(
self, obs_coord: Coord, agent_coord: Coord
- ) -> Optional[Coord]:
+ ) -> Coord | None:
grid_col = agent_coord[0] + obs_coord[0] - self.obs_dim
grid_row = agent_coord[1] + obs_coord[1] - self.obs_dim
@@ -724,17 +735,17 @@ def _map_obs_to_grid_coord(
return (grid_col, grid_row)
return None
- def get_obs_coords(self, origin: Coord) -> List[Coord]:
+ def get_obs_coords(self, origin: Coord) -> list[Coord]:
"""Get the list of coords observed from agent at origin."""
obs_size = (2 * self.obs_dim) + 1
- obs_coords: List[Coord] = []
+ obs_coords: list[Coord] = []
for col, row in product(range(obs_size), repeat=2):
obs_grid_coord = self._map_obs_to_grid_coord((col, row), origin)
if obs_grid_coord is not None:
obs_coords.append(obs_grid_coord)
return obs_coords
- def _get_rewards(self, state: PPState, next_state: PPState) -> Dict[str, float]:
+ def _get_rewards(self, state: PPState, next_state: PPState) -> dict[str, float]:
new_caught_prey = []
for i in range(self.num_prey):
if not state.prey_caught[i] and next_state.prey_caught[i]:
@@ -769,30 +780,32 @@ def _get_rewards(self, state: PPState, next_state: PPState) -> Dict[str, float]:
class PredatorPreyGrid(Grid):
"""A grid for the Predator-Prey Problem."""
+ MIN_GRID_SIZE = 3
+
def __init__(
self,
grid_size: int,
- block_coords: Optional[Set[Coord]],
- predator_start_coords: Optional[List[Coord]] = None,
- prey_start_coords: Optional[List[Coord]] = None,
- ):
- assert grid_size >= 3
+ block_coords: set[Coord] | None,
+ predator_start_coords: list[Coord] | None = None,
+ prey_start_coords: list[Coord] | None = None,
+ ) -> None:
+ assert grid_size >= self.MIN_GRID_SIZE
super().__init__(grid_size, grid_size, block_coords)
self.size = grid_size
# predators start in corners or half-way along a side
if predator_start_coords is None:
predator_start_coords = [
- c
+ (c[0], c[1])
for c in product([0, grid_size // 2, grid_size - 1], repeat=2)
if c[0] in (0, grid_size - 1) or c[1] in (0, grid_size - 1)
]
- self.predator_start_coords: List[Coord] = predator_start_coords
+ self.predator_start_coords: list[Coord] = predator_start_coords
self.prey_start_coords = prey_start_coords
def get_ascii_repr(
self,
- predator_coords: Optional[Sequence[Coord]],
- prey_coords: Optional[Sequence[Coord]],
+ predator_coords: Sequence[Coord] | None,
+ prey_coords: Sequence[Coord] | None,
) -> str:
"""Get ascii repr of grid."""
grid_repr = []
@@ -815,7 +828,7 @@ def get_ascii_repr(
return str(self) + "\n" + "\n".join([" ".join(r) for r in grid_repr])
- def get_unblocked_center_coords(self, num: int) -> List[Coord]:
+ def get_unblocked_center_coords(self, num: int) -> list[Coord]:
"""Get at least num closest coords to the center of grid.
May return more than num, since can be more than one coord at equal
@@ -929,7 +942,15 @@ def get_5x5_grid() -> PredatorPreyGrid:
def get_5x5_blocks_grid() -> PredatorPreyGrid:
"""Generate 5x5 Blocks grid layout."""
- grid_str = ".....\n" ".#.#.\n" ".....\n" ".#.#.\n" ".....\n"
+ # fmt: off
+ grid_str = (
+ ".....\n"
+ ".#.#.\n"
+ ".....\n"
+ ".#.#.\n"
+ ".....\n"
+ )
+ # fmt: on
return parse_grid_str(grid_str)
@@ -1014,8 +1035,7 @@ def get_20x20_blocks_grid() -> PredatorPreyGrid:
return parse_grid_str(grid_str)
-# (grid_make_fn, step_limit)
-SUPPORTED_GRIDS = {
+SUPPORTED_GRIDS: SupportedGridTypes[PredatorPreyGrid] = {
"5x5": (get_5x5_grid, 25),
"5x5Blocks": (get_5x5_blocks_grid, 50),
"10x10": (get_10x10_grid, 50),
diff --git a/posggym/envs/grid_world/pursuit_evasion.py b/posggym/envs/grid_world/pursuit_evasion.py
index 55fdc1d..ff95625 100644
--- a/posggym/envs/grid_world/pursuit_evasion.py
+++ b/posggym/envs/grid_world/pursuit_evasion.py
@@ -1,17 +1,23 @@
"""The Pursuit-Evasion Grid World Environment."""
+from __future__ import annotations
from collections import deque
-from typing import Any, Deque, Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast
+from typing import (
+ Any,
+ ClassVar,
+ NamedTuple,
+ cast,
+)
from gymnasium import spaces
import posggym.model as M
from posggym import logger
from posggym.core import DefaultEnv
-from posggym.envs.grid_world.core import Coord, Direction, Grid
+from posggym.envs.grid_world.core import Coord, Direction, Grid, SupportedGridTypes
from posggym.utils import seeding
-# State = (e_coord, e_dir, p_coord, p_dir, e_0_coord, p_0_coord, e_goal_coord)
+
INITIAL_DIR = Direction.NORTH
@@ -51,9 +57,9 @@ class PEState(NamedTuple):
# = Tuple[Tuple[int, int, int, int, int, int], Coord, Coord, Coord]
# Note, we use blank_coord for P Obs so Obs spaces are identical between the
# two agents. The blank_coord is always (0, 0).
-PEEvaderObs = Tuple[Tuple[int, ...], Coord, Coord, Coord]
-PEPursuerObs = Tuple[Tuple[int, ...], Coord, Coord, Coord]
-PEObs = Union[PEEvaderObs, PEPursuerObs]
+PEEvaderObs = tuple[tuple[int, ...], Coord, Coord, Coord]
+PEPursuerObs = tuple[tuple[int, ...], Coord, Coord, Coord]
+PEObs = PEEvaderObs | PEPursuerObs
class PursuitEvasionEnv(DefaultEnv):
@@ -151,9 +157,8 @@ class PursuitEvasionEnv(DefaultEnv):
be adjusted when using larger grids (this can be done by manually specifying a value
for `max_episode_steps` when creating the environment with `posggym.make`).
- Arguments
+ Arguments:
---------
-
- `grid` - the grid layout to use. This can either be a string specifying one of
the supported grids, or a custom :class:`PEGrid` object (default = `"16x16"`).
- `max_obs_distance` - the maximum number of cells in front each agent's field of
@@ -194,7 +199,7 @@ class PursuitEvasionEnv(DefaultEnv):
- removed `normalize_reward` parameter (rewards are now always normalized)
- `v0`: Initial version
- References
+ References:
----------
- [This Pursuit-Evasion implementation is directly inspired by the problem] Seaman,
Iris Rubi, Jan-Willem van de Meent, and David Wingate. 2018. “Nested Reasoning
@@ -207,24 +212,27 @@ class PursuitEvasionEnv(DefaultEnv):
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "ansi", "rgb_array"],
"render_fps": 15,
}
def __init__(
self,
- grid: Union[str, "PEGrid"] = "16x16",
+ grid: str | PEGrid = "16x16",
max_obs_distance: int = 12,
use_progress_reward: bool = True,
- render_mode: Optional[str] = None,
- ):
+ render_mode: str | None = None,
+ ) -> None:
model = PursuitEvasionModel(
grid,
max_obs_distance=max_obs_distance,
use_progress_reward=use_progress_reward,
)
- super().__init__(model, render_mode=render_mode)
+ super().__init__(
+ model,
+ render_mode=render_mode,
+ )
self.max_obs_distance = max_obs_distance
fov_width = model.grid.get_max_fov_width(
@@ -240,8 +248,8 @@ def __init__(
self._agent_imgs = None
def reset(
- self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
- ) -> Tuple[Dict[str, M.ObsType], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, M.ObsType], dict[str, dict]]:
# reset renderer since goal location can change between episodes
self._renderer = None
return super().reset(seed=seed, options=options)
@@ -249,7 +257,7 @@ def reset(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -279,6 +287,7 @@ def _render_ansi(self):
return "\n".join(output) + "\n"
def _render_img(self):
+ assert self.render_mode in ["human", "rgb_array"]
evader_coord = self._state[0]
pursuer_coord = self._state[2]
goal_coord = self._state[6]
@@ -311,12 +320,12 @@ def _render_img(self):
model.grid.get_fov(
self._state[2 * i],
self._state[2 * i + 1],
- self.model.FOV_EXPANSION_INCR,
+ model.FOV_EXPANSION_INCR,
self.max_obs_distance,
)
)
- render_objects = [
+ render_objects: list[render_lib.GWObject] = [
render_lib.GWRectangle(
goal_coord, self.renderer.cell_size, render_lib.get_color("green")
)
@@ -360,10 +369,10 @@ class PursuitEvasionModel(M.POSGModel[PEState, PEObs, PEAction]):
def __init__(
self,
- grid: Union[str, "PEGrid"],
+ grid: str | PEGrid,
max_obs_distance: int = 12,
use_progress_reward: bool = True,
- ):
+ ) -> None:
if isinstance(grid, str):
assert grid in SUPPORTED_GRIDS, (
f"Unsupported grid name '{grid}'. If grid is a string it must be one "
@@ -374,6 +383,7 @@ def __init__(
self._grid = grid
self.max_obs_distance = max_obs_distance
self.use_progress_reward = use_progress_reward
+ self.action_mask = [-1] * self.NUM_AGENTS
self._max_sp_distance = self._grid.get_max_shortest_path_distance()
self._max_raw_return = self.R_EVASION
@@ -387,7 +397,6 @@ def _coord_space():
)
self.possible_agents = tuple(str(i) for i in range(self.NUM_AGENTS))
- # s = Tuple[Coord, Direction, Coord, Direction, Coord, Coord, Coord, int]
# e_coord, e_dir, p_coord, p_dir, e_start, p_start, e_goal, max_sp
self.state_space = spaces.Tuple(
(
@@ -404,7 +413,6 @@ def _coord_space():
self.action_spaces = {
i: spaces.Discrete(len(Direction)) for i in self.possible_agents
}
- # o = Tuple[Tuple[WallObs, seen , heard], Coord, Coord, Coord]
# Wall obs, seen, heard, e_start, p_start, e_goal/blank
self.observation_spaces = {
i: spaces.Tuple(
@@ -420,12 +428,12 @@ def _coord_space():
self.is_symmetric = False
@property
- def grid(self) -> "PEGrid":
+ def grid(self) -> PEGrid:
"""The underlying grid for this model instance."""
return self._grid
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
max_reward = self.R_EVASION
if self.use_progress_reward:
max_reward += self.R_PROGRESS
@@ -438,7 +446,7 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: PEState) -> List[str]:
+ def get_agents(self, state: PEState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> PEState:
@@ -455,9 +463,9 @@ def sample_agent_initial_state(self, agent_id: str, obs: PEObs) -> PEState:
def _sample_initial_state(
self,
- evader_coord: Optional[Coord],
- pursuer_coord: Optional[Coord],
- goal_coord: Optional[Coord],
+ evader_coord: Coord | None,
+ pursuer_coord: Coord | None,
+ goal_coord: Coord | None,
) -> PEState:
if evader_coord is None:
evader_coord = self.rng.choice(self.grid.evader_start_coords)
@@ -476,11 +484,11 @@ def _sample_initial_state(
self.grid.get_shortest_path_distance(evader_coord, goal_coord),
)
- def sample_initial_obs(self, state: PEState) -> Dict[str, PEObs]:
+ def sample_initial_obs(self, state: PEState) -> dict[str, PEObs]:
return self._get_obs(state)[0]
def step(
- self, state: PEState, actions: Dict[str, PEAction]
+ self, state: PEState, actions: dict[str, PEAction]
) -> M.JointTimestep[PEState, PEObs]:
assert all(0 <= a_i < len(Direction) for a_i in actions.values())
next_state = self._get_next_state(state, actions)
@@ -489,7 +497,7 @@ def step(
all_done = self._is_done(next_state)
terminated = {i: all_done for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
if all_done:
for i, outcome in self._get_outcome(next_state).items():
info[i]["outcome"] = outcome
@@ -497,20 +505,28 @@ def step(
next_state, obs, rewards, terminated, truncated, all_done, info
)
- def _get_next_state(self, state: PEState, actions: Dict[str, PEAction]) -> PEState:
+ def _get_next_state(self, state: PEState, actions: dict[str, PEAction]) -> PEState:
evader_a = actions[str(self.EVADER_IDX)]
pursuer_a = actions[str(self.PURSUER_IDX)]
- pursuer_next_dir = Direction(ACTION_TO_DIR[pursuer_a][state.pursuer_dir])
- pursuer_next_coord = self.grid.get_next_coord(
- state.pursuer_coord, pursuer_next_dir, ignore_blocks=False
- )
+ if pursuer_a == self.action_mask[self.PURSUER_IDX]:
+ pursuer_next_dir = state.pursuer_dir
+ pursuer_next_coord = state.pursuer_coord
+ else:
+ pursuer_next_dir = Direction(ACTION_TO_DIR[pursuer_a][state.pursuer_dir])
+ pursuer_next_coord = self.grid.get_next_coord(
+ state.pursuer_coord, pursuer_next_dir, ignore_blocks=False
+ )
evader_next_coord = state.evader_coord
- evader_next_dir = Direction(ACTION_TO_DIR[evader_a][state.evader_dir])
- if pursuer_next_coord != state.evader_coord:
- evader_next_coord = self.grid.get_next_coord(
- state.evader_coord, evader_next_dir, ignore_blocks=False
- )
+ if evader_a == self.action_mask[self.EVADER_IDX]:
+ # Action is masked out!
+ evader_next_dir = state.evader_dir
+ else:
+ evader_next_dir = Direction(ACTION_TO_DIR[evader_a][state.evader_dir])
+ if pursuer_next_coord != state.evader_coord:
+ evader_next_coord = self.grid.get_next_coord(
+ state.evader_coord, evader_next_dir, ignore_blocks=False
+ )
min_sp_distance = min(
state.min_goal_dist,
@@ -530,7 +546,7 @@ def _get_next_state(self, state: PEState, actions: Dict[str, PEAction]) -> PESta
min_sp_distance,
)
- def _get_obs(self, state: PEState) -> Tuple[Dict[str, PEObs], bool]:
+ def _get_obs(self, state: PEState) -> tuple[dict[str, PEObs], bool]:
walls, seen, heard = self._get_agent_obs(
state.evader_coord, state.evader_dir, state.pursuer_coord
)
@@ -558,11 +574,11 @@ def _get_obs(self, state: PEState) -> Tuple[Dict[str, PEObs], bool]:
def _get_agent_obs(
self, agent_coord: Coord, agent_dir: Direction, opp_coord: Coord
- ) -> Tuple[Tuple[int, int, int, int], int, int]:
+ ) -> tuple[tuple[int, int, int, int], int, int]:
adj_coords = self.grid.get_neighbours(
agent_coord, ignore_blocks=True, include_out_of_bounds=True
)
- walls: Tuple[int, int, int, int] = tuple( # type: ignore
+ walls: tuple[int, int, int, int] = tuple( # type: ignore
int(not self.grid.coord_in_bounds(coord) or coord in self.grid.block_coords)
for coord in adj_coords
)
@@ -581,7 +597,7 @@ def _get_opponent_seen(
def _get_reward(
self, state: PEState, next_state: PEState, evader_seen: bool
- ) -> Dict[str, float]:
+ ) -> dict[str, float]:
evader_coord = next_state.evader_coord
pursuer_coord = next_state.pursuer_coord
evader_goal_coord = next_state.evader_goal_coord
@@ -611,7 +627,7 @@ def _is_done(self, state: PEState) -> bool:
or self._get_opponent_seen(pursuer_coord, pursuer_dir, evader_coord)
)
- def _get_outcome(self, state: PEState) -> Dict[str, M.Outcome]:
+ def _get_outcome(self, state: PEState) -> dict[str, M.Outcome]:
# Assuming this method is called on final timestep
evader_coord, pursuer_coord = state.evader_coord, state.pursuer_coord
evader_goal_coord = state.evader_goal_coord
@@ -639,11 +655,11 @@ def __init__(
self,
grid_width: int,
grid_height: int,
- block_coords: Set[Coord],
- goal_coords_map: Dict[Coord, List[Coord]],
- evader_start_coords: List[Coord],
- pursuer_start_coords: List[Coord],
- ):
+ block_coords: set[Coord],
+ goal_coords_map: dict[Coord, list[Coord]],
+ evader_start_coords: list[Coord],
+ pursuer_start_coords: list[Coord],
+ ) -> None:
super().__init__(grid_width, grid_height, block_coords)
self._goal_coords_map = goal_coords_map
self.evader_start_coords = evader_start_coords
@@ -651,14 +667,14 @@ def __init__(
self.shortest_paths = self.get_all_shortest_paths(self.all_goal_coords)
@property
- def all_goal_coords(self) -> List[Coord]:
+ def all_goal_coords(self) -> list[Coord]:
"""The list of all evader goal locations."""
all_locs = set()
for v in self._goal_coords_map.values():
all_locs.update(v)
return list(all_locs)
- def get_goal_coords(self, evader_start_coord: Coord) -> List[Coord]:
+ def get_goal_coords(self, evader_start_coord: Coord) -> list[Coord]:
"""Get list of possible evader goal coords for given start coords."""
return self._goal_coords_map[evader_start_coord]
@@ -685,9 +701,9 @@ def get_max_shortest_path_distance(self) -> int:
def get_ascii_repr(
self,
- goal_coord: Union[None, Coord, List[Coord]],
- evader_coord: Union[None, Coord, List[Coord]],
- pursuer_coord: Union[None, Coord, List[Coord]],
+ goal_coord: Coord | list[Coord] | None,
+ evader_coord: Coord | list[Coord] | None,
+ pursuer_coord: Coord | list[Coord] | None,
) -> str:
"""Get ascii repr of grid."""
if goal_coord is None:
@@ -712,7 +728,7 @@ def get_ascii_repr(
if evader_coord is None:
evader_coord = []
- elif not isinstance(evader_coord, List):
+ elif not isinstance(evader_coord, list):
evader_coord = [evader_coord]
for coord in evader_coord:
@@ -720,7 +736,7 @@ def get_ascii_repr(
if pursuer_coord is None:
pursuer_coord = []
- elif not isinstance(pursuer_coord, List):
+ elif not isinstance(pursuer_coord, list):
pursuer_coord = [pursuer_coord]
for coord in pursuer_coord:
@@ -743,7 +759,7 @@ def get_fov(
direction: Direction,
widening_increment: int,
max_depth: int,
- ) -> Set[Coord]:
+ ) -> set[Coord]:
"""Get the Field of vision from origin looking in given direction.
Uses BFS starting from origin and expanding in the direction, while
@@ -754,7 +770,7 @@ def get_fov(
assert max_depth > 0
fov = {origin}
- frontier_queue: Deque[Coord] = deque([origin])
+ frontier_queue: deque[Coord] = deque([origin])
visited = {origin}
while len(frontier_queue):
@@ -776,7 +792,7 @@ def _get_fov_successors(
coord: Coord,
widening_increment: int,
max_depth: int,
- ) -> List[Coord]:
+ ) -> list[Coord]:
if direction in [Direction.NORTH, Direction.SOUTH]:
depth = abs(origin[1] - coord[1])
else:
@@ -797,20 +813,20 @@ def _get_fov_successors(
# Don't expand sideways
return successors
- side_coords_list: List[Coord] = []
+ side_coords_list: list[Coord] = []
if direction in [Direction.NORTH, Direction.SOUTH]:
if 0 < coord[0] <= origin[0]:
side_coords_list.append((coord[0] - 1, coord[1]))
if origin[0] <= coord[0] < self.width - 1:
side_coords_list.append((coord[0] + 1, coord[1]))
- else:
- if 0 < coord[1] <= origin[1]:
- side_coords_list.append((coord[0], coord[1] - 1))
- elif origin[1] <= coord[1] < self.height - 1:
- side_coords_list.append((coord[0], coord[1] + 1))
- side_successor: Optional[Coord] = None
+ elif 0 < coord[1] <= origin[1]:
+ side_coords_list.append((coord[0], coord[1] - 1))
+ elif origin[1] <= coord[1] < self.height - 1:
+ side_coords_list.append((coord[0], coord[1] + 1))
+
+ side_successor: Coord | None = None
for side_coord in side_coords_list:
if side_coord in self.block_coords:
continue
@@ -821,7 +837,7 @@ def _get_fov_successors(
return successors
- def _get_fov_successor(self, coord: Coord, direction: Direction) -> Optional[Coord]:
+ def _get_fov_successor(self, coord: Coord, direction: Direction) -> Coord | None:
new_coord = self.get_next_coord(coord, direction, ignore_blocks=False)
if new_coord == coord:
# move in given direction is blocked or out-of-bounds
@@ -959,9 +975,9 @@ def _convert_map_to_grid(
height: int,
width: int,
block_symbol: str = "#",
- pursuer_start_symbols: Optional[Set[str]] = None,
- evader_start_symbols: Optional[Set[str]] = None,
- evader_goal_symbol_map: Optional[Dict] = None,
+ pursuer_start_symbols: set[str] | None = None,
+ evader_start_symbols: set[str] | None = None,
+ evader_goal_symbol_map: dict | None = None,
) -> PEGrid:
assert len(ascii_map) == height * width
@@ -1011,8 +1027,7 @@ def _convert_map_to_grid(
)
-# grid_name: (grid_make_fn, step_limit)
-SUPPORTED_GRIDS = {
+SUPPORTED_GRIDS: SupportedGridTypes[PEGrid] = {
"8x8": (get_8x8_grid, 50),
"16x16": (get_16x16_grid, 100),
"32x32": (get_32x32_grid, 200),
diff --git a/posggym/envs/grid_world/render.py b/posggym/envs/grid_world/render.py
index ce3b80b..eea6e64 100644
--- a/posggym/envs/grid_world/render.py
+++ b/posggym/envs/grid_world/render.py
@@ -1,19 +1,19 @@
"""Functions and classes for rendering grid world environments."""
import abc
-from typing import Dict, List, Optional, Tuple, Union
-
from pathlib import Path
+
import numpy as np
from posggym.envs.grid_world.core import Coord, Direction, Grid
-from posggym.error import DependencyNotInstalled
+from posggym.error import DependencyNotInstalledError
-ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, int]]
+
+ColorTuple = tuple[int, int, int] | tuple[int, int, int, int]
try:
import pygame
except ImportError as e:
- raise DependencyNotInstalled(
+ raise DependencyNotInstalledError(
"pygame is not installed, run `pip install posggym[grid-world]`"
) from e
@@ -31,7 +31,7 @@
]
-def get_agent_color(agent_id: str) -> Tuple[ColorTuple, ColorTuple]:
+def get_agent_color(agent_id: str) -> tuple[ColorTuple, ColorTuple]:
"""Get color for agent."""
return AGENT_COLORS[int(agent_id) % len(AGENT_COLORS)]
@@ -41,12 +41,12 @@ def get_color(color_name: str) -> ColorTuple:
return pygame.colordict.THECOLORS[color_name]
-def load_img_file(img_path: Path, cell_size: Tuple[int, int]):
+def load_img_file(img_path: Path, cell_size: tuple[int, int]):
"""Load an image from file and scale it to cell size."""
return pygame.transform.scale(pygame.image.load(img_path), cell_size)
-def get_default_font_size(cell_size: Tuple[int, int]) -> int:
+def get_default_font_size(cell_size: tuple[int, int]) -> int:
"""Get the default font size based on cell size."""
return cell_size[1] // 2
@@ -62,13 +62,13 @@ class GWObject(abc.ABC):
def __init__(
self,
coord: Coord,
- cell_size: Tuple[int, int],
- ):
+ cell_size: tuple[int, int],
+ ) -> None:
self.coord = coord
self.cell_size = cell_size
@property
- def pos(self) -> Tuple[int, int]:
+ def pos(self) -> tuple[int, int]:
"""The (x, y) position of the object on render surface."""
return (self.coord[0] * self.cell_size[0], self.coord[1] * self.cell_size[1])
@@ -83,9 +83,9 @@ class GWRectangle(GWObject):
def __init__(
self,
coord: Coord,
- cell_size: Tuple[int, int],
+ cell_size: tuple[int, int],
color: ColorTuple,
- ):
+ ) -> None:
super().__init__(coord, cell_size)
self.color = color
@@ -100,10 +100,10 @@ class GWTriangle(GWObject):
def __init__(
self,
coord: Coord,
- cell_size: Tuple[int, int],
+ cell_size: tuple[int, int],
color: ColorTuple,
facing_dir: Direction,
- ):
+ ) -> None:
super().__init__(coord, cell_size)
self.color = color
self.facing_dir = facing_dir
@@ -156,9 +156,9 @@ class GWCircle(GWObject):
def __init__(
self,
coord: Coord,
- cell_size: Tuple[int, int],
+ cell_size: tuple[int, int],
color: ColorTuple,
- ):
+ ) -> None:
super().__init__(coord, cell_size)
self.color = color
@@ -174,7 +174,9 @@ def render(self, surface: pygame.Surface):
class GWHighlight(GWObject):
"""A transparent rectangle for highlighting a cell in the grid world."""
- def __init__(self, coord: Coord, cell_size: Tuple[int, int], alpha: float = 0.25):
+ def __init__(
+ self, coord: Coord, cell_size: tuple[int, int], alpha: float = 0.25
+ ) -> None:
super().__init__(coord, cell_size)
self.alpha = alpha
self.surface = pygame.Surface(cell_size, pygame.SRCALPHA)
@@ -190,9 +192,9 @@ class GWImage(GWObject):
def __init__(
self,
coord: Coord,
- cell_size: Tuple[int, int],
+ cell_size: tuple[int, int],
img: pygame.Surface,
- ):
+ ) -> None:
super().__init__(coord, cell_size)
self.img = img
@@ -206,10 +208,10 @@ class GWText(GWObject):
def __init__(
self,
coord: Coord,
- cell_size: Tuple[int, int],
+ cell_size: tuple[int, int],
text: str,
font: pygame.font.Font,
- ):
+ ) -> None:
super().__init__(coord, cell_size)
self.text = text
self.font = font
@@ -227,11 +229,11 @@ class GWImageAndText(GWObject):
def __init__(
self,
coord: Coord,
- cell_size: Tuple[int, int],
+ cell_size: tuple[int, int],
img: pygame.Surface,
text: str,
font: pygame.font.Font,
- ):
+ ) -> None:
super().__init__(coord, cell_size)
self.img = img
self.text = text
@@ -258,7 +260,7 @@ def __init__(
bg_color: ColorTuple = (0, 0, 0),
grid_line_color: ColorTuple = (255, 255, 255),
block_color: ColorTuple = (131, 139, 139),
- ):
+ ) -> None:
self.render_mode = render_mode
self.grid = grid
self.render_fps = render_fps
@@ -278,7 +280,7 @@ def __init__(
for coord in grid.block_coords
]
# list of static objects user can add to
- self.static_objects: List[GWObject] = []
+ self.static_objects: list[GWObject] = []
pygame.init()
if render_mode == "human":
@@ -316,8 +318,8 @@ def reset_blocks(self):
]
def render(
- self, objects: List[GWObject], observed_coords: Optional[List[Coord]] = None
- ) -> Optional[np.ndarray]:
+ self, objects: list[GWObject], observed_coords: list[Coord] | None = None
+ ) -> np.ndarray | None:
"""Generate Grid-World render."""
self._reset_surface()
for obj in objects:
@@ -343,12 +345,12 @@ def render(
def render_agents(
self,
- objects: List[GWObject],
- agent_coords_and_dirs: Dict[str, Tuple[Coord, Direction]],
- agent_obs_dims: Union[int, Tuple[int, int, int, int]],
- observed_coords: Optional[List[Coord]] = None,
- agent_obs_mask: Optional[List[Coord]] = None,
- ) -> Dict[str, np.ndarray]:
+ objects: list[GWObject],
+ agent_coords_and_dirs: dict[str, tuple[Coord, Direction]],
+ agent_obs_dims: int | tuple[int, int, int, int],
+ observed_coords: list[Coord] | None = None,
+ agent_obs_mask: list[Coord] | None = None,
+ ) -> dict[str, np.ndarray]:
"""Generate environment and agent-centric grid-world renders."""
if agent_obs_mask is None:
agent_obs_mask = []
@@ -364,7 +366,7 @@ def render_agents(
env_array = np.array(pygame.surfarray.pixels3d(self.window_surface))
- array_dict: Dict[str, np.ndarray] = {}
+ array_dict: dict[str, np.ndarray] = {}
for i, (coord, facing_dir) in agent_coords_and_dirs.items():
# 1. get agent's view of env
# (min_col, max_col, min_row, max_row) of coords in grid that agent observed
diff --git a/posggym/envs/grid_world/two_paths.py b/posggym/envs/grid_world/two_paths.py
index d4d5bbb..95196a2 100644
--- a/posggym/envs/grid_world/two_paths.py
+++ b/posggym/envs/grid_world/two_paths.py
@@ -1,7 +1,7 @@
"""The Two-Paths Grid World Environment."""
import itertools
from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import ClassVar
from gymnasium import spaces
@@ -11,10 +11,10 @@
from posggym.envs.grid_world.core import Coord, Direction, Grid
from posggym.utils import seeding
-TPState = Tuple[Coord, Coord]
+
+TPState = tuple[Coord, Coord]
TPAction = int
-# Obs = adj_obs
-TPObs = Tuple[int, int, int, int]
+TPObs = tuple[int, int, int, int]
# Cell obs
OPPONENT = 0
@@ -99,9 +99,8 @@ class TwoPathsEnv(DefaultEnv[TPState, TPObs, TPAction]):
Episode ends when either the runner is caught, or reaches a goal. By default a
`max_episode_steps` limit of `20` is also set.
- Arguments
+ Arguments:
---------
-
- `grid_size` - the grid size to use. This can either `3`, `4`, or `7`, each size
`n` create a TwoPaths Env with a `n`-by-`n` grid layout (default = `7`).
- `action_probs` - the action success probability for each agent. This can be a
@@ -117,7 +116,7 @@ class TwoPathsEnv(DefaultEnv[TPState, TPObs, TPAction]):
"""
- metadata = {
+ metadata: ClassVar[dict] = {
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_dict"],
"render_fps": 15,
}
@@ -125,9 +124,9 @@ class TwoPathsEnv(DefaultEnv[TPState, TPObs, TPAction]):
def __init__(
self,
grid_size: int = 7,
- action_probs: Union[float, Tuple[float, float]] = 1.0,
- render_mode: Optional[str] = None,
- ):
+ action_probs: float | tuple[float, float] = 1.0,
+ render_mode: str | None = None,
+ ) -> None:
super().__init__(
TwoPathsModel(grid_size, action_probs),
render_mode=render_mode,
@@ -139,7 +138,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -150,8 +149,8 @@ def render(self):
return self._render_img()
def _render_ansi(self):
- grid = self.model.grid # type: ignore
- grid_str = grid.get_ascii_repr(self._state[0], self._state[1])
+ model: TwoPathsModel = self.model # type: ignore
+ grid_str = model.grid.get_ascii_repr(self._state[0], self._state[1])
output = [
f"Step: {self._step_num}",
@@ -167,14 +166,15 @@ def _render_ansi(self):
return "\n".join(output) + "\n"
def _render_img(self):
- grid: Grid = self.model.grid # type: ignore
+ assert self.render_mode in ["human", "rgb", "rgb_array", "rgb_array_dict"]
+ model: TwoPathsModel = self.model # type: ignore
import posggym.envs.grid_world.render as render_lib
if self.renderer is None:
self.renderer = render_lib.GWRenderer(
self.render_mode,
- grid,
+ model.grid,
render_fps=self.metadata["render_fps"],
env_name="Two Paths",
)
@@ -183,7 +183,7 @@ def _render_img(self):
render_lib.GWRectangle(
coord, self.renderer.cell_size, render_lib.get_color("green")
)
- for coord in grid.goal_coords
+ for coord in model.grid.goal_coords
]
self.renderer.static_objects.extend(goal_imgs)
@@ -199,10 +199,10 @@ def _render_img(self):
self.runner_img.coord = self._state[0]
self.chaser_img.coord = self._state[1]
- render_objects = [self.runner_img, self.chaser_img]
+ render_objects: list[render_lib.GWObject] = [self.runner_img, self.chaser_img]
- observed_coords = grid.get_neighbours(self._state[0])
- observed_coords.extend(grid.get_neighbours(self._state[1]))
+ observed_coords = model.grid.get_neighbours(self._state[0])
+ observed_coords.extend(model.grid.get_neighbours(self._state[1]))
if self.render_mode in ("human", "rgb_array"):
return self.renderer.render(render_objects, observed_coords)
@@ -251,8 +251,8 @@ class TwoPathsModel(M.POSGModel[TPState, TPObs, TPAction]):
def __init__(
self,
grid_size: int,
- action_probs: Union[float, Tuple[float, float]] = 1.0,
- ):
+ action_probs: float | tuple[float, float] = 1.0,
+ ) -> None:
assert grid_size in SUPPORTED_GRIDS, (
f"Unsupported grid_size of `{grid_size}`, must be one of: "
f"{SUPPORTED_GRIDS.keys()}."
@@ -261,7 +261,8 @@ def __init__(
if isinstance(action_probs, float):
action_probs = (action_probs, action_probs)
- self._action_probs = action_probs
+
+ self._action_probs: tuple[float, float] = action_probs # type: ignore
self.possible_agents = tuple(str(i) for i in range(self.NUM_AGENTS))
self.state_space = spaces.Tuple(
@@ -292,7 +293,7 @@ def __init__(
self.is_symmetric = False
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (self.R_CAPTURE, self.R_SAFE) for i in self.possible_agents}
@property
@@ -301,7 +302,7 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: TPState) -> List[str]:
+ def get_agents(self, state: TPState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> TPState:
@@ -310,25 +311,25 @@ def sample_initial_state(self) -> TPState:
def sample_agent_initial_state(self, agent_id: str, obs: TPObs) -> TPState:
return self.sample_initial_state()
- def get_initial_belief_dist(self) -> Dict[TPState, float]:
+ def get_initial_belief_dist(self) -> dict[TPState, float]:
s_0 = (self.grid.init_runner_coord, self.grid.init_chaser_coord)
return {
s: float(s == s_0) # type: ignore
for s in itertools.product(self.grid.all_coords, repeat=2)
}
- def sample_initial_obs(self, state: TPState) -> Dict[str, TPObs]:
+ def sample_initial_obs(self, state: TPState) -> dict[str, TPObs]:
return self._get_obs(state)
def step(
- self, state: TPState, actions: Dict[str, TPAction]
+ self, state: TPState, actions: dict[str, TPAction]
) -> M.JointTimestep[TPState, TPObs]:
assert all(0 <= a_i < len(Direction) for a_i in actions.values())
next_state = self._get_next_state(state, actions)
rewards = self._get_rewards(next_state)
all_done = self._state_is_terminal(next_state)
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
if all_done:
for i, outcome in self._get_outcome(next_state).items():
info[i]["outcome"] = outcome
@@ -341,7 +342,7 @@ def step(
next_state, obs, rewards, terminated, truncated, all_done, info
)
- def _get_next_state(self, state: TPState, actions: Dict[str, TPAction]) -> TPState:
+ def _get_next_state(self, state: TPState, actions: dict[str, TPAction]) -> TPState:
runner_coord = state[self.RUNNER_IDX]
chaser_coord = state[self.CHASER_IDX]
runner_a = actions[str(self.RUNNER_IDX)]
@@ -367,7 +368,7 @@ def _get_next_state(self, state: TPState, actions: Dict[str, TPAction]) -> TPSta
return (runner_next_coord, chaser_next_coord)
- def _get_obs(self, state: TPState) -> Dict[str, TPObs]:
+ def _get_obs(self, state: TPState) -> dict[str, TPObs]:
runner_coord = state[self.RUNNER_IDX]
chaser_coord = state[self.CHASER_IDX]
return {
@@ -377,7 +378,7 @@ def _get_obs(self, state: TPState) -> Dict[str, TPObs]:
def _get_adj_obs(
self, coord: Coord, opponent_coord: Coord
- ) -> Tuple[int, int, int, int]:
+ ) -> tuple[int, int, int, int]:
adj_obs = []
for d in Direction:
next_coord = self.grid.get_next_coord(coord, d, False)
@@ -389,7 +390,7 @@ def _get_adj_obs(
adj_obs.append(EMPTY)
return tuple(adj_obs) # type: ignore
- def _get_rewards(self, state: TPState) -> Dict[str, float]:
+ def _get_rewards(self, state: TPState) -> dict[str, float]:
runner_coord = state[self.RUNNER_IDX]
chaser_coord = state[self.CHASER_IDX]
r_runner, r_chaser = (self.R_ACTION, self.R_ACTION)
@@ -401,7 +402,7 @@ def _get_rewards(self, state: TPState) -> Dict[str, float]:
r_runner, r_chaser = (self.R_CAPTURE, -self.R_CAPTURE)
return {str(self.RUNNER_IDX): r_runner, str(self.CHASER_IDX): r_chaser}
- def _get_outcome(self, state: TPState) -> Dict[str, M.Outcome]:
+ def _get_outcome(self, state: TPState) -> dict[str, M.Outcome]:
# Assuming state is terminal
runner_coord = state[self.RUNNER_IDX]
chaser_coord = state[self.CHASER_IDX]
@@ -433,18 +434,18 @@ def __init__(
self,
grid_width: int,
grid_height: int,
- block_coords: Set[Coord],
- goal_coords: Set[Coord],
+ block_coords: set[Coord],
+ goal_coords: set[Coord],
init_runner_coord: Coord,
init_chaser_coord: Coord,
- ):
+ ) -> None:
super().__init__(grid_width, grid_height, block_coords)
self.goal_coords = goal_coords
self.init_runner_coord = init_runner_coord
self.init_chaser_coord = init_chaser_coord
def get_ascii_repr(
- self, runner_coord: Optional[Coord], chaser_coord: Optional[Coord]
+ self, runner_coord: Coord | None, chaser_coord: Coord | None
) -> str:
"""Get ascii repr of grid."""
grid_repr = []
@@ -578,7 +579,6 @@ def get_7x7_grid() -> TPGrid:
)
-# grid_size: grid_make_fn
SUPPORTED_GRIDS = {
3: get_3x3_grid,
4: get_4x4_grid,
diff --git a/posggym/envs/grid_world/uav.py b/posggym/envs/grid_world/uav.py
index af8415d..0e18ee9 100644
--- a/posggym/envs/grid_world/uav.py
+++ b/posggym/envs/grid_world/uav.py
@@ -1,7 +1,9 @@
"""The Unmanned Aerial Vehicle Grid World Environment."""
+from __future__ import annotations
+
import random
from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import ClassVar
from gymnasium import spaces
@@ -11,14 +13,15 @@
from posggym.envs.grid_world.core import Coord, Direction, Grid
from posggym.utils import seeding
-UAVState = Tuple[Coord, Coord]
+
+UAVState = tuple[Coord, Coord]
UAVAction = int
# UAV Obs = (uav coord, fug coord)
-UAVUAVObs = Tuple[Coord, Coord]
+UAVUAVObs = tuple[Coord, Coord]
# FUG Obs = house direction
UAVFUGObs = int
-UAVObs = Union[UAVUAVObs, UAVFUGObs]
+UAVObs = UAVUAVObs | UAVFUGObs
OBSNORTH = 0
OBSSOUTH = 1
@@ -94,9 +97,8 @@ class UAVEnv(DefaultEnv[UAVState, UAVObs, UAVAction]):
limit is reached, as specified by `max_episode_steps` when initializing the
environment with `posggym.make` (default=`50`).
- Arguments
+ Arguments:
---------
-
- `grid` - the grid of the environment. This can be an integer specifying
the width and height of the grid, in which case an empty grid with the given
dimensions and default position for the safe house will be used. Alternatively,
@@ -123,15 +125,16 @@ class UAVEnv(DefaultEnv[UAVState, UAVObs, UAVAction]):
---------
Panella, Alessandro, and Piotr Gmytrasiewicz. 2017. “Interactive POMDPs
with Finite-State Models of Other Agents.” Autonomous Agents and
- Multi-Agent Systems 31 (4): 861–904.
+ Multi-Agent Systems 31 (4): 861-904.
"""
- metadata = {"render_modes": ["human", "ansi", "rgb_array"], "render_fps": 15}
+ metadata: ClassVar[dict] = {
+ "render_modes": ["human", "ansi", "rgb_array"],
+ "render_fps": 15,
+ }
- def __init__(
- self, grid: Union["UAVGrid", int] = 5, render_mode: Optional[str] = None
- ):
+ def __init__(self, grid: UAVGrid | int = 5, render_mode: str | None = None) -> None:
super().__init__(UAVModel(grid), render_mode=render_mode)
self.renderer = None
self.uav_img = None
@@ -140,7 +143,7 @@ def __init__(
def render(self):
if self.render_mode is None:
assert self.spec is not None
- logger.warn(
+ logger.warning(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. posggym.make("{self.spec.id}", render_mode="rgb_array")'
@@ -222,13 +225,18 @@ class UAVModel(M.POSGModel[UAVState, UAVObs, UAVAction]):
R_CAPTURE = 1.0 # UAV reward, fugitive = -R_CAPTURE
R_SAFE = -1.0 # UAV reward, fugitive = -R_SAFE
- # Observatio Accuracy for each agent
+ # Observation Accuracy for each agent
FUG_OBS_ACC = 0.8
UAV_OBS_ACC = 0.9
- def __init__(self, grid: Union["UAVGrid", int]):
+ MIN_GRID_SIZE = 3
+ REQUIRED_ADJACENT_COORDS = 4
+
+ def __init__(self, grid: UAVGrid | int) -> None:
if isinstance(grid, int):
- assert grid >= 3, "Grid size must be >= 3."
+ assert (
+ grid >= self.MIN_GRID_SIZE
+ ), f"Grid size must be >= {self.MIN_GRID_SIZE}."
# grid specified size of grid,
grid = UAVGrid(grid, grid, None)
self.grid = grid
@@ -261,11 +269,11 @@ def __init__(self, grid: Union["UAVGrid", int]):
self.is_symmetric = False
# cache for sampling obs conditioned init state for fug
- self._cached_init_fug_obs: Optional[UAVFUGObs] = None
- self._valid_fug_coords_dist: Tuple[List[Coord], List[float]] = ([], [])
+ self._cached_init_fug_obs: UAVFUGObs | None = None
+ self._valid_fug_coords_dist: tuple[list[Coord], list[float]] = ([], [])
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
return {i: (self.R_SAFE, self.R_CAPTURE) for i in self.possible_agents}
@property
@@ -274,7 +282,7 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: UAVState) -> List[str]:
+ def get_agents(self, state: UAVState) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> UAVState:
@@ -294,7 +302,7 @@ def sample_agent_initial_state(self, agent_id: str, obs: UAVObs) -> UAVState:
house_adj_coords = self.grid.get_neighbours(
self.grid.safe_house_coord, ignore_blocks=False
)
- if len(house_adj_coords) != 4:
+ if len(house_adj_coords) != self.REQUIRED_ADJACENT_COORDS:
# Doesn't work for 3x3 grid
raise NotImplementedError(
"Sampling observation conditioned initial state for the fugitive is "
@@ -329,7 +337,7 @@ def _sample_fug(self, fug_obs: UAVFUGObs) -> UAVState:
uav_coord = self.rng.choice(uav_start_coords)
return uav_coord, fug_coord
- def _get_fug_coord_dist(self, obs: UAVFUGObs) -> Tuple[List[Coord], List[float]]:
+ def _get_fug_coord_dist(self, obs: UAVFUGObs) -> tuple[list[Coord], list[float]]:
house_adj_coords = self.grid.get_neighbours(
self.grid.safe_house_coord, ignore_blocks=False
)
@@ -369,11 +377,11 @@ def _get_fug_coord_dist(self, obs: UAVFUGObs) -> Tuple[List[Coord], List[float]]
dist.append((1.0 - self.FUG_OBS_ACC) / (num_adj - num_true))
return house_adj_coords, dist
- def sample_initial_obs(self, state: UAVState) -> Dict[str, UAVObs]:
+ def sample_initial_obs(self, state: UAVState) -> dict[str, UAVObs]:
return self._sample_obs(state)
def step(
- self, state: UAVState, actions: Dict[str, UAVAction]
+ self, state: UAVState, actions: dict[str, UAVAction]
) -> M.JointTimestep[UAVState, UAVObs]:
assert all(0 <= a_i < len(Direction) for a_i in actions.values())
next_state = self._sample_next_state(state, actions)
@@ -389,13 +397,13 @@ def step(
terminated = {i: False for i in self.possible_agents}
truncated = {i: False for i in self.possible_agents}
all_done = False
- info: Dict[str, Dict] = {i: {} for i in self.possible_agents}
+ info: dict[str, dict] = {i: {} for i in self.possible_agents}
return M.JointTimestep(
next_state, obs, rewards, terminated, truncated, all_done, info
)
def _sample_next_state(
- self, state: UAVState, actions: Dict[str, UAVAction]
+ self, state: UAVState, actions: dict[str, UAVAction]
) -> UAVState:
uav_a, fug_a = actions[self.UAV_ID], actions[self.FUG_ID]
uav_coord, fug_coord = state
@@ -415,7 +423,7 @@ def _sample_fug_coord(self, uav_coord: Coord) -> Coord:
fug_start_coords.remove(uav_coord)
return self.rng.choice(fug_start_coords)
- def _sample_obs(self, state: UAVState) -> Dict[str, UAVObs]:
+ def _sample_obs(self, state: UAVState) -> dict[str, UAVObs]:
return {
self.UAV_ID: self._sample_uav_obs(state),
self.FUG_ID: self._sample_fug_obs(state),
@@ -458,7 +466,7 @@ def _sample_fug_obs(self, state: UAVState) -> UAVFUGObs:
return true_obs
return self.rng.choice([OBSNORTH, OBSSOUTH, OBSLEVEL])
- def _get_reward(self, next_state: UAVState) -> Dict[str, float]:
+ def _get_reward(self, next_state: UAVState) -> dict[str, float]:
uav_coord, fug_coord = next_state
uav_reward, fug_reward = self.R_ACTION, self.R_ACTION
if fug_coord == self.grid.safe_house_coord:
@@ -480,11 +488,11 @@ def __init__(
self,
grid_width: int,
grid_height: int,
- block_coords: Optional[Set[Coord]],
- safe_house_coord: Optional[Coord] = None,
- init_fug_coords: Optional[List[Coord]] = None,
- init_uav_coords: Optional[List[Coord]] = None,
- ):
+ block_coords: set[Coord] | None,
+ safe_house_coord: Coord | None = None,
+ init_fug_coords: list[Coord] | None = None,
+ init_uav_coords: list[Coord] | None = None,
+ ) -> None:
super().__init__(grid_width, grid_height, block_coords)
if safe_house_coord is None:
safe_house_coord = (grid_width // 2, grid_height // 4)
@@ -503,9 +511,7 @@ def __init__(
self.valid_coords = set(self.unblocked_coords)
self.valid_coords.remove(self.safe_house_coord)
- def get_ascii_repr(
- self, fug_coord: Optional[Coord], uav_coord: Optional[Coord]
- ) -> str:
+ def get_ascii_repr(self, fug_coord: Coord | None, uav_coord: Coord | None) -> str:
"""Get ascii repr of grid."""
grid_repr = []
for row in range(self.height):
diff --git a/posggym/envs/registration.py b/posggym/envs/registration.py
index f82fd9e..572d3b5 100644
--- a/posggym/envs/registration.py
+++ b/posggym/envs/registration.py
@@ -13,16 +13,18 @@
import difflib
import importlib
import re
-import sys
from collections import defaultdict
+from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Any
from posggym import error, logger
-from posggym.wrappers import OrderEnforcing, PassiveEnvChecker, TimeLimit
-
-if sys.version_info < (3, 10):
- import importlib_metadata as metadata # type: ignore
+from posggym.wrappers import (
+ BatchTimeLimit,
+ OrderEnforcing,
+ PassiveEnvChecker,
+ TimeLimit,
+)
if TYPE_CHECKING:
@@ -38,12 +40,12 @@
def load(name: str) -> Callable:
"""Loads environment with name and returns an environment creation function.
- Arguments
+ Arguments:
---------
name : str
The environment name
- Returns
+ Returns:
-------
entry_point : Callable
Environment creation function.
@@ -55,17 +57,17 @@ def load(name: str) -> Callable:
return fn
-def parse_env_id(env_id: str) -> Tuple[str | None, str, int | None]:
+def parse_env_id(env_id: str) -> tuple[str | None, str, int | None]:
"""Parse environment ID string format.
[namespace/](env-name)-v(version) env-name is group 1, version is group 2
- Arguments
+ Arguments:
---------
env_id : str
The environment id to parse
- Returns
+ Returns:
-------
ns : str | None
The environment namespace
@@ -74,9 +76,11 @@ def parse_env_id(env_id: str) -> Tuple[str | None, str, int | None]:
version : int | None
The environment version
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
If the environment id does not a valid environment regex
"""
@@ -98,7 +102,7 @@ def get_env_id(ns: str | None, name: str, version: int | None) -> str:
Inverse of :meth:`parse_env_id`.
- Arguments
+ Arguments:
---------
ns : str | None
The environment namespace.
@@ -107,7 +111,7 @@ def get_env_id(ns: str | None, name: str, version: int | None) -> str:
version : int | None
The environment version.
- Returns
+ Returns:
-------
str
The environment id.
@@ -161,7 +165,7 @@ class EnvSpec:
disable_env_checker: bool = field(default=False)
# Environment Arguments
- kwargs: Dict = field(default_factory=dict)
+ kwargs: dict = field(default_factory=dict)
# post-init attributes
namespace: str | None = field(init=False)
@@ -201,7 +205,7 @@ def _check_namespace_exists(ns: str | None):
else f"Have you installed the proper package for {ns}?"
)
- raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")
+ raise error.NamespaceNotFoundError(f"Namespace {ns} not found. {suggestion_msg}")
def _check_name_exists(ns: str | None, name: str):
@@ -220,7 +224,7 @@ def _check_name_exists(ns: str | None, name: str):
namespace_msg = f" in namespace {ns}" if ns else ""
suggestion_msg = f"Did you mean: `{names[suggestion[0]]}`?" if suggestion else ""
- raise error.NameNotFound(
+ raise error.NameNotFoundError(
f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
)
@@ -231,7 +235,7 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
This is a complete test whether an environment identifier is valid, and will
provide the best available hints.
- Arguments
+ Arguments:
---------
ns : str | None
The environment namespace.
@@ -240,12 +244,12 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
version : int | None
The environment version.
- Raises
+ Raises:
------
- DeprecatedEnv
+ DeprecatedEnvError
The environment doesn't exist but a default version does or the environment
version is deprecated.
- VersionNotFound
+ VersionNotFoundError
The ``version`` used doesn't exist.
"""
@@ -273,7 +277,7 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
if default_spec:
message += f" It provides the default version {default_spec[0].id}`."
if len(env_specs) == 1:
- raise error.DeprecatedEnv(message)
+ raise error.DeprecatedEnvError(message)
# Process possible versioned environments
versioned_specs = [spec_ for spec_ in env_specs if spec_.version is not None]
@@ -287,10 +291,10 @@ def _check_version_exists(ns: str | None, name: str, version: int | None):
if version > latest_spec.version:
version_list_msg = ", ".join(f"`v{spec_.version}`" for spec_ in env_specs)
message += f" It provides versioned environments: [ {version_list_msg} ]."
- raise error.VersionNotFound(message)
+ raise error.VersionNotFoundError(message)
if version < latest_spec.version:
- raise error.DeprecatedEnv(
+ raise error.DeprecatedEnvError(
f"Environment version v{version} for `{get_env_id(ns, name, None)}` "
f"is deprecated. Please use `{latest_spec.id}` instead."
)
@@ -357,21 +361,22 @@ def _check_spec_register(spec: EnvSpec):
)
-def _check_metadata(metadata_: Dict):
+def _check_metadata(metadata_: dict):
"""Checks validity of metadata. Printing warnings if it's invalid."""
if not isinstance(metadata_, dict):
- raise error.InvalidMetadata(
- f"Expect the environment metadata to be dict, actual type: {type(metadata)}"
+ raise error.InvalidMetadataError(
+ "Expect the environment metadata to be dict,",
+ f"actual type: {type(metadata_)}",
)
render_modes = metadata_.get("render_modes")
if render_modes is None:
- logger.warn(
+ logger.warning(
"The environment creator metadata doesn't include `render_modes`, "
f"contains: {list(metadata_.keys())}"
)
elif not isinstance(render_modes, Iterable):
- logger.warn(
+ logger.warning(
"Expects the environment metadata render_modes to be a Iterable, actual "
f"type: {type(render_modes)}"
)
@@ -409,7 +414,7 @@ def register(
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
- Arguments
+ Arguments:
---------
id : str
The environment id.
@@ -442,7 +447,7 @@ def register(
kwargs.get("namespace") is not None
and kwargs.get("namespace") != current_namespace
):
- logger.warn(
+ logger.warning(
f"Custom namespace `{kwargs.get('namespace')}` is being overridden by "
f"namespace `{current_namespace}`. If you are developing a plugin you "
"shouldn't specify a namespace in `register` calls. "
@@ -464,7 +469,7 @@ def register(
)
_check_spec_register(new_spec)
if new_spec.id in registry:
- logger.warn(f"Overriding environment {new_spec.id} already in registry.")
+ logger.warning(f"Overriding environment {new_spec.id} already in registry.")
registry[new_spec.id] = new_spec
@@ -479,7 +484,7 @@ def make(
To find all available environments use `posggym.envs.registry.keys()` for all valid
ids.
- Arguments
+ Arguments:
---------
id : str | EnvSpec
Name of the environment. Optionally, a module to import can be included,
@@ -494,14 +499,16 @@ def make(
**kwargs
Additional arguments to pass to the environment constructor.
- Returns
+ Returns:
-------
Env
An instance of the environment.
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
If the ``id`` doesn't exist then an error is raised
"""
@@ -528,7 +535,7 @@ def make(
and latest_version is not None
and latest_version > version
):
- logger.warn(
+ logger.warning(
f"The environment {id} is out of date. You should consider "
f"upgrading to version `v{latest_version}`."
)
@@ -537,7 +544,7 @@ def make(
version = latest_version
new_env_id = get_env_id(ns, name, version)
spec_ = registry.get(new_env_id) # type: ignore
- logger.warn(
+ logger.warning(
f"Using the latest versioned environment `{new_env_id}` "
f"instead of the unversioned environment `{id}`."
)
@@ -564,7 +571,7 @@ def make(
mode = _kwargs.get("render_mode")
if mode is not None and render_modes is not None and mode not in render_modes:
- raise error.UnsupportedMode(
+ raise error.UnsupportedModeError(
f"The environment is being initialised with render_mode={mode} "
f"that is not in the possible render_modes ({render_modes})."
)
@@ -581,6 +588,8 @@ def make(
env.unwrapped.spec = spec_
env.unwrapped.model.spec = spec_
+ TL = BatchTimeLimit if hasattr(env, "batch_size") else TimeLimit
+
# Run the environment checker as the lowest level wrapper
if disable_env_checker is False or (
disable_env_checker is None and spec_.disable_env_checker is False
@@ -593,9 +602,9 @@ def make(
# Add the time limit wrapper
if max_episode_steps is not None:
- env = TimeLimit(env, max_episode_steps)
+ env = TL(env, max_episode_steps)
elif spec_.max_episode_steps is not None:
- env = TimeLimit(env, spec_.max_episode_steps)
+ env = TL(env, spec_.max_episode_steps)
return env
@@ -603,19 +612,21 @@ def make(
def spec(env_id: str) -> EnvSpec:
"""Retrieve the spec for the given environment from the global registry.
- Arguments
+ Arguments:
---------
env_id : str
The environment id.
- Returns
+ Returns:
-------
EnvSpec
The environment spec from the global registry.
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
If environment with given ``env_id`` doesn't exist in global registry.
"""
@@ -629,15 +640,18 @@ def spec(env_id: str) -> EnvSpec:
return spec_
+NAMESPACE_MIN_PARTS = 3
+
+
def pprint_registry(
- _registry: Dict = registry,
+ _registry: dict = registry,
num_cols: int = 3,
- exclude_namespaces: List[str] | None = None,
+ exclude_namespaces: list[str] | None = None,
disable_print: bool = False,
) -> str | None:
"""Pretty print the environments in the registry.
- Arguments
+ Arguments:
---------
_registry : Dict
Environment registry to be printed.
@@ -649,7 +663,7 @@ def pprint_registry(
Whether to return a string of all the namespaces and environment IDs instead of
printing it to console.
- Returns
+ Returns:
-------
str | None
Formatted str representation of registry, if ``disable_print=True``, otherwise
@@ -657,7 +671,7 @@ def pprint_registry(
"""
# Defaultdict to store environment names according to namespace.
- namespace_envs = defaultdict(lambda: [])
+ namespace_envs = defaultdict(list)
max_justify = float("-inf")
for env in _registry.values():
namespace, _, _ = parse_env_id(env.id)
@@ -666,7 +680,7 @@ def pprint_registry(
# entrypoints.
env_entry_point = re.sub(r":\w+", "", env.entry_point)
e_ep_split = env_entry_point.split(".")
- if len(e_ep_split) >= 3:
+ if len(e_ep_split) >= NAMESPACE_MIN_PARTS:
# If namespace is of the format - posggym.envs.env_group.env_name:env_id
# or posggym.envs.env_group:env_id
idx = 2
diff --git a/posggym/error.py b/posggym/error.py
index 2bbf9e8..6edb07d 100644
--- a/posggym/error.py
+++ b/posggym/error.py
@@ -11,27 +11,27 @@ class Error(Exception):
"""Base posggym error."""
-class Unregistered(Error):
+class UnregisteredError(Error):
"""Raised when user requests item from registry that doesn't exist."""
-class UnregisteredEnv(Unregistered):
+class UnregisteredEnvError(UnregisteredError):
"""Raised when user requests env from registry that doesn't exist."""
-class NamespaceNotFound(UnregisteredEnv):
+class NamespaceNotFoundError(UnregisteredEnvError):
"""Raised when user requests env from registry where namespace doesn't exist."""
-class NameNotFound(UnregisteredEnv):
+class NameNotFoundError(UnregisteredEnvError):
"""Raised when user requests env from registry where name doesn't exist."""
-class VersionNotFound(UnregisteredEnv):
+class VersionNotFoundError(UnregisteredEnvError):
"""Raised when user requests env from registry where version doesn't exist."""
-class DeprecatedEnv(Error):
+class DeprecatedEnvError(Error):
"""Raised when user requests env from registry with old version.
I.e. if the version number is older than the latest version env with the same
@@ -46,88 +46,84 @@ class RegistrationError(Error):
"""
-class UnseedableEnv(Error):
+class UnseedableEnvError(Error):
"""Raised when the user tries to seed an env that does not support seeding."""
-class DependencyNotInstalled(Error):
+class DependencyNotInstalledError(Error):
"""Raised when the user has not installed a dependency."""
-class UnsupportedMode(Error):
+class UnsupportedModeError(Error):
"""Raised when user requests rendering mode not supported by the environment."""
-class InvalidMetadata(Error):
+class InvalidMetadataError(Error):
"""Raised when the metadata of an environment is not valid."""
-class ResetNeeded(Error):
+class ResetNeededError(Error):
"""Raised when the user attempts to step environment before a reset."""
-class ResetNotAllowed(Error):
+class ResetNotAllowedError(Error):
"""Raised when user tries to reset an environment that's not done.
Applicable when monitor is active.
"""
-class InvalidAction(Error):
+class InvalidActionError(Error):
"""Raised when the user performs an action not contained within the action space."""
-class MissingArgument(Error):
+class MissingArgumentError(Error):
"""Raised when a required argument in the initializer is missing."""
-class InvalidProbability(Error):
+class InvalidProbabilityError(Error):
"""Raised when given an invalid value for a probability."""
-class InvalidBound(Error):
+class InvalidBoundError(Error):
"""Raised when the clipping an array with invalid upper and/or lower bound."""
# Video errors
-class VideoRecorderError(Error):
+class VideoRecorderErrorError(Error):
"""Video recorder error."""
- pass
-
-class InvalidFrame(Error):
+class InvalidFrameError(Error):
"""Invalid video frame error."""
- pass
-
# posggym.agent specific errors
-class UnregisteredPolicy(Unregistered):
+class UnregisteredPolicyError(UnregisteredError):
"""Raised when user requests policy from registry that doesn't exist."""
-class PolicyEnvIDNotFound(UnregisteredPolicy):
+class PolicyEnvIDNotFoundError(UnregisteredPolicyError):
"""Raised when user requests policy from registry with env-id that doesn't exist."""
-class PolicyEnvArgsIDNotFound(UnregisteredPolicy):
+class PolicyEnvArgsIDNotFoundError(UnregisteredPolicyError):
"""Raised when user requests policy from registry with env-args that don't exist."""
-class PolicyNameNotFound(UnregisteredPolicy):
+class PolicyNameNotFoundError(UnregisteredPolicyError):
"""Raised when user requests policy from registry where name doesn't exist."""
-class PolicyVersionNotFound(UnregisteredPolicy):
+class PolicyVersionNotFoundError(UnregisteredPolicyError):
"""Raised when user requests policy from registry where version doesn't exist."""
-class DeprecatedPolicy(Error):
+class DeprecatedPolicyError(Error):
"""Raised when user requests policy from registry with old version.
I.e. if the version number is older than the latest version env with the same
@@ -142,11 +138,11 @@ class PolicyRegistrationError(Error):
"""
-class UnseedablePolicy(Error):
+class UnseedablePolicyError(Error):
"""Raised when the user tries to seed an policy that does not support seeding."""
-class InvalidFile(Error):
+class InvalidFileError(Error):
"""Raised when trying to access and invalid posggym file."""
diff --git a/posggym/logger.py b/posggym/logger.py
index 99bac9f..8868eb2 100644
--- a/posggym/logger.py
+++ b/posggym/logger.py
@@ -6,7 +6,6 @@
"""
import sys
import warnings
-from typing import Optional, Type
from gymnasium.utils.colorize import colorize
@@ -38,15 +37,15 @@ def info(msg: str, *args):
print(f"INFO: {msg % args}", file=sys.stderr)
-def warn(
+def warning(
msg: str,
*args: object,
- category: Optional[Type[Warning]] = None,
+ category: type[Warning] | None = None,
stacklevel: int = 1,
):
"""Raises a warning to the user if the min_level <= WARN.
- Arguments
+ Arguments:
---------
msg: str
The message to warn the user
@@ -68,7 +67,7 @@ def warn(
def deprecation(msg: str, *args: object):
"""Logs a deprecation warning to users."""
- warn(msg, *args, category=DeprecationWarning, stacklevel=2)
+ warning(msg, *args, category=DeprecationWarning, stacklevel=2)
def error(msg: str, *args):
diff --git a/posggym/model.py b/posggym/model.py
index a54923a..8234b01 100644
--- a/posggym/model.py
+++ b/posggym/model.py
@@ -6,13 +6,20 @@
import dataclasses
import enum
import random
-from typing import TYPE_CHECKING, Dict, Generic, List, Tuple, TypeVar
+from typing import TYPE_CHECKING, Generic, TypeVar
import numpy as np
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
from posggym import error
from posggym.utils import seeding
+
if TYPE_CHECKING:
from gymnasium import spaces
@@ -36,12 +43,12 @@ class JointTimestep(Generic[StateType, ObsType]):
"""
state: StateType
- observations: Dict[str, ObsType]
- rewards: Dict[str, float]
- terminations: Dict[str, bool]
- truncations: Dict[str, bool]
+ observations: dict[str, ObsType]
+ rewards: dict[str, float]
+ terminations: dict[str, bool]
+ truncations: dict[str, bool]
all_done: bool
- infos: Dict[str, Dict]
+ infos: dict[str, dict]
def __iter__(self):
for field in dataclasses.fields(self):
@@ -91,7 +98,7 @@ class POSGModel(abc.ABC, Generic[StateType, ObsType, ActType]):
Custom models may optionally provide implementations for the
:meth:`sample_agent_initial_state` method and :attr:`state_space` attribute.
- Note
+ Note:
----
The POSGGym Model API models all environments as environments that are
`observation first`, that is the environment provides an initial observation before
@@ -107,23 +114,23 @@ class POSGModel(abc.ABC, Generic[StateType, ObsType, ActType]):
# EnvSpec used to instantiate env instance this model is for
# This is set when env is made using posggym.make function
- spec: "EnvSpec" | None = None
+ spec: EnvSpec | None = None
# All agents that may appear in the environment
- possible_agents: Tuple[str, ...]
+ possible_agents: tuple[str, ...]
# State space
state_space: spaces.Space | None = None
# Action space for each agent
- action_spaces: Dict[str, spaces.Space]
+ action_spaces: dict[str, spaces.Space]
# Observation space for each agent
- observation_spaces: Dict[str, spaces.Space]
+ observation_spaces: dict[str, spaces.Space]
# Whether the environment is symmetric or not (is asymmetric)
is_symmetric: bool
# Random number generator, created as needed by `rng` method.
_rng: seeding.RNG | None = None
@abc.abstractmethod
- def get_agents(self, state: StateType) -> List[str]:
+ def get_agents(self, state: StateType) -> list[str]:
"""Get list of IDs for all agents that are active in given state.
The list of active agents may change depending on state.
@@ -131,12 +138,12 @@ def get_agents(self, state: StateType) -> List[str]:
For any environment where the number of agents remains constant during AND
across episodes. This will be :attr:`possible_agents`, independent of state.
- Arguments
+ Arguments:
---------
state : StateType
The environment state
- Returns
+ Returns:
-------
List[str]
List of IDs for all agents that active in given state,
@@ -155,15 +162,15 @@ def sample_initial_state(self) -> StateType:
"""
@abc.abstractmethod
- def sample_initial_obs(self, state: StateType) -> Dict[str, ObsType]:
+ def sample_initial_obs(self, state: StateType) -> dict[str, ObsType]:
"""Sample initial agent observations given an initial state.
- Arguments
+ Arguments:
---------
state : StateType
The initial state.
- Returns
+ Returns:
-------
Dict[str, ObsType]
A mapping from agent ID to their initial observation.
@@ -172,7 +179,7 @@ def sample_initial_obs(self, state: StateType) -> Dict[str, ObsType]:
@abc.abstractmethod
def step(
- self, state: StateType, actions: Dict[str, ActType]
+ self, state: StateType, actions: dict[str, ActType]
) -> JointTimestep[StateType, ObsType]:
"""Perform generative step.
@@ -188,14 +195,14 @@ def step(
value. We suggest using the "outcome" key with an instance of the ``Outcome``
class for values.
- Arguments
+ Arguments:
---------
state : StateType
The state.
actions : Dict[str, ActType]
a joint action containing one action per active agent in the environment.
- Returns
+ Returns:
-------
JointTimestep
joint timestep result of performing actions in given state, including next
@@ -209,7 +216,7 @@ def seed(self, seed: int | None = None):
Also handles seeding for the action, observation, and (if it exists) state
spaces.
- Arguments
+ Arguments:
---------
seed : int, optional
The seed that is used to initialize the models's RNG. If the
@@ -222,8 +229,12 @@ def seed(self, seed: int | None = None):
self._rng, seed = seeding.std_random(seed)
elif isinstance(self.rng, np.random.Generator):
self._rng, seed = seeding.np_random(seed)
+ elif torch is not None and isinstance(self.rng, torch.Generator):
+ if seed is None:
+ seed = 42
+ self.rng.manual_seed(seed)
else:
- raise error.UnseedableEnv(
+ raise error.UnseedableEnvError(
f"{self.__class__.__name__} unseedable. Please ensure the model has "
"implemented the rng property. The model class must also overwrite "
"the `seed` method if it uses a RNG not from the `random` or "
@@ -249,19 +260,19 @@ def sample_agent_initial_state(self, agent_id: str, obs: ObsType) -> StateType:
are used for planning and where there are a huge number of possible initial
states.
- Arguments
+ Arguments:
---------
- agent_id : Union[int, str]
+ agent_id : int | str
The ID of the agent to get initial state for.
obs : ObsType
The initial observation of the agent.
- Returns
+ Returns:
-------
StateType
An initial state for the agent conditioned on their initial observation.
- Raises
+ Raises:
------
NotImplementedError
If this method is not implemented.
@@ -270,7 +281,7 @@ def sample_agent_initial_state(self, agent_id: str, obs: ObsType) -> StateType:
raise NotImplementedError
@property
- def reward_ranges(self) -> Dict[str, Tuple[float, float]]:
+ def reward_ranges(self) -> dict[str, tuple[float, float]]:
r"""A mapping from Agent ID to min and max possible rewards for that agent.
Each reward tuple corresponding to the minimum and maximum possible rewards for
@@ -327,7 +338,7 @@ class POSGFullModel(POSGModel[StateType, ObsType, ActType], abc.ABC):
"""
@abc.abstractmethod
- def get_initial_belief(self) -> Dict[StateType, float]:
+ def get_initial_belief(self) -> dict[StateType, float]:
r"""The initial belief distribution: :math:`b_{0}`.
The initial belief distribution :math:`b_{0}` maps initial states to
@@ -344,7 +355,7 @@ def get_initial_belief(self) -> Dict[StateType, float]:
@abc.abstractmethod
def transition_fn(
- self, state: StateType, actions: Dict[str, ActType], next_state: StateType
+ self, state: StateType, actions: dict[str, ActType], next_state: StateType
) -> float:
r"""Transition function :math:`T(s', a, s)`.
@@ -352,7 +363,7 @@ def transition_fn(
:math:`Pr(s'|s, a)`, the probability of getting next state `s'` given the
environment was in state `s` and joint action `a` was performed.
- Arguments
+ Arguments:
---------
state : StateType
the state the environment was in
@@ -361,7 +372,7 @@ def transition_fn(
next_state : StateType
the state of the environment after actions were performed
- Returns
+ Returns:
-------
float
:math:`Pr(s'|s, a)`, the probability of getting next state `s'` given the
@@ -372,9 +383,9 @@ def transition_fn(
@abc.abstractmethod
def observation_fn(
self,
- obs: Dict[str, ObsType],
+ obs: dict[str, ObsType],
next_state: StateType,
- actions: Dict[str, ActType],
+ actions: dict[str, ActType],
) -> float:
r"""Observation function :math:`Z(o, s', a)`.
@@ -382,7 +393,7 @@ def observation_fn(
:math:`Pr(o|s', a)`, the probability of joint observation `o` given the joint
action `a` was performed and the environment ended up in state `s'`
- Arguments
+ Arguments:
---------
obs : Dict[str, ObsType]
the observation received
@@ -391,7 +402,7 @@ def observation_fn(
next_state : StateType
the state of the environment after actions were performed
- Returns
+ Returns:
-------
float
:math:`Pr(o|s', a)`, the probability of joint observation `o` given the
@@ -401,22 +412,22 @@ def observation_fn(
@abc.abstractmethod
def reward_fn(
- self, state: StateType, actions: Dict[str, ActType]
- ) -> Dict[str, float]:
+ self, state: StateType, actions: dict[str, ActType]
+ ) -> dict[str, float]:
r"""The reward Function :math:`R(s, a)`.
The reward function :math:`R(s, a) \rightarrow \mathbf{R}^n` where `n` is the
number of agents, defines the reward each agent receives given joint action
`a` was performed in state `s`.
- Arguments
+ Arguments:
---------
state : StateType
the state the environment was in
actions : Dict[str, ActType]
the joint action performed
- Returns
+ Returns:
-------
Dict[str, float]
The reward each agent receives given joint action `a` was performed in
diff --git a/posggym/utils/env_checker.py b/posggym/utils/env_checker.py
index e68b491..fce2351 100644
--- a/posggym/utils/env_checker.py
+++ b/posggym/utils/env_checker.py
@@ -30,17 +30,18 @@
env_reset_passive_checker,
env_step_passive_checker,
)
+from posggym.utils.torch_utils import maybe_expand_dims
def check_reset_seed(env: posggym.Env):
"""Check that the environment can be reset with a seed.
- Arguments
+ Arguments:
---------
env
The environment to check
- Raises
+ Raises:
------
AssertionError
The environment cannot be reset with a random seed, even though `seed` or
@@ -113,7 +114,7 @@ def check_reset_seed(env: posggym.Env):
seed_param = signature.parameters.get("seed")
# Check the default value is None
if seed_param is not None and seed_param.default is not None:
- logger.warn(
+ logger.warning(
"The default seed argument in reset should be `None`, otherwise the "
"environment will by default always be deterministic. "
f"Actual default: {seed_param.default}"
@@ -128,12 +129,12 @@ def check_reset_seed(env: posggym.Env):
def check_reset_options(env: posggym.Env):
"""Check that the environment can be reset with options.
- Arguments
+ Arguments:
---------
env
The environment to check
- Raises
+ Raises:
------
AssertionError
The environment cannot be reset with options, even though `options` or `kwargs`
@@ -163,12 +164,12 @@ def check_reset_options(env: posggym.Env):
def check_reset_return_type(env: posggym.Env):
"""Checks that :meth:`reset` correctly returns a tuple of the form `(obs , info)`.
- Arguments
+ Arguments:
---------
env
The environment to check
- Raises
+ Raises:
------
AssertionError
depending on spec violation
@@ -199,7 +200,7 @@ def check_env(env: posggym.Env, skip_render_check: bool = False):
This is particularly useful when using a custom environment.
- Arguments
+ Arguments:
---------
env
The posggym environment that will be checked
@@ -215,7 +216,7 @@ def check_env(env: posggym.Env, skip_render_check: bool = False):
), f"The environment must inherit from the posggym.Env class. {more_info_msg}"
if env.unwrapped is not env:
- logger.warn(
+ logger.warning(
f"The environment ({env}) is different from the unwrapped version "
f"({env.unwrapped}). This could effect the environment checker as the "
"environment most likely has a wrapper applied to it. We recommend using "
@@ -242,8 +243,10 @@ def check_env(env: posggym.Env, skip_render_check: bool = False):
# ============ Check the returned values ===============
env_reset_passive_checker(env)
+
env_step_passive_checker(
- env, {i: env.action_spaces[i].sample() for i in env.agents}
+ env,
+ {i: maybe_expand_dims(env, env.action_spaces[i].sample()) for i in env.agents},
)
# ==== Check the render method and the declared render modes ====
diff --git a/posggym/utils/history.py b/posggym/utils/history.py
index c7b4397..3e6db8f 100644
--- a/posggym/utils/history.py
+++ b/posggym/utils/history.py
@@ -1,5 +1,5 @@
"""Utilities for storing and managing agent action-observation histories."""
-from typing import Dict, Generic, List, Optional, Tuple
+from typing import Generic
import posggym.model as M
@@ -12,8 +12,8 @@ class AgentHistory(Generic[M.ActType, M.ObsType]):
"""
def __init__(
- self, history: Tuple[Tuple[Optional[M.ActType], Optional[M.ObsType]], ...]
- ):
+ self, history: tuple[tuple[M.ActType | None, M.ObsType | None], ...]
+ ) -> None:
self.history = history
self.t = len(history) - 1
@@ -33,7 +33,7 @@ def get_sub_history(self, horizon: int) -> "AgentHistory":
return self
return AgentHistory(self.history[:horizon])
- def get_last_step(self) -> Tuple[Optional[M.ActType], Optional[M.ObsType]]:
+ def get_last_step(self) -> tuple[M.ActType | None, M.ObsType | None]:
"""Get the last step in the history."""
return self.history[-1]
@@ -46,7 +46,7 @@ def horizon(self) -> int:
return len(self.history)
@classmethod
- def get_init_history(cls, obs: Optional[M.ObsType] = None) -> "AgentHistory":
+ def get_init_history(cls, obs: M.ObsType | None = None) -> "AgentHistory":
"""Get Initial history."""
if obs is None:
return cls(())
@@ -81,7 +81,7 @@ def __iter__(self):
class _AgentHistoryIterator:
- def __init__(self, history: AgentHistory):
+ def __init__(self, history: AgentHistory) -> None:
self.history = history
self._idx = 0
@@ -98,14 +98,14 @@ def __next__(self):
class JointHistory:
"""A joint history for all agents in the environment."""
- def __init__(self, agent_histories: Dict[str, AgentHistory]):
+ def __init__(self, agent_histories: dict[str, AgentHistory]) -> None:
self.agent_histories = agent_histories
self.agent_ids = sorted(agent_histories.keys())
self.num_agents = len(self.agent_histories)
@classmethod
def get_init_history(
- cls, agent_ids: List[str], obs: Optional[Dict[str, M.ObsType]] = None
+ cls, agent_ids: list[str], obs: dict[str, M.ObsType] | None = None
) -> "JointHistory":
"""Get Initial joint history."""
if obs is None:
@@ -117,7 +117,7 @@ def get_agent_history(self, agent_id: str) -> AgentHistory:
return self.agent_histories[agent_id]
def extend(
- self, action: Dict[str, M.ActType], obs: Dict[str, M.ObsType]
+ self, action: dict[str, M.ActType], obs: dict[str, M.ObsType]
) -> "JointHistory":
"""Extend the current history with given action, observation pair."""
new_agent_histories = {
diff --git a/posggym/utils/model_checker.py b/posggym/utils/model_checker.py
index 7dddfc7..13c1e63 100644
--- a/posggym/utils/model_checker.py
+++ b/posggym/utils/model_checker.py
@@ -16,10 +16,9 @@
import inspect
from copy import deepcopy
-from typing import Optional
import posggym.model as M
-from posggym import logger
+from posggym import Env, logger
from posggym.utils.passive_env_checker import (
check_agent_action_spaces,
check_agent_obs,
@@ -31,22 +30,23 @@
data_equivalence,
model_step_passive_checker,
)
+from posggym.utils.torch_utils import maybe_expand_dims
def check_initial_state_type(model: M.POSGModel) -> M.StateType:
"""Checks that :meth:`sample_initial_state` correctly returns a valid state.
- Arguments
+ Arguments:
---------
model
The model to check
- Returns
+ Returns:
-------
state
sampled initial state
- Raises
+ Raises:
------
AssertionError
depending on spec violation
@@ -57,12 +57,12 @@ def check_initial_state_type(model: M.POSGModel) -> M.StateType:
return state
-def check_initial_obs_type(model: M.POSGModel, state: Optional[M.StateType] = None):
+def check_initial_obs_type(model: M.POSGModel, state: M.StateType | None = None):
"""Checks that :meth:`sample_initial_obs` works correctly.
Assumes ``model.sample_initial_state()`` works as expected.
- Arguments
+ Arguments:
---------
model
the model to check
@@ -70,7 +70,7 @@ def check_initial_obs_type(model: M.POSGModel, state: Optional[M.StateType] = No
the state to use for check, default is None in which case a new state is
sampled from the model.
- Raises
+ Raises:
------
AssertionError
depending on spec violation
@@ -82,21 +82,21 @@ def check_initial_obs_type(model: M.POSGModel, state: Optional[M.StateType] = No
try:
obs = model.sample_initial_obs(state)
check_agent_obs(obs, model.observation_spaces, "sample_initial_obs")
- except NotImplementedError:
+ except NotImplementedError as err:
raise AssertionError(
"Model requires the ``sample_initial_obs`` method to be implemented."
- )
+ ) from err
def check_initial_sampling_seed(model: M.POSGModel):
"""Check that model seeding works correctly for initial conditions.
- Arguments
+ Arguments:
---------
model
The environment model to check
- Raises
+ Raises:
------
AssertionError
The model random seeding doesn't work as expected.
@@ -179,21 +179,21 @@ def check_initial_sampling_seed(model: M.POSGModel):
seed_param = signature.parameters.get("seed")
# Check the default value is None
if seed_param is not None and seed_param.default is not None:
- logger.warn(
+ logger.warning(
"The default seed argument in `seed` method should be `None`, otherwise "
"the model will by default always be deterministic. "
f"Actual default: {seed_param.default}"
)
-def check_model(model: M.POSGModel):
+def check_model(env: Env, model: M.POSGModel):
"""Check that an environment model follows posggym API.
This is an invasive function that calls the models step.
This is particularly useful when using a custom environment.
- Arguments
+ Arguments:
---------
model
The posggym environment model that will be checked
@@ -259,5 +259,8 @@ def check_model(model: M.POSGModel):
model_step_passive_checker(
model,
state,
- {i: model.action_spaces[i].sample() for i in model.get_agents(state)},
+ {
+ i: maybe_expand_dims(env, model.action_spaces[i].sample())
+ for i in model.get_agents(state)
+ },
)
diff --git a/posggym/utils/passive_env_checker.py b/posggym/utils/passive_env_checker.py
index e27b410..8a1ad6d 100644
--- a/posggym/utils/passive_env_checker.py
+++ b/posggym/utils/passive_env_checker.py
@@ -8,8 +8,8 @@
import inspect
import random
+from collections.abc import Callable, Sequence
from functools import partial
-from typing import Callable, Dict, Optional, Sequence
import numpy as np
from gymnasium import Space, spaces
@@ -18,19 +18,40 @@
import posggym.model as M
from posggym import error, logger
from posggym.utils import seeding
+from posggym.utils.torch_utils import has_integer_dtype
+
+
+NUM_CHANNELS = 3
+
+try:
+ import torch
+
+ numpy_to_torch_dtype_map = {
+ np.dtype(np.float32): torch.float32,
+ np.dtype(np.float64): torch.float64,
+ np.dtype(np.int32): torch.int32,
+ np.dtype(np.int64): torch.int64,
+ np.dtype(np.uint8): torch.uint8,
+ np.dtype(np.bool_): torch.bool,
+ np.dtype(np.float16): torch.float16,
+ }
+
+except ImportError:
+ torch = None
+ numpy_to_torch_dtype_map = {}
def data_equivalence(data_1, data_2) -> bool:
"""Assert equality between data 1 and 2, i.e observations, actions, info.
- Arguments
+ Arguments:
---------
data_1
data structure 1
data_2
data structure 2
- Returns
+ Returns:
-------
bool
If observation 1 and 2 are equivalent
@@ -41,14 +62,19 @@ def data_equivalence(data_1, data_2) -> bool:
return data_1.keys() == data_2.keys() and all(
data_equivalence(data_1[k], data_2[k]) for k in data_1
)
- elif isinstance(data_1, (tuple, list)):
+ elif isinstance(data_1, tuple | list):
return len(data_1) == len(data_2) and all(
- data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
+ data_equivalence(o_1, o_2)
+ for o_1, o_2 in zip(data_1, data_2, strict=False)
)
elif isinstance(data_1, np.ndarray):
return data_1.shape == data_2.shape and np.allclose(
data_1, data_2, atol=0.00001
)
+ elif torch is not None and isinstance(data_1, torch.Tensor):
+ return data_1.shape == data_2.shape and torch.allclose(
+ data_1, data_2, atol=0.00001
+ )
else:
return data_1 == data_2
else:
@@ -61,6 +87,7 @@ def check_rng_equality(rng_1: seeding.RNG, rng_2: seeding.RNG, prefix=None):
rng_2
), f"{prefix}Differing RNG types: {rng_1} and {rng_2}"
if isinstance(rng_1, random.Random) and isinstance(rng_2, random.Random):
+ print(rng_1.getstate(), rng_2.getstate())
assert (
rng_1.getstate() == rng_2.getstate()
), f"{prefix}Internal states differ: {rng_1} and {rng_2}"
@@ -70,6 +97,16 @@ def check_rng_equality(rng_1: seeding.RNG, rng_2: seeding.RNG, prefix=None):
assert (
rng_1.bit_generator.state == rng_2.bit_generator.state
), f"{prefix}Internal states differ: {rng_1} and {rng_2}"
+ elif (
+ torch is not None
+ and isinstance(rng_1, torch.Generator)
+ and isinstance(rng_2, torch.Generator)
+ ):
+ print(rng_1.get_state(), rng_2.get_state())
+ assert torch.equal(
+ rng_1.get_state(), rng_2.get_state()
+ ), f"{prefix}Internal states differ: {rng_1} and {rng_2}"
+
else:
raise AssertionError(f"{prefix}Unsupported RNG type: '{type(rng_1)}'.")
@@ -78,12 +115,12 @@ def check_space_limit(space, space_type: str):
"""Check the space limit for only the Box space."""
if isinstance(space, spaces.Box):
if np.any(np.equal(space.low, -np.inf)):
- logger.warn(
+ logger.warning(
f"A Box {space_type} space minimum value is -infinity. This is "
"probably too low."
)
if np.any(np.equal(space.high, np.inf)):
- logger.warn(
+ logger.warning(
f"A Box {space_type} space maximum value is -infinity. This is "
"probably too high."
)
@@ -102,7 +139,7 @@ def check_space_limit(space, space_type: str):
or np.any(space.high > 1)
)
):
- logger.warn(
+ logger.warning(
"For Box action spaces, we recommend using a symmetric and normalized "
"space (range=[-1, 1] or [0, 1]). See "
"https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html "
@@ -116,9 +153,9 @@ def check_space_limit(space, space_type: str):
check_space_limit(subspace, space_type)
-def check_agent_space_limits(agent_spaces: Dict[str, Space], space_type: str):
+def check_agent_space_limits(agent_spaces: dict[str, Space], space_type: str):
"""Check the space limit for any agent Box space."""
- for i, agent_space in agent_spaces.items():
+ for _i, agent_space in agent_spaces.items():
check_space_limit(agent_space, space_type)
@@ -134,46 +171,52 @@ def _check_box_state_space(state_space: spaces.Box):
)
if np.any(state_space.low == state_space.high):
- logger.warn(
+ logger.warning(
"A Box state space maximum and minimum values are equal. "
"Actual equal coordinates: "
- f"{list(zip(*np.where(state_space.low == state_space.high)))}"
+ f"{list(zip(*np.where(state_space.low == state_space.high), strict=False))}"
)
elif np.any(state_space.high < state_space.low):
- logger.warn(
+ logger.warning(
"A Box state space low value is greater than a high value. "
"Actual less than coordinates: "
- f"{list(zip(*np.where(state_space.high < state_space.low)))}"
+ f"{list(zip(*np.where(state_space.high < state_space.low), strict=False))}"
)
def _check_box_observation_space(observation_space: spaces.Box):
"""Checks that a :class:`Box` observation space is defined in a sensible way."""
# Check if the box is an image
+ LOWER_BOUND = 0
+ UPPER_BOUND = np.iinfo(np.uint8).max
+
if (len(observation_space.shape) == 3 and observation_space.shape[0] != 1) or (
len(observation_space.shape) == 4 and observation_space.shape[0] == 1
):
if observation_space.dtype != np.uint8:
- logger.warn(
+ logger.warning(
"It seems a Box observation space is an image but the `dtype` is not "
f"`np.uint8`, actual type: {observation_space.dtype}. "
"If the Box observation space is not an image, we recommend flattening "
"the observation to have only a 1D vector."
)
- if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
- logger.warn(
+ if np.any(observation_space.low != LOWER_BOUND) or np.any(
+ observation_space.high != UPPER_BOUND
+ ):
+ logger.warning(
"It seems a Box observation space is an image but the lower and upper "
"bounds are not [0, 255]. "
- f"Actual lower bound: {np.min(observation_space.low)}, upper bound: "
- f"{np.max(observation_space.high)}. "
+ "Actual lower bound: %s, upper bound: %s. "
"Generally, CNN policies assume observations are within that range, so "
- "you may encounter an issue if the observation values are not."
+ "you may encounter an issue if the observation values are not.",
+ np.min(observation_space.low),
+ np.max(observation_space.high),
)
if len(observation_space.shape) not in [1, 3] and not (
len(observation_space.shape) == 2 and observation_space.shape[0] == 1
):
- logger.warn(
+ logger.warning(
"A Box observation space has an unconventional shape (neither an image, "
"nor a 1D vector). We recommend flattening the observation to have only a "
"1D vector or use a custom policy to properly process the data. "
@@ -191,16 +234,23 @@ def _check_box_observation_space(observation_space: spaces.Box):
)
if np.any(observation_space.low == observation_space.high):
- logger.warn(
+ equal_coords = list(
+ zip(
+ *np.where(observation_space.low == observation_space.high), strict=False
+ )
+ )
+ logger.warning(
"A Box observation space maximum and minimum values are equal. "
- "Actual equal coordinates: "
- f"{list(zip(*np.where(observation_space.low == observation_space.high)))}"
+ f"Actual equal coordinates: {equal_coords}"
)
+
elif np.any(observation_space.high < observation_space.low):
- logger.warn(
+ less_than_coords = list(
+ zip(*np.where(observation_space.high < observation_space.low), strict=False)
+ )
+ logger.warning(
"A Box observation space low value is greater than a high value. "
- "Actual less than coordinates: "
- f"{list(zip(*np.where(observation_space.high < observation_space.low)))}"
+ f"Actual less than coordinates: {less_than_coords}"
)
@@ -216,16 +266,20 @@ def _check_box_action_space(action_space: spaces.Box):
)
if np.any(action_space.low == action_space.high):
- logger.warn(
+ equal_coords = list(
+ zip(*np.where(action_space.low == action_space.high), strict=False)
+ )
+ logger.warning(
"A Box action space maximum and minimum values are equal. "
- "Actual equal coordinates: "
- f"{list(zip(*np.where(action_space.low == action_space.high)))}"
+ f"Actual equal coordinates: {equal_coords}"
)
elif np.any(action_space.high < action_space.low):
- logger.warn(
+ less_than_coords = list(
+ zip(*np.where(action_space.high < action_space.low), strict=False)
+ )
+ logger.warning(
"A Box action space low value is greater than a high value. "
- "Actual less than coordinates: "
- f"{list(zip(*np.where(action_space.high < action_space.low)))}"
+ f"Actual less than coordinates: {less_than_coords}"
)
@@ -279,7 +333,7 @@ def check_space(
def check_agent_spaces(
- agent_spaces: Dict[str, Space],
+ agent_spaces: dict[str, Space],
space_type: str,
check_box_space_fn: Callable[[spaces.Box], None],
):
@@ -312,7 +366,7 @@ def check_agent_spaces(
def check_state(state: M.StateType, model: M.POSGModel):
"""Check state is valid.
- Arguments
+ Arguments:
---------
state
the state to check
@@ -327,7 +381,7 @@ def check_state(state: M.StateType, model: M.POSGModel):
def check_obs(obs, observation_space: spaces.Space, method_name: str):
"""Check the observation returned by the environment correspond to the declared one.
- Arguments
+ Arguments:
---------
obs
The observation to check
@@ -339,30 +393,63 @@ def check_obs(obs, observation_space: spaces.Space, method_name: str):
"""
pre = f"The obs returned by the `{method_name}()` method"
if isinstance(observation_space, spaces.Discrete):
- if not isinstance(obs, (np.int64, int)):
- logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
+ if not isinstance(
+ obs, tuple([np.int64, int, torch.int64] if torch is not None else [])
+ ):
+ logger.warning(
+ f"{pre} should be an int, np.int64 or "
+ f"torch.int64, actual type: {type(obs)}"
+ )
elif isinstance(observation_space, spaces.Box):
if observation_space.shape != ():
- if not isinstance(obs, np.ndarray):
- logger.warn(
- f"{pre} was expecting a numpy array, actual type: {type(obs)}"
+ if not isinstance(obs, np.ndarray) and not (
+ torch is not None and isinstance(obs, torch.Tensor)
+ ):
+ logger.warning(
+ f"{pre} was expecting a numpy array"
+ f"or a torch tensor, actual type: {type(obs)}"
)
- elif obs.dtype != observation_space.dtype:
- logger.warn(
- f"{pre} was expecting numpy array dtype to be "
- f"{observation_space.dtype}, actual type: {obs.dtype}"
+ if isinstance(obs, np.ndarray):
+ # Check for NumPy array dtype
+ if obs.dtype != observation_space.dtype:
+ logger.warning(
+ f"{pre} was expecting numpy array dtype to be "
+ f"{observation_space.dtype}, actual dtype: {obs.dtype}"
+ )
+ elif torch is not None and isinstance(obs, torch.Tensor):
+ # Find the equivalent PyTorch dtype
+ torch_dtype = numpy_to_torch_dtype_map.get(
+ observation_space.dtype # type: ignore
)
- elif isinstance(observation_space, (spaces.MultiBinary, spaces.MultiDiscrete)):
- if not isinstance(obs, np.ndarray):
- logger.warn(f"{pre} was expecting a numpy array, actual type: {type(obs)}")
+
+ if torch_dtype is None:
+ logger.warning(
+ f"{pre} does not have a mapped torch"
+ f"equivalent for numpy dtype: {observation_space.dtype}"
+ )
+ elif obs.dtype != torch_dtype:
+ logger.warning(
+ f"{pre} was expecting torch tensor dtype to be "
+ f"{torch_dtype}, actual dtype: {obs.dtype}"
+ )
+ obs = obs.detach().cpu().numpy().squeeze()
+
+ elif isinstance(observation_space, spaces.MultiBinary | spaces.MultiDiscrete):
+ if not isinstance(obs, np.ndarray) and not (
+ torch is not None and isinstance(obs, torch.Tensor)
+ ):
+ logger.warning(
+ f"{pre} was expecting a numpy array or a"
+ f"torch tensor, actual type: {type(obs)}"
+ )
elif isinstance(observation_space, spaces.Tuple):
if not isinstance(obs, tuple):
- logger.warn(f"{pre} was expecting a tuple, actual type: {type(obs)}")
+ logger.warning(f"{pre} was expecting a tuple, actual type: {type(obs)}")
assert len(obs) == len(observation_space.spaces), (
f"{pre} length is not same as the observation space length, obs length: "
f"{len(obs)}, space length: {len(observation_space.spaces)}"
)
- for sub_obs, sub_space in zip(obs, observation_space.spaces):
+ for sub_obs, sub_space in zip(obs, observation_space.spaces, strict=False):
check_obs(sub_obs, sub_space, method_name)
elif isinstance(observation_space, spaces.Dict):
assert isinstance(obs, dict), f"{pre} must be a dict, actual type: {type(obs)}"
@@ -376,19 +463,19 @@ def check_obs(obs, observation_space: spaces.Space, method_name: str):
try:
if obs not in observation_space:
- logger.warn(f"{pre} is not within the observation space.")
+ logger.warning(f"{pre} is not within the observation space.")
except Exception as e:
- logger.warn(f"{pre} is not within the observation space with exception: {e}")
+ logger.warning(f"{pre} is not within the observation space with exception: {e}")
def check_agent_obs(
- obs: Dict[str, M.ObsType],
- observation_spaces: Dict[str, Space],
+ obs: dict[str, M.ObsType],
+ observation_spaces: dict[str, Space],
method_name: str,
):
"""Check that each agent's observation returned by the environment is valid.
- Arguments
+ Arguments:
---------
obs
The observation for each agent to check
@@ -405,10 +492,10 @@ def check_agent_obs(
raise AssertionError("Invalid observation for agent `{i}`.") from e
-def check_reset_obs(obs: Dict[str, M.ObsType], model: M.POSGModel):
+def check_reset_obs(obs: dict[str, M.ObsType], model: M.POSGModel):
"""Check agent observations returned by the environment `reset()` method are valid.
- Arguments
+ Arguments:
---------
obs
The observation for each agent to check
@@ -420,7 +507,7 @@ def check_reset_obs(obs: Dict[str, M.ObsType], model: M.POSGModel):
"Expected observation from `env.reset()` to be a dictionary, mapping agent IDs "
" to agent obs."
)
- for i, o_i in obs.items():
+ for i, _o_i in obs.items():
assert (
i in model.possible_agents
), f"Invalid agent ID `{i}`. Possible IDs are {model.possible_agents}."
@@ -436,7 +523,7 @@ def env_reset_passive_checker(env, **kwargs):
"""
signature = inspect.signature(env.reset)
if "seed" not in signature.parameters and "kwargs" not in signature.parameters:
- logger.warn(
+ logger.warning(
"posggym requires that `Env.reset` can be passed a `seed` for resetting the"
" environment random number generator."
)
@@ -444,14 +531,14 @@ def env_reset_passive_checker(env, **kwargs):
seed_param = signature.parameters.get("seed")
# Check the default value is None
if seed_param is not None and seed_param.default is not None:
- logger.warn(
+ logger.warning(
"The default seed argument in `Env.reset` should be `None`, otherwise "
"the environment will by default always be deterministic. "
f"Actual default: {seed_param}"
)
if "options" not in signature.parameters and "kwargs" not in signature.parameters:
- logger.warn(
+ logger.warning(
"posggym requires that `Env.reset` can be passed `options` to allow the "
"environment initialisation to be passed additional information."
)
@@ -460,13 +547,13 @@ def env_reset_passive_checker(env, **kwargs):
result = env.reset(**kwargs)
if not isinstance(result, tuple):
- logger.warn(
+ logger.warning(
"The result returned by `env.reset()` was not a tuple of the form "
"`(obs, info)`, where `obs` is a observation and `info` is a dictionary "
f"containing additional information. Actual type: `{type(result)}`"
)
elif len(result) != 2:
- logger.warn(
+ logger.warning(
"The result returned by `env.reset()` should be `(obs, info)` by default, "
"where `obs` is a observation and `info` is a dictionary containing "
"additional information."
@@ -483,9 +570,9 @@ def env_reset_passive_checker(env, **kwargs):
def _check_agent_dict(
agent_dict,
- possible_agents: Optional[Sequence[str]],
+ possible_agents: Sequence[str] | None,
dict_type: str,
- expected_agents: Optional[Sequence[str]] = None,
+ expected_agents: Sequence[str] | None = None,
):
assert isinstance(agent_dict, dict), (
f"Agent {dict_type} dictionary must be a dictionary mapping agentID to values."
@@ -506,7 +593,7 @@ def _check_agent_dict(
def model_step_passive_checker(
- model: M.POSGModel, state: M.StateType, actions: Dict[str, M.ActType]
+ model: M.POSGModel, state: M.StateType, actions: dict[str, M.ActType]
):
"""A passive check for the model step.
@@ -540,46 +627,76 @@ def model_step_passive_checker(
check_state(next_state, model)
check_agent_obs(obs, model.observation_spaces, "step")
- if not all(isinstance(t_i, (bool, np.bool_)) for t_i in terminated.values()):
- logger.warn(
+ if not all(
+ isinstance(t_i, bool | np.bool_)
+ or (
+ torch is not None
+ and isinstance(t_i, torch.Tensor)
+ and t_i.dtype == torch.bool
+ )
+ for t_i in terminated.values()
+ ):
+ logger.warning(
"Expects `terminated` signal to be a boolean for every agent, "
f"actual types: {[type(t_i) for t_i in terminated.values()]}."
)
- if not all(isinstance(t_i, (bool, np.bool_)) for t_i in truncated.values()):
- logger.warn(
+ if not all(
+ isinstance(t_i, bool | np.bool_)
+ or (
+ torch is not None
+ and isinstance(t_i, torch.Tensor)
+ and t_i.dtype == torch.bool
+ )
+ for t_i in truncated.values()
+ ):
+ logger.warning(
"Expects `truncated` signal to be a boolean for every agent, "
f"actual types: {[type(t_i) for t_i in truncated.values()]}."
)
- if not isinstance(done, (bool, np.bool_)):
- logger.warn(
+ if not (
+ isinstance(done, bool | np.bool_)
+ or (
+ torch is not None
+ and isinstance(done, torch.Tensor)
+ and done.dtype == torch.bool
+ )
+ ):
+ logger.warning(
"Expects `done` signal returned by `step()` to be a boolean, "
- f"actual type: {type(info)}"
+ f"actual type: {type(done)}"
)
if not (
all(
np.issubdtype(type(r_i), np.integer)
or np.issubdtype(type(r_i), np.floating)
+ or (torch and (torch.is_floating_point(r_i) or has_integer_dtype(r_i)))
for r_i in reward.values()
)
):
- logger.warn(
+ logger.warning(
"The reward returned for each agent by `step()` must be a float, int, "
- "np.integer or np.floating, "
+ "torch.integer, torch.floating, np.integer or np.floating, "
f"actual types: {[type(r_i) for r_i in reward.values()]}."
)
else:
for i, r_i in reward.items():
- if np.isnan(r_i): # type: ignore
- logger.warn(f"The reward for agent `{i}` is a NaN value.")
- if np.isinf(r_i): # type: ignore
- logger.warn(f"The reward for agent `{i}` is an inf value.")
+ if torch is not None and isinstance(r_i, torch.Tensor):
+ if torch.isnan(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is a NaN value.")
+ if torch.isinf(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is an inf value.")
+ else:
+ if np.isnan(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is a NaN value.")
+ if np.isinf(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is an inf value.")
return result
-def env_step_passive_checker(env: posggym.Env, actions: Dict[str, M.ActType]):
+def env_step_passive_checker(env: posggym.Env, actions: dict[str, M.ActType]):
"""A passive check for the environment step.
Investigating the returning data then returning the data unchanged.
@@ -609,21 +726,43 @@ def env_step_passive_checker(env: posggym.Env, actions: Dict[str, M.ActType]):
# record keeping
_check_agent_dict(info, env.possible_agents, "info")
- if not all(isinstance(t_i, (bool, np.bool_)) for t_i in terminated.values()):
- logger.warn(
+ if not all(
+ isinstance(t_i, bool | np.bool_)
+ or (
+ torch is not None
+ and isinstance(t_i, torch.Tensor)
+ and t_i.dtype == torch.bool
+ )
+ for t_i in terminated.values()
+ ):
+ logger.warning(
"Expects `terminated` signal to be a boolean for every agent, "
f"actual types: {[type(t_i) for t_i in terminated.values()]}."
)
- if not all(isinstance(t_i, (bool, np.bool_)) for t_i in truncated.values()):
- logger.warn(
+ if not all(
+ isinstance(t_i, bool | np.bool_)
+ or (
+ torch is not None
+ and isinstance(t_i, torch.Tensor)
+ and t_i.dtype == torch.bool
+ )
+ for t_i in truncated.values()
+ ):
+ logger.warning(
"Expects `truncated` signal to be a boolean for every agent, "
f"actual types: {[type(t_i) for t_i in truncated.values()]}."
)
-
- if not isinstance(done, (bool, np.bool_)):
- logger.warn(
+ if not (
+ isinstance(done, bool | np.bool_)
+ or (
+ torch is not None
+ and isinstance(done, torch.Tensor)
+ and done.dtype == torch.bool
+ )
+ ):
+ logger.warning(
"Expects `done` signal returned by `step()` to be a boolean, "
- f"actual type: {type(info)}"
+ f"actual type: {type(done)}"
)
else:
raise error.Error(
@@ -637,73 +776,83 @@ def env_step_passive_checker(env: posggym.Env, actions: Dict[str, M.ActType]):
all(
np.issubdtype(type(r_i), np.integer)
or np.issubdtype(type(r_i), np.floating)
+ or (torch and (torch.is_floating_point(r_i) or has_integer_dtype(r_i)))
for r_i in reward.values()
)
):
- logger.warn(
+ logger.warning(
"The reward returned for each agent by `step()` must be a float, int, "
"np.integer or np.floating, "
f"actual types: {[type(r_i) for r_i in reward.values()]}."
)
else:
for i, r_i in reward.items():
- if np.isnan(r_i): # type: ignore
- logger.warn(f"The reward for agent `{i}` is a NaN value.")
- if np.isinf(r_i): # type: ignore
- logger.warn(f"The reward for agent `{i}` is an inf value.")
+ if torch is not None and isinstance(r_i, torch.Tensor):
+ if torch.isnan(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is a NaN value.")
+ if torch.isinf(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is an inf value.")
+ else:
+ if np.isnan(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is a NaN value.")
+ if np.isinf(r_i): # type: ignore
+ logger.warning(f"The reward for agent `{i}` is an inf value.")
return result
-def _check_render_return(render_mode, render_return):
+def _check_render_return(render_mode, render_return): # noqa: PLR0912
"""Produces warning if `render_return` doesn't match `render_mode`."""
if render_mode == "human":
if render_return is not None:
- logger.warn(
+ logger.warning(
f"Human rendering should return `None`, got {type(render_return)}"
)
elif render_mode == "rgb_array":
if not isinstance(render_return, np.ndarray):
- logger.warn(
+ logger.warning(
"RGB-array rendering should return a numpy array, got "
f"{type(render_return)}"
)
else:
if render_return.dtype != np.uint8:
- logger.warn(
+ logger.warning(
"RGB-array rendering should return a numpy array with dtype "
f"uint8, got {render_return.dtype}"
)
- if render_return.ndim != 3:
- logger.warn(
+ if render_return.ndim != NUM_CHANNELS:
+ logger.warning(
"RGB-array rendering should return a numpy array with three axes, "
f"got {render_return.ndim}"
)
- if render_return.ndim == 3 and render_return.shape[2] != 3:
- logger.warn(
+ if (
+ render_return.ndim == NUM_CHANNELS
+ and render_return.shape[2] != NUM_CHANNELS
+ ):
+ logger.warning(
"RGB-array rendering should return a numpy array in which the "
f"last axis has three dimensions, got {render_return.shape[2]}"
)
elif render_mode == "depth_array":
if not isinstance(render_return, np.ndarray):
- logger.warn(
+ logger.warning(
"Depth-array rendering should return a numpy array, got "
f"{type(render_return)}"
)
elif render_return.ndim != 2:
- logger.warn(
+ logger.warning(
"Depth-array rendering should return a numpy array with two axes, "
f"got {render_return.ndim}"
)
elif render_mode in ["ansi", "ascii"]:
if not isinstance(render_return, str):
- logger.warn(
+ logger.warning(
"ANSI/ASCII rendering should produce a string, got "
f"{type(render_return)}"
)
elif render_mode.endswith("_list"):
if not isinstance(render_return, list):
- logger.warn(
+ logger.warning(
"Render mode `{render_mode}` should produce a list, got "
f"{type(render_return)}"
)
@@ -717,7 +866,7 @@ def _check_render_return(render_mode, render_return):
# check posggym specific render modes, namely dict renders which
# return mapping from agentID to agent specific render.
if not isinstance(render_return, dict):
- logger.warn(
+ logger.warning(
f"Render mode `{render_mode}` should produce a dict, got "
f"{type(render_return)}"
)
@@ -741,19 +890,19 @@ def env_render_passive_checker(env: posggym.Env):
"""
render_modes = env.metadata.get("render_modes")
if render_modes is None:
- logger.warn(
+ logger.warning(
"No render modes were declared in the environment "
"(env.metadata['render_modes'] is None or not defined), you may have "
"trouble when calling `.render()`."
)
else:
- if not isinstance(render_modes, (list, tuple)):
- logger.warn(
+ if not isinstance(render_modes, list | tuple):
+ logger.warning(
"Expects the render_modes to be a sequence (i.e. list, tuple), "
f"actual type: {type(render_modes)}"
)
elif not all(isinstance(mode, str) for mode in render_modes):
- logger.warn(
+ logger.warning(
"Expects all render modes to be strings, "
f"actual types: {[type(mode) for mode in render_modes]}"
)
@@ -762,25 +911,24 @@ def env_render_passive_checker(env: posggym.Env):
# We only require `render_fps` if rendering is actually implemented
if len(render_modes) > 0:
if render_fps is None:
- logger.warn(
+ logger.warning(
"No render fps was declared in the environment "
"(env.metadata['render_fps'] is None or not defined), rendering "
"may occur at inconsistent fps."
)
+ elif not (
+ np.issubdtype(type(render_fps), np.integer)
+ or np.issubdtype(type(render_fps), np.floating)
+ ):
+ logger.warning(
+ "Expects the `env.metadata['render_fps']` to be an integer "
+ f"or a float, actual type: {type(render_fps)}"
+ )
else:
- if not (
- np.issubdtype(type(render_fps), np.integer)
- or np.issubdtype(type(render_fps), np.floating)
- ):
- logger.warn(
- "Expects the `env.metadata['render_fps']` to be an integer "
- f"or a float, actual type: {type(render_fps)}"
- )
- else:
- assert render_fps > 0, (
- "Expects the `env.metadata['render_fps']` to be greater than "
- f"zero, actual value: {render_fps}"
- )
+ assert render_fps > 0, (
+ "Expects the `env.metadata['render_fps']` to be greater than "
+ f"zero, actual value: {render_fps}"
+ )
# env.render is now an attribute with default None
if len(render_modes) == 0:
diff --git a/posggym/utils/run_random_agents.py b/posggym/utils/run_random_agents.py
new file mode 100644
index 0000000..4f7dc9f
--- /dev/null
+++ b/posggym/utils/run_random_agents.py
@@ -0,0 +1,124 @@
+"""Run random agents on an environment.
+
+This script runs an environment using random agents.
+
+The script takes a number of arguments (number of episodes, environment id, render
+mode, etc.). To see all available arguments, run:
+
+ python run_random_agents.py --help
+
+Example, to run 10 episodes of the `Driving-v1` environment with `human` rendering mode,
+
+ python run_random_agents.py \
+ --env_id Driving-v1 \
+ --num_episodes 10 \
+ --render_mode human
+"""
+
+import argparse
+
+import posggym
+
+
+def run_random(
+ env: posggym.Env,
+ num_episodes: int,
+ max_episode_steps: int | None = None,
+ seed: int | None = None,
+):
+ env.reset(seed=seed)
+
+ dones = 0
+ episode_steps = []
+ episode_rewards: dict[str, list[float]] = {i: [] for i in env.possible_agents}
+ for ep_num in range(num_episodes):
+ env.render()
+
+ t = 0
+ done = False
+ rewards = {i: 0.0 for i in env.possible_agents}
+ while not done and (max_episode_steps is None or t < max_episode_steps):
+ a = {i: env.action_spaces[i].sample() for i in env.agents}
+ _, r, _, _, done, _ = env.step(a)
+ t += 1
+
+ env.render()
+
+ for i, r_i in r.items():
+ rewards[i] += r_i
+
+ print(f"End episode {ep_num}")
+ dones += int(done)
+ episode_steps.append(t)
+
+ env.reset()
+
+ for i, r_i in rewards.items():
+ episode_rewards[i].append(r_i)
+
+ if done:
+ print(t, rewards)
+
+ env.close()
+
+ print("All episodes finished")
+ print(
+ f"Completed episodes (i.e. where 'done=True') = {dones} out of {num_episodes}"
+ )
+ mean_steps = sum(episode_steps) / len(episode_steps)
+ print(f"Mean episode steps = {mean_steps:.2f}")
+ mean_returns = {i: sum(r) / len(r) for i, r in episode_rewards.items()}
+ print(f"Mean Episode returns {mean_returns}")
+ return mean_steps, mean_returns
+
+
+def run_random_agent(
+ env_id: str,
+ num_episodes: int,
+ max_episode_steps: int | None = None,
+ seed: int | None = None,
+ render_mode: str | None = None,
+ **kwargs,
+):
+ """Run random agents."""
+ if max_episode_steps is not None:
+ env = posggym.make(
+ env_id,
+ render_mode=render_mode,
+ max_episode_steps=max_episode_steps,
+ **kwargs,
+ )
+ else:
+ env = posggym.make(env_id, render_mode=render_mode, **kwargs)
+
+ return run_random(env, num_episodes, max_episode_steps, seed)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--env_id", type=str, required=True, help="ID of environment to run"
+ )
+ parser.add_argument(
+ "--num_episodes",
+ type=int,
+ default=1,
+ help="The number of episodes to run.",
+ )
+ parser.add_argument(
+ "--max_episode_steps",
+ type=int,
+ default=None,
+ help="Max number of steps to run each episode for.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="Random Seed.")
+ parser.add_argument(
+ "--render_mode",
+ type=str,
+ default=None,
+ help="Mode to use for rendering.",
+ )
+ args = parser.parse_args()
+ run_random_agent(**vars(args))
diff --git a/posggym/utils/seeding.py b/posggym/utils/seeding.py
index 0663cd5..9d5e746 100644
--- a/posggym/utils/seeding.py
+++ b/posggym/utils/seeding.py
@@ -6,34 +6,36 @@
"""
import random
-from typing import Optional, Tuple, Union
import numpy as np
+from torch import Generator
from posggym import error
-RNG = Union[random.Random, np.random.Generator]
+RNG = random.Random | np.random.Generator | Generator
-def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, int]:
+def np_random(seed: int | None = None) -> tuple[np.random.Generator, int]:
"""Create a numpy random number generator.
- Arguments
+ Arguments:
---------
seed : int, optional
the seed used to create the generator.
- Returns
+ Returns:
-------
rng : np.random.Generator
the random number generator
seed : int
the seed used for the rng (will equal argument seed if one is provided.)
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
if seed is not None or a non-negative integer.
"""
@@ -55,24 +57,26 @@ def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, int]:
return rng, np_seed
-def std_random(seed: Optional[int] = None) -> Tuple[random.Random, int]:
+def std_random(seed: int | None = None) -> tuple[random.Random, int]:
"""Create random number generator using python built-in `random.Random`.
- Arguments
+ Arguments:
---------
seed : int, optional
the seed used to create the generator.
- Returns
+ Returns:
-------
rng : random.Random
the random number generator
seed : int
the seed used for the rng (will equal argument seed if one is provided.)
- Raises
+ Raises:
------
- Error
+
+ Error:
+ -----
if seed is not None or a non-negative integer.
"""
diff --git a/posggym/utils/torch_utils.py b/posggym/utils/torch_utils.py
new file mode 100644
index 0000000..a45d32f
--- /dev/null
+++ b/posggym/utils/torch_utils.py
@@ -0,0 +1,51 @@
+import numpy as np
+
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def maybe_expand_dims(env, x):
+ if hasattr(env, "batch_size"):
+ return np.expand_dims(x, axis=0) # Expand dims by 1 if batch_size exists
+ return x
+
+
+# https://github.com/pytorch/pytorch/issues/52161
+def has_integer_dtype(tensor, signed: bool | None = None) -> bool:
+ """Determines if a PyTorch tensor has an integer dtype.
+
+ It also can force `tensor` to be signed or unsigned.
+
+ Parameters
+ ----------
+ tensor
+ The tensor to check.
+ signed
+ Determines which dtypes are allowed for `tensor`:
+
+ - If ``None`` both unsigned and signed integer will be allowed.
+
+ - If ``False`` only unsigned dtypes will be allowed.
+
+ - If ``True`` only signed dtypes will be allowed.
+
+ Returns
+ -------
+ bool
+ ``True`` if the input tensor satisfies the requested condition, ``False``
+ otherwise.
+
+ """
+ assert torch is not None
+
+ uint_types = [torch.uint8]
+ sint_types = [torch.int8, torch.int16, torch.int32, torch.int64]
+ if signed is None:
+ return tensor.dtype in uint_types + sint_types
+ elif signed:
+ return tensor.dtype in sint_types
+ else:
+ return tensor.dtype in uint_types
diff --git a/posggym/vector/__init__.py b/posggym/vector/__init__.py
index 518359a..34628e5 100644
--- a/posggym/vector/__init__.py
+++ b/posggym/vector/__init__.py
@@ -1,4 +1,5 @@
"""Module for posggym vector utils."""
from posggym.vector.sync_vector_env import SyncVectorEnv
+
__all__ = ["SyncVectorEnv"]
diff --git a/posggym/vector/sync_vector_env.py b/posggym/vector/sync_vector_env.py
index 42d8ae2..904629c 100644
--- a/posggym/vector/sync_vector_env.py
+++ b/posggym/vector/sync_vector_env.py
@@ -8,8 +8,13 @@
from __future__ import annotations
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+ from collections.abc import Callable, Iterable
+
from copy import deepcopy
-from typing import Any, Callable, Dict, Iterable, List, Tuple
import numpy as np
from gymnasium.vector.utils import concatenate, create_empty_array
@@ -58,10 +63,10 @@ def __init__(
self,
env_fns: Iterable[Callable[[], posggym.Env]],
copy: bool = True,
- ):
+ ) -> None:
"""Initialize the vectorized environment.
- Arguments
+ Arguments:
---------
env_fns
iterable of callable functions that create the environments.
@@ -118,8 +123,8 @@ def __init__(
def reset(
self,
*,
- seed: int | None | List[int] = None,
- options: Dict[str, Any] | None = None,
+ seed: int | None | list[int] = None,
+ options: dict[str, Any] | None = None,
):
"""Reset all environments and return batch of initial observations and info."""
if seed is None:
@@ -136,7 +141,7 @@ def reset(
observations = {i: [] for i in self.single_observation_spaces}
infos = {i: {} for i in self.single_observation_spaces}
- for env_num, (env, s) in enumerate(zip(self.envs, seed)):
+ for env_num, (env, s) in enumerate(zip(self.envs, seed, strict=False)):
obs, info = env.reset(seed=s, options=options)
for i in self.single_observation_spaces:
observations[i].append(obs[i])
@@ -159,14 +164,14 @@ def step(self, actions):
dictionary under the keys ``final_observation`` and ``final_info``.
- Arguments
+ Arguments:
---------
actions
dict mapping agent ID to batch of actions for that agent, with one action
for each environment. So should be a dict of arrays or lists, with each
array/list having length equal to the number of environments.
- Returns
+ Returns:
-------
observations
dict mapping agent ID to batch of observations for that agent, with one
@@ -244,7 +249,7 @@ def close(self):
for env in self.envs:
env.close()
- def call(self, name: str, *args, **kwargs) -> Tuple:
+ def call(self, name: str, *args, **kwargs) -> tuple:
"""Call a method on all environments and return the results."""
results = []
for env in self.envs:
@@ -256,7 +261,7 @@ def call(self, name: str, *args, **kwargs) -> Tuple:
return tuple(results)
@property
- def possible_agents(self) -> Tuple[str, ...]:
+ def possible_agents(self) -> tuple[str, ...]:
return self.envs[0].possible_agents
@property
@@ -284,7 +289,7 @@ def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
whether or not the i-indexed environment has this `info`.
Arguments:
- ----------
+ ---------
infos
the infos of the vectorized environment
info
@@ -292,7 +297,7 @@ def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
env_num
the index of the single environment
- Returns
+ Returns:
-------
infos
the (updated) infos of the vectorized environment
@@ -308,7 +313,7 @@ def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
- def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
+ def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric the info array will have the
@@ -316,12 +321,12 @@ def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
same length is returned. It will be used for assessing which environment has
info data.
- Arguments
+ Arguments:
---------
dtype
data type of the info coming from the env.
- Returns
+ Returns:
-------
array
the initialized info array.
diff --git a/posggym/wrappers/__init__.py b/posggym/wrappers/__init__.py
index 3847459..7c1d330 100644
--- a/posggym/wrappers/__init__.py
+++ b/posggym/wrappers/__init__.py
@@ -1,5 +1,7 @@
"""Module of wrapper classes."""
+from posggym.wrappers.batched_time_limit import BatchTimeLimit
from posggym.wrappers.discretize_actions import DiscretizeActions
+from posggym.wrappers.discretize_obs import DiscretizeObservations
from posggym.wrappers.env_checker import PassiveEnvChecker
from posggym.wrappers.flatten_observations import FlattenObservations
from posggym.wrappers.order_enforcing import OrderEnforcing
diff --git a/posggym/wrappers/batched_time_limit.py b/posggym/wrappers/batched_time_limit.py
new file mode 100644
index 0000000..7ebc59d
--- /dev/null
+++ b/posggym/wrappers/batched_time_limit.py
@@ -0,0 +1,63 @@
+import numpy as np
+
+import posggym
+
+
+class BatchTimeLimit(posggym.Wrapper):
+ """Wraps environment batch to enforce environment time limit.
+
+ This wrapper will issue a `truncated` signal in the :meth:`step` method for any
+ agents that have not already reached a terminal state by the time a maximum number
+ of timesteps is exceeded. It will also signal that the episode is `done` for all
+ agents in all environments when their respective time limits are exceeded.
+
+ Arguments:
+ ---------
+ env : posggym.Env
+ The environment batch to apply the wrapper
+ max_episode_steps : int, optional
+ The maximum length of episode before it is truncated. If None then will not
+ truncate episodes.
+ """
+
+ def __init__(self, env: posggym.Env, max_episode_steps: int | None = None) -> None:
+ super().__init__(env)
+
+ assert hasattr(self.env, "batch_size")
+ self.num_envs = self.env.batch_size # type: ignore
+
+ if max_episode_steps is None and self.env.spec is not None:
+ assert env.spec is not None
+ max_episode_steps = env.spec.max_episode_steps
+ if self.env.spec is not None:
+ self.env.spec.max_episode_steps = max_episode_steps
+ self._max_episode_steps = max_episode_steps
+ self._elapsed_steps = np.zeros(self.num_envs, dtype=int)
+ self._terminated_agents = [set() for _ in range(self.num_envs)]
+
+ def step(self, actions):
+ """Take a step in all batched environments with time limit enforcement."""
+ obs, rewards, terminated, truncated, done, info = self.env.step(actions)
+
+ for env_idx in range(self.num_envs):
+ self._elapsed_steps[env_idx] += 1
+
+ # Check if max steps are reached for the environment
+ if self._elapsed_steps[env_idx] >= self._max_episode_steps:
+ for agent, agent_truncated in truncated.items():
+ if agent not in self._terminated_agents[env_idx]:
+ agent_truncated[env_idx] = True
+ done[env_idx] = True
+ else:
+ for agent, agent_terminated in terminated.items():
+ # If the agent has terminated in this environment, mark it
+ if agent_terminated[env_idx]:
+ self._terminated_agents[env_idx].add(agent)
+
+ return obs, rewards, terminated, truncated, done, info
+
+ def reset(self, **kwargs):
+ self._elapsed_steps.fill(0)
+ obs, info = self.env.reset(**kwargs)
+ self._terminated_agents = [set() for _ in range(self.num_envs)]
+ return obs, info
diff --git a/posggym/wrappers/discretize_actions.py b/posggym/wrappers/discretize_actions.py
index 968a185..dc1f072 100644
--- a/posggym/wrappers/discretize_actions.py
+++ b/posggym/wrappers/discretize_actions.py
@@ -1,6 +1,7 @@
"""Wrapper to discretize continuous actions."""
-from typing import Dict, Sequence, Union, cast
+from collections.abc import Sequence
+from typing import cast
import numpy as np
from gymnasium import spaces
@@ -16,7 +17,7 @@ class DiscretizeActions(ActionWrapper):
space is multi-dimensional with :attr:`ndim` dimensions, then will create
discretized space with :attr:`num_actions ** ndim` actions.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
@@ -29,7 +30,7 @@ class DiscretizeActions(ActionWrapper):
"""
- def __init__(self, env: Env, num_actions: int, flatten: bool = False):
+ def __init__(self, env: Env, num_actions: int, flatten: bool = False) -> None:
super().__init__(env)
assert all(
isinstance(act_space, spaces.Box)
@@ -39,8 +40,8 @@ def __init__(self, env: Env, num_actions: int, flatten: bool = False):
self.num_actions = num_actions
self.flatten = flatten
- box_action_spaces = cast(Dict[str, spaces.Box], self.action_spaces)
- self._unflat_space: Dict[str, spaces.MultiDiscrete] = {}
+ box_action_spaces = cast(dict[str, spaces.Box], self.action_spaces)
+ self._unflat_space: dict[str, spaces.MultiDiscrete] = {}
if self.flatten:
self._unflat_space = {
i: self.discretize_action_space( # type: ignore
@@ -55,10 +56,28 @@ def __init__(self, env: Env, num_actions: int, flatten: bool = False):
)
for i, act_space in box_action_spaces.items()
}
+ self.model_action_spaces = self.model.action_spaces
+
+ def _wrap_model(self, model):
+ class DiscretizedModel:
+ def __init__(self, base_model, parent):
+ self._model = base_model
+ self._parent = parent
+ self.action_spaces = parent._action_spaces
+
+ def step(self, state, actions):
+ undiscretized = self._parent.actions(actions)
+ return self._model.step(state, undiscretized)
+
+ def __getattr__(self, name):
+ # Delegate to base model for any undefined attribute
+ return getattr(self._model, name)
+
+ return DiscretizedModel(model, self)
def discretize_action_space(
self, action_space: spaces.Box, num_actions: int, flatten: bool = False
- ) -> Union[spaces.MultiDiscrete, spaces.Discrete]:
+ ) -> spaces.MultiDiscrete | spaces.Discrete:
assert isinstance(action_space, spaces.Box), "Action space must be a Box"
assert (
len(action_space.shape) > 0
@@ -83,13 +102,13 @@ def actions(self, actions):
for i, act_i in actions.items()
}
return {
- i: self.undiscretize_action(act_i, self.model.action_spaces[i])
+ i: self.undiscretize_action(act_i, self.model_action_spaces[i])
for i, act_i in actions.items()
}
def undiscretize_action(
self,
- discrete_action: Union[int, Sequence[int], np.ndarray],
+ discrete_action: int | Sequence[int] | np.ndarray,
action_space: spaces.Box,
) -> np.ndarray:
if isinstance(discrete_action, int):
diff --git a/posggym/wrappers/discretize_obs.py b/posggym/wrappers/discretize_obs.py
new file mode 100644
index 0000000..cb0cd4c
--- /dev/null
+++ b/posggym/wrappers/discretize_obs.py
@@ -0,0 +1,69 @@
+import numpy as np
+from gymnasium import spaces
+
+from posggym import ObservationWrapper
+
+
+class DiscretizeObservations(ObservationWrapper):
+ def __init__(self, env, num_bins=10):
+ super().__init__(env)
+
+ # Ensure all observation spaces are Box
+ assert all(
+ isinstance(space, spaces.Box) for space in env.observation_spaces.values()
+ ), "All agent observation spaces must be gym.spaces.Box"
+
+ self.num_bins = num_bins
+ self._original_obs_spaces = self.observation_spaces
+ self._bin_specs = {}
+
+ # Build new observation spaces
+ self.observation_spaces = {}
+ for agent_id, space in self._original_obs_spaces.items():
+ assert space.dtype in (np.float32, np.float64)
+ assert np.all(np.isfinite(space.low)) and np.all(np.isfinite(space.high))
+ self._bin_specs[agent_id] = {
+ "low": space.low,
+ "high": space.high,
+ "bin_width": (space.high - space.low) / num_bins,
+ }
+ self.observation_spaces[agent_id] = spaces.MultiDiscrete(
+ [num_bins] * space.shape[0] # type: ignore
+ )
+ self.model = self._wrap_model(self.model)
+
+ def observations(self, observation):
+ return {
+ agent_id: self._discretize(obs, self._bin_specs[agent_id])
+ for agent_id, obs in observation.items()
+ }
+
+ def _wrap_model(self, model):
+ class DiscretizedModel:
+ def __init__(self, base_model, parent):
+ self._model = base_model
+ self._parent = parent
+ self.observation_spaces = parent.observation_spaces
+
+ def step(self, state, actions):
+ step_result = self._model.step(state, actions)
+
+ # Discretize observations in-place
+ step_result = step_result._replace(
+ observations=self._parent.observations(step_result.observations)
+ )
+ return step_result
+
+ def sample_initial_obs(self, state):
+ obs = self._model.sample_initial_obs(state)
+ return self._parent.observations(obs)
+
+ def __getattr__(self, name):
+ return getattr(self._model, name)
+
+ return DiscretizedModel(model, self)
+
+ def _discretize(self, obs, spec):
+ obs = np.clip(obs, spec["low"], spec["high"])
+ discrete = ((obs - spec["low"]) / spec["bin_width"]).astype(int)
+ return tuple(np.clip(discrete, 0, self.num_bins - 1))
diff --git a/posggym/wrappers/env_checker.py b/posggym/wrappers/env_checker.py
index 9a6c865..688f464 100644
--- a/posggym/wrappers/env_checker.py
+++ b/posggym/wrappers/env_checker.py
@@ -1,6 +1,5 @@
"""A passive environment checker wrapper for an environment."""
-from typing import Dict
import posggym
import posggym.model as M
@@ -19,19 +18,19 @@ class PassiveEnvChecker(posggym.Wrapper):
Surrounds the step, reset and render functions to check that they follow the
posggym environment and model APIs.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
- Note
+ Note:
----
This implementation is based on the similar Gymnasium wrapper:
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/gymnasium/wrappers/env_checker.py
"""
- def __init__(self, env: posggym.Env):
+ def __init__(self, env: posggym.Env) -> None:
super().__init__(env)
assert hasattr(env, "model"), "The environment must specify a model."
@@ -49,7 +48,7 @@ def __init__(self, env: posggym.Env):
self.checked_step = False
self.checked_render = False
- def step(self, actions: Dict[str, M.ActType]):
+ def step(self, actions: dict[str, M.ActType]):
"""Steps through the environment.
On the first call will run the `passive_env_step_check`.
diff --git a/posggym/wrappers/flatten_observations.py b/posggym/wrappers/flatten_observations.py
index 557cf9e..b4f05e6 100644
--- a/posggym/wrappers/flatten_observations.py
+++ b/posggym/wrappers/flatten_observations.py
@@ -7,19 +7,19 @@
class FlattenObservations(ObservationWrapper):
"""Observation wrapper that flattens the observation.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
- Note
+ Note:
----
This implementation is based on the similar Gymnasium wrapper:
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/gymnasium/wrappers/flatten_observation.py
"""
- def __init__(self, env: Env):
+ def __init__(self, env: Env) -> None:
super().__init__(env)
self._observation_spaces = {
i: spaces.flatten_space(obs_space)
diff --git a/posggym/wrappers/monitoring/video_recorder.py b/posggym/wrappers/monitoring/video_recorder.py
index 18f0219..32c2f4d 100644
--- a/posggym/wrappers/monitoring/video_recorder.py
+++ b/posggym/wrappers/monitoring/video_recorder.py
@@ -4,11 +4,10 @@
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/gymnasium/wrappers/monitoring/video_recorder.py
"""
-
import json
import tempfile
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import ClassVar
from posggym import Env, error, logger
@@ -23,27 +22,27 @@ class VideoRecorder:
makes it compatible with posggym.env.render function which returns rgb
arrays for the whole environment as well as (optionally) each agent.
- Note
+ Note:
----
You are responsible for calling `close` on a created VideoRecorder, or else
you may leak an encoder process.
"""
- combatible_render_modes = ["rgb_array", "rgb_array_dict"]
+ combatible_render_modes: ClassVar[list] = ["rgb_array", "rgb_array_dict"]
def __init__(
self,
env: Env,
- path: Optional[Path] = None,
- metadata: Optional[Dict] = None,
+ path: Path | None = None,
+ metadata: dict | None = None,
enabled: bool = True,
- base_path: Optional[Path] = None,
+ base_path: Path | None = None,
disable_logger: bool = False,
- ):
+ ) -> None:
"""Video recorder renders a nice movie of a rollout, frame by frame.
- Arguments
+ Arguments:
---------
env : posggym.Env
Environment to take video of.
@@ -64,7 +63,7 @@ def __init__(
# check that moviepy is now installed
import moviepy # noqa
except ImportError as e:
- raise error.DependencyNotInstalled(
+ raise error.DependencyNotInstalledError(
"MoviePy is not installed, run `pip install moviepy`"
) from e
@@ -73,13 +72,13 @@ def __init__(
self.disable_logger = disable_logger
self._closed = False
- self.render_history: List = []
+ self.render_history: list = []
self.env = env
self.render_mode = env.render_mode
if self.render_mode not in self.combatible_render_modes:
- logger.warn(
+ logger.warning(
f"Disabling video recorder because environment {env} was not "
"initialized with any compatible video modes in "
f"{self.combatible_render_modes}."
@@ -129,7 +128,7 @@ def __init__(
self.write_metadata()
logger.info("Starting new video recorder writing to %s", self.path)
- self.recorded_frames: List = []
+ self.recorded_frames: list = []
@property
def functional(self) -> bool:
@@ -150,7 +149,7 @@ def capture_frame(self):
frame = frame[-1]
elif isinstance(frame, dict):
if "env" not in frame:
- logger.warn(
+ logger.warning(
"The video recorder expects an entry with the key `env` when "
"trying to record an environment that is using the "
"`rgb_array_dict` render mode."
@@ -160,7 +159,7 @@ def capture_frame(self):
frame = frame["env"]
if self._closed:
- logger.warn(
+ logger.warning(
"The video recorder has been closed and no frames will be "
"captured anymore."
)
@@ -173,7 +172,7 @@ def capture_frame(self):
else:
# Indicates a bug in the environment: don't want to raise
# an error here.
- logger.warn(
+ logger.warning(
"Env returned None on `render()`. Disabling further rendering for "
f"video recorder by marking as disabled: path={self.path} "
f"metadata_path={self.metadata_path}"
@@ -193,7 +192,7 @@ def close(self):
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError as e:
- raise error.DependencyNotInstalled(
+ raise error.DependencyNotInstalledError(
"MoviePy is not installed, run `pip install moviepy`"
) from e
@@ -220,4 +219,4 @@ def __del__(self):
"""Closes the environment correctly when the recorder is deleted."""
# Make sure we've closed up shop when garbage collecting
if not self._closed:
- logger.warn("Unable to save last video! Did you call close()?")
+ logger.warning("Unable to save last video! Did you call close()?")
diff --git a/posggym/wrappers/order_enforcing.py b/posggym/wrappers/order_enforcing.py
index 40c24e7..62dc333 100644
--- a/posggym/wrappers/order_enforcing.py
+++ b/posggym/wrappers/order_enforcing.py
@@ -1,6 +1,6 @@
"""Wrapper to enforce the proper ordering of environment operations."""
import posggym
-from posggym.error import ResetNeeded
+from posggym.error import ResetNeededError
class OrderEnforcing(posggym.Wrapper):
@@ -8,28 +8,30 @@ class OrderEnforcing(posggym.Wrapper):
Will produce an error if :meth:`step` is called before an initial :meth:`reset`.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
disable_render_order_enforcing : bool
Whether to disable enforcing of reset before render is called or not.
- Note
+ Note:
----
This implementation is based on the similar Gymnasium wrapper:
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/gymnasium/wrappers/order_enforcing.py
"""
- def __init__(self, env: posggym.Env, disable_render_order_enforcing: bool = False):
+ def __init__(
+ self, env: posggym.Env, disable_render_order_enforcing: bool = False
+ ) -> None:
super().__init__(env)
self._has_reset = False
self._disable_render_order_enforcing = disable_render_order_enforcing
def step(self, actions):
if not self._has_reset:
- raise ResetNeeded("Cannot call env.step() before calling env.reset()")
+ raise ResetNeededError("Cannot call env.step() before calling env.reset()")
return self.env.step(actions)
def reset(self, **kwargs):
@@ -38,7 +40,7 @@ def reset(self, **kwargs):
def render(self):
if not self._disable_render_order_enforcing and not self._has_reset:
- raise ResetNeeded(
+ raise ResetNeededError(
"Cannot call `env.render()` before calling `env.reset()`, if this is an"
" intended action, set `disable_render_order_enforcing=True` on the "
"OrderEnforcer wrapper."
diff --git a/posggym/wrappers/petting_zoo.py b/posggym/wrappers/petting_zoo.py
index 47f5165..8575974 100644
--- a/posggym/wrappers/petting_zoo.py
+++ b/posggym/wrappers/petting_zoo.py
@@ -1,14 +1,15 @@
"""Wrapper for converting a posggym environment into pettingzoo environment."""
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any
import posggym
+
try:
from pettingzoo.utils.env import ActionDict, ObsDict, ParallelEnv
except ImportError as e:
- raise posggym.error.DependencyNotInstalled(
+ raise posggym.error.DependencyNotInstalledError(
"pettingzoo is not installed, run `pip install pettingzoo` or visit "
"'https://github.com/Farama-Foundation/PettingZoo#installation' for details on "
"installing pettingzoo."
@@ -33,35 +34,34 @@ class PettingZoo(ParallelEnv):
References
----------
-
- parallel env docs: https://pettingzoo.farama.org/api/parallel/
- parallel env code:
https://github.com/Farama-Foundation/PettingZoo/blob/master/pettingzoo/utils/env.py
"""
- def __init__(self, env: posggym.Env):
+ def __init__(self, env: posggym.Env) -> None:
self.env = env
- self._done_agents: Set[str] = set()
+ self._done_agents: set[str] = set()
@property
- def metadata(self) -> Dict[str, Any]:
+ def metadata(self) -> dict[str, Any]:
return self.env.metadata
@property
- def agents(self) -> List[str]:
+ def agents(self) -> list[str]:
return [i for i in self.env.agents if i not in self._done_agents]
@property
- def possible_agents(self) -> List[str]:
+ def possible_agents(self) -> list[str]:
return list(self.env.possible_agents)
@property
- def action_spaces(self) -> Dict[str, spaces.Space]:
+ def action_spaces(self) -> dict[str, spaces.Space]:
return self.env.action_spaces
@property
- def observation_spaces(self) -> Dict[str, spaces.Space]:
+ def observation_spaces(self) -> dict[str, spaces.Space]:
return self.env.observation_spaces
@property
@@ -70,9 +70,9 @@ def render_mode(self) -> str | None:
def reset(
self,
- seed: Optional[int] = None,
- return_info: bool = False,
- options: Optional[dict] = None,
+ seed: int | None = None,
+ return_info: bool = True,
+ options: dict | None = None,
) -> ObsDict:
obs, info = self.env.reset(seed=seed, options=options)
self._done_agents = set()
@@ -83,8 +83,8 @@ def reset(
def step(
self, actions: ActionDict
- ) -> Tuple[
- ObsDict, Dict[str, float], Dict[str, bool], Dict[str, bool], Dict[str, dict]
+ ) -> tuple[
+ ObsDict, dict[str, float], dict[str, bool], dict[str, bool], dict[str, dict]
]:
obs, rewards, terminated, truncated, all_done, info = self.env.step(actions)
@@ -102,10 +102,9 @@ def step(
for i, done in truncated.items():
if done:
self._done_agents.add(i)
-
return obs, rewards, terminated, truncated, info
- def render(self) -> None | np.ndarray | str | List:
+ def render(self) -> None | np.ndarray | str | list:
output = self.env.render()
if isinstance(output, dict):
return output.get("env", None)
diff --git a/posggym/wrappers/record_episode_statistics.py b/posggym/wrappers/record_episode_statistics.py
index 46656f6..927fa47 100644
--- a/posggym/wrappers/record_episode_statistics.py
+++ b/posggym/wrappers/record_episode_statistics.py
@@ -49,7 +49,7 @@ class RecordEpisodeStatistics(posggym.Wrapper):
can be accessed via :attr:`wrapped_env.return_queue` and
:attr:`wrapped_env.length_queue` respectively.
- Attributes
+ Attributes:
----------
episode_count : int
The number of episodes that have been recorded.
@@ -61,21 +61,21 @@ class RecordEpisodeStatistics(posggym.Wrapper):
A queue of the last ``deque_size`` episode lengths. Each entry is a dictionary
mapping agent ids to the episode length of the respective agent for an episode.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
deque_size : int
The size of the buffer for storing the previous episode statistics.
- Note
+ Note:
----
This implementation is based on the similar Gymnasium wrapper:
https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/record_episode_statistics.py
"""
- def __init__(self, env: posggym.Env, deque_size: int = 100):
+ def __init__(self, env: posggym.Env, deque_size: int = 100) -> None:
super().__init__(env)
self._deque_size = deque_size
diff --git a/posggym/wrappers/record_video.py b/posggym/wrappers/record_video.py
index dcb227f..2a5e580 100644
--- a/posggym/wrappers/record_video.py
+++ b/posggym/wrappers/record_video.py
@@ -5,14 +5,17 @@
"""
+from collections.abc import Callable
from pathlib import Path
-from typing import Callable, Optional, Union
from posggym import logger
from posggym.core import Env, Wrapper
from posggym.wrappers.monitoring.video_recorder import VideoRecorder
+CUBIC_SCHEDULE_THRESHOLD = 1000
+
+
def capped_cubic_video_schedule(episode_id: int) -> bool:
"""Get cubic schedule.
@@ -20,36 +23,36 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
the 1000th episode, then every 1000 episodes after that:
0, 1, 8, 27, 64, 125, 216, 343, 512, 729, 1000, 2000, 3000, ...
- Arguments
+ Arguments:
---------
episode_id: int
The episode number
- Returns
+ Returns:
-------
bool
Whether to record episode or not.
"""
- if episode_id < 1000:
+ if episode_id < CUBIC_SCHEDULE_THRESHOLD:
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
else:
- return episode_id % 1000 == 0
+ return episode_id % CUBIC_SCHEDULE_THRESHOLD == 0
class RecordVideo(Wrapper):
"""Wrapper for recording videos of rollouts.
- Arguments
+ Arguments:
---------
env: posggym.Env
The environment that will be wrapped
video_folder: str
The folder where the recordings will be stored
- episode_trigger: Optional[Callable[[int], bool]]
+ episode_trigger: Callable[[int], bool] | None
Function that accepts an integer and returns ``True`` iff a recording should be
started at this episode
- step_trigger: Optional[Callable[[int], bool]]
+ step_trigger: Callable[[int], bool] | None
Function that accepts an integer and returns ``True`` iff a recording should be
started at this step
video_length: int
@@ -60,7 +63,7 @@ class RecordVideo(Wrapper):
disable_logger: bool
Whether to disable moviepy logger or not.
- Note
+ Note:
----
This implementation is based on the gymnasium.wrappers.RecordVideo (version
gymnasium 0.27) wrapper, adapted here to work with posggym's multiagent environment:
@@ -71,13 +74,13 @@ class RecordVideo(Wrapper):
def __init__(
self,
env: Env,
- video_folder: Union[Path, str],
- episode_trigger: Optional[Callable[[int], bool]] = None,
- step_trigger: Optional[Callable[[int], bool]] = None,
+ video_folder: Path | str,
+ episode_trigger: Callable[[int], bool] | None = None,
+ step_trigger: Callable[[int], bool] | None = None,
video_length: int = 0,
name_prefix: str = "posggym-video",
disable_logger: bool = False,
- ):
+ ) -> None:
super().__init__(env)
if episode_trigger is None and step_trigger is None:
@@ -88,7 +91,7 @@ def __init__(
self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
- self.video_recorder: Optional[VideoRecorder] = None
+ self.video_recorder: VideoRecorder | None = None
self.disable_logger = disable_logger
if isinstance(video_folder, str):
@@ -97,7 +100,7 @@ def __init__(
self.video_folder = video_folder
# Create output folder if needed
if self.video_folder.is_dir():
- logger.warn(
+ logger.warning(
f"Overwriting existing videos at {self.video_folder} folder (try "
"specifying a different `video_folder` for the `RecordVideo` wrapper "
"if this is not desired)"
diff --git a/posggym/wrappers/rescale_actions.py b/posggym/wrappers/rescale_actions.py
index 12ed670..d899402 100644
--- a/posggym/wrappers/rescale_actions.py
+++ b/posggym/wrappers/rescale_actions.py
@@ -1,5 +1,4 @@
"""Wrapper to rescale continuous actions from [min, max] range."""
-from typing import Dict, Union
import numpy as np
from gymnasium import spaces
@@ -18,16 +17,16 @@ class RescaleActions(posggym.ActionWrapper):
the given agent. If :attr:`min_action` or :attr:`max_action` are dictionaries then
they must have an entry for each possible agent ID in the wrapped environment.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
- min_action : float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
+ min_action : float, int, np.ndarray, Dict[str, float | int | np.ndarray]
The minimum value for the scaled actions.
- max_action : float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
+ max_action : float, int, np.ndarray, Dict[str, float | int | np.ndarray]
The maximum value for the scaled actions.
- Note
+ Note:
----
Explanation of how to scale number from one interval into new interval:
https://stats.stackexchange.com/questions/281162/scale-a-number-between-a-range
@@ -40,13 +39,9 @@ class RescaleActions(posggym.ActionWrapper):
def __init__(
self,
env: posggym.Env,
- min_action: Union[
- float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
- ],
- max_action: Union[
- float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
- ],
- ):
+ min_action: float | int | np.ndarray | dict[str, float | int | np.ndarray],
+ max_action: float | int | np.ndarray | dict[str, float | int | np.ndarray],
+ ) -> None:
self.min_action = {}
self.max_action = {}
self.rescale_factor = {}
@@ -57,7 +52,7 @@ def __init__(
)
min_action_i = min_action[i] if isinstance(min_action, dict) else min_action
- if isinstance(min_action_i, (float, int)):
+ if isinstance(min_action_i, float | int):
self.min_action[i] = np.full_like(action_space.low, min_action_i)
else:
assert isinstance(min_action_i, np.ndarray), min_action_i
@@ -68,7 +63,7 @@ def __init__(
self.min_action[i] = min_action_i
max_action_i = max_action[i] if isinstance(max_action, dict) else max_action
- if isinstance(max_action_i, (float, int)):
+ if isinstance(max_action_i, float | int):
self.max_action[i] = np.full_like(action_space.high, max_action_i)
else:
assert isinstance(max_action_i, np.ndarray), max_action_i
diff --git a/posggym/wrappers/rescale_observations.py b/posggym/wrappers/rescale_observations.py
index 2a6acf7..23e8f32 100644
--- a/posggym/wrappers/rescale_observations.py
+++ b/posggym/wrappers/rescale_observations.py
@@ -1,5 +1,4 @@
"""Wrapper for rescaling observations to within min and max values."""
-from typing import Dict, Union
import numpy as np
from gymnasium import spaces
@@ -16,17 +15,17 @@ class RescaleObservations(ObservationWrapper):
the given agent. If :attr:`min_obs` or :attr:`max_obs` are dictionaries then they
must have an entry for each possible agent ID in the wrapped environment.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
- min_obs : float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
+ min_obs : float, int, np.ndarray, Dict[str, float | int | np.ndarray]
The minimum value for the scaled observations.
- max_obs : float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
+ max_obs : float, int, np.ndarray, Dict[str, float | int | np.ndarray]
The maximum value for the scaled observations.
- Note
+ Note:
----
Explanation of how to scale number from one interval into new interval:
https://stats.stackexchange.com/questions/281162/scale-a-number-between-a-range
@@ -36,13 +35,9 @@ class RescaleObservations(ObservationWrapper):
def __init__(
self,
env: Env,
- min_obs: Union[
- float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
- ],
- max_obs: Union[
- float, int, np.ndarray, Dict[str, Union[float, int, np.ndarray]]
- ],
- ):
+ min_obs: float | int | np.ndarray | dict[str, float | int | np.ndarray],
+ max_obs: float | int | np.ndarray | dict[str, float | int | np.ndarray],
+ ) -> None:
self.min_obs = {}
self.max_obs = {}
self.rescale_factor = {}
@@ -53,7 +48,7 @@ def __init__(
)
min_obs_i = min_obs[i] if isinstance(min_obs, dict) else min_obs
- if isinstance(min_obs_i, (float, int)):
+ if isinstance(min_obs_i, float | int):
self.min_obs[i] = np.full_like(obs_space.low, min_obs_i)
else:
assert isinstance(min_obs_i, np.ndarray), min_obs_i
@@ -64,7 +59,7 @@ def __init__(
self.min_obs[i] = min_obs_i
max_obs_i = max_obs[i] if isinstance(max_obs, dict) else max_obs
- if isinstance(max_obs_i, (float, int)):
+ if isinstance(max_obs_i, float | int):
self.max_obs[i] = np.full_like(obs_space.high, max_obs_i)
else:
assert isinstance(max_obs_i, np.ndarray), max_obs_i
diff --git a/posggym/wrappers/rllib_env.py b/posggym/wrappers/rllib_env.py
index 42be879..c3c39ec 100644
--- a/posggym/wrappers/rllib_env.py
+++ b/posggym/wrappers/rllib_env.py
@@ -1,24 +1,24 @@
"""Wrapper for converting a posggym environment into rllib multi-agent environment."""
import warnings
-from typing import Optional, Set, Tuple
from gymnasium import spaces
import posggym
+
try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.typing import AgentID, MultiAgentDict
-except ImportError:
- raise posggym.error.DependencyNotInstalled(
+except ImportError as err:
+ raise posggym.error.DependencyNotInstalledError(
"The posggym.wrapper.rllib_multi_agent_env wrapper depends on the Ray RLlib "
"library. run `pip install ray[rllib]>=2.3` or visit "
"'https://docs.ray.io/en/latest/ray-overview/installation.html` for more "
"details on installing rllib. "
- )
+ ) from err
class RllibMultiAgentEnv(MultiAgentEnv):
@@ -28,14 +28,13 @@ class RllibMultiAgentEnv(MultiAgentEnv):
References
----------
-
- https://github.com/ray-project/ray/blob/ray-2.3.0/rllib/env/multi_agent_env.py
"""
- def __init__(self, env: posggym.Env):
+ def __init__(self, env: posggym.Env) -> None:
self.env = env
- self._done_agents: Set[AgentID] = set()
+ self._done_agents: set[AgentID] = set()
# must assign this first before calling super().__init__() so that
# property functions are initialized before super().__init__() is
@@ -64,22 +63,22 @@ def action_space(self):
"""
return spaces.Dict(self.env.action_spaces)
- def get_agent_ids(self) -> Set[AgentID]:
+ def get_agent_ids(self) -> set[AgentID]:
"""Return a set of agent ids in the environment."""
return self._agent_ids
def reset( # type: ignore
self,
*,
- seed: Optional[int] = None,
- options: Optional[dict] = None,
- ) -> Tuple[MultiAgentDict, MultiAgentDict]:
+ seed: int | None = None,
+ options: dict | None = None,
+ ) -> tuple[MultiAgentDict, MultiAgentDict]:
self._done_agents = set()
return self.env.reset(seed=seed, options=options)
def step( # type: ignore
self, action_dict: MultiAgentDict
- ) -> Tuple[
+ ) -> tuple[
MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
]:
"""Returns observations from ready agents.
@@ -87,12 +86,12 @@ def step( # type: ignore
The returns are dicts mapping from agent_id strings to values. The
number of agents in the env can vary over time.
- Arguments
+ Arguments:
---------
action_dict : MultiAgentDict
action for each agent
- Returns
+ Returns:
-------
observations : MultiAgentDict
new observations for each ready agent
diff --git a/posggym/wrappers/stack.py b/posggym/wrappers/stack.py
index 27f1bb5..9b75764 100644
--- a/posggym/wrappers/stack.py
+++ b/posggym/wrappers/stack.py
@@ -1,5 +1,5 @@
"""Environment wrapper class that stacks agent observations into a single array."""
-from typing import Any, Dict
+from typing import Any
import numpy as np
from gymnasium import spaces
@@ -89,7 +89,7 @@ class StackEnv(posggym.Wrapper):
"""
- def __init__(self, env: posggym.Env):
+ def __init__(self, env: posggym.Env) -> None:
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
@@ -123,10 +123,25 @@ def reset(self, **kwargs):
def step(self, actions):
# input shape (num_envs * num_agents, *single_action_space.shape)
# convert to dict of actions, shape (num_envs, *single_action_space.shape)
- action_map = {
- i: actions[idx :: len(self.possible_agents)]
- for idx, i in enumerate(self.possible_agents)
- }
+ try:
+ is_multi_discrete = all(
+ isinstance(
+ self.env.unwrapped.single_action_spaces[key], spaces.MultiDiscrete
+ )
+ for key in self.env.unwrapped.single_action_spaces
+ )
+ except AttributeError:
+ is_multi_discrete = False
+
+ if is_multi_discrete:
+ action_map = {
+ i: actions[:, idx, :] for idx, i in enumerate(self.possible_agents)
+ }
+ else:
+ action_map = {
+ i: actions[idx :: len(self.possible_agents)]
+ for idx, i in enumerate(self.possible_agents)
+ }
obs, rewards, terminated, truncated, dones, infos = self.env.step(action_map)
return (
self._stack_output(obs),
@@ -137,9 +152,9 @@ def step(self, actions):
infos,
)
- def _stack_output(self, output: Dict[str, Any]) -> np.ndarray:
+ def _stack_output(self, output: dict[str, Any]) -> np.ndarray:
"""Stacks the output of the environment into a single array."""
- x0 = list(output.values())[0]
+ x0 = next(iter(output.values()))
x0 = x0 if isinstance(x0, np.ndarray) else np.array([x0])
num_agents = len(self.possible_agents)
if self.is_vector_env:
diff --git a/posggym/wrappers/time_limit.py b/posggym/wrappers/time_limit.py
index 39469f7..ef58c04 100644
--- a/posggym/wrappers/time_limit.py
+++ b/posggym/wrappers/time_limit.py
@@ -1,6 +1,5 @@
"""Wrapper for limiting the time steps of an environment."""
-from typing import Optional, Set
import posggym
@@ -13,7 +12,7 @@ class TimeLimit(posggym.Wrapper):
of timesteps is exceeded. It will also signal that the episode is `done` for all
agents.
- Arguments
+ Arguments:
---------
env : posggym.Env
The environment to apply the wrapper
@@ -21,14 +20,14 @@ class TimeLimit(posggym.Wrapper):
The maximum length of episode before it is truncated. If None then will not
truncate episodes.
- Note
+ Note:
----
This implementation is based on the similar Gymnasium wrapper:
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/gymnasium/wrappers/time_limit.py
"""
- def __init__(self, env: posggym.Env, max_episode_steps: Optional[int] = None):
+ def __init__(self, env: posggym.Env, max_episode_steps: int | None = None) -> None:
super().__init__(env)
if max_episode_steps is None and self.env.spec is not None:
assert env.spec is not None
@@ -37,7 +36,7 @@ def __init__(self, env: posggym.Env, max_episode_steps: Optional[int] = None):
self.env.spec.max_episode_steps = max_episode_steps
self._max_episode_steps = max_episode_steps
self._elapsed_steps = 0
- self._terminated_agents: Set[str] = set()
+ self._terminated_agents: set[str] = set()
def step(self, actions):
obs, rewards, terminated, truncated, done, info = self.env.step(actions)
diff --git a/pyproject.toml b/pyproject.toml
index d6dd325..9a3ecb8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
name = "posggym"
description = "A library for planning and reinforcement learning research in partially observable, multi-agent environments."
readme = "README.md"
-requires-python = ">= 3.8"
+requires-python = ">= 3.10"
authors = [
{ name = "Jonathon Schwartz", email = "jonathon.schwartz@anu.edu.au" },
]
@@ -33,12 +33,13 @@ classifiers = [
'Topic :: Scientific/Engineering :: Artificial Intelligence',
]
dependencies = [
- "gymnasium >=0.26",
+ "gymnasium>=0.26,<=0.29.1",
"numpy >=1.21.0",
"typing-extensions >=4.3.0",
"importlib-metadata >=4.8.0; python_version < '3.10'",
"pygame >=2.0",
- "pymunk >=6.0.0",
+ "pymunk >=6.0.0, <7.0.0",
+ "vmas==1.4.1"
]
dynamic = ["version"]
@@ -89,27 +90,76 @@ posggym = ["envs/grid_world/img/*.png", "py.typed"]
# Linters and Test tools #######################################################
[tool.ruff]
-# https://beta.ruff.rs/docs/settings/
src = ["posggym", "tests", "docs/scripts", "scripts", "examples", "notebooks"]
-extend-select = ["C4", "SIM", "TCH"]
show-fixes = true
# Same as Black.
line-length = 88
-# Assume Python 3.8.
-target-version = "py38"
+# Assume Python 3.10.
+target-version = "py310"
+
+[tool.ruff.lint]
+extend-select = ["C4", "SIM", "TCH"]
+
+ignore = [
+ "A001", "A002",
+ "ARG001", "ARG002",
+ "FBT001", "FBT002", "FBT003",
+ "N802", "N806", "N812",
+ "PGH003", "PGH004",
+ "S101", "S202", "S301", "S310", "S311",
+ "T201", "T203"
+]
+select = [
+ "A",
+ "ARG",
+ "B",
+ "C4",
+ "C90",
+ "E",
+ "ERA",
+ "F",
+ "FBT",
+ "ICN",
+ "I",
+ "ISC",
+ "N",
+ "NPY",
+ "PD",
+ "PGH",
+ "PIE",
+ "PLE",
+ "PLR",
+ "Q",
+ "RUF",
+ "S",
+ "SIM",
+ "T",
+ "UP",
+ "W",
+]
# Ignore `F401` (import violations) in all `__init__.py` files,
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "E402"]
+"tests/**/*.py" = ["PLR2004", "N802", "N801"]
+"docs/conf.py" = ["A001"]
+"docs/**/*.py" = ["PLR2004"]
+"posggym/utils/**.py" = ["PLR2004"]
-[tool.ruff.mccabe]
+[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
-max-complexity = 10
+max-complexity = 30
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
lines-after-imports = 2
extra-standard-library = ["typing_extensions"]
+[tool.ruff.lint.pylint]
+max-statements = 100
+max-branches=20
+max-args=15
+
+
[tool.black]
line-length = 88
@@ -119,7 +169,7 @@ exclude = ["**/node_modules", "**/__pycache__"]
strict = []
typeCheckingMode = "basic"
-pythonVersion = "3.8"
+pythonVersion = "3.10"
pythonPlatform = "All"
enableTypeIgnoreComments = true
diff --git a/scripts/pairwise_agent_comparison.py b/scripts/pairwise_agent_comparison.py
index bda0b45..8e5d18e 100644
--- a/scripts/pairwise_agent_comparison.py
+++ b/scripts/pairwise_agent_comparison.py
@@ -9,6 +9,7 @@
from posggym.agents.evaluation import pairwise
+
if __name__ == "__main__":
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
diff --git a/scripts/time_env.py b/scripts/time_env.py
index c993478..6f32e28 100644
--- a/scripts/time_env.py
+++ b/scripts/time_env.py
@@ -8,24 +8,23 @@
import argparse
import time
-from typing import Optional
import posggym
def time_env_step_rate(
- env_id: str, num_steps: int, seed: Optional[int], render_mode: Optional[str]
+ env_id: str, num_steps: int, seed: int | None, render_mode: str | None
) -> float:
"""Calculate the step rate of environment.
- Arguments
+ Arguments:
---------
env_id: ID of environment to test
num_steps: The number of steps to test for
seed: the random seed to use
render_mode: render mode for environment
- Returns
+ Returns:
-------
step_rate: the average steps per second executed in the environment
@@ -50,10 +49,10 @@ def time_env_step_rate(
def main(
- env_id: Optional[str] = None,
+ env_id: str | None = None,
num_steps: int = 1000,
- seed: Optional[int] = None,
- render_mode: Optional[str] = None,
+ seed: int | None = None,
+ render_mode: str | None = None,
):
env_ids = list(posggym.registry) if env_id is None else [env_id]
diff --git a/setup.py b/setup.py
index 4194960..601f94b 100644
--- a/setup.py
+++ b/setup.py
@@ -8,6 +8,7 @@
from setuptools import setup
from setuptools.command import build_py
+
CWD = Path(__file__).absolute().parent
ASSETS_URL = (
diff --git a/tests/agents/continuous/predator_prey/test_heuristic.py b/tests/agents/continuous/predator_prey/test_heuristic.py
index 9c73bf1..8aad1fe 100644
--- a/tests/agents/continuous/predator_prey/test_heuristic.py
+++ b/tests/agents/continuous/predator_prey/test_heuristic.py
@@ -1,10 +1,9 @@
"""Tests for the heuristic agent in the predator prey continuous environment."""
import numpy as np
-import pytest
-
import posggym
import posggym.agents as pga
+import pytest
RENDER_MODE = None
diff --git a/tests/agents/continuous/pursuit_evasion/test_shortest_path.py b/tests/agents/continuous/pursuit_evasion/test_shortest_path.py
index 1e52d80..6a1dc15 100644
--- a/tests/agents/continuous/pursuit_evasion/test_shortest_path.py
+++ b/tests/agents/continuous/pursuit_evasion/test_shortest_path.py
@@ -1,10 +1,9 @@
"""Tests for the shortest path policy in pursuit-evasion continuous environment."""
import numpy as np
-import pytest
-
import posggym
import posggym.agents as pga
+import pytest
from posggym.agents.utils.action_distributions import DeterministicActionDistribution
@@ -16,7 +15,7 @@
class ConstantPolicy(pga.Policy):
"""A policy that always returns the same action."""
- def __init__(self, model, agent_id, policy_id, action):
+ def __init__(self, model, agent_id, policy_id, action) -> None:
super().__init__(model, agent_id, policy_id)
self.action = action
@@ -90,7 +89,6 @@ def test_shortest_path(agent_id, world):
if __name__ == "__main__":
# For manual debugging
- # RENDER_MODE = "human"
for fn in [
test_shortest_path,
]:
diff --git a/tests/agents/helpers.py b/tests/agents/helpers.py
index bab0819..a6569c7 100644
--- a/tests/agents/helpers.py
+++ b/tests/agents/helpers.py
@@ -4,7 +4,6 @@
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/tests/envs/utils.py
"""
-from typing import List, Optional
import numpy as np
import posggym
@@ -15,7 +14,7 @@
from tests.conftest import env_id_prefix
-def try_make_policy(spec: PolicySpec) -> Optional[pga.Policy]:
+def try_make_policy(spec: PolicySpec) -> pga.Policy | None:
"""Tries to make the policy showing if it is possible."""
try:
if spec.env_id is None:
@@ -32,10 +31,10 @@ def try_make_policy(spec: PolicySpec) -> Optional[pga.Policy]:
return pga.make(spec, env.model, agent_id)
except (
ImportError,
- posggym.error.DependencyNotInstalled,
- posggym.error.MissingArgument,
+ posggym.error.DependencyNotInstalledError,
+ posggym.error.MissingArgumentError,
) as e:
- posggym.logger.warn(
+ posggym.logger.warning(
f"Not testing posggym.agents policy spec `{spec.id}` due to error: {e}"
)
except RuntimeError as e:
@@ -46,28 +45,28 @@ def try_make_policy(spec: PolicySpec) -> Optional[pga.Policy]:
# Tries to make all policies to test with
-_all_testing_initialised_policies: List[Optional[pga.Policy]] = [
+_all_testing_initialised_policies: list[pga.Policy] | None = [
try_make_policy(policy_spec)
for policy_spec in pga.registry.values()
if env_id_prefix is None or policy_spec.id.startswith(env_id_prefix)
]
-all_testing_initialised_policies: List[pga.Policy] = [
+all_testing_initialised_policies: list[pga.Policy] = [
policy for policy in _all_testing_initialised_policies if policy is not None
]
-all_testing_initialised_torch_policies: List[torch_policy.PPOPolicy] = [
+all_testing_initialised_torch_policies: list[torch_policy.PPOPolicy] = [
policy
for policy in all_testing_initialised_policies
if isinstance(policy, torch_policy.PPOPolicy)
]
# All testing posggym-agents policy specs
-all_testing_policy_specs: List[PolicySpec] = [
+all_testing_policy_specs: list[PolicySpec] = [
policy.spec
for policy in all_testing_initialised_policies
if policy.spec is not None
]
# All testing posggym-agents policy specs that use torch
-all_testing_torch_policy_specs: List[PolicySpec] = [
+all_testing_torch_policy_specs: list[PolicySpec] = [
policy.spec
for policy in all_testing_initialised_torch_policies
if policy.spec is not None
@@ -77,7 +76,7 @@ def try_make_policy(spec: PolicySpec) -> Optional[pga.Policy]:
def assert_equals(a, b, prefix=None):
"""Assert equality of data structures `a` and `b`.
- Arguments
+ Arguments:
---------
a: first data structure
b: second data structure
@@ -95,8 +94,8 @@ def assert_equals(a, b, prefix=None):
np.testing.assert_array_equal(a, b)
elif isinstance(a, torch.Tensor):
assert torch.equal(a, b), f"{prefix}Tensors differ: {a} and {b}"
- elif isinstance(a, (tuple, list)):
- for elem_from_a, elem_from_b in zip(a, b):
+ elif isinstance(a, tuple | list):
+ for elem_from_a, elem_from_b in zip(a, b, strict=False):
assert_equals(elem_from_a, elem_from_b, prefix)
else:
assert a == b
diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py
index 9e9223c..37ff176 100644
--- a/tests/agents/test_agents.py
+++ b/tests/agents/test_agents.py
@@ -44,7 +44,7 @@ def test_policy(spec: PolicySpec):
obs, _ = env.reset(seed=SEED)
if spec.valid_agent_ids:
- test_agent_id = list(set(env.agents).intersection(spec.valid_agent_ids))[0]
+ test_agent_id = next(iter(set(env.agents).intersection(spec.valid_agent_ids)))
else:
test_agent_id = env.agents[0]
@@ -53,7 +53,7 @@ def test_policy(spec: PolicySpec):
test_policy.reset(seed=SEED + 1)
- for t in range(2):
+ for _t in range(2):
joint_action = {}
for i in env.agents:
if i == test_agent_id and test_policy.observes_state:
@@ -110,7 +110,7 @@ def test_policy_determinism_rollout(spec: PolicySpec):
env_2.reset(seed=SEED)
if spec.valid_agent_ids:
- agent_id = list(set(env_1.agents).intersection(spec.valid_agent_ids))[0]
+ agent_id = next(iter(set(env_1.agents).intersection(spec.valid_agent_ids)))
else:
agent_id = env_1.agents[0]
@@ -122,7 +122,7 @@ def test_policy_determinism_rollout(spec: PolicySpec):
assert_equals(policy_1.get_state(), policy_2.get_state())
- for time_step in range(NUM_STEPS):
+ for _time_step in range(NUM_STEPS):
if policy_1.observes_state:
action_1 = policy_1.step(env_1.state)
action_2 = policy_2.step(env_1.state)
diff --git a/tests/agents/test_make.py b/tests/agents/test_make.py
index c57b3c5..26df168 100644
--- a/tests/agents/test_make.py
+++ b/tests/agents/test_make.py
@@ -7,15 +7,15 @@
import re
import warnings
-import pytest
-from tests.agents.helpers import assert_equals
-
import posggym
import posggym.agents as pga
+import pytest
from posggym import error
from posggym.agents.random_policies import DiscreteFixedDistributionPolicy, RandomPolicy
from posggym.agents.registration import get_env_args_id
from posggym.agents.utils.action_distributions import DiscreteActionDistribution
+from tests.agents.helpers import assert_equals
+
TEST_ENV_ID = "MultiAccessBroadcastChannel-v0"
TEST_ENV_ID_UNV = "MultiAccessBroadcastChannel"
@@ -29,7 +29,7 @@
@pytest.fixture(scope="function")
def register_make_testing_policies():
- """Registers testing policies for `posggym.agents.make`"""
+ """Registers testing policies for `posggym.agents.make`."""
pga.register(policy_name="GenericTestPolicy", entry_point=RandomPolicy, version=0)
pga.register(
policy_name="EnvTestPolicy",
@@ -240,7 +240,7 @@ def test_policy_suggestions(
):
env = posggym.make(TEST_ENV_ID)
with pytest.raises(
- error.UnregisteredPolicy, match=f"Did you mean: `{policy_id_suggested}`?"
+ error.UnregisteredPolicyError, match=f"Did you mean: `{policy_id_suggested}`?"
):
pga.make(policy_id_input, env.model, env.agents[0])
@@ -262,13 +262,13 @@ def test_env_version_suggestions(
env = posggym.make(TEST_ENV_ID)
if default_version:
with pytest.raises(
- error.DeprecatedPolicy,
+ error.DeprecatedPolicyError,
match="It provides the default version",
):
pga.make(policy_id_input, env.model, env.agents[0])
else:
with pytest.raises(
- error.UnregisteredPolicy,
+ error.UnregisteredPolicyError,
match=f"It provides versioned policies: \\[ {suggested_versions} \\]",
):
pga.make(policy_id_input, env.model, env.agents[0])
diff --git a/tests/agents/test_register.py b/tests/agents/test_register.py
index 182c9e6..b929972 100644
--- a/tests/agents/test_register.py
+++ b/tests/agents/test_register.py
@@ -5,7 +5,7 @@
"""
import re
-from typing import Any, Dict, Optional
+from typing import Any
import posggym.agents as pga
import pytest
@@ -32,10 +32,10 @@
],
)
def test_register(
- env_id: Optional[str],
- env_args: Optional[Dict[str, Any]],
+ env_id: str | None,
+ env_args: dict[str, Any] | None,
policy_name: str,
- version: Optional[int],
+ version: int | None,
):
pga.register(
policy_name=policy_name,
@@ -76,10 +76,10 @@ def test_register(
],
)
def test_register_error(
- env_id: Optional[str],
- env_args: Optional[Dict[str, Any]],
+ env_id: str | None,
+ env_args: dict[str, Any] | None,
policy_name: str,
- version: Optional[int],
+ version: int | None,
):
with pytest.raises(error.Error, match="^Malformed policy ID:"):
pga.register(
@@ -98,10 +98,10 @@ def test_register_error(
],
)
def test_register_error2(
- env_id: Optional[str],
- env_args: Optional[Dict[str, Any]],
+ env_id: str | None,
+ env_args: dict[str, Any] | None,
policy_name: str,
- version: Optional[int],
+ version: int | None,
):
with pytest.raises(error.Error, match="^Cannot create policy ID."):
pga.register(
diff --git a/tests/agents/test_spec.py b/tests/agents/test_spec.py
index 4dc8dae..6fd605f 100644
--- a/tests/agents/test_spec.py
+++ b/tests/agents/test_spec.py
@@ -25,7 +25,7 @@
@pytest.fixture(scope="function")
def register_make_testing_policies():
- """Registers testing policies for `posggym_agents.make`"""
+ """Registers testing policies for `posggym_agents.make`."""
pga.register(policy_name="GenericTestPolicy", entry_point=RandomPolicy, version=0)
pga.register(
policy_name="EnvTestPolicy",
@@ -131,7 +131,7 @@ def test_generic_spec_missing_lookup(register_make_testing_policies):
pga.register("Other1", entry_point="no-entry-point", version=100)
with pytest.raises(
- error.DeprecatedPolicy,
+ error.DeprecatedPolicyError,
match=re.escape(
"Policy version v1 for `Test1` is deprecated. Please use `Test1-v15` "
"instead."
@@ -140,7 +140,7 @@ def test_generic_spec_missing_lookup(register_make_testing_policies):
pga.spec("Test1-v1")
with pytest.raises(
- error.UnregisteredPolicy,
+ error.UnregisteredPolicyError,
match=re.escape(
"Policy version `v1000` for policy `Test1` doesn't exist. "
"It provides versioned policies: [ `v0`, `v9`, `v15` ]."
@@ -149,7 +149,7 @@ def test_generic_spec_missing_lookup(register_make_testing_policies):
pga.spec("Test1-v1000")
with pytest.raises(
- error.UnregisteredPolicy,
+ error.UnregisteredPolicyError,
match=re.escape("Policy Unknown1 doesn't exist. "),
):
pga.spec("Unknown1-v1")
@@ -163,7 +163,7 @@ def test_env_spec_missing_lookup():
pga.register("Other1", entry_point="no-entry-point", version=100, env_id=env_id)
with pytest.raises(
- error.DeprecatedPolicy,
+ error.DeprecatedPolicyError,
match=re.escape(
f"Policy version v1 for `{env_id}/Test1` is deprecated. Please use "
f"`{env_id}/Test1-v15` instead."
@@ -172,7 +172,7 @@ def test_env_spec_missing_lookup():
pga.spec(f"{env_id}/Test1-v1")
with pytest.raises(
- error.UnregisteredPolicy,
+ error.UnregisteredPolicyError,
match=re.escape(
f"Policy version `v1000` for policy `{env_id}/Test1` doesn't exist. "
"It provides versioned policies: [ `v0`, `v9`, `v15` ]."
@@ -181,7 +181,7 @@ def test_env_spec_missing_lookup():
pga.spec(f"{env_id}/Test1-v1000")
with pytest.raises(
- error.UnregisteredPolicy,
+ error.UnregisteredPolicyError,
match=re.escape(f"Policy Unknown1 doesn't exist for env ID {env_id}. "),
):
pga.spec(f"{env_id}/Unknown1-v1")
@@ -206,7 +206,7 @@ def test_spec_default_lookups():
pga.register("Test4", entry_point="no-entry-point", version=None, env_id=None)
with pytest.raises(
- error.DeprecatedPolicy,
+ error.DeprecatedPolicyError,
match=re.escape(
f"Policy version `v0` for policy `{env_id}/Test3` doesn't exist. "
f"It provides the default version {env_id}/Test3`."
@@ -217,7 +217,7 @@ def test_spec_default_lookups():
assert pga.spec(f"{env_id}/Test3") is not None
with pytest.raises(
- error.DeprecatedPolicy,
+ error.DeprecatedPolicyError,
match=re.escape(
"Policy version `v0` for policy `Test4` doesn't exist. "
"It provides the default version Test4`."
@@ -228,7 +228,7 @@ def test_spec_default_lookups():
assert pga.spec("Test4") is not None
with pytest.raises(
- error.DeprecatedPolicy,
+ error.DeprecatedPolicyError,
match=re.escape(
"Policy version `v0` for policy `Test4` doesn't exist. "
"It provides the default version Test4`."
diff --git a/tests/agents/test_torch_agents.py b/tests/agents/test_torch_agents.py
index c45b101..0b82d78 100644
--- a/tests/agents/test_torch_agents.py
+++ b/tests/agents/test_torch_agents.py
@@ -33,7 +33,7 @@ def test_policy(spec: PolicySpec):
obs, _ = env.reset(seed=SEED)
if spec.valid_agent_ids:
- test_agent_id = list(set(env.agents).intersection(spec.valid_agent_ids))[0]
+ test_agent_id = next(iter(set(env.agents).intersection(spec.valid_agent_ids)))
else:
test_agent_id = env.agents[0]
@@ -95,7 +95,7 @@ def test_policy_determinism_rollout(spec: PolicySpec):
env_2.reset(seed=SEED)
if spec.valid_agent_ids:
- agent_id = list(set(env_1.agents).intersection(spec.valid_agent_ids))[0]
+ agent_id = next(iter(set(env_1.agents).intersection(spec.valid_agent_ids)))
else:
agent_id = env_1.agents[0]
diff --git a/tests/download_assets.py b/tests/download_assets.py
index a36a80a..04fe6fc 100644
--- a/tests/download_assets.py
+++ b/tests/download_assets.py
@@ -2,6 +2,7 @@
import tempfile
import urllib.request
+
ASSETS_URL = "https://github.com/RDLLab/posggym-agent-models/tarball/refs/tags/v0.4.0"
@@ -17,6 +18,6 @@ def show_progress(block_num, block_size, total_size):
print(f"Downloading assets from {ASSETS_URL}")
-tarfile_path = tempfile.mktemp(suffix=".tar.gz")
+tarfile_path = tempfile.mktemp(suffix=".tar.gz") # noqa: S306
print(f"Downloading assets to {tarfile_path}")
urllib.request.urlretrieve(ASSETS_URL, filename=tarfile_path, reporthook=show_progress)
diff --git a/tests/envs/benchmark.py b/tests/envs/benchmark.py
new file mode 100644
index 0000000..e802bc4
--- /dev/null
+++ b/tests/envs/benchmark.py
@@ -0,0 +1,90 @@
+import timeit
+from functools import partial
+
+import numpy as np
+import torch
+from posggym.envs.continuous.predator_prey_continuous import PredatorPreyContinuousEnv
+from posggym.envs.differentiable.predator_prey_diff import PredatorPreyDiff
+from posggym.vector.sync_vector_env import SyncVectorEnv
+
+
+if __name__ == "__main__":
+ for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
+
+ def env_fn():
+ return PredatorPreyContinuousEnv(
+ world="20x20Blocks", num_predators=7, num_prey=7
+ )
+
+ envs = SyncVectorEnv([env_fn for _ in range(batch_size)])
+
+ actions = {
+ i: np.stack([act_space.sample() for _ in range(batch_size)])
+ for i, act_space in envs.single_action_spaces.items()
+ }
+
+ def step(envs_, actions_):
+ envs_.step(actions_)
+
+ # Create a partial function with envs and actions as bound arguments
+ step_with_args = partial(step, envs_=envs, actions_=actions)
+
+ # Time the step function over 20 executions
+ execution_time = timeit.timeit(step_with_args, number=20)
+
+ for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
+ env = PredatorPreyDiff(
+ batch_size=batch_size,
+ world="20x20Blocks",
+ num_predators=7,
+ num_prey=7,
+ )
+
+ def batch_sample_(a_s, batch_size):
+ return np.array([a_s.sample() for _ in range(batch_size)])
+
+ # Take a random action as input to the step function
+ a = {
+ i: torch.Tensor(batch_sample_(env.action_spaces[i], batch_size))
+ for i in env.agents
+ }
+ for action in a.values():
+ action.requires_grad_(True)
+
+ # Function to benchmark the step function
+ def benchmark_step1(env, a):
+ env.step(a)
+
+ benchmark_step_with_args = partial(benchmark_step1, env=env, a=a)
+
+ # Time the step function over 20 executions
+ execution_time = timeit.timeit(benchmark_step_with_args, number=20)
+ print(f"Average time per step: {execution_time / 20} seconds")
+
+ for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
+ env = PredatorPreyDiff(
+ batch_size=batch_size,
+ world="20x20Blocks",
+ num_predators=7,
+ num_prey=7,
+ )
+
+ def batch_samples(a_s, batch_size):
+ return np.array([a_s.sample() for _ in range(batch_size)])
+
+ # Take a random action as input to the step function
+ a = {
+ i: torch.Tensor(batch_samples(env.action_spaces[i], batch_size))
+ for i in env.agents
+ }
+
+ # Function to benchmark the step function
+ def benchmark_step2(env, a):
+ with torch.no_grad():
+ env.step(a)
+
+ benchmark_step_with_args = partial(benchmark_step2, env=env, a=a)
+
+ # Time the step function over 20 executions
+ execution_time = timeit.timeit(benchmark_step2, number=20)
+ print(f"Average time per step: {execution_time / 20} seconds")
diff --git a/tests/envs/continuous/pursuit_evasion_balancing.py b/tests/envs/continuous/pursuit_evasion_balancing.py
index 569219e..eed6a12 100644
--- a/tests/envs/continuous/pursuit_evasion_balancing.py
+++ b/tests/envs/continuous/pursuit_evasion_balancing.py
@@ -3,13 +3,14 @@
import argparse
import sys
from itertools import product
-from typing import Dict, List, Optional, Tuple, cast
+from typing import cast
import numpy as np
import posggym
import pygame
from posggym.envs.continuous.pursuit_evasion_continuous import PEWorld
+
key_action_map = {
None: 0,
pygame.K_UP: np.array([0.0, 1.0], dtype=np.float32),
@@ -26,8 +27,8 @@
def run_keyboard_agent(
- env: posggym.Env, keyboard_agent_id: List[str]
-) -> Optional[Tuple[Dict[str, float], int]]:
+ env: posggym.Env, keyboard_agent_id: list[str]
+) -> tuple[dict[str, float], int] | None:
"""Run manual keyboard agent in continuous environment.
Assumes environment actions are continuous (i.e. space.Box). So user will be
diff --git a/tests/envs/continuous/test_drone_team_capture.py b/tests/envs/continuous/test_drone_team_capture.py
index 296b95a..b183d4a 100644
--- a/tests/envs/continuous/test_drone_team_capture.py
+++ b/tests/envs/continuous/test_drone_team_capture.py
@@ -1,8 +1,7 @@
"""Specific tests for the DroneTeamCapture-v0 environment."""
-import pytest
-
import posggym
+import pytest
@pytest.mark.parametrize("num_pursuers", [2, 3, 4, 8])
@@ -21,7 +20,7 @@ def test_init_steps(num_pursuers: int):
)
env.reset(seed=35)
- for t in range(100):
+ for _t in range(100):
a = {i: env.action_spaces[i].sample() for i in env.agents}
obs, _, _, _, all_done, _ = env.step(a)
diff --git a/tests/envs/continuous/test_predator_prey.py b/tests/envs/continuous/test_predator_prey.py
index a0e0a04..45f82b5 100644
--- a/tests/envs/continuous/test_predator_prey.py
+++ b/tests/envs/continuous/test_predator_prey.py
@@ -69,7 +69,7 @@ def test_collisions(world):
size = model.world.agent_radius
# pypunk munk allows overlaps, but only typically ~10-15%, so if overlap is as big
# as agent radius then something is wrong
- min_dist = 2 * size * 0.85
+ min_dist = 2 * size * 0.80
for _ in range(100):
state = cast(PPState, env.state)
diff --git a/tests/envs/continuous/test_pursuit_evasion.py b/tests/envs/continuous/test_pursuit_evasion.py
index b2f4748..f90b6c6 100644
--- a/tests/envs/continuous/test_pursuit_evasion.py
+++ b/tests/envs/continuous/test_pursuit_evasion.py
@@ -2,7 +2,6 @@
from typing import cast
import numpy as np
-
import posggym
from posggym.envs.continuous.pursuit_evasion_continuous import (
PEState,
@@ -44,8 +43,8 @@ def test_obs():
state = cast(PEState, env.state)
# Check state is as expected
- assert np.allclose(state.evader_state[:3], evader_start_coord + (0,))
- assert np.allclose(state.pursuer_state[:3], pursuer_start_coord + (0,))
+ assert np.allclose(state.evader_state[:3], (*evader_start_coord, 0))
+ assert np.allclose(state.pursuer_state[:3], (*pursuer_start_coord, 0))
assert np.allclose(state.evader_start_coord, evader_start_coord)
assert np.allclose(pursuer_start_coord, pursuer_start_coord)
assert np.allclose(state.evader_goal_coord, goal_coord)
@@ -154,8 +153,8 @@ def test_shortest_path():
state = cast(PEState, env.state)
# Check state is as expected
- assert np.allclose(state.evader_state[:3], evader_start_coord + (0,))
- assert np.allclose(state.pursuer_state[:3], pursuer_start_coord + (0,))
+ assert np.allclose(state.evader_state[:3], (*evader_start_coord, 0))
+ assert np.allclose(state.pursuer_state[:3], (*pursuer_start_coord, 0))
assert np.allclose(state.evader_start_coord, evader_start_coord)
assert np.allclose(pursuer_start_coord, pursuer_start_coord)
assert np.allclose(state.evader_goal_coord, goal_coord)
@@ -210,8 +209,8 @@ def test_shortest_path_not_double_reward():
state = cast(PEState, env.state)
# Check state is as expected
- assert np.allclose(state.evader_state[:3], evader_start_coord + (0,))
- assert np.allclose(state.pursuer_state[:3], pursuer_start_coord + (0,))
+ assert np.allclose(state.evader_state[:3], (*evader_start_coord, 0))
+ assert np.allclose(state.pursuer_state[:3], (*pursuer_start_coord, 0))
assert np.allclose(state.evader_start_coord, evader_start_coord)
assert np.allclose(pursuer_start_coord, pursuer_start_coord)
assert np.allclose(state.evader_goal_coord, goal_coord)
diff --git a/tests/envs/differentiable/__init__.py b/tests/envs/differentiable/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/envs/differentiable/test_drone_team_capture.py b/tests/envs/differentiable/test_drone_team_capture.py
new file mode 100644
index 0000000..ca4ece3
--- /dev/null
+++ b/tests/envs/differentiable/test_drone_team_capture.py
@@ -0,0 +1,145 @@
+"""Specific tests for the PredatorPreyDifferentiable-v0 environment."""
+
+import numpy as np
+import posggym
+import pytest
+import torch
+
+
+@pytest.mark.parametrize("num_predators", [2, 3, 4, 8])
+@torch.no_grad()
+def test_obs_steps(num_predators: int):
+ """Check observations are as expected after reset."""
+ BATCH_SIZE = 10
+
+ env = posggym.make(
+ "PredatorPreyDifferentiable-v0",
+ max_episode_steps=2,
+ batch_size=BATCH_SIZE,
+ num_predators=num_predators,
+ disable_env_checker=True,
+ )
+ env.reset(seed=35)
+
+ def batch_samples(a_s):
+ return np.array([a_s.sample() for _ in range(BATCH_SIZE)])
+
+ for _ in range(10):
+ a = {i: torch.Tensor(batch_samples(env.action_spaces[i])) for i in env.agents}
+
+ obs, _, _, _, all_done, _ = env.step(a)
+
+ for i, o_i in obs.items():
+ for b in range(BATCH_SIZE):
+ assert env.observation_spaces[i].contains(
+ o_i[b].detach().cpu().numpy().squeeze()
+ ), f"Agent {i} observation {o_i[b]} is not in its observation space."
+
+ env.close()
+
+
+def run_grad_step(num_predators: int):
+ """Check observations are as expected after reset."""
+ BATCH_SIZE = 10
+
+ env = posggym.make(
+ "PredatorPreyDifferentiable-v0",
+ max_episode_steps=40,
+ batch_size=BATCH_SIZE,
+ num_predators=num_predators,
+ disable_env_checker=True,
+ )
+ env.reset(seed=35)
+
+ def batch_samples(a_s):
+ return np.array([a_s.sample() for _ in range(BATCH_SIZE)])
+
+ for t in range(5):
+ a = {i: torch.Tensor(batch_samples(env.action_spaces[i])) for i in env.agents}
+
+ for action in a.values():
+ action.requires_grad_(True)
+
+ if t == 0:
+ first_action = a
+
+ obs, rews, _, _, all_done, _ = env.step(a)
+
+ assert not all_done.all()
+
+ loss = obs["agent_0"].mean() + rews["agent_0"].mean()
+ grad = torch.autograd.grad(loss, first_action["agent_0"], allow_unused=True)
+
+ assert grad is not None
+ assert abs(grad[0]).sum() > 0
+
+ env.close()
+
+ return grad[0]
+
+
+def run_model_step(num_predators: int):
+ """Check observations are as expected after reset."""
+ BATCH_SIZE = 10
+
+ env = posggym.make(
+ "PredatorPreyDifferentiable-v0",
+ max_episode_steps=40,
+ batch_size=BATCH_SIZE,
+ num_predators=num_predators,
+ disable_env_checker=True,
+ )
+ env.reset(seed=35)
+
+ def batch_samples(a_s):
+ return np.array([a_s.sample() for _ in range(BATCH_SIZE)])
+
+ for t in range(5):
+ a = {i: torch.Tensor(batch_samples(env.action_spaces[i])) for i in env.agents}
+
+ for action in a.values():
+ action.requires_grad_(True)
+
+ if t == 0:
+ first_action = a
+
+ state = env.model.sample_initial_state()
+ state = env.model.step(state, a).state
+ state = env.model.step(state, a).state
+ state = env.model.step(state, a).state
+ state = env.model.step(state, a).state
+
+ obs, rews, _, _, all_done, _ = env.step(a)
+
+ assert not all_done.all()
+
+ loss = obs["agent_0"].mean() + rews["agent_0"].mean()
+ grad = torch.autograd.grad(loss, first_action["agent_0"], allow_unused=True)
+
+ assert grad is not None
+ assert abs(grad[0]).sum() > 0
+
+ env.close()
+
+ return grad[0]
+
+
+@pytest.mark.parametrize("num_predators", [3])
+def test_grad(num_predators: int):
+ run_grad_step(num_predators)
+
+
+@pytest.mark.parametrize("num_predators", [3])
+def test_model_step(num_predators: int):
+ run_model_step(num_predators)
+
+
+@pytest.mark.parametrize("num_predators", [3])
+def test_compare_gradients(num_predators: int):
+ """Compare gradients from test_grad and test_model_step."""
+ g1 = run_grad_step(num_predators)
+ g2 = run_model_step(num_predators)
+
+ assert torch.allclose(
+ g1, g2
+ ), "Gradients from test_grad and test_model_step do not match."
diff --git a/tests/envs/grid_world/run_driving_grid_generator.py b/tests/envs/grid_world/run_driving_grid_generator.py
index e45dd7a..ddbdea1 100644
--- a/tests/envs/grid_world/run_driving_grid_generator.py
+++ b/tests/envs/grid_world/run_driving_grid_generator.py
@@ -2,7 +2,6 @@
import argparse
import sys
-from typing import Optional
from posggym.envs.grid_world.driving_gen import DrivingGridGenerator
@@ -10,7 +9,7 @@
def main(
width: int,
height: int,
- max_obstacle_size: Optional[int] = None,
+ max_obstacle_size: int | None = None,
seed: int = 0,
):
"""Run."""
diff --git a/tests/envs/grid_world/run_grid_generator.py b/tests/envs/grid_world/run_grid_generator.py
index d0e00b7..666f866 100644
--- a/tests/envs/grid_world/run_grid_generator.py
+++ b/tests/envs/grid_world/run_grid_generator.py
@@ -3,12 +3,11 @@
import argparse
import random
import sys
-from typing import Optional
from posggym.envs.grid_world.core import GridGenerator
-def _generate_mask(width: int, height: int, seed: Optional[int]):
+def _generate_mask(width: int, height: int, seed: int | None):
rng = random.Random(None) if seed is None else random.Random(seed + 1)
mask = set()
@@ -24,9 +23,9 @@ def main(
width: int,
height: int,
use_random_mask: bool,
- max_obstacle_size: Optional[int] = None,
+ max_obstacle_size: int | None = None,
seed: int = 0,
- check_grid_connectedness: Optional[int] = False,
+ check_grid_connectedness: int | None = False,
):
"""Run."""
mask = _generate_mask(width, height, seed) if use_random_mask else set()
diff --git a/tests/envs/grid_world/test_core.py b/tests/envs/grid_world/test_core.py
index fd36689..64a899d 100644
--- a/tests/envs/grid_world/test_core.py
+++ b/tests/envs/grid_world/test_core.py
@@ -1,5 +1,5 @@
"""Tests for envs.grid_world.core."""
-from posggym.envs.grid_world.core import Grid, Direction
+from posggym.envs.grid_world.core import Direction, Grid
class TestGrid:
diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py
index efc22ee..b93ad0e 100644
--- a/tests/envs/test_action_dim_check.py
+++ b/tests/envs/test_action_dim_check.py
@@ -3,14 +3,22 @@
Ref:
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/tests/envs/test_action_dim_check.py
"""
-from typing import Dict, Tuple, Union
import numpy as np
import pytest
from gymnasium import spaces
+
from tests.envs.utils import all_testing_initialised_envs
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
import posggym
+from posggym.utils.torch_utils import maybe_expand_dims
+
DISCRETE_ENVS = list(
filter(
@@ -42,9 +50,8 @@ def test_discrete_actions_out_of_bound(env: posggym.Env):
i: env.action_spaces[i].start + env.action_spaces[i].n # type: ignore
for i in env.agents
}
-
env.reset()
- with pytest.raises(Exception):
+ with pytest.raises(AssertionError):
env.step(upper_bounds)
env.close()
@@ -63,7 +70,7 @@ def test_discrete_actions_out_of_bound(env: posggym.Env):
def tuple_equal(
- a: Tuple[Union[int, np.ndarray, float]], b: Tuple[Union[int, np.ndarray, float]]
+ a: tuple[int | np.ndarray | float], b: tuple[int | np.ndarray | float]
) -> bool:
if len(a) != len(b):
return False
@@ -71,6 +78,14 @@ def tuple_equal(
if isinstance(a[i], np.ndarray) and isinstance(b[i], np.ndarray):
if not np.array_equal(a[i], b[i]):
return False
+ elif (
+ torch is not None
+ and isinstance(a[i], torch.Tensor)
+ and isinstance(b[i], torch.Tensor)
+ ):
+ if not torch.equal(a[i], b[i]): # type: ignore
+ return False
+
elif a[i] != b[i]:
return False
return True
@@ -92,19 +107,23 @@ def test_box_actions_out_of_bound(env: posggym.Env):
oob_env = posggym.make(env.spec.id, disable_env_checker=True)
oob_env.reset(seed=42)
- action_spaces: Dict[str, spaces.Box] = env.action_spaces # type: ignore
+ action_spaces: dict[str, spaces.Box] = env.action_spaces # type: ignore
assert all(
isinstance(act_space, spaces.Box) for act_space in action_spaces.values()
)
dtypes = {i: action_spaces[i].dtype for i in env.agents}
- upper_bounds = {i: action_spaces[i].high for i in env.agents}
- lower_bounds = {i: action_spaces[i].low for i in env.agents}
+
+ upper_bounds = {
+ i: maybe_expand_dims(env, action_spaces[i].high) for i in env.agents
+ }
+ lower_bounds = {i: maybe_expand_dims(env, action_spaces[i].low) for i in env.agents}
if all(np.all(action_spaces[i].bounded_above) for i in env.agents):
obs, _, _, _, _, _ = env.step(upper_bounds)
oob_actions = {
- i: np.cast[dtypes[i]](upper_bounds[i] + OOB_VALUE) for i in upper_bounds
+ i: np.asarray(upper_bounds[i] + OOB_VALUE, dtype=dtypes[i])
+ for i in upper_bounds
}
assert all(np.all(oob_actions[i] > upper_bounds[i]) for i in upper_bounds)
@@ -116,7 +135,8 @@ def test_box_actions_out_of_bound(env: posggym.Env):
obs, _, _, _, _, _ = env.step(lower_bounds)
oob_actions = {
- i: np.cast[dtypes[i]](lower_bounds[i] - OOB_VALUE) for i in lower_bounds
+ i: np.asarray(lower_bounds[i] - OOB_VALUE, dtype=dtypes[i])
+ for i in lower_bounds
}
assert all(np.all(oob_actions[i] < lower_bounds[i]) for i in lower_bounds)
oob_obs, _, _, _, _, _ = oob_env.step(oob_actions)
diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py
index 183ce52..c51f226 100644
--- a/tests/envs/test_envs.py
+++ b/tests/envs/test_envs.py
@@ -6,17 +6,23 @@
import pickle
import warnings
+import posggym
import pytest
+from posggym.envs.registration import EnvSpec
+from posggym.utils.env_checker import check_env
+from posggym.utils.passive_env_checker import data_equivalence
+from posggym.utils.torch_utils import maybe_expand_dims
+
from tests.envs.utils import (
all_testing_env_specs,
- all_testing_initialised_envs,
assert_equals,
)
-import posggym
-from posggym.envs.registration import EnvSpec
-from posggym.utils.env_checker import check_env
-from posggym.utils.passive_env_checker import data_equivalence
+
+try:
+ import torch
+except ImportError:
+ torch = None
PASSIVE_CHECK_IGNORE_WARNING = [
f"\x1b[33mWARN: {message}"
@@ -27,7 +33,6 @@
CHECK_ENV_IGNORE_WARNINGS = [
- # f"\x1b[33mWARN: {message}\x1b[0m"
f"\x1b[33mWARN: {message}"
for message in [
"A Box observation space minimum value is -infinity. This is probably too low.",
@@ -109,7 +114,10 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
assert_equals(env_1.state, env_2.state, f"[{time_step}][State] ")
# We don't evaluate the determinism of actions
- actions = {i: env_1.action_spaces[i].sample() for i in env_1.agents}
+ actions = {
+ i: maybe_expand_dims(env_1, env_1.action_spaces[i].sample())
+ for i in env_1.agents
+ }
obs_1, rew_1, term_1, trunc_1, done_1, info_1 = env_1.step(actions)
obs_2, rew_2, term_2, trunc_2, done_2, info_2 = env_2.step(actions)
@@ -117,6 +125,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
assert_equals(obs_1, obs_2, f"[{time_step}][Observations] ")
# obs_2 verified by previous assertion
for i, o_i in obs_1.items():
+ if torch is not None and isinstance(o_i, torch.Tensor):
+ o_i = o_i.cpu().detach().numpy().squeeze()
+
assert env_1.observation_spaces[i].contains(o_i)
assert_equals(rew_1, rew_2, f"[{time_step}][Rewards] ")
@@ -138,17 +149,20 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
@pytest.mark.parametrize(
- "env",
- all_testing_initialised_envs,
- ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
+ "env_spec",
+ all_testing_env_specs,
+ ids=[env.id for env in all_testing_env_specs],
)
-def test_pickle_env(env: posggym.Env):
+def test_pickle_env(env_spec: EnvSpec):
"""Test that env can be pickled consistently."""
+ env = env_spec.make(disable_env_checker=True)
pickled_env = pickle.loads(pickle.dumps(env))
data_equivalence(env.reset(), pickled_env.reset())
- actions = {i: env.action_spaces[i].sample() for i in env.agents}
+ actions = {
+ i: maybe_expand_dims(env, env.action_spaces[i].sample()) for i in env.agents
+ }
data_equivalence(env.step(actions), pickled_env.step(actions))
env.close()
pickled_env.close()
diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py
index 497053b..64d9670 100644
--- a/tests/envs/test_make.py
+++ b/tests/envs/test_make.py
@@ -7,11 +7,12 @@
import warnings
from copy import deepcopy
-import pytest
-
import posggym
+import pytest
from posggym.envs.classic import mabc
+from posggym.utils.torch_utils import maybe_expand_dims
from posggym.wrappers import OrderEnforcing, PassiveEnvChecker, TimeLimit
+
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_env_specs
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
@@ -20,7 +21,7 @@
@pytest.fixture(scope="function")
def register_make_testing_envs():
- """Registers testing envs for `posggym.make`"""
+ """Registers testing envs for `posggym.make`."""
posggym.register(
"DummyEnv-v0",
entry_point="tests.envs.utils_envs:DummyEnv",
@@ -165,7 +166,13 @@ def test_passive_checker_wrapper_warnings(spec):
with warnings.catch_warnings(record=True) as caught_warnings:
env = posggym.make(spec) # disable_env_checker=False
env.reset()
- env.step({i: env.action_spaces[i].sample() for i in env.agents})
+
+ env.step(
+ {
+ i: maybe_expand_dims(env, env.action_spaces[i].sample())
+ for i in env.agents
+ }
+ )
env.close()
for warning in caught_warnings:
@@ -240,26 +247,16 @@ def test_make_render_mode(register_make_testing_envs):
# Add this test when it is
# def test_make_human_rendering(register_make_testing_envs):
# # Make sure that native rendering is used when possible
-# env = posggym.make("MultiAccessBroadcastChannel-v0", render_mode="human")
-# assert not has_wrapper(env, HumanRendering) # Should use native human-rendering
-# assert env.render_mode == "human"
-# env.close()
# with pytest.warns(
# UserWarning,
-# match=re.escape(
# "You are trying to use 'human' rendering for an environment that doesn't "
# "natively support it. The HumanRendering wrapper is being applied to "
# " your environment."
# ),
# ):
# # Make sure that `HumanRendering` is applied here
-# env = posggym.make(
-# "test/NoHuman-v0", render_mode="human"
# ) # This environment doesn't use native rendering
-# assert has_wrapper(env, HumanRendering)
-# assert env.render_mode == "human"
-# env.close()
def test_make_kwargs(register_make_testing_envs):
diff --git a/tests/envs/test_models.py b/tests/envs/test_models.py
index a125c77..e1b60a2 100644
--- a/tests/envs/test_models.py
+++ b/tests/envs/test_models.py
@@ -5,15 +5,22 @@
"""
import warnings
-import pytest
-from tests.envs.test_envs import CHECK_ENV_IGNORE_WARNINGS
-from tests.envs.utils import all_testing_env_specs, assert_equals
-
import posggym
import posggym.model as M
+import pytest
from posggym.envs.registration import EnvSpec
from posggym.utils.model_checker import check_model
+from tests.envs.test_envs import CHECK_ENV_IGNORE_WARNINGS
+from tests.envs.utils import all_testing_env_specs, assert_equals
+
+
+try:
+ import torch
+except ImportError:
+ torch = None
+from posggym.utils.torch_utils import maybe_expand_dims
+
@pytest.mark.parametrize(
"spec",
@@ -24,7 +31,7 @@ def test_models_pass_env_checker(spec):
"""Check that all environment models pass checker with no unexpected warnings."""
with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make(disable_env_checker=True).unwrapped
- check_model(env.model)
+ check_model(env, env.model)
env.close()
@@ -99,6 +106,11 @@ def test_model_determinism_rollout(env_spec: EnvSpec):
initial_obs_2 = model_2.sample_initial_obs(initial_state_2)
assert_equals(initial_obs_1, initial_obs_2)
# obs_2 verified by previous assertion
+ if torch is not None and isinstance(
+ next(iter(initial_obs_1.values())), torch.Tensor
+ ):
+ initial_obs_1 = {k: v.cpu().numpy().squeeze() for k, v in initial_obs_1.items()}
+
assert all(
model_1.observation_spaces[i].contains(o_i) for i, o_i in initial_obs_1.items()
)
@@ -130,7 +142,8 @@ def test_model_determinism_rollout(env_spec: EnvSpec):
for t in range(num_steps):
# We don't evaluate the determinism of actions
actions = {
- i: model_1.action_spaces[i].sample() for i in model_1.get_agents(state)
+ i: maybe_expand_dims(env_1, model_1.action_spaces[i].sample())
+ for i in model_1.get_agents(state)
}
result_1 = model_1.step(state, actions)
@@ -153,6 +166,9 @@ def test_model_determinism_rollout(env_spec: EnvSpec):
)
# obs_2 verified by previous assertion
for i, o_i in result_1.observations.items():
+ if torch is not None and isinstance(o_i, torch.Tensor):
+ o_i = o_i.cpu().detach().numpy().squeeze()
+
assert model_1.observation_spaces[i].contains(o_i)
assert all(i in result_1.observations for i in model_1.get_agents(state))
diff --git a/tests/envs/test_register.py b/tests/envs/test_register.py
index ae45c87..195f62a 100644
--- a/tests/envs/test_register.py
+++ b/tests/envs/test_register.py
@@ -5,11 +5,9 @@
"""
import re
-from typing import Optional
-
-import pytest
import posggym
+import pytest
@pytest.fixture(scope="function")
@@ -65,9 +63,7 @@ def register_registration_testing_envs():
("MyAwesomeEnv-v", None, "MyAwesomeEnv-v", None),
],
)
-def test_register(
- env_id: str, namespace: Optional[str], name: str, version: Optional[int]
-):
+def test_register(env_id: str, namespace: str | None, name: str, version: int | None):
posggym.register(env_id, "no-entry-point")
assert posggym.spec(env_id).id == env_id
@@ -113,7 +109,7 @@ def test_env_suggestions(
register_registration_testing_envs, env_id_input, env_id_suggested
):
with pytest.raises(
- posggym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?"
+ posggym.error.UnregisteredEnvError, match=f"Did you mean: `{env_id_suggested}`?"
):
posggym.make(env_id_input, disable_env_checker=True)
@@ -134,13 +130,13 @@ def test_env_version_suggestions(
):
if default_version:
with pytest.raises(
- posggym.error.DeprecatedEnv,
+ posggym.error.DeprecatedEnvError,
match="It provides the default version", # env name,
):
posggym.make(env_id_input, disable_env_checker=True)
else:
with pytest.raises(
- posggym.error.UnregisteredEnv,
+ posggym.error.UnregisteredEnvError,
match=f"It provides versioned environments: \\[ {suggested_versions} \\]",
):
posggym.make(env_id_input, disable_env_checker=True)
diff --git a/tests/envs/test_rendering.py b/tests/envs/test_rendering.py
index a0c8392..81e7099 100644
--- a/tests/envs/test_rendering.py
+++ b/tests/envs/test_rendering.py
@@ -6,7 +6,7 @@
import numpy as np
import pytest
from posggym.envs.registration import EnvSpec
-from posggym.logger import warn
+from posggym.logger import warning
from tests.envs.utils import all_testing_env_specs
@@ -32,7 +32,7 @@ def check_rendered(rendered_frame, mode: str):
assert isinstance(rendered_frame, str)
assert len(rendered_frame) > 0
else:
- warn(
+ warning(
f"Unknown render mode: {mode}, cannot check that the rendered data is "
"correct. Add case to `check_rendered`"
)
@@ -44,8 +44,6 @@ def check_rendered(rendered_frame, mode: str):
def test_render_modes(spec: EnvSpec):
env = spec.make(disable_env_checker=True)
- # assert "rgb_array" in env.metadata["render_modes"]
-
for mode in env.metadata["render_modes"]:
if mode != "human":
new_env = spec.make(render_mode=mode, disable_env_checker=True)
diff --git a/tests/envs/test_spec.py b/tests/envs/test_spec.py
index 4f96465..937aace 100644
--- a/tests/envs/test_spec.py
+++ b/tests/envs/test_spec.py
@@ -5,9 +5,8 @@
"""
import re
-import pytest
-
import posggym
+import pytest
def test_spec():
@@ -34,7 +33,7 @@ def test_spec_missing_lookup():
posggym.register(id="Other1-v100", entry_point="no-entry-point")
with pytest.raises(
- posggym.error.DeprecatedEnv,
+ posggym.error.DeprecatedEnvError,
match=re.escape(
"Environment version v1 for `Test1` is deprecated. Please use `Test1-v15` "
"instead."
@@ -43,7 +42,7 @@ def test_spec_missing_lookup():
posggym.spec("Test1-v1")
with pytest.raises(
- posggym.error.UnregisteredEnv,
+ posggym.error.UnregisteredEnvError,
match=re.escape(
"Environment version `v1000` for environment `Test1` doesn't exist. "
"It provides versioned environments: [ `v0`, `v9`, `v15` ]."
@@ -52,7 +51,7 @@ def test_spec_missing_lookup():
posggym.spec("Test1-v1000")
with pytest.raises(
- posggym.error.UnregisteredEnv,
+ posggym.error.UnregisteredEnvError,
match=re.escape("Environment Unknown1 doesn't exist. "),
):
posggym.spec("Unknown1-v1")
@@ -75,7 +74,7 @@ def test_spec_versioned_lookups():
posggym.register("test/Test2-v5", "no-entry-point")
with pytest.raises(
- posggym.error.VersionNotFound,
+ posggym.error.VersionNotFoundError,
match=re.escape(
"Environment version `v9` for environment `test/Test2` doesn't exist. "
"It provides versioned environments: [ `v5` ]."
@@ -84,7 +83,7 @@ def test_spec_versioned_lookups():
posggym.spec("test/Test2-v9")
with pytest.raises(
- posggym.error.DeprecatedEnv,
+ posggym.error.DeprecatedEnvError,
match=re.escape(
"Environment version v4 for `test/Test2` is deprecated. Please use "
"`test/Test2-v5` instead."
@@ -99,7 +98,7 @@ def test_spec_default_lookups():
posggym.register("test/Test3", "no-entry-point")
with pytest.raises(
- posggym.error.DeprecatedEnv,
+ posggym.error.DeprecatedEnvError,
match=re.escape(
"Environment version `v0` for environment `test/Test3` doesn't exist. "
"It provides the default version test/Test3`."
diff --git a/tests/envs/utils.py b/tests/envs/utils.py
index 317fbf0..df84156 100644
--- a/tests/envs/utils.py
+++ b/tests/envs/utils.py
@@ -3,7 +3,6 @@
Reference:
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/tests/envs/utils.py
"""
-from typing import List, Optional
import numpy as np
import posggym
@@ -13,7 +12,13 @@
from tests.conftest import env_id_prefix
-def try_make_env(env_spec: EnvSpec) -> Optional[posggym.Env]:
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def try_make_env(env_spec: EnvSpec) -> posggym.Env | None:
"""Tries to make the environment showing if it is possible.
Warning the environments have no wrappers, including time limit and order enforcing.
@@ -28,25 +33,25 @@ def try_make_env(env_spec: EnvSpec) -> Optional[posggym.Env]:
return env_spec.make(disable_env_checker=True).unwrapped
except (
ImportError,
- posggym.error.DependencyNotInstalled,
- posggym.error.MissingArgument,
+ posggym.error.DependencyNotInstalledError,
+ posggym.error.MissingArgumentError,
) as e:
- logger.warn(f"Not testing {env_spec.id} due to error: {e}")
+ logger.warning(f"Not testing {env_spec.id} due to error: {e}")
return None
# Tries to make all environment to test with
-_all_testing_initialised_envs: List[Optional[posggym.Env]] = [
+_all_testing_initialised_envs: list[posggym.Env] | None = [
try_make_env(env_spec)
for env_spec in posggym.envs.registry.values()
if env_id_prefix is None or env_spec.id.startswith(env_id_prefix)
]
-all_testing_initialised_envs: List[posggym.Env] = [
+all_testing_initialised_envs: list[posggym.Env] = [
env for env in _all_testing_initialised_envs if env is not None
]
# All testing posggym environment specs
-all_testing_env_specs: List[EnvSpec] = [
+all_testing_env_specs: list[EnvSpec] = [
env.spec for env in all_testing_initialised_envs if env.spec is not None
]
@@ -54,7 +59,7 @@ def try_make_env(env_spec: EnvSpec) -> Optional[posggym.Env]:
def assert_equals(a, b, prefix=None):
"""Assert equality of data structures `a` and `b`.
- Arguments
+ Arguments:
---------
a:
first data structure
@@ -73,8 +78,11 @@ def assert_equals(a, b, prefix=None):
assert_equals(v_a, v_b, prefix)
elif isinstance(a, np.ndarray):
np.testing.assert_array_equal(a, b)
+ elif torch is not None and isinstance(a, torch.Tensor):
+ torch.testing.assert_close(a, b)
+
elif isinstance(a, tuple):
- for elem_from_a, elem_from_b in zip(a, b):
+ for elem_from_a, elem_from_b in zip(a, b, strict=False):
assert_equals(elem_from_a, elem_from_b, prefix)
else:
assert a == b
diff --git a/tests/envs/utils_envs.py b/tests/envs/utils_envs.py
index 98fd294..3a94e2d 100644
--- a/tests/envs/utils_envs.py
+++ b/tests/envs/utils_envs.py
@@ -4,6 +4,8 @@
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/tests/envs/utils_envs.py
"""
+from typing import ClassVar
+
import posggym
from tests.envs.utils_models import TestModel
@@ -12,21 +14,21 @@
class DummyEnv(posggym.DefaultEnv):
"""Dummy env for use in environment registration and make tests ."""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(TestModel())
class RegisterDuringMakeEnv(posggym.DefaultEnv):
"""For `test_registration.py` to check `env.make` can import and register env."""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(TestModel())
class ArgumentEnv(posggym.DefaultEnv):
"""For `test_registration.py` to check `env.make` can import and register env."""
- def __init__(self, arg1, arg2, arg3):
+ def __init__(self, arg1, arg2, arg3) -> None:
super().__init__(TestModel())
self.arg1 = arg1
self.arg2 = arg2
@@ -37,9 +39,9 @@ def __init__(self, arg1, arg2, arg3):
class NoHuman(posggym.DefaultEnv):
"""Environment that does not have human-rendering."""
- metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
+ metadata: ClassVar[dict] = {"render_modes": ["rgb_array_list"], "render_fps": 4}
- def __init__(self, render_mode=None):
+ def __init__(self, render_mode=None) -> None:
super().__init__(TestModel())
assert render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
@@ -48,9 +50,9 @@ def __init__(self, render_mode=None):
class NoHumanNoRGB(posggym.DefaultEnv):
"""Environment that has neither human- nor rgb-rendering."""
- metadata = {"render_modes": ["ascii"], "render_fps": 4}
+ metadata: ClassVar[dict] = {"render_modes": ["ascii"], "render_fps": 4}
- def __init__(self, render_mode=None):
+ def __init__(self, render_mode=None) -> None:
super().__init__(TestModel())
assert render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
diff --git a/tests/envs/utils_models.py b/tests/envs/utils_models.py
index c5cfa30..1bc6481 100644
--- a/tests/envs/utils_models.py
+++ b/tests/envs/utils_models.py
@@ -1,5 +1,4 @@
"""Test models for posggym."""
-from typing import Dict, List
import posggym.model as M
from gymnasium import spaces
@@ -9,7 +8,7 @@
class TestModel(M.POSGModel):
"""Basic test model."""
- def __init__(self):
+ def __init__(self) -> None:
self.possible_agents = (0, 1)
self.action_spaces = {i: spaces.Discrete(2) for i in self.possible_agents}
self.observation_spaces = {i: spaces.Discrete(2) for i in self.possible_agents}
@@ -21,16 +20,16 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: int) -> List[str]:
+ def get_agents(self, state: int) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> int:
return 0
- def sample_initial_obs(self, state: int) -> Dict[str, int]:
+ def sample_initial_obs(self, state: int) -> dict[str, int]:
return {i: 0 for i in self.possible_agents}
- def step(self, state: int, actions: Dict[str, int]) -> M.JointTimestep[int, int]:
+ def step(self, state: int, actions: dict[str, int]) -> M.JointTimestep[int, int]:
return M.JointTimestep(
0,
{i: 0 for i in self.possible_agents},
diff --git a/tests/test_core.py b/tests/test_core.py
index 1c09271..bd7d162 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -4,7 +4,7 @@
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/tests/test_core.py
"""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any
import numpy as np
import posggym.model as M
@@ -25,7 +25,7 @@
class ExampleEnv(DefaultEnv[int, int, int]):
"""Example testing environment."""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(ExampleModel())
@@ -41,25 +41,25 @@ def test_posggym_env():
class ExampleWrapper(Wrapper):
"""An example testing wrapper."""
- def __init__(self, env: Env[M.StateType, M.ObsType, M.ActType]):
+ def __init__(self, env: Env[M.StateType, M.ObsType, M.ActType]) -> None:
"""Constructor that sets the reward."""
super().__init__(env)
self.new_reward = 3
def reset(
- self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
- ) -> Tuple[Dict[str, WrapperObsType], Dict[str, Dict]]:
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, WrapperObsType], dict[str, dict]]:
return super().reset(seed=seed, options=options)
def step(
- self, actions: Dict[str, WrapperActType]
- ) -> Tuple[
- Dict[str, WrapperObsType],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, WrapperActType]
+ ) -> tuple[
+ dict[str, WrapperObsType],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict],
+ dict[str, dict],
]:
obs, reward, term, trunc, done, info = self.env.step(actions) # type: ignore
reward = {i: self.new_reward for i in reward}
@@ -96,21 +96,21 @@ def test_posggym_wrapper():
class ExampleRewardWrapper(RewardWrapper):
"""Example reward wrapper for testing."""
- def rewards(self, rewards: Dict[str, float]) -> Dict[str, float]:
+ def rewards(self, rewards: dict[str, float]) -> dict[str, float]:
return {i: 1 for i in rewards}
class ExampleObservationWrapper(ObservationWrapper):
"""Example observation wrapper for testing."""
- def observations(self, obs: Dict[str, M.ObsType]) -> Dict[str, WrapperObsType]:
+ def observations(self, obs: dict[str, M.ObsType]) -> dict[str, WrapperObsType]:
return {i: np.array([1]) for i in obs} # type: ignore
class ExampleActionWrapper(ActionWrapper):
"""Example action wrapper for testing."""
- def actions(self, actions: Dict[str, M.ActType]) -> Dict[str, WrapperActType]:
+ def actions(self, actions: dict[str, M.ActType]) -> dict[str, WrapperActType]:
return {i: np.array([1]) for i in actions} # type: ignore
@@ -120,18 +120,18 @@ class ActionWrapperTestEnv(DefaultEnv[int, int, int]):
Step returns the action as an observation.
"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(ExampleModel())
def step(
- self, actions: Dict[str, int]
- ) -> Tuple[
- Dict[str, int],
- Dict[str, float],
- Dict[str, bool],
- Dict[str, bool],
+ self, actions: dict[str, int]
+ ) -> tuple[
+ dict[str, int],
+ dict[str, float],
+ dict[str, bool],
+ dict[str, bool],
bool,
- Dict[str, Dict],
+ dict[str, dict],
]:
step = self.model.step(self._state, actions)
self._step_num += 1
diff --git a/tests/test_model.py b/tests/test_model.py
index 4c7c8be..2e61dc4 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,16 +1,14 @@
"""Checks that the core posggym model API is implemented as expected."""
-from typing import Dict, List
-
-from gymnasium import spaces
import posggym.model as M
+from gymnasium import spaces
from posggym.utils import seeding
class ExampleModel(M.POSGModel[int, int, int]):
"""Example discrete testing model."""
- def __init__(self):
+ def __init__(self) -> None:
self.possible_agents = (0, 1)
self.action_spaces = {i: spaces.Discrete(2) for i in self.possible_agents}
self.observation_spaces = {i: spaces.Discrete(2) for i in self.possible_agents}
@@ -22,16 +20,16 @@ def rng(self) -> seeding.RNG:
self._rng, seed = seeding.std_random()
return self._rng
- def get_agents(self, state: int) -> List[str]:
+ def get_agents(self, state: int) -> list[str]:
return list(self.possible_agents)
def sample_initial_state(self) -> int:
return 0
- def sample_initial_obs(self, state: int) -> Dict[str, int]:
+ def sample_initial_obs(self, state: int) -> dict[str, int]:
return {i: 0 for i in self.possible_agents}
- def step(self, state: int, actions: Dict[str, int]) -> M.JointTimestep[int, int]:
+ def step(self, state: int, actions: dict[str, int]) -> M.JointTimestep[int, int]:
return M.JointTimestep(
0,
{i: 0 for i in self.possible_agents},
diff --git a/tests/utils/test_seeding.py b/tests/utils/test_seeding.py
index 67a643c..5487f1e 100644
--- a/tests/utils/test_seeding.py
+++ b/tests/utils/test_seeding.py
@@ -18,14 +18,18 @@ def test_invalid_seeds():
except error.Error:
pass
else:
- assert False, f"Invalid seed {seed} passed validation for `np_random`"
+ raise AssertionError(
+ f"Invalid seed {seed} passed validation for `np_random`"
+ )
try:
seeding.std_random(seed)
except error.Error:
pass
else:
- assert False, f"Invalid seed {seed} passed validation for `std_random`"
+ raise AssertionError(
+ f"Invalid seed {seed} passed validation for `std_random`"
+ )
def test_valid_seeds():
diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py
index 9b1f6ea..ba1ef96 100644
--- a/tests/vector/test_sync_vector_env.py
+++ b/tests/vector/test_sync_vector_env.py
@@ -62,7 +62,7 @@ def test_reset_sync_vector_env():
assert isinstance(env.observation_spaces[agent_id], spaces.Box)
assert isinstance(obs_i, np.ndarray)
assert obs_i.shape == env.observation_spaces[agent_id].shape
- assert obs_i.shape == (8,) + env.single_observation_spaces[agent_id].shape
+ assert obs_i.shape == (8, *env.single_observation_spaces[agent_id].shape)
assert obs_i.dtype == env.observation_spaces[agent_id].dtype
assert isinstance(info_i, dict)
@@ -109,11 +109,11 @@ def test_step_sync_vector_env(use_single_action_space):
assert isinstance(env.observation_spaces[i], spaces.Box)
assert isinstance(observations[i], np.ndarray)
assert observations[i].shape == env.observation_spaces[i].shape
- assert observations[i].shape == (8,) + env.single_observation_spaces[i].shape
+ assert observations[i].shape == (8, *env.single_observation_spaces[i].shape)
assert observations[i].dtype == env.observation_spaces[i].dtype
assert isinstance(rewards[i], np.ndarray)
- assert isinstance(rewards[i][0], (float, np.floating))
+ assert isinstance(rewards[i][0], float | np.floating)
assert rewards[i].shape == (8,)
assert isinstance(terminations[i], np.ndarray)
diff --git a/tests/wrappers/test_discretize_actions.py b/tests/wrappers/test_discretize_actions.py
index 81c6431..265d00a 100644
--- a/tests/wrappers/test_discretize_actions.py
+++ b/tests/wrappers/test_discretize_actions.py
@@ -1,12 +1,8 @@
"""Test for DiscretizeActions Wrapper."""
-from typing import cast
-
import numpy as np
+import posggym
import pytest
from gymnasium import spaces
-
-import posggym
-from posggym.envs.continuous.driving_continuous import DrivingContinuousModel
from posggym.wrappers import DiscretizeActions
@@ -22,16 +18,13 @@ def test_discretize_actions_flatten(num_actions):
)
wrapped_env = DiscretizeActions(env, num_actions=num_actions, flatten=True)
- model = cast(DrivingContinuousModel, env.model)
-
box_act_dim = 2
base_space = spaces.Box(
- low=np.array([-model.dyaw_limit, -model.dvel_limit], dtype=np.float32),
- high=np.array([model.dyaw_limit, model.dvel_limit], dtype=np.float32),
+ low=np.array([-1, -1], dtype=np.float32),
+ high=np.array([1, 1], dtype=np.float32),
)
n_flat_actions = np.prod([num_actions] * box_act_dim)
- # wrapped_space = spaces.Discrete(n_flat_actions)
assert all(
act_space.n == n_flat_actions
for act_space in wrapped_env.action_spaces.values()
@@ -71,16 +64,12 @@ def test_discretize_actions_multidiscrete(num_actions):
)
wrapped_env = DiscretizeActions(env, num_actions=num_actions, flatten=False)
- model = cast(DrivingContinuousModel, env.model)
-
box_act_dim = 2
base_space = spaces.Box(
- low=np.array([-model.dyaw_limit, -model.dvel_limit], dtype=np.float32),
- high=np.array([model.dyaw_limit, model.dvel_limit], dtype=np.float32),
+ low=np.array([-1, -1], dtype=np.float32),
+ high=np.array([1, 1], dtype=np.float32),
)
- # wrapped_space = spaces.MultiDiscrete([num_actions] * box_act_dim)
-
# perform actions and then check last_action from unwrapped_env
env.reset()
wrapped_env.reset()
diff --git a/tests/wrappers/test_flatten_observations.py b/tests/wrappers/test_flatten_observations.py
index b5f6b5a..7309d29 100644
--- a/tests/wrappers/test_flatten_observations.py
+++ b/tests/wrappers/test_flatten_observations.py
@@ -7,7 +7,7 @@
import numpy as np
import posggym
from gymnasium import spaces
-from posggym.envs.grid_world.driving import CELL_OBS, Speed
+from posggym.envs.grid_world.driving import CELL_OBS
from posggym.wrappers import FlattenObservations
@@ -34,7 +34,7 @@ def test_flatten_observation():
spaces.Discrete(len(CELL_OBS)) for _ in range(obs_depth * obs_width)
)
),
- spaces.Discrete(len(Speed)),
+ spaces.Discrete(4), # reverse, stopped, forward, forward fast,
spaces.Tuple(
(spaces.Discrete(grid_width), spaces.Discrete(grid_height))
), # current coord
@@ -50,7 +50,7 @@ def test_flatten_observation():
1,
[
len(CELL_OBS) * obs_depth * obs_width
- + len(Speed)
+ + 4 # reverse, stopped, forward, forward fast
+ grid_width
+ grid_height
+ grid_width
@@ -62,10 +62,10 @@ def test_flatten_observation():
)
assert all(i in obs for i in env.agents)
- for i, obs_i in obs.items():
+ for _i, obs_i in obs.items():
assert space.contains(obs_i)
assert all(i in wrapped_obs for i in env.agents)
- for i, wrapped_obs_i in wrapped_obs.items():
+ for _i, wrapped_obs_i in wrapped_obs.items():
assert wrapped_space.contains(wrapped_obs_i), wrapped_obs_i.shape
assert isinstance(info, dict)
assert isinstance(wrapped_obs_info, dict)
diff --git a/tests/wrappers/test_order_enforcing.py b/tests/wrappers/test_order_enforcing.py
index 9972abf..4dce39b 100644
--- a/tests/wrappers/test_order_enforcing.py
+++ b/tests/wrappers/test_order_enforcing.py
@@ -4,12 +4,12 @@
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/tests/wrappers/test_order_enforcing.py
"""
-import pytest
-
import posggym
+import pytest
from posggym.envs.classic.mabc import MABCEnv
-from posggym.error import ResetNeeded
+from posggym.error import ResetNeededError
from posggym.wrappers import OrderEnforcing
+
from tests.envs.utils import all_testing_env_specs
from tests.wrappers.utils import has_wrapper
@@ -36,9 +36,9 @@ def test_order_enforcing():
# Assert that the order enforcing works for step and render before reset
order_enforced_env = OrderEnforcing(env)
assert order_enforced_env.has_reset is False
- with pytest.raises(ResetNeeded):
+ with pytest.raises(ResetNeededError):
order_enforced_env.step({i: 0 for i in env.possible_agents})
- with pytest.raises(ResetNeeded):
+ with pytest.raises(ResetNeededError):
order_enforced_env.render()
assert order_enforced_env.has_reset is False
diff --git a/tests/wrappers/test_petting_zoo.py b/tests/wrappers/test_petting_zoo.py
index 8efa191..cf15b7c 100644
--- a/tests/wrappers/test_petting_zoo.py
+++ b/tests/wrappers/test_petting_zoo.py
@@ -17,8 +17,8 @@ def test_make_petting_zoo(spec):
from pettingzoo.utils import agent_selector # type: ignore
from pettingzoo.utils.conversions import parallel_to_aec_wrapper # type: ignore
from posggym.wrappers.petting_zoo import PettingZoo
- except (ImportError, posggym.error.DependencyNotInstalled) as e:
- pytest.skip(f"pettingzoo not installed.: {str(e)}")
+ except (ImportError, posggym.error.DependencyNotInstalledError) as e:
+ pytest.skip(f"pettingzoo not installed.: {e!s}")
class custom_parallel_to_aec_wrapper(parallel_to_aec_wrapper):
"""PettingZoo ParallelEnv to AECEnv converter.
diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py
index 871b9f3..8e6a68e 100644
--- a/tests/wrappers/test_record_episode_statistics.py
+++ b/tests/wrappers/test_record_episode_statistics.py
@@ -42,9 +42,7 @@ def test_record_episode_statistics(env_id, deque_size):
assert len(infos) == len(env.possible_agents)
for i in env.possible_agents:
assert "episode" in infos[i]
- assert all(
- [item in infos[i]["episode"] for item in ["r", "l", "t"]]
- )
+ assert all(item in infos[i]["episode"] for item in ["r", "l", "t"])
assert np.isclose(infos[i]["episode"]["r"], agent_returns[i])
assert infos[i]["episode"]["l"] == t + 1
break
@@ -95,7 +93,7 @@ def _make_env():
assert "episode" in infos[i]
assert "_episode" in infos[i]
assert all(infos[i]["_episode"] == dones)
- assert all([item in infos[i]["episode"] for item in ["r", "l", "t"]])
+ assert all(item in infos[i]["episode"] for item in ["r", "l", "t"])
break
else:
for i in envs.possible_agents:
diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py
index 98d757b..6c209ed 100644
--- a/tests/wrappers/test_record_video.py
+++ b/tests/wrappers/test_record_video.py
@@ -6,6 +6,7 @@
"""
import shutil
from pathlib import Path
+
import posggym
from posggym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
diff --git a/tests/wrappers/test_rescale_actions.py b/tests/wrappers/test_rescale_actions.py
index 0fee28a..3976702 100644
--- a/tests/wrappers/test_rescale_actions.py
+++ b/tests/wrappers/test_rescale_actions.py
@@ -1,12 +1,9 @@
"""Test for RescaleActions Wrapper."""
-from typing import cast
import numpy as np
+import posggym
import pytest
from gymnasium import spaces
-
-import posggym
-from posggym.envs.continuous.driving_continuous import DrivingContinuousModel
from posggym.wrappers import RescaleActions
@@ -34,16 +31,14 @@ def test_rescale_actions(min_val, max_val):
)
wrapped_env = RescaleActions(env, min_action=min_val, max_action=max_val)
- model = cast(DrivingContinuousModel, env.model)
-
box_act_dim = 2
base_space = spaces.Box(
- low=np.array([-model.dyaw_limit, -model.dvel_limit], dtype=np.float32),
- high=np.array([model.dyaw_limit, model.dvel_limit], dtype=np.float32),
+ low=np.array([-1, -1], dtype=np.float32),
+ high=np.array([1, 1], dtype=np.float32),
)
wrapped_spaces = {}
- if isinstance(min_val, (int, float)):
+ if isinstance(min_val, int | float):
# assume max_val also (int, float)
wrapped_spaces = {
i: spaces.Box(
diff --git a/tests/wrappers/test_rescale_observations.py b/tests/wrappers/test_rescale_observations.py
index cf05304..c1eb267 100644
--- a/tests/wrappers/test_rescale_observations.py
+++ b/tests/wrappers/test_rescale_observations.py
@@ -2,10 +2,9 @@
import math
import numpy as np
+import posggym
import pytest
from gymnasium import spaces
-
-import posggym
from posggym.wrappers import RescaleObservations
@@ -41,12 +40,14 @@ def test_rescale_observation(min_val, max_val):
sensors_dim, obs_dim = n_sensors * 2, n_sensors * 2 + 5
sensor_low, sensor_high = [0.0] * sensors_dim, [obs_dist] * sensors_dim
base_space = spaces.Box(
- low=np.array([*sensor_low, -2 * math.pi, -1, -1, 0, 0], dtype=np.float32),
+ low=np.array(
+ [*sensor_low, -2 * math.pi, -1, -1, -size, -size], dtype=np.float32
+ ),
high=np.array([*sensor_high, 2 * math.pi, 1, 1, size, size], dtype=np.float32),
)
wrapped_spaces = {}
- if isinstance(min_val, (int, float)):
+ if isinstance(min_val, int | float):
# assume max_val also (int, float)
wrapped_spaces = {
i: spaces.Box(
@@ -80,7 +81,7 @@ def test_rescale_observation(min_val, max_val):
}
assert all(i in obs for i in env.agents)
- for i, obs_i in obs.items():
+ for _i, obs_i in obs.items():
assert base_space.contains(obs_i), (obs_i, base_space)
assert all(i in wrapped_obs for i in env.agents)
diff --git a/tests/wrappers/test_rllib_multi_agent_env.py b/tests/wrappers/test_rllib_multi_agent_env.py
index 58379fc..49cb0dd 100644
--- a/tests/wrappers/test_rllib_multi_agent_env.py
+++ b/tests/wrappers/test_rllib_multi_agent_env.py
@@ -25,7 +25,7 @@ def test_make_rllib_multi_agent_env(spec):
from posggym.posggym.wrappers.rllib_env import RllibMultiAgentEnv
from ray.rllib.utils.pre_checks.env import check_env
except ImportError as e:
- pytest.skip(f"ray[rllib] not installed.: {str(e)}")
+ pytest.skip(f"ray[rllib] not installed.: {e!s}")
env = posggym.make(spec.id, disable_env_checker=True)
rllib_env = RllibMultiAgentEnv(env)
diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py
index 0769978..fc39713 100644
--- a/tests/wrappers/test_time_limit.py
+++ b/tests/wrappers/test_time_limit.py
@@ -4,9 +4,8 @@
https://github.com/Farama-Foundation/Gymnasium/blob/v0.27.0/tests/wrappers/test_time_limit.py
"""
-import pytest
-
import posggym
+import pytest
from posggym.envs.classic.mabc import MABCEnv
from posggym.wrappers import TimeLimit
diff --git a/tests/wrappers/test_video_recorder.py b/tests/wrappers/test_video_recorder.py
index 495d2b7..b78a25d 100644
--- a/tests/wrappers/test_video_recorder.py
+++ b/tests/wrappers/test_video_recorder.py
@@ -5,6 +5,7 @@
"""
import re
+from typing import ClassVar
import posggym
import pytest
@@ -12,9 +13,9 @@
class BrokenRecordableEnv(posggym.Env):
- metadata = {"render_modes": ["rgb_array"]}
+ metadata: ClassVar[dict] = {"render_modes": ["rgb_array"]}
- def __init__(self, render_mode="rgb_array"):
+ def __init__(self, render_mode="rgb_array") -> None:
self.render_mode = render_mode
def render(self):
@@ -29,9 +30,9 @@ def state(self):
class UnrecordableEnv(posggym.Env):
- metadata = {"render_modes": [None]}
+ metadata: ClassVar[dict] = {"render_modes": [None]}
- def __init__(self, render_mode=None):
+ def __init__(self, render_mode=None) -> None:
self.render_mode = render_mode
def render(self):
@@ -103,13 +104,3 @@ def test_record_breaking_render_method():
# def test_text_envs():
-# env = posggym.make(
-# "MultiAgentTiger-v0", render_mode="ansi", disable_env_checker=True
-# )
-# video = VideoRecorder(env)
-# try:
-# env.reset()
-# video.capture_frame()
-# video.close()
-# finally:
-# os.remove(video.path)
diff --git a/tests/wrappers/utils.py b/tests/wrappers/utils.py
index 6e0e7c5..1d1048e 100644
--- a/tests/wrappers/utils.py
+++ b/tests/wrappers/utils.py
@@ -15,7 +15,7 @@ def has_wrapper(wrapped_env: posggym.Env, wrapper_type: type) -> bool:
def assert_equals(a, b, prefix=None):
"""Assert equality of data structures `a` and `b`.
- Arguments
+ Arguments:
---------
a:
first data structure
@@ -35,7 +35,7 @@ def assert_equals(a, b, prefix=None):
elif isinstance(a, np.ndarray):
np.testing.assert_array_equal(a, b)
elif isinstance(a, tuple):
- for elem_from_a, elem_from_b in zip(a, b):
+ for elem_from_a, elem_from_b in zip(a, b, strict=False):
assert_equals(elem_from_a, elem_from_b, prefix)
else:
assert a == b