1 change: 1 addition & 0 deletions ss2r/algorithms/ppo/__init__.py
@@ -94,6 +94,7 @@ def get_train_fn(cfg, checkpoint_path, restore_checkpoint_path):
**agent_cfg,
**training_cfg,
network_factory=network_factory,
checkpoint_logdir=checkpoint_path if cfg.training.store_checkpoint else None,
restore_checkpoint_path=restore_checkpoint_path,
penalizer=penalizer,
penalizer_params=penalizer_params,
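Review note: a minimal sketch of what this one-line change does. The hunk only shows the kwarg list, so the assumption here is that `get_train_fn` builds the train function with `functools.partial` over `train()`; `cfg.training.store_checkpoint` is the new opt-in flag.

```python
import functools

def get_train_fn_sketch(train, cfg, checkpoint_path):
    # Hypothetical reduction of get_train_fn: when store_checkpoint is
    # false, train() receives checkpoint_logdir=None, so the periodic
    # checkpoint.save() branch added in train.py below never executes.
    return functools.partial(
        train,
        checkpoint_logdir=checkpoint_path if cfg.training.store_checkpoint else None,
    )
```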
2 changes: 0 additions & 2 deletions ss2r/algorithms/ppo/franka_ppo_to_onnx.py
@@ -1,5 +1,3 @@
from __future__ import annotations

import logging
from collections.abc import Mapping, Sequence
from typing import Any
76 changes: 62 additions & 14 deletions ss2r/algorithms/ppo/train.py
@@ -29,9 +29,9 @@
from brax import envs
from brax.training import pmap, types
from brax.training.acme import running_statistics, specs
from brax.training.agents.ppo import checkpoint
from brax.training.types import Params, PRNGKey
from etils import epath
from orbax import checkpoint as ocp
from ml_collections import config_dict

from ss2r.algorithms.penalizers import Penalizer
from ss2r.algorithms.ppo import _PMAP_AXIS_NAME, Metrics, TrainingState
@@ -40,6 +40,7 @@
from ss2r.algorithms.ppo import training_step as ppo_training_step
from ss2r.algorithms.ppo.wrappers import TrackOnlineCosts
from ss2r.rl.evaluation import ConstraintsEvaluator
from ss2r.rl.utils import restore_state


def _unpmap(v):
@@ -92,6 +93,7 @@ def train(
normalize_advantage: bool = True,
eval_env: Optional[envs.Env] = None,
policy_params_fn: Callable[..., None] = lambda *args: None,
checkpoint_logdir: Optional[str] = None,
restore_checkpoint_path: Optional[str] = None,
safety_budget: float = float("inf"),
penalizer: Penalizer | None = None,
@@ -287,20 +289,50 @@ def training_epoch_with_timing(
penalizer_params=penalizer_params,
) # type: ignore

if (
restore_checkpoint_path is not None
and epath.Path(restore_checkpoint_path).exists()
):
logging.info("restoring from checkpoint %s", restore_checkpoint_path)
orbax_checkpointer = ocp.PyTreeCheckpointer()
target = training_state.normalizer_params, init_params, penalizer_params
(normalizer_params, init_params) = orbax_checkpointer.restore(
restore_checkpoint_path, item=target
if restore_checkpoint_path is not None:
loaded_params = checkpoint.load(restore_checkpoint_path)
restored_normalizer = restore_state(
loaded_params[0], training_state.normalizer_params
)
restored_network_params = training_state.params
restored_penalizer_params = training_state.penalizer_params
restored_optimizer_state = training_state.optimizer_state

if len(loaded_params) >= 2:
try:
restored_network_params = restore_state(
loaded_params[1], training_state.params
)
except Exception:
if len(loaded_params) >= 3:
restored_network_params = training_state.params.replace( # type: ignore
policy=restore_state(
loaded_params[1], training_state.params.policy
),
value=restore_state(
loaded_params[2], training_state.params.value
),
)
if len(loaded_params) >= 3:
try:
restored_penalizer_params = restore_state(
loaded_params[2], training_state.penalizer_params
)
except Exception:
pass
if len(loaded_params) >= 4:
try:
restored_optimizer_state = restore_state(
loaded_params[3], training_state.optimizer_state
)
except Exception:
pass

training_state = training_state.replace( # type: ignore
normalizer_params=normalizer_params,
params=init_params,
penalizer_params=penalizer_params,
normalizer_params=restored_normalizer,
params=restored_network_params,
penalizer_params=restored_penalizer_params,
optimizer_state=restored_optimizer_state,
) # type: ignore

if num_timesteps == 0:
@@ -383,6 +415,22 @@ def training_epoch_with_timing(
progress_fn(current_step, metrics)
params = _unpmap((training_state.normalizer_params, training_state.params))
policy_params_fn(current_step, make_policy, params)
if checkpoint_logdir:
checkpoint_params = _unpmap(
(
training_state.normalizer_params,
training_state.params,
training_state.penalizer_params,
training_state.optimizer_state,
)
)
dummy_ckpt_config = config_dict.ConfigDict()
checkpoint.save(
checkpoint_logdir,
current_step,
checkpoint_params,
dummy_ckpt_config,
)

total_steps = current_step
assert total_steps >= num_timesteps
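Review note: the save and load sides of this diff have to agree on a tuple layout. A minimal sketch of that contract, using only the `checkpoint.save`/`checkpoint.load` calls visible above; the round-trip behavior of brax's `ppo.checkpoint` module is assumed from those calls, not verified here.

```python
from brax.training.agents.ppo import checkpoint
from ml_collections import config_dict

def save_full_state(checkpoint_logdir, step, training_state):
    # Slot order fixes the meaning of loaded_params[0..3] in the restore
    # branch of train(): normalizer, network params, penalizer, optimizer.
    checkpoint.save(
        checkpoint_logdir,
        step,
        (
            training_state.normalizer_params,  # loaded_params[0]
            training_state.params,             # loaded_params[1]
            training_state.penalizer_params,   # loaded_params[2]
            training_state.optimizer_state,    # loaded_params[3]
        ),
        config_dict.ConfigDict(),  # brax expects a config; contents unused here
    )

# checkpoint.load(path) is expected to return the same tuple. The length
# checks and try/except fallbacks in train() exist so that shorter or older
# layouts (e.g. (normalizer, policy, value)) still restore what they can.
```

One design note: the bare `except Exception: pass` fallbacks keep old checkpoints loadable, but they also mask genuine restore failures; logging when a slot falls back to the fresh state might be worth considering.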
38 changes: 38 additions & 0 deletions ss2r/configs/experiment/rccar_sim_to_real_unsafe_ppo.yaml
@@ -0,0 +1,38 @@
# @package _global_
defaults:
- override /environment: rccar_real
- override /agent: ppo
- override /agent/penalizer: null
- _self_

environment:
action_delay: 1
observation_delay: 0
sliding_window: 5
dt: 0.03333333
sample_init_pose: true
init_pose: [1.42, -1.04, -3.142]

training:
num_envs: 4096
num_timesteps: 250000000
episode_length: 250
train_domain_randomization: true
eval_domain_randomization: true
safe: false

agent:
batch_size: 256
discounting: 0.97
entropy_cost: 0.01
learning_rate: 0.0003
max_grad_norm: 1.0
policy_hidden_layer_sizes: [64, 64]
value_hidden_layer_sizes: [64, 64]
normalize_observations: true
num_minibatches: 32
num_resets_per_eval: 1
num_updates_per_batch: 4
reward_scaling: 1.0
unroll_length: 20
activation: swish
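Review note: a hedged sketch of how this experiment file composes, assuming the standard Hydra layout implied by `# @package _global_` and the `override /...` defaults. The primary config name (`config`), the config root path, and the `experiment` group prefix are assumptions, not confirmed by this diff.

```python
from hydra import compose, initialize

# Hypothetical composition check; config_path and config_name are assumed.
with initialize(version_base=None, config_path="ss2r/configs"):
    cfg = compose(
        config_name="config",
        overrides=["+experiment=rccar_sim_to_real_unsafe_ppo"],
    )
    # The experiment swaps in the real RC-car environment, nulls out the
    # safety penalizer, and disables the safety constraint entirely.
    assert cfg.training.safe is False
    assert cfg.environment.action_delay == 1
```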