Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a32d1ce
changes to driving
rhys-newbury Feb 12, 2024
5b11f4a
use options
rhys-newbury Feb 12, 2024
d557ca1
fix bugs
rhys-newbury Feb 12, 2024
2a9efd2
randomize on reset
rhys-newbury Feb 13, 2024
3471af0
randomize dynamics
rhys-newbury Feb 13, 2024
19780ca
randomize dynamics
rhys-newbury Feb 13, 2024
99c2bc3
randomize more envs + fix type errors
rhys-newbury Feb 15, 2024
4489e95
add flag
rhys-newbury Feb 15, 2024
7f01f39
fix bug
rhys-newbury Feb 15, 2024
4b4f57e
update cont + random
rhys-newbury Feb 19, 2024
078faa5
improve indexing
rhys-newbury Feb 19, 2024
8c40e39
adopting to use api which should generalize, plz
rhys-newbury Feb 19, 2024
8d0faa1
add empty env
rhys-newbury Feb 20, 2024
d6e42b1
make sim more stable
rhys-newbury Feb 20, 2024
88b7876
rm try
rhys-newbury Feb 20, 2024
827a606
minimum mass
rhys-newbury Feb 20, 2024
e3e1948
update empty world
rhys-newbury Feb 26, 2024
62c78fb
rm print
rhys-newbury Feb 26, 2024
d1b40e0
check all action spaces
rhys-newbury Feb 26, 2024
136969d
some changes
rhys-newbury Feb 27, 2024
6776e24
removed abs
rhys-newbury Feb 28, 2024
f049449
undo change
rhys-newbury Feb 28, 2024
d9a2ba5
up supported num agents
rhys-newbury Feb 29, 2024
8ffa3dd
hacky changes :(
rhys-newbury Feb 29, 2024
9eeffff
merge
rhys-newbury May 15, 2024
f815e1c
Update Rules + Add Differentiable POSG
rhys-newbury May 26, 2025
5273606
improve discretize
rhys-newbury Jun 2, 2025
9b16e54
docker build fix
rhys-newbury Jun 2, 2025
f7d3a22
docker build fix 2
rhys-newbury Jun 2, 2025
0624dbd
lower pymunk
rhys-newbury Jun 2, 2025
ada9c6b
do not wrap model
rhys-newbury Jun 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: Build Docker Image

on: [push]

jobs:
docker-build:
runs-on: ubuntu-latest
env:
IMAGE_NAME: ghcr.io/${{ github.repository_owner }}/posggym:latest


permissions:
contents: read
packages: write

steps:
- name: Checkout Repository
uses: actions/checkout@v3
with:
submodules: recursive

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Log in to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Normalize tag name
run: echo "IMAGE_TAG=$(echo $IMAGE_NAME | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV

- name: Build and Push Docker Image
uses: docker/build-push-action@v5
with:
context: .
file: Dockerfile
push: true
tags: ${{ env.IMAGE_TAG }}
cache-from: type=registry,ref=${{ env.IMAGE_TAG }}
cache-to: type=registry,ref=${{ env.IMAGE_TAG }},mode=max
16 changes: 13 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,24 @@ on: [push]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: pip install
run: pip install --upgrade pip && pip install --user -e .[all] && pip install --user -e .[testing]
- name: Install dependencies
run: |
pip install --upgrade pip
pip install --user -e .[all]
pip install --user -e .[testing]

- name: Run tests
run : pytest
run: pytest
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ dmypy.json

# Ruff linter
.ruff_cache/
*.pickle
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ repos:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: 'v0.0.254'
rev: 'v0.4.10'
hooks:
- id: ruff
args:
- --fix
- --unsafe-fixes
- repo: local
hooks:
- id: pyright
Expand Down
12 changes: 12 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime

WORKDIR /app

COPY pyproject.toml ./
COPY setup.py ./
COPY ./posggym/__init__.py /app/posggym/__init__.py

RUN pip install -e .[all]


COPY . .
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import posggym


project = "POSGGym"
copyright = "2023, Jonathon Schwartz"
author = "Jonathon Schwartz"
Expand Down
10 changes: 4 additions & 6 deletions docs/scripts/gen_agent_gifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import re
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, List
from typing import Any

import posggym
import posggym.agents as pga
Expand All @@ -27,7 +27,7 @@

def gen_agent_gif(
env_id: str,
policy_ids: List[str],
policy_ids: list[str],
ignore_existing: bool = False,
length: int = 300,
custom_env: bool = False,
Expand All @@ -43,7 +43,7 @@ def gen_agent_gif(
for policy_id in policy_ids:
try:
pi_spec = pga.spec(policy_id)
except posggym.error.NameNotFound as e:
except posggym.error.NameNotFoundError as e:
if "/" not in policy_id:
# try prepending env id
policy_id = f"{env_id}/{policy_id}"
Expand All @@ -65,8 +65,6 @@ def gen_agent_gif(
env = posggym.make(
env_id, disable_env_checker=True, render_mode="rgb_array", **env_args
)
# env = posggym.wrappers.RescaleObservations(env, min_obs=-1.0, max_obs=1.0)
# env = posggym.wrappers.RescaleActions(env, min_action=-1.0, max_action=1.0)

policies = {}
for idx, spec in enumerate(policy_specs):
Expand Down Expand Up @@ -107,7 +105,7 @@ def gen_agent_gif(
for _ in range(repeat):
frames.append(Image.fromarray(frame))

actions: Dict[str, Any] = {}
actions: dict[str, Any] = {}
for i in env.agents:
if policies[i].observes_state:
actions[i] = policies[i].step(env.state)
Expand Down
7 changes: 3 additions & 4 deletions docs/scripts/gen_agent_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

"""
import re
from typing import Dict, List
from pathlib import Path

import posggym
Expand All @@ -22,7 +21,7 @@

all_agents = list(pga.registry.values())
# env_type -> env_name -> [PolicySpec]
filtered_agents_by_env_type: Dict[str, Dict[str, List[PolicySpec]]] = {}
filtered_agents_by_env_type: dict[str, dict[str, list[PolicySpec]]] = {}

# Obtain filtered list
for pi_spec in tqdm(all_agents):
Expand Down Expand Up @@ -100,7 +99,7 @@
else:
info = (
"These policies are for the "
+ f"<a href='../../../environments/{env_type}/{snake_env_name}'>"
f"<a href='../../../environments/{env_type}/{snake_env_name}'>"
f"{title_env_name} environment</a>. Read environment page for detailed "
"information about the environment."
)
Expand All @@ -124,7 +123,7 @@
env_args_ids.sort()

if None in filtered_agents_by_env_args_id:
env_args_ids = [None] + env_args_ids
env_args_ids = [None, *env_args_ids]

for env_args_id in env_args_ids:
policy_specs = filtered_agents_by_env_args_id[env_args_id]
Expand Down
36 changes: 11 additions & 25 deletions docs/scripts/gen_env_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,21 @@

import re
from functools import reduce
from typing import Dict, List
from pathlib import Path
from tqdm import tqdm
from utils import kill_strs, trim

import posggym
from posggym.envs.registration import EnvSpec
from tqdm import tqdm

from utils import kill_strs, trim


pattern = re.compile(r"(?<!^)(?=[A-Z])")

posggym.logger.set_level(posggym.logger.DISABLED)

all_envs = list(posggym.envs.registry.values())
filtered_envs_by_type: Dict[str, Dict[str, EnvSpec]] = {}
filtered_envs_by_type: dict[str, dict[str, EnvSpec]] = {}

# Obtain filtered list
for env_spec in tqdm(all_envs):
Expand All @@ -47,7 +48,7 @@
split = str(type(env.unwrapped)).split(".")
env_name = split[3]

if env_type not in filtered_envs_by_type.keys():
if env_type not in filtered_envs_by_type:
filtered_envs_by_type[env_type] = {}
# only store new entries and higher versions
if env_name not in filtered_envs_by_type[env_type] or (
Expand All @@ -60,7 +61,7 @@
print(e)

# Sort
filtered_envs: List = list(
filtered_envs: list = list(
reduce(
lambda s, x: s + x, # type: ignore
(
Expand Down Expand Up @@ -120,15 +121,15 @@
if "rgb_array" in env.metadata["render_modes"]:
gif = (
"```{figure}"
+ f" ../../_static/videos/{env_type}/{snake_env_name}.gif"
+ f"\n:width: 200px\n:name: {snake_env_name}\n```"
f" ../../_static/videos/{env_type}/{snake_env_name}.gif"
f"\n:width: 200px\n:name: {snake_env_name}\n```"
)
else:
gif = ""
info = (
"This environment is part of the "
+ f"<a href='..'>{env_type_title} environments</a>. "
+ "Please read that page first for general information."
f"<a href='..'>{env_type_title} environments</a>. "
"Please read that page first for general information."
)

act_spaces_str = str(env.action_spaces)
Expand All @@ -144,32 +145,17 @@
env_table += f"| Symmetric | {env.is_symmetric} |\n"

# if env.observation_space.shape:
# env_table += f"| Observation Shape | {env.observation_space.shape} |\n"

# if hasattr(env.observation_space, "high"):
# high = env.observation_space.high

# if hasattr(high, "shape"):
# if len(high.shape) == 3:
# high = high[0][0][0]
# if env_type == "mujoco":
# high = high[0]
# high = np.round(high, 2)
# high = str(high).replace("\n", " ")
# env_table += f"| Observation High | {high} |\n"

# if hasattr(env.observation_space, "low"):
# low = env.observation_space.low
# if hasattr(low, "shape"):
# if len(low.shape) == 3:
# low = low[0][0][0]
# if env_type == "mujoco":
# low = low[0]
# low = np.round(low, 2)
# low = str(low).replace("\n", " ")
# env_table += f"| Observation Low | {low} |\n"
# else:
# env_table += f"| Observation Space | {env.observation_space} |\n"

env_table += f'| Import | `posggym.make("{env_spec.id}")` |\n'

Expand Down
5 changes: 3 additions & 2 deletions docs/scripts/gen_envs_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from pathlib import Path


DOCS_DIR = Path(__file__).resolve().parent.parent

all_envs = [
Expand Down Expand Up @@ -100,7 +101,7 @@ def generate_page(env, limit=-1, base_path=""):
type_arg = sys.argv[1]

for env in all_envs:
if type_arg == env["id"] or type_arg == "":
if type_arg in {env["id"], ""}:
type_dict_arr.append(env)

for type_dict in type_dict_arr:
Expand All @@ -127,7 +128,7 @@ def generate_page(env, limit=-1, base_path=""):
env_name = " ".join(type_id.split("_")).title()
fp.write(
f"# Complete List - {env_name}\n\n"
+ "```{raw} html\n:file: complete_list.html\n```"
"```{raw} html\n:file: complete_list.html\n```"
)
else:
page = generate_page(type_dict)
Expand Down
2 changes: 1 addition & 1 deletion docs/scripts/gen_gifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def gen_gif(
repeat = (
int(60 / env.metadata["render_fps"]) if env_type == "classic" else 1
)
for i in range(repeat):
for _i in range(repeat):
frames.append(Image.fromarray(frame))
action = {i: env.action_spaces[i].sample() for i in env.agents}
_, _, _, _, done, _ = env.step(action)
Expand Down
Loading