diff --git a/source/extensions/omni.isaac.groundcontrol_assets/data/Props/terrain/racetrack-terrain.usd b/source/extensions/omni.isaac.groundcontrol_assets/data/Props/terrain/racetrack-terrain.usd new file mode 100644 index 0000000..bdda1f1 Binary files /dev/null and b/source/extensions/omni.isaac.groundcontrol_assets/data/Props/terrain/racetrack-terrain.usd differ diff --git a/source/extensions/omni.isaac.groundcontrol_assets/data/Robots/UWRLL/mitcar.usd b/source/extensions/omni.isaac.groundcontrol_assets/data/Robots/UWRLL/mitcar.usd new file mode 100644 index 0000000..4dd33fc Binary files /dev/null and b/source/extensions/omni.isaac.groundcontrol_assets/data/Robots/UWRLL/mitcar.usd differ diff --git a/source/extensions/omni.isaac.groundcontrol_assets/omni/isaac/groundcontrol_assets/__init__.py b/source/extensions/omni.isaac.groundcontrol_assets/omni/isaac/groundcontrol_assets/__init__.py index 58473d1..5f17878 100644 --- a/source/extensions/omni.isaac.groundcontrol_assets/omni/isaac/groundcontrol_assets/__init__.py +++ b/source/extensions/omni.isaac.groundcontrol_assets/omni/isaac/groundcontrol_assets/__init__.py @@ -9,17 +9,17 @@ import toml # Conveniences to other module directories via relative paths -ISAACLAB_ASSETS_EXT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +GROUNDCONTROL_ASSETS_EXT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) """Path to the extension source directory.""" -ISAACLAB_ASSETS_DATA_DIR = os.path.join(ISAACLAB_ASSETS_EXT_DIR, "data") +GROUNDCONTROL_ASSETS_DATA_DIR = os.path.join(GROUNDCONTROL_ASSETS_EXT_DIR, "data") """Path to the extension data directory.""" -ISAACLAB_ASSETS_METADATA = toml.load(os.path.join(ISAACLAB_ASSETS_EXT_DIR, "config", "extension.toml")) +GROUNDCONTROL_ASSETS_METADATA = toml.load(os.path.join(GROUNDCONTROL_ASSETS_EXT_DIR, "config", "extension.toml")) """Extension metadata dictionary parsed from the extension.toml file.""" # Configure the module-level 
variables -__version__ = ISAACLAB_ASSETS_METADATA["package"]["version"] +__version__ = GROUNDCONTROL_ASSETS_METADATA["package"]["version"] ## diff --git a/source/extensions/omni.isaac.groundcontrol_assets/omni/isaac/groundcontrol_assets/mitcar.py b/source/extensions/omni.isaac.groundcontrol_assets/omni/isaac/groundcontrol_assets/mitcar.py new file mode 100644 index 0000000..9390fa6 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_assets/omni/isaac/groundcontrol_assets/mitcar.py @@ -0,0 +1,57 @@ +from omni.isaac.lab.assets import ArticulationCfg +import omni.isaac.lab.sim as sim_utils +from omni.isaac.lab.actuators import ImplicitActuatorCfg + +from omni.isaac.groundcontrol_assets import GROUNDCONTROL_ASSETS_DATA_DIR + +JOINT_NAMES = [ + 'chassis_to_back_left_wheel', + 'chassis_to_back_right_wheel', + 'chassis_to_front_left_hinge', + 'chassis_to_front_right_hinge', + 'front_left_hinge_to_wheel', + 'front_right_hinge_to_wheel', +] + +MITCAR_CFG = ArticulationCfg( + spawn=sim_utils.UsdFileCfg( + usd_path=f"{GROUNDCONTROL_ASSETS_DATA_DIR}/Robots/UWRLL/mitcar.usd", + rigid_props=sim_utils.RigidBodyPropertiesCfg( + rigid_body_enabled=True, + max_linear_velocity=1000.0, + max_angular_velocity=1000.0, + max_depenetration_velocity=100.0, + enable_gyroscopic_forces=True, + ), + articulation_props=sim_utils.ArticulationRootPropertiesCfg( + enabled_self_collisions=False, + solver_position_iteration_count=4, + solver_velocity_iteration_count=0, + sleep_threshold=0.005, + stabilization_threshold=0.001, + ), + collision_props=sim_utils.CollisionPropertiesCfg( + collision_enabled=True + ) + ), + init_state=ArticulationCfg.InitialStateCfg( + pos=(0.0, 0.0, 0.2), + joint_pos={ + 'front_left_hinge_to_wheel': 0.0, + 'front_right_hinge_to_wheel': 0.0, + 'chassis_to_back_left_wheel': 0.0, + 'chassis_to_back_right_wheel': 0.0, + 'chassis_to_front_left_hinge': 0.0, + 'chassis_to_front_right_hinge': 0.0, + }, + ), + actuators={ + f"{k}_actuator": ImplicitActuatorCfg( + 
joint_names_expr=[k], + effort_limit=400.0, + velocity_limit=100.0, + stiffness=0.0, + damping=10.0, + ) for k in JOINT_NAMES + }, +) \ No newline at end of file diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/__init__.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/__init__.py index 5dd3986..204fe64 100644 --- a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/__init__.py +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/__init__.py @@ -22,6 +22,69 @@ # Register Gym environments. ## +import gymnasium as gym + +###### WHEELED ###### + +from .manager_based.wheeled.mitcar.mitcar_manager_env_cfg import ( + MITCarRLEnvCfg, MITCarPlayEnvCfg, MITCarIRLEnvCfg +) + +import omni.isaac.groundcontrol_tasks.manager_based.wheeled.mitcar.agents as agents + +gym.register( + id='Isaac-MITCar-v0', + entry_point='omni.isaac.lab.envs:ManagerBasedRLEnv', + kwargs={ + "env_cfg_entry_point":MITCarRLEnvCfg, + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:MITCarPPORunnerCfg", + } +) + +gym.register( + id='Isaac-MITCarPlay-v0', + entry_point='omni.isaac.lab.envs:ManagerBasedRLEnv', + kwargs={ + "env_cfg_entry_point":MITCarPlayEnvCfg, + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:MITCarPPORunnerCfg", + } +) + +gym.register( + id='Isaac-MITCarIRL-v0', + entry_point='omni.isaac.lab.envs:ManagerBasedRLEnv', + kwargs={ + "env_cfg_entry_point":MITCarIRLEnvCfg, + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:MITCarPPORunnerCfg", + } +) + +######################################## +############ RACETRACK ENVS ############ +######################################## + +from .manager_based.wheeled.mitcar.mitcar_manager_racetrack_env_cfg import ( + MITCarRacetrackRLEnvCfg, MITCarRacetrackPlayEnvCfg +) + +gym.register( + id='Isaac-MITCarRacetrack-v0', + entry_point='omni.isaac.lab.envs:ManagerBasedRLEnv', + 
kwargs={ + "env_cfg_entry_point":MITCarRacetrackRLEnvCfg, + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:MITCarPPORunnerCfg", + } +) + +gym.register( + id='Isaac-MITCarRacetrackPlay-v0', + entry_point='omni.isaac.lab.envs:ManagerBasedRLEnv', + kwargs={ + "env_cfg_entry_point":MITCarRacetrackPlayEnvCfg, + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:MITCarPPORunnerCfg", + } +) + from .utils import import_packages # The blacklist is used to prevent importing configs from sub-packages diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/__init__.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/agents/rsl_rl_ppo_cfg.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/agents/rsl_rl_ppo_cfg.py new file mode 100644 index 0000000..d92afd9 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/agents/rsl_rl_ppo_cfg.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022-2024, The Isaac Lab Project Developers. +# All rights reserved. 
+ +from omni.isaac.lab.utils import configclass + +from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import ( + RslRlOnPolicyRunnerCfg, + RslRlPpoActorCriticCfg, + RslRlPpoAlgorithmCfg, +) + + +# Adapted from lab_tasks/direct/cartpole/agents/rsl_rl_ppo_cfg.py +# +@configclass +class MITCarPPORunnerCfg(RslRlOnPolicyRunnerCfg): + num_steps_per_env = 128 + max_iterations = 150 + save_interval = 50 + experiment_name = "ppo_mitcar" + + empirical_normalization = False + policy = RslRlPpoActorCriticCfg( + init_noise_std=1.0, + actor_hidden_dims=[128, 128], + critic_hidden_dims=[128, 128], + activation="elu", + ) + algorithm = RslRlPpoAlgorithmCfg( + value_loss_coef=1.0, + use_clipped_value_loss=True, + clip_param=0.2, + entropy_coef=0.005, + num_learning_epochs=5, + num_mini_batches=4, + learning_rate=1.0e-3, + schedule="adaptive", + gamma=0.99, + lam=0.95, + desired_kl=0.01, + max_grad_norm=1.0, + ) diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/__init__.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/__init__.py new file mode 100644 index 0000000..3107d53 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/__init__.py @@ -0,0 +1,15 @@ +# This is the observation ordering (DO NOT CHANGE) +# TODO: allow arbitrary re-ordering using this list +JOINT_NAMES = [ + 'chassis_to_back_left_wheel', + 'chassis_to_back_right_wheel', + 'chassis_to_front_left_hinge', + 'chassis_to_front_right_hinge', + 'front_left_hinge_to_wheel', + 'front_right_hinge_to_wheel', +] + +from .actions import ActionsCfg +from .observations import ObservationsCfg +from .events import EventCfg +from .commands import NoCommandsCfg diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/actions.py 
b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/actions.py new file mode 100644 index 0000000..5115b17 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/actions.py @@ -0,0 +1,14 @@ +import omni.isaac.lab.envs.mdp as mdp +from omni.isaac.lab.utils import configclass + +from . import JOINT_NAMES + +@configclass +class ActionsCfg: + """Action specifications for the environment.""" + + joint_efforts = mdp.JointEffortActionCfg( + asset_name="robot", + joint_names=JOINT_NAMES, + scale=250. + ) \ No newline at end of file diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/commands.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/commands.py new file mode 100644 index 0000000..df27f12 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/commands.py @@ -0,0 +1,9 @@ +from omni.isaac.lab.envs import mdp +from omni.isaac.lab.utils import configclass + +@configclass +class NoCommandsCfg: + """Command terms for the MDP.""" + + # no commands for this MDP + null = mdp.NullCommandCfg() diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/events.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/events.py new file mode 100644 index 0000000..ef8fe4d --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/events.py @@ -0,0 +1,42 @@ +from omni.isaac.lab.envs import mdp +from omni.isaac.lab.managers import EventTermCfg as EventTerm +from omni.isaac.lab.managers import SceneEntityCfg 
+from omni.isaac.lab.utils import configclass + +from . import JOINT_NAMES + +@configclass +class EventCfg: + """Configuration for events.""" + + # on startup + # add_pole_mass = EventTerm( + # func=mdp.randomize_rigid_body_mass, + # mode="startup", + # params={ + # "asset_cfg": SceneEntityCfg("chassis"), + # "mass_distribution_params": (0.1, 0.5), + # "operation": "add", + # }, + # ) + + # on reset + reset_car_joints = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=JOINT_NAMES), + "position_range": (-0.0, 0.0), + "velocity_range": (-0.0, 0.0), + }, + ) + + reset_car_pos = EventTerm( + func=mdp.reset_root_state_uniform, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg("robot", body_names=['chassis']), + "pose_range": {}, + "velocity_range": {}, + }, + ) diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/observations.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/observations.py new file mode 100644 index 0000000..fdb4779 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/observations.py @@ -0,0 +1,30 @@ + +import omni.isaac.lab.envs.mdp as mdp +from omni.isaac.lab.utils import configclass +from omni.isaac.lab.managers import ObservationGroupCfg as ObsGroup +from omni.isaac.lab.managers import ObservationTermCfg as ObsTerm + +from .. 
import utils + +@configclass +class ObservationsCfg: + """Observation specifications for the environment.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for policy group.""" + + # observation terms (order preserved) + joint_vel_rel = ObsTerm(func=mdp.joint_vel_rel) + root_pos_w = ObsTerm(func=utils.root_pos_w) # position in simulation world frame + root_quat_w = ObsTerm(func=mdp.root_quat_w) + base_lin_vel = ObsTerm(func=mdp.base_lin_vel) + base_ang_vel = ObsTerm(func=mdp.base_ang_vel) + + def __post_init__(self) -> None: + self.enable_corruption = False + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/rl_env.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/rl_env.py new file mode 100644 index 0000000..84dd709 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/rl_env.py @@ -0,0 +1,33 @@ +from dataclasses import field + +from omni.isaac.lab.utils import configclass +from omni.isaac.lab.envs import ManagerBasedRLEnvCfg + +from . import ObservationsCfg, ActionsCfg, EventCfg + +@configclass +class MITCarRLCommonCfg(ManagerBasedRLEnvCfg): + """ + Common configuration for the MIT Car environment. + Includes the basic settings: + - Observations + - Actions + - Events + as well as the number of environments and the spacing between them. + + Also sets sim dt and decimation. 
+ """ + + num_envs: int = field(default=5) + env_spacing: float = field(default=0.05) + + # Basic settings + observations = ObservationsCfg() + actions = ActionsCfg() + events = EventCfg() + + def __post_init__(self): + + self.sim.dt = 0.025 # sim step every 25ms = 40Hz + self.decimation = 4 # env step every 4 sim steps: 40Hz / 4 = 10Hz + self.sim.render_interval = self.decimation diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/scenes.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/scenes.py new file mode 100644 index 0000000..7007743 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/common/scenes.py @@ -0,0 +1,64 @@ +import torch + +import omni.isaac.lab.sim as sim_utils +import omni.isaac.lab.utils.math as math_utils + +from omni.isaac.lab.assets import AssetBaseCfg, ArticulationCfg +from omni.isaac.lab.scene import InteractiveSceneCfg +from omni.isaac.lab.terrains import TerrainImporterCfg +from omni.isaac.lab.utils import configclass + +from omni.isaac.groundcontrol_tasks.manager_based.wheeled.terrains import rough, racetrack +from omni.isaac.groundcontrol_assets.mitcar import MITCAR_CFG + +################ +#### SCENES #### +################ + +FLAT_TERRAIN_CFG = TerrainImporterCfg( + prim_path="/World/ground", + terrain_type="plane", + debug_vis=False, + ) + + +@configclass +class MITCarBaseSceneCfg(InteractiveSceneCfg): + + """Configuration for a MIT car Scene""" + + # Distant Light + light = AssetBaseCfg( + prim_path="/World/light", + spawn=sim_utils.DistantLightCfg(color=(0.75, 0.75, 0.75), intensity=3000.0), + ) + + # Mesh + robot: ArticulationCfg = MITCAR_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + + +@configclass +class MITCarFlatSceneCfg(MITCarBaseSceneCfg): + """Configuration for a MIT car Scene with flat terrain""" + 
terrain = FLAT_TERRAIN_CFG + + +@configclass +class MITCarRoughSceneCfg(MITCarBaseSceneCfg): + """Configuration for a MIT car Scene with rough terrain""" + terrain = rough.ROUGH_TERRAIN_CFG + + +@configclass +class MITCarRacetrackSceneCfg(MITCarBaseSceneCfg): + """Configuration for a MIT car Scene with racetrack terrain""" + + terrain = racetrack.RacetrackTerrainImporterCfg() + + def __post_init__(self): + """Post initialization.""" + super().__post_init__() + # Set initial state of the robot + self.robot.init_state = self.robot.init_state.replace( + pos=(0.0, 0.0, 0.5) + ) diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/mitcar_manager_env_cfg.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/mitcar_manager_env_cfg.py new file mode 100644 index 0000000..13759bb --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/mitcar_manager_env_cfg.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn + +import omni.isaac.lab.envs.mdp as mdp +import omni.isaac.lab.utils.math as math_utils + +from omni.isaac.lab.managers import TerminationTermCfg as DoneTerm +from omni.isaac.lab.managers import RewardTermCfg as RewTerm +from omni.isaac.lab.utils import configclass + +################ +# REWARDS +################ + +def dist2goal(env, target): + root_pos = mdp.root_pos_w(env) + # target = env.scene.env_origins + torch.tensor(target, dtype=root_pos.dtype, device=root_pos.device) + target = torch.tensor(target, dtype=root_pos.dtype, device=root_pos.device) + dist = torch.norm(target - root_pos, dim=-1) + return dist + + +def upright_penalty(env, thresh_deg): + rot_mat = math_utils.matrix_from_quat(mdp.root_quat_w(env)) + up_dot = rot_mat[:, 2, 2] + up_dot = torch.rad2deg(torch.arccos(up_dot)) + penalty = torch.where(up_dot > thresh_deg, up_dot - thresh_deg, 0.) 
+ return penalty + + +@configclass +class RewardsCfg: + + """Reward terms for the MDP.""" + # (1) Primary task: Distance to target + dist2goal = RewTerm( + func=dist2goal, + weight=-1.0, + params={"target": [3., 2., 0.]}, + ) + + # (2) Upright + upright = RewTerm( + func=upright_penalty, + weight=-1.0, + params={"thresh_deg": 30.}, + ) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + # (1) Time out + time_out = DoneTerm(func=mdp.time_out, time_out=True) + # (2) Cart out of bounds + # cart_out_of_bounds = DoneTerm( + # func=mdp.joint_pos_out_of_manual_limit, + # params={"asset_cfg": SceneEntityCfg("robot", joint_names=["slider_to_cart"]), "bounds": (-3.0, 3.0)}, + # ) + +@configclass +class CurriculumCfg: + """Configuration for the curriculum.""" + pass + +################################################## +############## ENVIRONMENT CONFIGS ############### +################################################## + +from . import common +from .common.rl_env import MITCarRLCommonCfg +from .common.scenes import MITCarRoughSceneCfg, MITCarFlatSceneCfg, MITCarBaseSceneCfg + +####### RL Environment ####### + +@configclass +class MITCarRLEnvCfg(MITCarRLCommonCfg): + """Configuration for the cartpole environment.""" + + seed: int = 42 + + # # MDP settings + curriculum: CurriculumCfg = CurriculumCfg() + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + # No command generator + commands: common.NoCommandsCfg = common.NoCommandsCfg() + + + def __post_init__(self): + """Post initialization.""" + super().__post_init__() + # viewer settings + self.viewer.eye = [11., 0.0, 14.0] + self.viewer.lookat = [2.0, 0.0, 0.] 
+ + # termination settings + self.episode_length_s = 5 + # self.max_episode_length = 512 + + # Scene settings + self.scene = MITCarRoughSceneCfg( + num_envs=self.num_envs, env_spacing=self.env_spacing, + ) + + # Set seed for terrain generator + self.scene.terrain.terrain_generator = self.scene.terrain.terrain_generator.replace(seed=self.seed) + + # HACK to gain control of ordering of joints through JOINT_NAMES + # observation terms (order preserved) + # self.observations.joint_pos_rel.func = lambda x : self.observations.joint_pos_rel.func(x, joint_order_asset_cfg) + # self.observations.joint_vel_rel.func = lambda x : self.observations.joint_vel_rel.func(x, joint_order_asset_cfg) + # self.observations.pos_w.func = lambda x : self.observations.pos_w.func(x, joint_order_asset_cfg) + # self.observations.base_lin_vel.func = lambda x : self.observations.base_lin_vel.func(x, joint_order_asset_cfg) + # self.observations.lin_vel_w.func = lambda x : self.observations.lin_vel_w.func(x, joint_order_asset_cfg) + # self.observations.base_ang_vel.func = lambda x : self.observations.base_ang_vel.func(x, joint_order_asset_cfg) + + +####### Base Environment ####### + +@configclass +class NoTerminationsCfg: + """No terminations for the MDP.""" + pass + +@configclass +class MITCarPlayEnvCfg(MITCarRLEnvCfg): + """Configuration for the cartpole environment.""" + + terminations = NoTerminationsCfg() + + def __post_init__(self): + """Post initialization.""" + super().__post_init__() + # viewer settings + self.viewer.eye = [11., 0.0, 14.0] + self.viewer.lookat = [2.0, 0.0, 0.] 
+ + # Scene settings + self.scene:MITCarBaseSceneCfg = MITCarFlatSceneCfg( + num_envs=self.num_envs, env_spacing=self.env_spacing + ) + + +####### IRL ####### + +def _compute_learned_reward(env, model): + obs = env.observation_manager.compute()['reward'] + rew = model(obs).squeeze(dim=-1) + return rew + +@configclass +class LearnedRewardsCfg: + """Learned reward terms for the MDP.""" + reward:RewTerm = None + + +def root_2dpos_w(env): + return mdp.root_pos_w(env)[..., :2] + + +@configclass +class MITCarIRLEnvCfg(MITCarRLEnvCfg): + + rew_model: nn.Module = None + + def __post_init__(self): + super().__post_init__() + self.rewards = LearnedRewardsCfg() + rewterm = lambda env: _compute_learned_reward(env, self.rew_model) + self.rewards.reward = RewTerm( + func=rewterm, + weight=1.0, + ) + + +########################################### +############## ENVIRONMENTS ############### +########################################### + +# TODO +# from wheeled_gym.tasks.wrappers.irl_wrapper import IRLWrapper +# class MITCarIRLWrapper(IRLWrapper): +# """MIT Car environment for IRL.""" + +# # def __init__(self, cfg: ManagerBasedRLEnvCfg, render_mode: str | None = None, **kwargs): +# def __init__(self, env): +# super().__init__(env) +# self._episode_length_s_persistent = self.cfg.episode_length_s + +# def irl_mode(self, ep_steps): +# super().irl_mode() +# self.env.cfg.episode_length_s = ep_steps + +# def rl_mode(self): +# super().rl_mode() +# self.env.cfg.episode_length_s = self._episode_length_s_persistent + +# # TODO: Use ClipActionWrapper to clip actions (which uses env.action_space.lower/upper +# # which should be set in the config) +# def step(self, action): +# action = torch.clip(action, -1., 1.) 
+# obs, rew, term, trunc, info = super().step(action) + +# return obs, rew, term, trunc, info diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/mitcar_manager_racetrack_env_cfg.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/mitcar_manager_racetrack_env_cfg.py new file mode 100644 index 0000000..0ef3c49 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/manager_based/wheeled/mitcar/mitcar_manager_racetrack_env_cfg.py @@ -0,0 +1,183 @@ +import torch + +import omni.isaac.lab.envs.mdp as mdp +import omni.isaac.lab.utils.math as math_utils + +from omni.isaac.lab.managers import TerminationTermCfg as DoneTerm +from omni.isaac.lab.managers import RewardTermCfg as RewTerm +from omni.isaac.lab.utils import configclass + +################ +# REWARDS +################ + +def track_progress_rate(env): + '''Estimate track progress by positive z-axis angular velocity around the environment''' + root_ang_vel = mdp.root_ang_vel_w(env) + progress_rate = root_ang_vel[..., 2] + return progress_rate + + +def upright_penalty(env, thresh_deg): + rot_mat = math_utils.matrix_from_quat(mdp.root_quat_w(env)) + up_dot = rot_mat[:, 2, 2] + up_dot = torch.rad2deg(torch.arccos(up_dot)) + penalty = torch.where(up_dot > thresh_deg, up_dot - thresh_deg, 0.) 
+ return penalty + +def forward_vel_rew(env): + lin_vel = mdp.base_lin_vel(env) + return lin_vel[..., 0] + +def falling_penalty(env): + pos = mdp.root_pos_w(env) + return torch.where(pos[..., 2] < 0.1, 100.0, 0.0) + +def forward_wheel_spin(env): + joint_vel = mdp.joint_vel(env) + return torch.sum(joint_vel[..., [0,1,4,5]], dim=-1) + +@configclass +class RacetrackRewardsCfg: + + """Reward terms for the MDP.""" + # (1) Progress around track + # progress = RewTerm( + # func=track_progress_rate, + # weight=1.0, + # ) + + # (2) Upright + upright = RewTerm( + func=upright_penalty, + weight=-1.0, + params={"thresh_deg": 30.}, + ) + + # (3) Forward velocity + forward_vel = RewTerm( + func=forward_vel_rew, + weight=20., + ) + + # (4) Falling penalty + falling_penalty = RewTerm( + func=falling_penalty, + weight=-10.0, + ) + + # (5) Forward wheel spin + forward_wheel_spin = RewTerm( + func=forward_wheel_spin, + weight=1., + ) + + +@configclass +class RacetrackTerminationsCfg: + """Termination terms for the MDP.""" + + # (1) Time out + time_out = DoneTerm(func=mdp.time_out, time_out=True) + # (2) Cart out of bounds + cart_out_of_bounds = DoneTerm( + func=mdp.root_height_below_minimum, + params={"minimum_height": 0.}, + ) + # (3) Stuck TODO + # stuck = DoneTerm( + # func=mdp.root_lin_vel_below_threshold, + # params={"threshold": 0.01}, + # ) + +@configclass +class CurriculumCfg: + """Configuration for the curriculum.""" + pass + +##################################### +############## EVENTS ############### +##################################### + +from omni.isaac.lab.managers import EventTermCfg as EventTerm +from .utils import reset_root_state_from_terrain_points + +@configclass +class RacetrackEventsCfg: + """Configuration for the events.""" + reset_root_state_from_terrain_points = EventTerm( + func=reset_root_state_from_terrain_points, + mode="reset", + ) + + +################################################## +############## ENVIRONMENT CONFIGS ############### 
+################################################## + +from . import common +from .common.rl_env import MITCarRLCommonCfg +from .common.scenes import MITCarRacetrackSceneCfg + +####### RL Environment ####### + +@configclass +class MITCarRacetrackRLEnvCfg(MITCarRLCommonCfg): + """Configuration for the cartpole environment.""" + + seed: int = 42 + + # Reset config + events: RacetrackEventsCfg = RacetrackEventsCfg() + + # MDP settings + curriculum: CurriculumCfg = CurriculumCfg() + rewards: RacetrackRewardsCfg = RacetrackRewardsCfg() + terminations: RacetrackTerminationsCfg = RacetrackTerminationsCfg() + # No command generator + commands: common.NoCommandsCfg = common.NoCommandsCfg() + + + def __post_init__(self): + """Post initialization.""" + super().__post_init__() + # viewer settings + self.viewer.eye = [11., -20.0, 14.0] + self.viewer.lookat = [2.0, 0.0, 0.] + + # Terminations config + self.episode_length_s = 30 + # self.max_episode_length = 512 + + # Scene settings + self.scene:MITCarRacetrackSceneCfg = MITCarRacetrackSceneCfg( + num_envs=self.num_envs, env_spacing=self.env_spacing, + ) + + # HACK to gain control of ordering of joints through JOINT_NAMES + # observation terms (order preserved) + # self.observations.joint_pos_rel.func = lambda x : self.observations.joint_pos_rel.func(x, joint_order_asset_cfg) + # self.observations.joint_vel_rel.func = lambda x : self.observations.joint_vel_rel.func(x, joint_order_asset_cfg) + # self.observations.pos_w.func = lambda x : self.observations.pos_w.func(x, joint_order_asset_cfg) + # self.observations.base_lin_vel.func = lambda x : self.observations.base_lin_vel.func(x, joint_order_asset_cfg) + # self.observations.lin_vel_w.func = lambda x : self.observations.lin_vel_w.func(x, joint_order_asset_cfg) + # self.observations.base_ang_vel.func = lambda x : self.observations.base_ang_vel.func(x, joint_order_asset_cfg) + + +####### Base Environment ####### + +@configclass +class NoTerminationsCfg: + pass + +@configclass +class 
import torch

import omni.isaac.lab.envs.mdp as mdp
import omni.isaac.lab.utils.math as math_utils

from omni.isaac.lab.envs import ManagerBasedEnv
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.assets import Articulation, RigidObject
from omni.isaac.lab.terrains import TerrainImporter


def root_pos_w(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Root position of the asset in the simulation world frame."""
    asset: Articulation = env.scene[asset_cfg.name]
    return asset.data.root_pos_w


def root_lin_vel_below_threshold(env: ManagerBasedEnv, threshold: float) -> torch.Tensor:
    """Elementwise check of the world-frame root linear velocity against ``threshold``.

    Args:
        env: The environment.
        threshold: The threshold value.

    Returns:
        A boolean tensor (same shape as ``mdp.root_lin_vel_w``) that is True
        wherever the velocity entry is below ``threshold``.
    """
    return mdp.root_lin_vel_w(env) < threshold


@torch.jit.script
def _quat_from_euler_single(euler_xyz: torch.Tensor) -> torch.Tensor:
    # Single-sample kernel: a (3,) tensor of Euler angles -> a (4,) quaternion.
    return math_utils.quat_from_euler_xyz(euler_xyz[0], euler_xyz[1], euler_xyz[2])


# Batched version of the kernel above, vectorized over the leading dimension.
_quat_from_euler_batched = torch.vmap(_quat_from_euler_single)


def quat_from_euler_xyz_vect(euler_xyz: torch.Tensor) -> torch.Tensor:
    """Convert Euler XYZ angles to quaternions.

    Args:
        euler_xyz: A tensor of Euler XYZ angles in radians with shape ``(N, 3)``.

    Returns:
        A tensor of quaternions with shape ``(N, 4)``.
    """
    return _quat_from_euler_batched(euler_xyz)


def reset_root_state_from_terrain_points(
    env: ManagerBasedEnv,
    env_ids: torch.Tensor,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
    """Reset the asset root pose by sampling from the terrain's valid init poses.

    Picks, for each environment in ``env_ids``, one pose at random from
    ``terrain.cfg.valid_init_poses``, offsets its position by the asset's
    default root position, and writes the resulting pose into the simulation.

    Note:
        Velocity randomization and extra orientation randomization are not
        implemented yet (TODO in the original design).

    Raises:
        ValueError: If the terrain config does not define ``valid_init_poses``.
    """
    # access the used quantities (to enable type-hinting)
    asset: RigidObject | Articulation = env.scene[asset_cfg.name]
    terrain: TerrainImporter = env.scene.terrain

    valid_poses = terrain.cfg.valid_init_poses
    if valid_poses is None:
        raise ValueError(
            "The event term 'reset_root_state_from_terrain_points' requires 'valid_init_poses' in the TerrainImporterCfg."
        )

    # Tensorize the valid poses.
    # TODO: do this once in the terrain importer constructor instead of per reset.
    positions_table = torch.stack([torch.tensor(p.pos, device=env.device) for p in valid_poses])
    euler_table = [torch.deg2rad(torch.tensor(p.rot_euler_xyz_deg, device=env.device)) for p in valid_poses]
    quat_table = torch.stack([math_utils.quat_from_euler_xyz(*euler) for euler in euler_table])

    # sample one valid pose per environment being reset
    choice = torch.randint(0, len(valid_poses), size=(len(env_ids),), device=env.device)
    positions = positions_table[choice] + asset.data.default_root_state[env_ids, :3]
    orientations = quat_table[choice]

    # set into the physics simulation
    asset.write_root_pose_to_sim(torch.cat([positions, orientations], dim=-1), env_ids=env_ids)
import omni.isaac.lab.sim as sim_utils
from omni.isaac.lab.utils import configclass
from omni.isaac.lab.terrains import TerrainImporterCfg
from omni.isaac.groundcontrol_assets import GROUNDCONTROL_ASSETS_DATA_DIR

# Shared racetrack terrain importer (USD mesh shipped with the groundcontrol assets).
RACETRACK_TERRAIN_CFG = TerrainImporterCfg(
    prim_path="/World/ground",
    terrain_type="usd",
    usd_path=f"{GROUNDCONTROL_ASSETS_DATA_DIR}/Props/terrain/racetrack-terrain.usd",
    collision_group=-1,
    physics_material=sim_utils.RigidBodyMaterialCfg(
        friction_combine_mode="multiply",
        restitution_combine_mode="multiply",
        static_friction=1.5,
        dynamic_friction=1.5,
    ),
    debug_vis=False,
)


@configclass
class RacetrackTerrainImporterCfg(TerrainImporterCfg):
    """Racetrack terrain importer with a set of known-good spawn poses."""

    @configclass
    class InitialPoseCfg:
        # Position in the world frame (meters).
        pos: tuple[float, float, float] = (0.0, 0.0, 0.0)
        # Orientation as XYZ Euler angles (degrees).
        rot_euler_xyz_deg: tuple[float, float, float] = (0.0, 0.0, 0.0)

    # Common spawn height for all initial poses.
    height = 0.0
    # Hand-picked poses on the racetrack that are safe to spawn at.
    valid_init_poses = [
        InitialPoseCfg(pos=(12.0, 1.27, height), rot_euler_xyz_deg=(0.0, 0.0, 135.0)),
        InitialPoseCfg(pos=(-5.33, 3.3, height), rot_euler_xyz_deg=(0.0, 0.0, 180.0)),
        InitialPoseCfg(pos=(-8.7, -7.27, height)),
        InitialPoseCfg(pos=(0.0, 0.0, height)),
    ]
    prim_path = "/World/ground"
    terrain_type = "usd"
    # BUGFIX: the original had a trailing comma here (`usd_path = f"...",`),
    # which makes the class attribute a 1-tuple instead of a string.
    usd_path = f"{GROUNDCONTROL_ASSETS_DATA_DIR}/Props/terrain/racetrack-terrain.usd"
    collision_group = -1
    physics_material = sim_utils.RigidBodyMaterialCfg(
        friction_combine_mode="multiply",
        restitution_combine_mode="multiply",
        static_friction=1.5,
        dynamic_friction=1.5,
    )
    debug_vis = False
class RLRunner(ABC):
    """Minimal interface that every RL runner implementation must satisfy."""

    @abstractmethod
    def learn(self, num_learning_iterations: int):
        """Run the training loop for the given number of iterations."""

    @abstractmethod
    def reset(self):
        """Re-initialize the runner / optimizer state."""

    @abstractmethod
    def collect_rollouts(self, max_steps: int):
        """Roll out the current policy in the environment for up to ``max_steps``."""
class ExtraRslRlRolloutStorage(rsl_rl.storage.rollout_storage.RolloutStorage):
    """Rollout storage that additionally records named tensors from the step infos."""

    def __init__(self, extras_shape: Dict[str, torch.Size], *args, **kwargs):
        '''
        example of extras_info:
        {
            "reward": (torch.Size([64, 2]), lambda infos: infos["observations"]["reward"]),
            ...
        }
        The name is the name of the extras tensor and the size is the size of this tensor.
        The callable is the mapping from the infos dict to the tensor.
        '''
        super().__init__(*args, **kwargs)
        # BUGFIX: allocate on the storage's device; the original used the default
        # (CPU) device, which breaks the slice-assignment below for CUDA tensors.
        # NOTE(review): assumes the base RolloutStorage exposes `self.device` — confirm.
        self.extras = {
            name: torch.empty(self.num_transitions_per_env, self.num_envs, *shape, device=self.device)
            for name, shape in extras_shape.items()
        }

    def add_transitions(self, transition, infos):
        """Record the transition plus the configured extras from ``infos``."""
        super().add_transitions(transition)
        self.step -= 1  # Undo super step so extras land in the same slot
        for name in self.extras.keys():
            self.extras[name][self.step] = infos['observations'][name]
        self.step += 1  # Redo super step


class OnPolicyRunner(runners.OnPolicyRunner, RLRunner):
    '''Override of rsl-rl's OnPolicyRunner for logging purposes.'''

    def __init__(self, env, cfg, log_dir=None, device="cpu"):
        super().__init__(env, cfg, log_dir, device)
        self.no_log = self.cfg.get("rl_no_log", True)
        self.no_wandb = self.cfg.get("no_wandb", False)
        self.logger_type = None
        if not self.no_wandb:
            self.logger_type = "wandb"

    def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        """Train for ``num_learning_iterations`` PPO iterations.

        Args:
            num_learning_iterations: Number of collect+update iterations to run.
            init_at_random_ep_len: Randomize initial episode lengths for decorrelation.
        """
        # initialize writer
        if not self.no_log and self.logger_type == "wandb":
            from rsl_rl.utils.wandb_utils import WandbSummaryWriter

            self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
            self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)

        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(
                self.env.episode_length_buf, high=int(self.env.max_episode_length)
            )
        obs, extras = self.env.get_observations()
        critic_obs = extras["observations"].get("critic", obs)
        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
        self.train_mode()  # switch to train mode (for dropout for example)

        ep_infos = []
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        start_iter = self.current_learning_iteration
        tot_iter = start_iter + num_learning_iterations
        # BUGFIX: `tqdm` is None when tqdm.rich is unavailable (see module import
        # fallback); degrade to a plain range instead of raising TypeError.
        iterations = range(start_iter, tot_iter)
        if tqdm is not None:
            iterations = tqdm(iterations)
        for it in iterations:
            start = time.time()
            # Rollout
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions = self.alg.act(obs, critic_obs)
                    obs, rewards, dones, infos = self.env.step(actions)
                    if actions.isnan().any():
                        raise ValueError("NaN in actions")
                    obs = self.obs_normalizer(obs)
                    if "critic" in infos["observations"]:
                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
                    else:
                        critic_obs = obs
                    obs, critic_obs, rewards, dones = (
                        obs.to(self.device),
                        critic_obs.to(self.device),
                        rewards.to(self.device),
                        dones.to(self.device),
                    )
                    self.alg.process_env_step(rewards, dones, infos)

                    if not self.no_log:
                        # Book keeping
                        # note: we changed logging to use "log" instead of "episode" to avoid confusion with
                        # different types of logging data (rewards, curriculum, etc.)
                        if "episode" in infos:
                            ep_infos.append(infos["episode"])
                        elif "log" in infos:
                            ep_infos.append(infos["log"])
                        cur_reward_sum += rewards
                        cur_episode_length += 1
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0

                stop = time.time()
                collection_time = stop - start

                # Learning step
                start = stop
                self.alg.compute_returns(critic_obs)

            mean_value_loss, mean_surrogate_loss = self.alg.update()
            stop = time.time()
            learn_time = stop - start
            self.current_learning_iteration = it
            if not self.no_log:
                self.log(locals())
                if it % self.save_interval == 0:
                    self.save(os.path.join(self.log_dir, "models", f"model_{it}.pt"))
                ep_infos.clear()
self.critic_obs_normalizer(infos["observations"]["critic"]) + else: + critic_obs = obs + obs, critic_obs, rewards, dones = ( + obs.to(self.device), + critic_obs.to(self.device), + rewards.to(self.device), + dones.to(self.device), + ) + + ## on_policy_runner.OnPolicyRunner.process_env_step: + alg.transition.rewards = rewards.clone() + alg.transition.dones = dones + # Bootstrapping on time outs + if "time_outs" in infos: + alg.transition.rewards += alg.gamma * torch.squeeze( + alg.transition.values * infos["time_outs"].unsqueeze(1).to(alg.device), 1 + ) + + # Record the transitions to both buffers + # alg.storage.add_transitions(alg.transition) + rollouts_storage.add_transitions(alg.transition, infos) + + # Reset + alg.transition.clear() + alg.actor_critic.reset(dones) + + return rollouts_storage + + def reset(self): + self.__init__(self.env, self.cfg, self.log_dir, self.device) diff --git a/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/utils/wrappers/torch_clip_action.py b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/utils/wrappers/torch_clip_action.py new file mode 100644 index 0000000..3e96cc5 --- /dev/null +++ b/source/extensions/omni.isaac.groundcontrol_tasks/omni/isaac/groundcontrol_tasks/utils/wrappers/torch_clip_action.py @@ -0,0 +1,37 @@ +import torch + +import gymnasium as gym + +class ClipAction(gym.ActionWrapper): + """ Adapted from https://github.com/openai/gym/blob/master/gym/wrappers/clip_action.py + Clip the continuous action within the valid :class:`Box` observation space bound. + + Example: + >>> import gym + >>> env = gym.make('Bipedal-Walker-v3') + >>> env = ClipAction(env) + >>> env.action_space + Box(-1.0, 1.0, (4,), float32) + >>> env.step(np.array([5.0, 2.0, -10.0, 0.0])) + # Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment + """ + + def __init__(self, env: gym.Env): + """A wrapper for clipping continuous actions within the valid bound. 
+ + Args: + env: The environment to apply the wrapper + """ + # assert isinstance(env.action_space, Box) + super().__init__(env) + + def action(self, action): + """Clips the action within the valid bounds. + + Args: + action: The action to clip + + Returns: + The clipped action + """ + return torch.clip(action, min=self.action_space.low, max=self.action_space.high) \ No newline at end of file diff --git a/source/standalone/train/train_rsl_rl.py b/source/standalone/train/train_rsl_rl.py new file mode 100644 index 0000000..d7f0544 --- /dev/null +++ b/source/standalone/train/train_rsl_rl.py @@ -0,0 +1,145 @@ +""" +This script tests the OffroadCarEnv in envs. +It also finds the joint_ids order returned by ObsTerm funcs + +example usage: + +python train/train_rsl_rl.py --headless + +loading a previous run: + +python train/train_rsl_rl.py --load-run +""" +################################### +###### BEGIN ISAACLAB SPINUP ###### +################################### + +import argparse +from utils.app_startup import startup, add_all_wheeled_gym_args, add_rsl_rl_args + +parser = argparse.ArgumentParser(description="Random agent for Isaac Lab environments.") + +overrides = { + "rl_max_iterations": 4096, + "env_name": "Isaac-MITCar-v0", + "num_envs": 1024, + "video": True, + "log_every": 5, + "video_interval": 5000, + "video_length": 1000, + # "no_wandb": True, +} +add_all_wheeled_gym_args(parser, overrides) +add_rsl_rl_args(parser) + +def _args_cb(args): + args.save_interval = args.log_every + args.rl_no_log = args.no_log + +simulation_app, args_cli = startup(parser=parser, prelaunch_callback=_args_cb) + +##################### +###### LOGGING ###### +##################### + +import gymnasium as gym +import os +from datetime import datetime +import torch +from omni.isaac.lab.utils.io import dump_pickle, dump_yaml + +if not args_cli.no_wandb: + import wandb + run = wandb.init( + project="IRL", + ) + wandb_name = wandb.run.name + run_name = wandb_name +else: + import random + 
run_name = f"bfirl-local-{random.randint(0, 1e7)}" + +log_dir = os.path.join(args_cli.log_dir, f'{run_name}_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}') +model_save_path = os.path.join(log_dir, "models") + +if not args_cli.no_log: + paths = [model_save_path] + for path in paths: + if not os.path.exists(path): + os.makedirs(path) + +############################ +#### CREATE ENVIRONMENT #### +############################ + +from omni.isaac.lab.utils.dict import print_dict +from omni.isaac.lab.utils import update_class_from_dict +from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlVecEnvWrapper +from omni.isaac.lab_tasks.utils import parse_env_cfg +from omni.isaac.lab_tasks.utils import get_checkpoint_path + +import omni.isaac.groundcontrol_tasks +from omni.isaac.groundcontrol_tasks.utils.wrappers.torch_clip_action import ClipAction +from omni.isaac.groundcontrol_tasks.utils.runners.rslrl_runner import OnPolicyRunner + +from utils import WHEELED_LAB_LOGS_DIR +from utils.args import parse_rsl_rl_cfg, default_isaac_cfg + +####### FETCH CONFIGS ####### +isaac_cfg = default_isaac_cfg( + device=args_cli.device, + num_envs=args_cli.num_envs, + use_fabric=not args_cli.disable_fabric +) +env_cfg = parse_env_cfg( + args_cli.env_name, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric +) +update_class_from_dict(env_cfg, isaac_cfg) + +env = gym.make(args_cli.env_name, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) + +####### INSTANTIATE ENV ####### +env.action_space.low = -1. +env.action_space.high = 1. 
+env = ClipAction(env) + +if args_cli.video: + video_kwargs = { + "video_folder": os.path.join(log_dir, "videos"), + "step_trigger": lambda step: step % args_cli.video_interval*args_cli.num_envs == 0, + "video_length": args_cli.video_length, + "disable_logger": True, + } + print("[INFO] Recording videos during training.") + print_dict(video_kwargs, nesting=4) + env = gym.wrappers.RecordVideo(env, **video_kwargs) + +env = RslRlVecEnvWrapper(env) + +#### CREATE AGENT (FACTORY) #### +agent_cfg = parse_rsl_rl_cfg(args_cli.env_name, args_cli) + +# dump the configuration into log-directory +if not args_cli.no_log: + dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg) + dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg) + dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg) + dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg) + +runner_log_dir = None if args_cli.test_mode else log_dir +runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=runner_log_dir, device=args_cli.device) + +if args_cli.load_run: + # get path to previous checkpoint + resume_path = get_checkpoint_path(WHEELED_LAB_LOGS_DIR, run_dir=args_cli.load_run, + other_dirs=["models"], checkpoint="model_.*") + print(f"[INFO]: Loading model checkpoint from: {resume_path}") + # load previously trained model + runner.load(resume_path) + +env.seed(agent_cfg.seed) + +runner.learn(num_learning_iterations=args_cli.rl_max_iterations) + +if not args_cli.no_wandb: + run.finish() diff --git a/source/standalone/train/train_sb3_rl.py b/source/standalone/train/train_sb3_rl.py new file mode 100644 index 0000000..f5a85b4 --- /dev/null +++ b/source/standalone/train/train_sb3_rl.py @@ -0,0 +1,148 @@ +""" +This script tests the OffroadCarEnv in envs. 
+It also finds the joint_ids order returned by ObsTerm funcs + +example usage: + +python test/test_rl.py --headless +""" +################################### +###### BEGIN ISAACLAB SPINUP ###### +################################### + +import argparse +from ..utils.app_startup import startup, add_all_wheeled_gym_args, add_rsl_rl_args + +parser = argparse.ArgumentParser(description="Random agent for Isaac Lab environments.") + +overrides = { + "rl_max_iterations": 1024, + "env_name": "Isaac-MITCarRacetrack-v0", + "num_envs": 1024, + "video": True, + "log_every": 5, + "video_interval": 1000, + "video_length": 1200, + # "no_wandb": True, +} +add_all_wheeled_gym_args(parser, overrides) +add_rsl_rl_args(parser) + +def _args_cb(args): + args.save_interval = args.log_every + args.rl_no_log = args.no_log + +simulation_app, args_cli = startup(parser=parser, prelaunch_callback=_args_cb) + +##################### +###### LOGGING ###### +##################### + +import gymnasium as gym +import os +from datetime import datetime +import torch + +if not args_cli.no_wandb: + import wandb + run = wandb.init( + project="IRL", + ) + wandb_name = wandb.run.name + run_name = wandb_name +else: + import random + run_name = f"bfirl-local-{random.randint(0, 1e7)}" + +log_dir = os.path.join(args_cli.log_dir, f'{run_name}_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}') +# dump the configuration into log-directory +# dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg) +# dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg) +model_save_path = os.path.join(log_dir, "models") + +############################ +#### CREATE ENVIRONMENT #### +############################ + +import wheeled_gym.tasks # register envs to gym +from wheeled_gym.train.utils.utils import default_isaac_cfg +from wheeled_gym.utils.data_processing import load_from_sb3 +from wheeled_gym.tasks.wrappers.clip_action import ClipAction +from wheeled_gym import WHEELED_GYM_LOGS_DIR + +from 
omni.isaac.lab.utils.dict import print_dict +from omni.isaac.lab.utils import update_class_from_dict +from omni.isaac.lab_tasks.utils import parse_env_cfg +from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper +from omni.isaac.lab_tasks.utils import get_checkpoint_path + +from wandb.integration.sb3 import WandbCallback +from stable_baselines3 import PPO +from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList + +####### FETCH CONFIGS ####### +isaac_cfg = default_isaac_cfg( + device=args_cli.device, + num_envs=args_cli.num_envs, + use_fabric=not args_cli.disable_fabric +) +env_cfg = parse_env_cfg( + args_cli.env_name, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric +) +update_class_from_dict(env_cfg, isaac_cfg) + +env = gym.make(args_cli.env_name, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) +n_timesteps = args_cli.rl_max_iterations * args_cli.agent_n_steps * env.scene.num_envs + +####### INSTANTIATE ENV ####### +# env.action_space.low = torch.tensor(-1., device=args_cli.device) +# env.action_space.high = torch.tensor(1., device=args_cli.device) +env.action_space.low = -1. +env.action_space.high = 1. 
+env = ClipAction(env) + +if args_cli.video: + video_kwargs = { + "video_folder": os.path.join(log_dir, "videos"), + "step_trigger": lambda step: step % args_cli.video_interval*args_cli.num_envs == 0, + "video_length": args_cli.video_length, + "disable_logger": True, + } + print("[INFO] Recording videos during training.") + print_dict(video_kwargs, nesting=4) + env = gym.wrappers.RecordVideo(env, **video_kwargs) + +env = Sb3VecEnvWrapper(env) + +#### CREATE AGENT (FACTORY) #### +env.seed(args_cli.seed) + +policy_kwargs = dict(net_arch=dict(pi=[32, 32], vf=[32, 32])) +checkpoint_callback = CheckpointCallback(save_freq=args_cli.checkpoint_every, + save_path=os.path.join(model_save_path), + name_prefix="model") +callbacklist = [checkpoint_callback] +if not args_cli.no_wandb: + callbacklist.append(WandbCallback()) +callbacklist = CallbackList(callbacklist) + +if args_cli.load_run: + run_path = os.path.join(WHEELED_GYM_LOGS_DIR, args_cli.load_run) + resume_path = get_checkpoint_path(WHEELED_GYM_LOGS_DIR, run_dir=args_cli.load_run, + other_dirs=["models"], checkpoint="model_.*") + print(f"[INFO]: Loading model checkpoint from: {resume_path}") + runner = load_from_sb3(resume_path, policy_type=args_cli.rl_algo_type) + runner.env = env +else: + runner = PPO("MlpPolicy", env, + device=args_cli.device, + policy_kwargs=policy_kwargs) + +try: + runner.learn(total_timesteps=n_timesteps, callback=callbacklist, progress_bar=True) +except KeyboardInterrupt: + runner.save(os.path.join(model_save_path, f"{run_name}_interrupted")) + +runner.save(os.path.join(model_save_path, f"{run_name}_done")) + +run.finish() diff --git a/source/standalone/train/utils/__init__.py b/source/standalone/train/utils/__init__.py new file mode 100644 index 0000000..74fdc22 --- /dev/null +++ b/source/standalone/train/utils/__init__.py @@ -0,0 +1,10 @@ +import os + +# DEFAULT_RUN_DIRNAME = "graceful-frost-97_2024-08-12_19-29-14" + +WHEELED_LAB_ROOT_DIR = 
WHEELED_LAB_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
WHEELED_LAB_RESOURCES_DIR = os.path.join(WHEELED_LAB_ROOT_DIR, 'resources')
WHEELED_LAB_LOGS_DIR = os.path.join(WHEELED_LAB_ROOT_DIR, 'logs')
WHEELED_LAB_CORE_DATA_DIR = os.path.join(WHEELED_LAB_ROOT_DIR, 'core_data')


"""
Boilerplate code for starting up IsaacLab backend
"""

import argparse
from . import WHEELED_LAB_LOGS_DIR


# Shared default values for all CLI arguments; per-script overrides are merged
# in by the add_*_args helpers below.
defaults = {
    "no_log": False,
    "log_dir": WHEELED_LAB_LOGS_DIR,
    "log_every": 10,
    "video": False,
    "video_length": 500,
    # BUGFIX: these two back `type=int` arguments; 1e4 is a float literal.
    "video_interval": 10_000,
    "no_checkpoints": False,
    "checkpoint_every": 10_000,
    "no_wandb": False,
    "rl_max_iterations": 1024,
    "irl_max_iterations": 1024,
    "agent_n_steps": 256,
    "cpu": False,
    "device": "cuda:0",
    "disable_fabric": False,
    "num_envs": 256,
    "env_name": "Isaac-MITCar-v0",
    "rl_algo_type": "PPO",
    "rl_algo_lib": "sb3",
    "seed": 42,
    "rl_no_log": True,
    "logger": "wandb",
    "load-run": None,

    "test_mode": False,
}


def add_logging_args(parser, default_overrides=None):
    """Add logging-related CLI arguments to *parser*.

    Args:
        parser: Argument parser to add arguments to.
        default_overrides: Optional mapping merged into the shared defaults.
    """
    # BUGFIX: avoid the mutable default argument anti-pattern.
    if default_overrides:
        defaults.update(default_overrides)
    parser.add_argument("--no-log", action="store_true", default=defaults["no_log"], help="Disable logging")
    parser.add_argument("--rl-no-log", action="store_true", default=defaults["rl_no_log"], help="Disable logging for RL algorithm")
    parser.add_argument("--log-dir", type=str, default=defaults["log_dir"], help="Directory for logging.")
    parser.add_argument("--log-every", type=int, default=defaults["log_every"], help="Log every n updates.")
    parser.add_argument("--video", action="store_true", default=defaults["video"], help="Record videos during training.")
    parser.add_argument("--video-length", type=int, default=defaults["video_length"], help="Length of the recorded video (in steps).")
    parser.add_argument("--video-interval", type=int, default=defaults["video_interval"], help="Interval between video recordings (in steps).")
    parser.add_argument("--no-checkpoints", action="store_true", default=defaults["no_checkpoints"], help="Save model checkpoints.")
    parser.add_argument("--checkpoint-every", type=int, default=defaults["checkpoint_every"], help="Save model checkpoints every n steps.")
    parser.add_argument("--no-wandb", action="store_true", default=defaults["no_wandb"], help="Disable wandb logging.")
    parser.add_argument("--test-mode", action="store_true", default=defaults["test_mode"], help="Disable logging; Disable wandb; Disable video recording; Disable checkpoints.")


def add_train_args(parser, default_overrides=None):
    """Add training-related CLI arguments to *parser*."""
    if default_overrides:
        defaults.update(default_overrides)
    parser.add_argument("--seed", type=int, default=defaults["seed"], help="Seed for training")
    parser.add_argument("--rl-max-iterations", type=int, default=defaults["rl_max_iterations"], help="RL rl_algo training iterations.")
    parser.add_argument("--irl-max-iterations", type=int, default=defaults["irl_max_iterations"], help="IRL rl_algo training iterations.")
    parser.add_argument("--agent-n-steps", type=int, default=defaults["agent_n_steps"], help="Agent max steps")
    parser.add_argument("--rl-algo-lib", type=str, default=defaults["rl_algo_lib"], help="library for rl_algo [sb3|rsl]")
    parser.add_argument("--rl-algo-type", type=str, default=defaults["rl_algo_type"], help="type of rl_algo [SAC|PPO|manual]")
    parser.add_argument("--load-run", type=str, default=defaults['load-run'], help="Name of run to load from logs dir.")


def add_env_args(parser, default_overrides=None):
    '''
    Add standard environment arguments to the parser.
    parser: argparse.ArgumentParser
        Argument parser to add arguments to.
    '''
    if default_overrides:
        defaults.update(default_overrides)
    # parser.add_argument("--cpu", ...) and --device are deprecated: handled by IsaacLab v1.0.0's AppLauncher
    parser.add_argument(
        "--disable_fabric", action="store_true", default=defaults["disable_fabric"], help="Disable fabric and use USD I/O operations."
    )
    parser.add_argument('-ne', "--num-envs", type=int, default=defaults["num_envs"], help="Number of environments to simulate.")
    parser.add_argument('-en', "--env-name", type=str, default=defaults["env_name"], help="Name of the task.")


def add_rsl_rl_args(parser: argparse.ArgumentParser, default_overrides=None):
    """Add RSL-RL arguments to the parser.

    Args:
        parser: The parser to add the arguments to.
        default_overrides: Optional mapping merged into the shared defaults.
    """
    if default_overrides:
        defaults.update(default_overrides)
    # create a new argument group
    arg_group = parser.add_argument_group("rsl_rl", description="Arguments for RSL-RL agent.")
    # -- experiment arguments
    arg_group.add_argument(
        "--experiment_name", type=str, default=None, help="Name of the experiment folder where logs will be stored."
    )
    arg_group.add_argument("--run_name", type=str, default=None, help="Run name suffix to the log directory.")
    # -- load arguments
    arg_group.add_argument("--resume", type=bool, default=None, help="Whether to resume from a checkpoint.")
    arg_group.add_argument("--load_run", type=str, default=None, help="Name of the run folder to resume from.")
    arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file to resume from.")
    # -- logger arguments
    arg_group.add_argument(
        "--logger", type=str, default=defaults["logger"], choices={"wandb", "tensorboard", "neptune"}, help="Logger module to use."
    )
    arg_group.add_argument(
        "--log_project_name", type=str, default=None, help="Name of the logging project when using wandb or neptune."
    )


def add_all_wheeled_gym_args(parser, default_overrides=None):
    """Convenience helper that installs every argument group at once."""
    add_logging_args(parser, default_overrides)
    add_train_args(parser, default_overrides)
    add_env_args(parser, default_overrides)


def startup(parser=None, prelaunch_callback=None, import_gym_envs=True):
    '''
    Startup IsaacLab backend. Imports wheeled_gym environments optionally.
    Args:
        parser: argparse.ArgumentParser, optional, default=None
            Argument parser to add arguments to.
        prelaunch(args): function to be executed right before launching the app, optional, default=None
    Returns:
        simulation_app: the launched Omniverse simulation app instance.
        args_cli: argparse.Namespace
            Parsed command line arguments.
    '''
    # deferred import: omni is only importable once the Isaac environment is set up
    from omni.isaac.lab.app import AppLauncher

    if parser is None:
        parser = argparse.ArgumentParser(description="Used Boilerplate Starter.")

    AppLauncher.add_app_launcher_args(parser)
    args_cli = parser.parse_args()

    if prelaunch_callback is not None:
        prelaunch_callback(args_cli)

    if "test_mode" in args_cli and args_cli.test_mode:
        args_cli.no_log = True
        args_cli.rl_no_log = True
        args_cli.no_wandb = True
        args_cli.video = False
        args_cli.no_checkpoints = True

    if args_cli.video:
        # video recording requires offscreen cameras to be enabled
        args_cli.enable_cameras = True

    # launch omniverse app
    app_launcher = AppLauncher(args_cli)
    simulation_app = app_launcher.app

    return simulation_app, args_cli
+ Needs rsl_rl_cfg_entry_point to load Agent Config + + Args: + task_name: The name of the environment. + args_cli: The command line arguments. + + Returns: + The parsed configuration for RSL-RL agent based on inputs. + """ + from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry + + # load the default configuration + rslrl_cfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") + + # override the default configuration with CLI arguments + if args_cli.seed is not None: + rslrl_cfg.seed = args_cli.seed + if args_cli.resume is not None: + rslrl_cfg.resume = args_cli.resume + if args_cli.load_run is not None: + rslrl_cfg.load_run = args_cli.load_run + if args_cli.checkpoint is not None: + rslrl_cfg.load_checkpoint = args_cli.checkpoint + if args_cli.run_name is not None: + rslrl_cfg.run_name = args_cli.run_name + if args_cli.logger is not None: + rslrl_cfg.logger = args_cli.logger + # set the project name for wandb and neptune + if rslrl_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name: + rslrl_cfg.wandb_project = args_cli.log_project_name + rslrl_cfg.neptune_project = args_cli.log_project_name + + # wheeled_gym naming convention overrides + rslrl_cfg.max_iterations = args_cli.rl_max_iterations + rslrl_cfg.num_steps_per_env = args_cli.agent_n_steps + if args_cli.no_log: + rslrl_cfg.log_dir = None + + rslrl_cfg.rl_no_log = args_cli.rl_no_log + + return rslrl_cfg + + +def default_isaac_cfg( + device: str = "cuda:0", num_envs: int | None = None, use_fabric: bool | None = None + ): + default_cfg = {"sim": {"physx": dict()}, "scene": dict()} + + # simulation device + default_cfg["sim"]["device"] = device + + # disable fabric to read/write through USD + if use_fabric is not None: + default_cfg["sim"]["use_fabric"] = use_fabric + + # number of environments + if num_envs is not None: + default_cfg["scene"]["num_envs"] = num_envs + + return default_cfg \ No newline at end of file