Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions scripts/reinforcement_learning/rlopt/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# Feiyang Wu (feiyangwu@gatech.edu), based on sb3/trian.py
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# Feiyang Wu (feiyangwu@gatech.edu), based on sb3/train.py

"""Script to train RL agent with Stable Baselines3."""

Expand Down Expand Up @@ -35,7 +40,7 @@
dest="algorithm",
type=str.upper,
default="PPO",
choices=["PPO", "SAC", "IPMD"],
choices=["PPO", "SAC", "IPMD", "FASTTD3"],
help="RLOpt algorithm to train (must match the agent config).",
)

Expand Down Expand Up @@ -85,7 +90,7 @@ def cleanup_pbar(*args):

import gymnasium as gym
import torch
from rlopt.agent import IPMD, PPO, SAC
from rlopt.agent import IPMD, PPO, SAC, FastTD3
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import (
Expand Down Expand Up @@ -120,6 +125,7 @@ def cleanup_pbar(*args):
"PPO": PPO,
"SAC": SAC,
"IPMD": IPMD,
"FASTTD3": FastTD3,
}


Expand Down Expand Up @@ -169,6 +175,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
args_cli.max_iterations * agent_cfg.collector.total_frames * env_cfg.scene.num_envs
)
agent_cfg.collector.frames_per_batch *= env_cfg.scene.num_envs
agent_cfg.collector.init_random_frames *= env_cfg.scene.num_envs
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed
Expand Down Expand Up @@ -236,13 +243,17 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = env.set_info_dict_reader(
IsaacLabTerminalObsReader(observation_spec=env.observation_spec, backend="gymnasium") # type: ignore
)
if args_cli.algorithm in ["FASTTD3", "SAC"]:
# off-policy algorithms, should not use normalization in environment wrapper
transform = Compose(
RewardSum(),
StepCounter(1000),
)
else:
transform = Compose(RewardSum(), StepCounter(1000), VecNormV2(in_keys=agent_cfg.policy.input_keys + ["reward"]))
env = TransformedEnv(
env=env,
transform=Compose(
RewardSum(), # type: ignore
StepCounter(1000), # type: ignore
VecNormV2(in_keys=policy_in_keys + ["reward"]),
),
transform=transform,
)

agent_class = ALGORITHM_CLASS_MAP[args_cli.algorithm]
Expand Down
7 changes: 6 additions & 1 deletion source/isaaclab_rl/isaaclab_rl/rlopt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from collections import deque
import os

import gymnasium as gym
import torch
from rlopt.agent import IPMDRLOptConfig, PPORLOptConfig, SACRLOptConfig # noqa: F401
from rlopt.agent import IPMDRLOptConfig, PPORLOptConfig, SACRLOptConfig, FastTD3RLOptConfig # noqa: F401
from rlopt.config_base import RLOptConfig
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs.libs.gym import GymWrapper, _gym_to_torchrl_spec_transform, terminal_obs_reader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"rlopt_cfg_entry_point": f"{agents.__name__}.rlopt_sac_cfg:G1RLOptSACFlatConfig",
"rlopt_ppo_cfg_entry_point": f"{agents.__name__}.rlopt_ppo_cfg:G1RLOptPPOFlatConfig",
"rlopt_sac_cfg_entry_point": f"{agents.__name__}.rlopt_sac_cfg:G1RLOptSACFlatConfig",
"rlopt_fasttd3_cfg_entry_point": f"{agents.__name__}.rlopt_fasttd3_cfg:G1RLOptFastTD3FlatConfig",
},
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab.utils import configclass

from isaaclab_rl.rlopt import FastTD3RLOptConfig


# Convenience configurations for different scenarios
@configclass
class G1RLOptFastTD3Config(FastTD3RLOptConfig):
    """Base RLOpt FastTD3 configuration for the G1 robot.

    Note: input_dim values are left as None for lazy initialization.
    The networks will automatically infer dimensions from the environment specs.
    """

    def __post_init__(self):
        """Apply G1-specific overrides on top of the parent FastTD3 defaults."""
        super().__post_init__()

        # --- Data collection ---
        # Per-env counts; train.py multiplies these by the number of parallel envs.
        self.collector.frames_per_batch = 1  # num_steps_per_env (multiplied by num_envs in train.py)
        self.collector.init_random_frames = 10

        # --- FastTD3 algorithm hyperparameters ---
        self.fasttd3.gamma = 0.99
        self.fasttd3.policy_noise = 0.001
        self.fasttd3.noise_clip = 0.5
        self.fasttd3.use_cdq = True  # clipped double-Q
        self.fasttd3.disable_bootstrap = False
        # Distributional critic value-support range and resolution.
        self.fasttd3.v_min = -10.0
        self.fasttd3.v_max = 10.0
        self.fasttd3.num_atoms = 251
        self.fasttd3.batch_size = 8
        self.fasttd3.action_bounds = 1.0
        self.fasttd3.std_max = 0.4
        self.fasttd3.tau = 0.1  # target-network soft-update rate
        self.fasttd3.num_updates = 4
        self.fasttd3.num_steps = 8

        # --- Optimizer ---
        self.optim.optimizer = "adamw"
        self.optim.weight_decay = 0.1
        self.optim.lr = 3e-4
        self.optim.max_grad_norm = None  # gradient clipping disabled

        # --- Replay buffer ---
        self.replay_buffer.size = 1024 * 10
        self.replay_buffer.prb = False  # uniform (non-prioritized) sampling


@configclass
class G1RLOptFastTD3FlatConfig(G1RLOptFastTD3Config):
    """RLOpt FastTD3 configuration for G1 on flat terrain."""

    def __post_init__(self):
        """Post-initialization setup for flat terrain."""
        super().__post_init__()

        # Flat-terrain rollout/update schedule (same as the base config; kept
        # explicit here so terrain-specific tuning has an obvious place).
        self.fasttd3.num_steps = 8
        self.fasttd3.num_updates = 4

        # Training duration (total environment frames across all parallel envs).
        self.collector.total_frames = 100_000_000