From 6035b6e44a64ed3c742f2b63cf474a9ad3a0c713 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 20:28:57 +0100 Subject: [PATCH 1/8] Add pyproject and Makefile --- Makefile | 28 ++++++++++++++++++++++++++++ README.md | 41 ++++++++++++++++++++--------------------- pyproject.toml | 29 +++++++++++++++++++++++++++++ setup.py | 12 +++++++----- train_sbx.py | 4 +++- 5 files changed, 87 insertions(+), 27 deletions(-) create mode 100644 Makefile create mode 100644 pyproject.toml diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9dfa5d2 --- /dev/null +++ b/Makefile @@ -0,0 +1,28 @@ +SHELL=/bin/bash +LINT_PATHS=frasa_env/ setup.py + +mypy: + mypy ${LINT_PATHS} + +lint: + # stop the build if there are Python syntax errors or undefined names + # see https://www.flake8rules.com/ + ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full + # exit-zero treats all errors as warnings. + ruff check ${LINT_PATHS} --exit-zero --output-format=concise + +format: + # Sort imports + ruff check --select I ${LINT_PATHS} --fix + # Reformat using black + black ${LINT_PATHS} + +check-codestyle: + # Sort imports + ruff check --select I ${LINT_PATHS} + # Reformat using black + black --check ${LINT_PATHS} + +commit-checks: format mypy lint + +.PHONY: lint format check-codestyle commit-checks diff --git a/README.md b/README.md index 3915868..baf9a27 100644 --- a/README.md +++ b/README.md @@ -30,13 +30,13 @@ Please note that the RL Baselines3 Zoo create a conflict on the gymnasium module ### Issue with MuJoCo using Wayland -If you are using Wayland (instead of X11), you may encounter issues with the MuJoCo viewer, +If you are using Wayland (instead of X11), you may encounter issues with the MuJoCo viewer, such as frozen windows or buttons not working. To solve this, you can build GLFW from source with shared libraries and set the `LD_LIBRARY_PATH` and `PYGLFW_LIBRARY` environment variables to point to the built libraries. 1. Download source package here and unzip it: https://www.glfw.org/download.html - + 2. Install dependancies to build GLFW for Wayland and X11 ``` @@ -58,7 +58,7 @@ make ``` 5. Change LD_LIBRARY_PATH and PYGLFW_LIBRARY to match GLFW version you built and add it to your bashrc - + ``` export LD_LIBRARY_PATH="LD_LIBRARY_PATH:path/to/glfw/build/src/" export PYGLFW_LIBRARY="path/to/glfw/build/src/libglfw.so" @@ -68,18 +68,18 @@ export PYGLFW_LIBRARY="path/to/glfw/build/src/libglfw.so" ### Generating initial positions -Pre-generating initial positions for the standup environment is recommended, -as it can be time-consuming to generate them during training. To do so, you can +Pre-generating initial positions for the standup environment is recommended, +as it can be time-consuming to generate them during training. To do so, you can use the standup_generate_initial.py script: ```bash python standup_generate_initial.py ``` -It will generate initial positions by letting the robot fall from random positions +It will generate initial positions by letting the robot fall from random positions and store them in `frasa_env/env/standup_initial_configurations.pkl`. -Let the script run until you have collected enough initial positions -(typically a few thousand). You can stop the script at any time using Ctrl+C; +Let the script run until you have collected enough initial positions +(typically a few thousand). You can stop the script at any time using Ctrl+C; the generated positions will be saved automatically. @@ -104,7 +104,6 @@ You can train an agent using: python train_sbx.py \ --algo crossq \ --env frasa-standup-v0 \ - --gym-packages frasa_env \ --conf hyperparams/crossq.yml ``` @@ -150,14 +149,14 @@ Where the arguments are:
-The Sigmaban robot is a small humanoid developed by the Rhoban team to -compete in the RoboCup KidSize league. The robot is 70 cm tall and weighs -7.5 kg. It has 20 degrees of freedom, a camera, and is equipped with pressure -sensors and an IMU. +The Sigmaban robot is a small humanoid developed by the Rhoban team to +compete in the RoboCup KidSize league. The robot is 70 cm tall and weighs +7.5 kg. It has 20 degrees of freedom, a camera, and is equipped with pressure +sensors and an IMU. -The MuJoCo XML model of this robot is located in the `frasa_env/mujoco_simulator/model` folder. +The MuJoCo XML model of this robot is located in the `frasa_env/mujoco_simulator/model` folder. -For a detailed description of the process to convert URDF to MuJoCo XML, +For a detailed description of the process to convert URDF to MuJoCo XML, see the [README](frasa_env/mujoco_simulator/model/README.md) in the model directory.
@@ -172,7 +171,7 @@ The stand up environment can be found in `frasa_env/env/standup_env.py`. The environment simplifies the learning process by controlling only the 5 DoFs presented on the right (elbow, shoulder_pitch, hip_pitch, knee, ankle_pitch). These are the primary joints involved in recovery and standing movements in the sagittal plane $(x, z)$. The actions are symmetrically applied to both sides of the robot. -The robot's state is characterized by the **pose vector** $\psi$, which includes the joint angles $q$ and the trunk pitch $\theta$. The target pose for recovery is defined as: +The robot's state is characterized by the **pose vector** $\psi$, which includes the joint angles $q$ and the trunk pitch $\theta$. The target pose for recovery is defined as: $$ \psi_{\text{target}} = [q_{\text{target}}, \theta_{\text{target}}] $$ @@ -180,7 +179,7 @@ This target pose represents the upright position the robot should achieve after At the start of each episode, the robot’s joint angles $q$ and trunk pitch $\theta$ are set to random values within their physical limits. The robot is then released above the ground, simulating various fallen states. The physical simulation is -then stepped using current joints configuration as targets, until stability is reached, at which point +then stepped using current joints configuration as targets, until stability is reached, at which point the learning process begins. An episode termination occurs when the robot reaches an unsafe or irreversible state: @@ -206,7 +205,7 @@ The observation space for the **StandupEnv** captures key elements of the robot' ### Action Space -The action space specifies control commands for the robot's 5 joints. Its structure depends on the control mode (`position`, `velocity`, or `error`) specified in `options["control"]`. +The action space specifies control commands for the robot's 5 joints. Its structure depends on the control mode (`position`, `velocity`, or `error`) specified in `options["control"]`. | **Mode** | **Description** | **Range** | **Notes** | |---|---|---|---| @@ -226,7 +225,7 @@ $$ R = R_{state} + R_{variation} + R_{collision} $$ The state proximity reward $R_{state}$ represents the proximity to the desired state $\psi_{target}$: -$$ R_{state} = \text{exp}\left(-20 \cdot \left| \psi_{current} - \psi_\{target} \right|^2 \right) $$ +$$ R_{state} = \text{exp}\left(-20 \cdot \left| \psi_{current} - \psi_\{target} \right|^2 \right) $$ 2. Variation Penalty Reward @@ -294,12 +293,12 @@ To cite this repository in publications: ```bibtex @article{frasa2024, - title={FRASA: An End-to-End Reinforcement Learning Agent for Fall Recovery and Stand Up of Humanoid Robots}, + title={FRASA: An End-to-End Reinforcement Learning Agent for Fall Recovery and Stand Up of Humanoid Robots}, author={Clément Gaspard and Marc Duclusaud and Grégoire Passault and Mélodie Daniel and Olivier Ly}, year={2024}, eprint={2410.08655}, archivePrefix={arXiv}, primaryClass={cs.RO}, - url={https://arxiv.org/abs/2410.08655}, + url={https://arxiv.org/abs/2410.08655}, } ``` diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..81c9cab --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ +[tool.ruff] +# Same as Black. +line-length = 127 +# Assume Python 3.9 +target-version = "py39" + +[tool.ruff.lint] +# See https://beta.ruff.rs/docs/rules/ +select = ["E", "F", "B", "UP", "C90", "RUF"] +# B028: Ignore explicit stacklevel` +# RUF013: Too many false positives (implicit optional) +ignore = ["B028", "RUF013"] + +[tool.ruff.lint.per-file-ignores] + + +[tool.ruff.lint.mccabe] +# Unlike Flake8, default to a complexity level of 10. +max-complexity = 15 + +[tool.black] +line-length = 127 + +[tool.mypy] +ignore_missing_imports = true +follow_imports = "silent" +show_error_codes = true +# exclude = """(?x)( +# )""" diff --git a/setup.py b/setup.py index 555000a..cba0da0 100644 --- a/setup.py +++ b/setup.py @@ -1,15 +1,17 @@ from setuptools import setup +from setuptools import find_packages -setup(name="frasa_env", - version="1.0", +setup(name="frasa_env", + packages=[package for package in find_packages() if package.startswith("frasa_env")], + version="1.0", description='FRASA RL Environment', install_requires=[ - "gymnasium==0.29.1", + "gymnasium>=0.29.1,<1.1.0", "numpy>=1.20.0", "stable_baselines3>=2.1.0", - "sb3-contrib>=2.1.0", + "sb3-contrib>=2.1.0", "mujoco>=3.1.5", "meshcat>=0.3.2", "sbx-rl>=0.17.0", - ], + ], ) diff --git a/train_sbx.py b/train_sbx.py index 32495df..466e2f8 100644 --- a/train_sbx.py +++ b/train_sbx.py @@ -3,6 +3,8 @@ from rl_zoo3.train import train from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ +import frasa_env # noqa: F404 + rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN # See SBX readme to use DroQ configuration @@ -17,4 +19,4 @@ if __name__ == "__main__": - train() \ No newline at end of file + train() From c65954f28401cac99528e343837f0dc1c9162090 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 20:33:25 +0100 Subject: [PATCH 2/8] Reformat --- enjoy_sbx.py | 4 +++- frasa_env/__init__.py | 2 +- frasa_env/env/__init__.py | 7 +++++- frasa_env/env/standup.py | 6 +++-- frasa_env/env/standup_env.py | 9 ++++---- frasa_env/mujoco_simulator/simulator.py | 16 +++++-------- pyproject.toml | 3 ++- setup.py | 30 ++++++++++++------------- train_sbx.py | 2 +- 9 files changed, 43 insertions(+), 36 deletions(-) diff --git a/enjoy_sbx.py b/enjoy_sbx.py index 1c1ce04..b649be0 100644 --- a/enjoy_sbx.py +++ b/enjoy_sbx.py @@ -3,6 +3,8 @@ from rl_zoo3.enjoy import enjoy from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ +import frasa_env # noqa: F401 + rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN # See SBX readme to use DroQ configuration @@ -17,4 +19,4 @@ if __name__ == "__main__": - enjoy() \ No newline at end of file + enjoy() diff --git a/frasa_env/__init__.py b/frasa_env/__init__.py index 26463dc..83d9bbc 100644 --- a/frasa_env/__init__.py +++ b/frasa_env/__init__.py @@ -2,5 +2,5 @@ register( id="frasa-standup-v0", - entry_point="frasa_env.env:FRASAEnv" + entry_point="frasa_env.env:FRASAEnv", ) diff --git a/frasa_env/env/__init__.py b/frasa_env/env/__init__.py index e30deab..2e86685 100644 --- a/frasa_env/env/__init__.py +++ b/frasa_env/env/__init__.py @@ -1,3 +1,8 @@ +from frasa_env.env.standup import FRASAEnv from frasa_env.env.standup_env import StandupEnv -from frasa_env.env.standup import FRASAEnv \ No newline at end of file + +__all__ = [ + "FRASAEnv", + "StandupEnv" +] diff --git a/frasa_env/env/standup.py b/frasa_env/env/standup.py index 4eedf04..18cbc72 100644 --- a/frasa_env/env/standup.py +++ b/frasa_env/env/standup.py @@ -1,6 +1,8 @@ -from .standup_env import StandupEnv import numpy as np +from .standup_env import StandupEnv + + class FRASAEnv(StandupEnv): def __init__(self, evaluation="none", render_mode="none", options={}): options["stabilization_time"] = 2.0 @@ -9,4 +11,4 @@ def __init__(self, evaluation="none", render_mode="none", options={}): options["vmax"] = 2 * np.pi options["reset_final_p"] = 0.1 - super().__init__(evaluation=evaluation, render_mode=render_mode, options=options) \ No newline at end of file + super().__init__(evaluation=evaluation, render_mode=render_mode, options=options) diff --git a/frasa_env/env/standup_env.py b/frasa_env/env/standup_env.py index ac39899..f4c170f 100644 --- a/frasa_env/env/standup_env.py +++ b/frasa_env/env/standup_env.py @@ -1,10 +1,11 @@ -import numpy as np +import os import pickle import random -import os -import time + import gymnasium +import numpy as np from gymnasium import spaces + from frasa_env.mujoco_simulator.simulator import Simulator, tf @@ -285,7 +286,7 @@ def step(self, action): state_current = [*self.q_history[-1], self.tilt_history[-1]] - reward = np.exp(-20*(np.linalg.norm(np.array(state_current) - np.array(self.options["desired_state"]))**2)) + reward = np.exp(-20 * (np.linalg.norm(np.array(state_current) - np.array(self.options["desired_state"])) ** 2)) action_variation = np.abs(action - self.previous_actions[-1]) self.previous_actions.append(action) diff --git a/frasa_env/mujoco_simulator/simulator.py b/frasa_env/mujoco_simulator/simulator.py index b7eac7f..a5b13ee 100644 --- a/frasa_env/mujoco_simulator/simulator.py +++ b/frasa_env/mujoco_simulator/simulator.py @@ -1,9 +1,10 @@ import os -import numpy as np import time + +import meshcat.transformations as tf import mujoco import mujoco.viewer -import meshcat.transformations as tf +import numpy as np class Simulator: @@ -14,9 +15,7 @@ def __init__(self, model_dir: str | None = None): self.model_dir = model_dir # Load the model and data - self.model: mujoco.MjModel = mujoco.MjModel.from_xml_path( - f"{model_dir}/scene.xml" - ) + self.model: mujoco.MjModel = mujoco.MjModel.from_xml_path(f"{model_dir}/scene.xml") self.data: mujoco.MjData = mujoco.MjData(self.model) # Retrieve the degrees of freedom id/name pairs @@ -47,12 +46,9 @@ def self_collisions(self) -> float: def centroidal_force(self) -> float: return np.linalg.norm(self.data.qfrc_constraint[3:]) - + def dof_names(self) -> list: - return [ - mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_ACTUATOR, i) - for i in range(self.model.nu) - ] + return [mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_ACTUATOR, i) for i in range(self.model.nu)] def reset(self) -> None: mujoco.mj_resetData(self.model, self.data) diff --git a/pyproject.toml b/pyproject.toml index 81c9cab..ee17a68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,8 @@ target-version = "py39" select = ["E", "F", "B", "UP", "C90", "RUF"] # B028: Ignore explicit stacklevel` # RUF013: Too many false positives (implicit optional) -ignore = ["B028", "RUF013"] +# RUF012: ClassVar +ignore = ["B028", "RUF013", "RUF012"] [tool.ruff.lint.per-file-ignores] diff --git a/setup.py b/setup.py index cba0da0..fcccf91 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,17 @@ -from setuptools import setup -from setuptools import find_packages +from setuptools import find_packages, setup -setup(name="frasa_env", +setup( + name="frasa_env", packages=[package for package in find_packages() if package.startswith("frasa_env")], - version="1.0", - description='FRASA RL Environment', - install_requires=[ - "gymnasium>=0.29.1,<1.1.0", - "numpy>=1.20.0", - "stable_baselines3>=2.1.0", - "sb3-contrib>=2.1.0", - "mujoco>=3.1.5", - "meshcat>=0.3.2", - "sbx-rl>=0.17.0", - ], - ) + version="1.0", + description="FRASA RL Environment", + install_requires=[ + "gymnasium>=0.29.1,<1.1.0", + "numpy>=1.20.0", + "stable_baselines3>=2.1.0", + "sb3-contrib>=2.1.0", + "mujoco>=3.1.5", + "meshcat>=0.3.2", + "sbx-rl>=0.17.0", + ], +) diff --git a/train_sbx.py b/train_sbx.py index 466e2f8..91ec284 100644 --- a/train_sbx.py +++ b/train_sbx.py @@ -3,7 +3,7 @@ from rl_zoo3.train import train from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ -import frasa_env # noqa: F404 +import frasa_env # noqa: F401 rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN From 25019fb5d2635cf98fb4e1ca428a0c61bccd8a8d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 20:38:03 +0100 Subject: [PATCH 3/8] Fix warnings --- frasa_env/env/__init__.py | 6 +----- frasa_env/env/standup.py | 3 ++- frasa_env/env/standup_env.py | 11 ++++++----- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/frasa_env/env/__init__.py b/frasa_env/env/__init__.py index 2e86685..1bf2727 100644 --- a/frasa_env/env/__init__.py +++ b/frasa_env/env/__init__.py @@ -1,8 +1,4 @@ from frasa_env.env.standup import FRASAEnv from frasa_env.env.standup_env import StandupEnv - -__all__ = [ - "FRASAEnv", - "StandupEnv" -] +__all__ = ["FRASAEnv", "StandupEnv"] diff --git a/frasa_env/env/standup.py b/frasa_env/env/standup.py index 18cbc72..3f77975 100644 --- a/frasa_env/env/standup.py +++ b/frasa_env/env/standup.py @@ -4,7 +4,8 @@ class FRASAEnv(StandupEnv): - def __init__(self, evaluation="none", render_mode="none", options={}): + def __init__(self, evaluation=False, render_mode="none", options=None): + options = options or {} options["stabilization_time"] = 2.0 options["truncate_duration"] = 5.0 options["dt"] = 0.05 diff --git a/frasa_env/env/standup_env.py b/frasa_env/env/standup_env.py index f4c170f..864773d 100644 --- a/frasa_env/env/standup_env.py +++ b/frasa_env/env/standup_env.py @@ -1,6 +1,7 @@ import os import pickle import random +from typing import Optional import gymnasium import numpy as np @@ -12,7 +13,7 @@ class StandupEnv(gymnasium.Env): metadata = {"render_modes": ["human", "none"]} - def __init__(self, render_mode="none", options: dict = {}, evaluation: bool = False): + def __init__(self, render_mode="none", options: Optional[dict] = None, evaluation: bool = False): self.options = { # Duration of the stabilization pre-simulation (waiting for the gravity to stabilize the robot) [s] "stabilization_time": 2.0, @@ -53,7 +54,7 @@ def __init__(self, render_mode="none", options: dict = {}, evaluation: bool = Fa # Previous actions "previous_actions": 1, } - self.options.update(options) + self.options.update(options or {}) self.render_mode = render_mode self.sim = Simulator() @@ -369,7 +370,7 @@ def apply_randomization(self): def randomize_fall(self, target: bool = False): # Decide if we will use the target my_target = np.copy(self.options["desired_state"]) - if target == False: + if target is False: target = self.np_random.random() < self.options["reset_final_p"] # Selecting a random configuration @@ -397,7 +398,7 @@ def randomize_fall(self, target: bool = False): self.sim.set_T_world_site("trunk", T_world_trunk) # Wait for the robot to stabilize - for k in range(round(self.options["stabilization_time"] / self.sim.dt)): + for _ in range(round(self.options["stabilization_time"] / self.sim.dt)): self.sim.step() def reset( @@ -405,7 +406,7 @@ def reset( seed: int = None, target: bool = False, use_cache: bool = True, - options: dict = {}, + options: Optional[dict] = None, ): super().reset(seed=seed) self.sim.reset() From 214817092f7b21418a0d31e289ee9035d408d9f0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 20:41:48 +0100 Subject: [PATCH 4/8] Reformat files and add py39 compat --- Makefile | 2 +- enjoy_sbx.py | 5 +++- frasa_env/mujoco_simulator/simulator.py | 3 ++- standup_generate_initial.py | 8 ++++-- test_env.py | 34 ++++++++++--------------- train_sbx.py | 5 +++- 6 files changed, 30 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index 9dfa5d2..c549021 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ SHELL=/bin/bash -LINT_PATHS=frasa_env/ setup.py +LINT_PATHS=frasa_env/ *.py mypy: mypy ${LINT_PATHS} diff --git a/enjoy_sbx.py b/enjoy_sbx.py index b649be0..de5c39c 100644 --- a/enjoy_sbx.py +++ b/enjoy_sbx.py @@ -1,9 +1,12 @@ +import gymnasium as gym import rl_zoo3 import rl_zoo3.enjoy from rl_zoo3.enjoy import enjoy from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ -import frasa_env # noqa: F401 +import frasa_env + +gym.register_envs(frasa_env) rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN diff --git a/frasa_env/mujoco_simulator/simulator.py b/frasa_env/mujoco_simulator/simulator.py index a5b13ee..8770c95 100644 --- a/frasa_env/mujoco_simulator/simulator.py +++ b/frasa_env/mujoco_simulator/simulator.py @@ -1,5 +1,6 @@ import os import time +from typing import Optional import meshcat.transformations as tf import mujoco @@ -8,7 +9,7 @@ class Simulator: - def __init__(self, model_dir: str | None = None): + def __init__(self, model_dir: Optional[str] = None): # If model_dir is not provided, use the current directory if model_dir is None: model_dir = os.path.join(os.path.dirname(__file__) + "/model/") diff --git a/standup_generate_initial.py b/standup_generate_initial.py index 9665649..e0d3e46 100644 --- a/standup_generate_initial.py +++ b/standup_generate_initial.py @@ -1,8 +1,12 @@ -import frasa_env.env -import pickle import os +import pickle + import gymnasium as gym +import frasa_env + +gym.register_envs(frasa_env) + env = gym.make("frasa-standup-v0") configs: list = [] filename: str = env.unwrapped.get_initial_config_filename() diff --git a/test_env.py b/test_env.py index 4024b68..63506d9 100644 --- a/test_env.py +++ b/test_env.py @@ -1,31 +1,23 @@ -import frasa_env -import numpy as np +import argparse + import gymnasium as gym +import numpy as np from stable_baselines3.common.noise import ( NormalActionNoise, OrnsteinUhlenbeckActionNoise, ) -import argparse -argparser = argparse.ArgumentParser( - description="Test the sigmaban-standup-v0 environment" -) -argparser.add_argument( - "--env", type=str, default="frasa-standup-v0", help="Environment to test" -) -argparser.add_argument( - "--random", action="store_true", help="Use random actions instead of zeros" -) +import frasa_env + +gym.register_envs(frasa_env) + +argparser = argparse.ArgumentParser(description="Test the sigmaban-standup-v0 environment") +argparser.add_argument("--env", type=str, default="frasa-standup-v0", help="Environment to test") +argparser.add_argument("--random", action="store_true", help="Use random actions instead of zeros") argparser.add_argument("--normal", action="store_true", help="Use normal action noise") -argparser.add_argument( - "--orn", action="store_true", help="Use Ornstein-Uhlenbeck action noise" -) -argparser.add_argument( - "--std", type=float, default=0.1, help="Standard deviation for the action noise" -) -argparser.add_argument( - "--theta", type=float, default=0.15, help="Theta for the Ornstein-Uhlenbeck noise" -) +argparser.add_argument("--orn", action="store_true", help="Use Ornstein-Uhlenbeck action noise") +argparser.add_argument("--std", type=float, default=0.1, help="Standard deviation for the action noise") +argparser.add_argument("--theta", type=float, default=0.15, help="Theta for the Ornstein-Uhlenbeck noise") args = argparser.parse_args() env = gym.make(args.env) diff --git a/train_sbx.py b/train_sbx.py index 91ec284..e92bab5 100644 --- a/train_sbx.py +++ b/train_sbx.py @@ -1,9 +1,12 @@ +import gymnasium as gym import rl_zoo3 import rl_zoo3.train from rl_zoo3.train import train from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ -import frasa_env # noqa: F401 +import frasa_env + +gym.register_envs(frasa_env) rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN From 5fc2ba32d31b5d4acf8b371f01f8a586869a942d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 20:49:18 +0100 Subject: [PATCH 5/8] Fix for gym reset --- frasa_env/env/standup_env.py | 5 +++-- standup_generate_initial.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/frasa_env/env/standup_env.py b/frasa_env/env/standup_env.py index 864773d..1361a12 100644 --- a/frasa_env/env/standup_env.py +++ b/frasa_env/env/standup_env.py @@ -404,12 +404,13 @@ def randomize_fall(self, target: bool = False): def reset( self, seed: int = None, - target: bool = False, - use_cache: bool = True, options: Optional[dict] = None, ): super().reset(seed=seed) self.sim.reset() + options = options or {} + target = options.get("target", False) + use_cache = options.get("use_cache", True) # Initial robot configuration if use_cache and self.initial_config is not None: diff --git a/standup_generate_initial.py b/standup_generate_initial.py index e0d3e46..bc0f60d 100644 --- a/standup_generate_initial.py +++ b/standup_generate_initial.py @@ -16,7 +16,7 @@ try: while True: - env.reset(use_cache=False) + env.reset(options={"use_cache": False}) configs.append([env.unwrapped.sim.data.qpos.copy(), env.unwrapped.sim.data.ctrl.copy()]) if len(configs) % 100 == 0: From 706b230de717b1813149f722b03fc7bbf5eb7d0b Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 21:10:41 +0100 Subject: [PATCH 6/8] Remove warnings --- frasa_env/env/standup_env.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/frasa_env/env/standup_env.py b/frasa_env/env/standup_env.py index 1361a12..5c488a4 100644 --- a/frasa_env/env/standup_env.py +++ b/frasa_env/env/standup_env.py @@ -11,7 +11,7 @@ class StandupEnv(gymnasium.Env): - metadata = {"render_modes": ["human", "none"]} + metadata = {"render_modes": ["human", "none"], "render_fps": 30} def __init__(self, render_mode="none", options: Optional[dict] = None, evaluation: bool = False): self.options = { @@ -82,14 +82,14 @@ def __init__(self, render_mode="none", options: Optional[dict] = None, evaluatio ) elif self.options["control"] == "velocity": self.action_space = spaces.Box( - np.array([-self.options["vmax"]] * len(self.dofs)), - np.array([self.options["vmax"]] * len(self.dofs)), + np.array([-self.options["vmax"]] * len(self.dofs), dtype=np.float32), + np.array([self.options["vmax"]] * len(self.dofs), dtype=np.float32), dtype=np.float32, ) elif self.options["control"] == "error": self.action_space = spaces.Box( - np.array([-np.pi / 4] * len(self.dofs)), - np.array([np.pi / 4] * len(self.dofs)), + np.array([-np.pi / 4] * len(self.dofs), dtype=np.float32), + np.array([np.pi / 4] * len(self.dofs), dtype=np.float32), dtype=np.float32, ) else: @@ -111,7 +111,8 @@ def __init__(self, render_mode="none", options: Optional[dict] = None, evaluatio -10, # Previous action *(list(self.action_space.low) * self.options["previous_actions"]), - ] + ], + dtype=np.float32, ), np.array( [ @@ -127,7 +128,8 @@ def __init__(self, render_mode="none", options: Optional[dict] = None, evaluatio 10, # Previous action *(list(self.action_space.high) * self.options["previous_actions"]), - ] + ], + dtype=np.float32, ), dtype=np.float32, ) From f637ed1722dacbe54f06ed9dadfab10ac0aa9438 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 21:16:02 +0100 Subject: [PATCH 7/8] Warn when cache not found --- frasa_env/env/standup_env.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/frasa_env/env/standup_env.py b/frasa_env/env/standup_env.py index 5c488a4..63b17ad 100644 --- a/frasa_env/env/standup_env.py +++ b/frasa_env/env/standup_env.py @@ -1,6 +1,7 @@ import os import pickle import random +import warnings from typing import Optional import gymnasium @@ -160,7 +161,7 @@ def __init__(self, render_mode="none", options: Optional[dict] = None, evaluatio initial_config_path = self.get_initial_config_filename() self.initial_config = None if os.path.exists(initial_config_path): - print(f"Loading initial configurations from {initial_config_path}") + # print(f"Loading initial configurations from {initial_config_path}") with open(initial_config_path, "rb") as f: self.initial_config = pickle.load(f) @@ -414,6 +415,11 @@ def reset( target = options.get("target", False) use_cache = options.get("use_cache", True) + if use_cache and self.initial_config is not None: + warnings.warn( + "use_cache=True but no initial config file could be loaded." "Did you run standup_generate_initial.py?" + ) + # Initial robot configuration if use_cache and self.initial_config is not None: qpos, ctrl = random.choice(self.initial_config) From e3e6fdc44e11937a84a01ee2a922161c4f40f82e Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 22 Nov 2024 21:20:37 +0100 Subject: [PATCH 8/8] Fix condition --- frasa_env/env/standup_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frasa_env/env/standup_env.py b/frasa_env/env/standup_env.py index 63b17ad..308fa51 100644 --- a/frasa_env/env/standup_env.py +++ b/frasa_env/env/standup_env.py @@ -415,9 +415,9 @@ def reset( target = options.get("target", False) use_cache = options.get("use_cache", True) - if use_cache and self.initial_config is not None: + if use_cache and self.initial_config is None: warnings.warn( - "use_cache=True but no initial config file could be loaded." "Did you run standup_generate_initial.py?" + "use_cache=True but no initial config file could be loaded. Did you run standup_generate_initial.py?" ) # Initial robot configuration