This is a working version of Kurtland Chua's (the guy who gave us PETS!!! 🙇‍♂️) repo. The repo has been modified to work with newer versions of JAX and gym, since JAX has officially stopped supporting CUDA 11. This version works on CUDA 12.
It currently implements only PETS, MBPO and a model-based policy agent.
There aren't many usable implementations of these MBRL algorithms in JAX, which makes this a useful starting point for running MBRL experiments in JAX.
Warning: This is a work in progress and has not been evaluated on harder environments!
A Dockerfile with all required dependencies is provided in the /docker/ folder. It differs from the original Dockerfile provided by Kurt and uses Python 3.10.12 along with jax-v0.4.16, gym-v0.26.2 and mujoco_py-v2.1.2.14. mujoco_py has some Cython compilation issues in Singularity environments running on a SLURM system (#687, #644), so the Dockerfile in this repo uses a modified version of mujoco_py that is compatible with Singularity. To build the image, run docker_build.sh with appropriate tags. Alternatively, use a prebuilt container: docker pull avirupdas55/jax:kchua.
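Once inside the container, it can be useful to confirm that the installed packages match the versions quoted above. The following is a minimal sketch using only the standard library; `check_versions` and `EXPECTED` are hypothetical names that are not part of this repo, and the modified mujoco_py build may report a different version string than the one listed here.

```python
from importlib import metadata

# Versions quoted in this README. Treat them as assumptions about the image;
# the modified mujoco_py build may report a different version string.
EXPECTED = {"jax": "0.4.16", "gym": "0.26.2", "mujoco_py": "2.1.2.14"}


def check_versions(expected=EXPECTED):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed in this environment
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in check_versions().items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```

An empty result means every listed package is installed at the expected version.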
A starter script for running an example experiment on cartpole is provided in model_based_experiment.py.
This script can be run via
python3 model_based_experiment.py
--logdir DIR (optional) Directory for saving checkpoints and rollout recordings.
--save-every FREQ (optional) Saving frequency. Defaults to 1 (i.e. save after every iteration).
--keep-all-checkpoints (optional) Flag which enables saving of all checkpoints (instead of only the most recent one).
--iters ITERS (optional) Number of training iterations to run. Defaults to 100.
-s SEED (optional) Experiment random seed. If not provided, uniformly chosen in [0, 10000).
env ENV (required) Experiment environment. Currently supports [`MujocoCartpole-v0`, `HalfCheetah-v3`].
agent_type AGENT (required) Agent type. Choices: [PETS, Policy].
For example, to run PETS and save recordings of rollouts to /external/:
python3 model_based_experiment.py --logdir /external/ MujocoCartpole-v0 PETS
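For readers who want to wrap or extend the script, the options documented above could be parsed roughly as follows. This is a hypothetical sketch built with `argparse` from the help text alone, not the actual code in model_based_experiment.py; `build_parser` and `parse_args` are illustrative names, and the seed handling is an assumption based on the documented default.

```python
import argparse
import random


def build_parser():
    """Build a parser mirroring the documented options (illustrative only)."""
    parser = argparse.ArgumentParser(description="Model-based RL experiment.")
    parser.add_argument("--logdir", metavar="DIR", default=None,
                        help="Directory for checkpoints and rollout recordings.")
    parser.add_argument("--save-every", metavar="FREQ", type=int, default=1,
                        help="Saving frequency (default: every iteration).")
    parser.add_argument("--keep-all-checkpoints", action="store_true",
                        help="Keep all checkpoints, not only the latest.")
    parser.add_argument("--iters", type=int, default=100,
                        help="Number of training iterations.")
    parser.add_argument("-s", dest="seed", metavar="SEED", type=int,
                        default=None, help="Experiment random seed.")
    parser.add_argument("env", metavar="ENV",
                        choices=["MujocoCartpole-v0", "HalfCheetah-v3"])
    parser.add_argument("agent_type", metavar="AGENT",
                        choices=["PETS", "Policy"])
    return parser


def parse_args(argv=None):
    """Parse argv and fill in a random seed in [0, 10000) if none was given."""
    args = build_parser().parse_args(argv)
    if args.seed is None:
        args.seed = random.randrange(10000)  # assumed to match the documented default
    return args


if __name__ == "__main__":
    print(parse_args())
```

Calling `parse_args(["--logdir", "/external/", "MujocoCartpole-v0", "PETS"])` corresponds to the example command above.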