diff --git a/ss2r/algorithms/ppo/__init__.py b/ss2r/algorithms/ppo/__init__.py
index 85ed3df25..485da0585 100644
--- a/ss2r/algorithms/ppo/__init__.py
+++ b/ss2r/algorithms/ppo/__init__.py
@@ -94,6 +94,7 @@ def get_train_fn(cfg, checkpoint_path, restore_checkpoint_path):
         **agent_cfg,
         **training_cfg,
         network_factory=network_factory,
+        checkpoint_logdir=checkpoint_path if cfg.training.store_checkpoint else None,
         restore_checkpoint_path=restore_checkpoint_path,
         penalizer=penalizer,
         penalizer_params=penalizer_params,
diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py
index 3182d9486..3ca0a4c38 100644
--- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py
+++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import logging
 from collections.abc import Mapping, Sequence
 from typing import Any
diff --git a/ss2r/algorithms/ppo/train.py b/ss2r/algorithms/ppo/train.py
index 8d1e62fa0..5e29e0e47 100644
--- a/ss2r/algorithms/ppo/train.py
+++ b/ss2r/algorithms/ppo/train.py
@@ -29,9 +29,9 @@
 from brax import envs
 from brax.training import pmap, types
 from brax.training.acme import running_statistics, specs
+from brax.training.agents.ppo import checkpoint
 from brax.training.types import Params, PRNGKey
-from etils import epath
-from orbax import checkpoint as ocp
+from ml_collections import config_dict
 
 from ss2r.algorithms.penalizers import Penalizer
 from ss2r.algorithms.ppo import _PMAP_AXIS_NAME, Metrics, TrainingState
@@ -40,6 +40,7 @@
 from ss2r.algorithms.ppo import training_step as ppo_training_step
 from ss2r.algorithms.ppo.wrappers import TrackOnlineCosts
 from ss2r.rl.evaluation import ConstraintsEvaluator
+from ss2r.rl.utils import restore_state
 
 
 def _unpmap(v):
@@ -92,6 +93,7 @@ def train(
     normalize_advantage: bool = True,
     eval_env: Optional[envs.Env] = None,
     policy_params_fn: Callable[..., None] = lambda *args: None,
+    checkpoint_logdir: Optional[str] = None,
     restore_checkpoint_path: Optional[str] = None,
     safety_budget: float = float("inf"),
     penalizer: Penalizer | None = None,
@@ -287,20 +289,50 @@ def training_epoch_with_timing(
         penalizer_params=penalizer_params,
     )  # type: ignore
 
-    if (
-        restore_checkpoint_path is not None
-        and epath.Path(restore_checkpoint_path).exists()
-    ):
-        logging.info("restoring from checkpoint %s", restore_checkpoint_path)
-        orbax_checkpointer = ocp.PyTreeCheckpointer()
-        target = training_state.normalizer_params, init_params, penalizer_params
-        (normalizer_params, init_params) = orbax_checkpointer.restore(
-            restore_checkpoint_path, item=target
+    if restore_checkpoint_path is not None:
+        loaded_params = checkpoint.load(restore_checkpoint_path)
+        restored_normalizer = restore_state(
+            loaded_params[0], training_state.normalizer_params
         )
+        restored_network_params = training_state.params
+        restored_penalizer_params = training_state.penalizer_params
+        restored_optimizer_state = training_state.optimizer_state
+
+        if len(loaded_params) >= 2:
+            try:
+                restored_network_params = restore_state(
+                    loaded_params[1], training_state.params
+                )
+            except Exception:
+                if len(loaded_params) >= 3:
+                    restored_network_params = training_state.params.replace(  # type: ignore
+                        policy=restore_state(
+                            loaded_params[1], training_state.params.policy
+                        ),
+                        value=restore_state(
+                            loaded_params[2], training_state.params.value
+                        ),
+                    )
+        if len(loaded_params) >= 3:
+            try:
+                restored_penalizer_params = restore_state(
+                    loaded_params[2], training_state.penalizer_params
+                )
+            except Exception:
+                pass
+        if len(loaded_params) >= 4:
+            try:
+                restored_optimizer_state = restore_state(
+                    loaded_params[3], training_state.optimizer_state
+                )
+            except Exception:
+                pass
+
         training_state = training_state.replace(  # type: ignore
-            normalizer_params=normalizer_params,
-            params=init_params,
-            penalizer_params=penalizer_params,
+            normalizer_params=restored_normalizer,
+            params=restored_network_params,
+            penalizer_params=restored_penalizer_params,
+            optimizer_state=restored_optimizer_state,
         )  # type: ignore
 
     if num_timesteps == 0:
@@ -383,6 +415,22 @@ def training_epoch_with_timing(
             progress_fn(current_step, metrics)
             params = _unpmap((training_state.normalizer_params, training_state.params))
             policy_params_fn(current_step, make_policy, params)
+            if checkpoint_logdir:
+                checkpoint_params = _unpmap(
+                    (
+                        training_state.normalizer_params,
+                        training_state.params,
+                        training_state.penalizer_params,
+                        training_state.optimizer_state,
+                    )
+                )
+                dummy_ckpt_config = config_dict.ConfigDict()
+                checkpoint.save(
+                    checkpoint_logdir,
+                    current_step,
+                    checkpoint_params,
+                    dummy_ckpt_config,
+                )
 
     total_steps = current_step
     assert total_steps >= num_timesteps
diff --git a/ss2r/configs/experiment/rccar_sim_to_real_unsafe_ppo.yaml b/ss2r/configs/experiment/rccar_sim_to_real_unsafe_ppo.yaml
new file mode 100644
index 000000000..eae63906b
--- /dev/null
+++ b/ss2r/configs/experiment/rccar_sim_to_real_unsafe_ppo.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+defaults:
+  - override /environment: rccar_real
+  - override /agent: ppo
+  - override /agent/penalizer: null
+  - _self_
+
+environment:
+  action_delay: 1
+  observation_delay: 0
+  sliding_window: 5
+  dt: 0.03333333
+  sample_init_pose: true
+  init_pose: [1.42, -1.04, -3.142]
+
+training:
+  num_envs: 4096
+  num_timesteps: 250000000
+  episode_length: 250
+  train_domain_randomization: true
+  eval_domain_randomization: true
+  safe: false
+
+agent:
+  batch_size: 256
+  discounting: 0.97
+  entropy_cost: 0.01
+  learning_rate: 0.0003
+  max_grad_norm: 1.0
+  policy_hidden_layer_sizes: [64, 64]
+  value_hidden_layer_sizes: [64, 64]
+  normalize_observations: true
+  num_minibatches: 32
+  num_resets_per_eval: 1
+  num_updates_per_batch: 4
+  reward_scaling: 1.0
+  unroll_length: 20
+  activation: swish
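
Note for reviewers: `restore_state` is imported from `ss2r.rl.utils` but not defined in this diff. The sketch below is only a guess at its contract, assuming it re-hydrates the leaves of a loaded checkpoint entry into the structure of a freshly initialized target and raises on a structure mismatch; that raising behavior is what the `try`/`except` fallbacks in `train.py` rely on. The real helper may differ.

```python
import jax


def restore_state(source, target):
    """Hypothetical sketch: pour the leaves of a loaded checkpoint entry
    into the pytree structure of ``target``.

    ``tree_unflatten`` raises when the number of leaves in ``source``
    does not match ``target``'s structure, which is the failure mode the
    try/except fallbacks in train.py catch before falling back to the
    separate (policy, value) layout or to the fresh initialization.
    """
    leaves = jax.tree_util.tree_leaves(source)
    treedef = jax.tree_util.tree_structure(target)
    return jax.tree_util.tree_unflatten(treedef, leaves)
```

Under this reading, checkpoints saved before this change (normalizer and network params only) still load: the penalizer params and optimizer state simply keep their fresh values, since the tuple saved by the new `checkpoint.save` call appends them as entries 2 and 3.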