Benchmarking Backends for Reinforcement Learning: PyTorch vs JAX (Flax.NNX) vs JAX (Flax.Linen)
This project uses two separate .env files:
Run the setup script to create setup/docker/.env with your user/group IDs (required for proper file permissions in containers):
```bash
./setup/scripts/create-env.sh
```

This creates:
```
UID=1000
GID=1000
DOCKER_GID=999
```
Copy the example file and add your WandB credentials:
```bash
cp .env.example .env
```

Then edit .env:
```
WANDB_API_KEY=your_api_key_here
WANDB_ENTITY=your_username_or_team
```
Run the benchmark suite with:

```bash
python -m benchback_rl.benchmarks.runner
```

setup/docker/Dockerfile.run with setup/docker/docker-compose.run.yml is used to run benchmarks in a reproducible way, installing pinned dependencies from requirements.txt.
setup/docker/Dockerfile.dev with setup/docker/docker-compose.dev.yml and .devcontainer/devcontainer.json is used for development, installing dependencies from pyproject.toml. setup/scripts/export_requirements.sh generates requirements.txt from within the development container.
This repo takes jax from the nvcr.io/nvidia/jax:25.10-py3 Docker container for GPU support, and installs torch with its bundled CUDA dependencies. Each framework therefore uses its own CUDA libraries, which gives the best performance and compatibility at the cost of a larger container image.
All implementations follow the 13 core implementation details from The 37 Implementation Details of Proximal Policy Optimization.
Rollout Buffer Storage Layout
The buffer stores transitions with the following semantics:
- `obs[t]` — observation fed to the network at step t
- `action[t]`, `log_prob[t]`, `value[t]` — network outputs given `obs[t]`
- `reward[t]`, `done[t]` — result of taking `action[t]` in the environment
- `obs[t+1]` — next observation (stored at the next index)
This means `done[t]` indicates whether the episode ended after taking `action[t]`, not whether `obs[t]` is the first observation of a new episode. The buffer stores `num_steps + 1` observations (including the final bootstrap observation) but only `num_steps` of everything else.
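A minimal NumPy sketch of this layout (illustrative only; the flat-observation shapes, discrete actions, and the `add` method are our assumptions, not necessarily the repo's code):

```python
import numpy as np

class RolloutBuffer:
    """Sketch of the storage layout described above."""

    def __init__(self, num_steps: int, num_envs: int, obs_dim: int):
        # num_steps + 1 observation slots: the extra one holds the bootstrap obs.
        self.obs = np.zeros((num_steps + 1, num_envs, obs_dim), dtype=np.float32)
        self.actions = np.zeros((num_steps, num_envs), dtype=np.int64)
        self.log_probs = np.zeros((num_steps, num_envs), dtype=np.float32)
        self.values = np.zeros((num_steps, num_envs), dtype=np.float32)
        self.rewards = np.zeros((num_steps, num_envs), dtype=np.float32)
        self.dones = np.zeros((num_steps, num_envs), dtype=np.float32)
        self.step = 0

    def set_initial_obs(self, obs):
        self.obs[0] = obs

    def add(self, action, log_prob, value, reward, done, next_obs):
        t = self.step
        self.actions[t] = action      # network outputs given obs[t]
        self.log_probs[t] = log_prob
        self.values[t] = value
        self.rewards[t] = reward      # result of taking action[t]
        self.dones[t] = done          # episode ended after action[t]
        self.obs[t + 1] = next_obs    # next observation, stored at the next index
        self.step += 1

    def reset(self):
        self.step = 0
```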
Termination vs Truncation
Gymnax environments combine true terminations (the agent reached a terminal state) and truncations (the time limit was reached) into a single done flag. We accept this simplification, which introduces a small bias for truncated episodes: when an episode is truncated due to the time limit, the bootstrap value should ideally be V(final_obs) rather than 0, since the episode could have continued (the sketch after the list below makes this concrete). However:
- For environments with natural termination conditions (CartPole, Atari), true terminations dominate
- The bias is typically small for well-tuned time limits
- Handling truncation separately would require modifications to gymnax or manual time tracking
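A sketch of standard GAE with a single done flag (our illustration of the usual γ/λ recursion, not necessarily the repo's exact code) shows where the bias enters: a done step always zeroes the bootstrap term, even when the episode was merely truncated:

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    # Standard GAE recursion. `dones` folds termination and truncation together,
    # so a truncated step bootstraps with 0 instead of the ideal V(final_obs).
    num_steps = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(num_steps)):
        next_value = last_value if t == num_steps - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        last_gae = delta + gamma * lam * not_done * last_gae
        advantages[t] = last_gae
    return advantages
```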
Buffer Reset Behavior
The buffer does NOT automatically carry forward the final observation to the next rollout. The caller must explicitly:
- Call `buffer.reset()` to clear the step counter
- Call `buffer.set_initial_obs(obs)` with the appropriate starting observation
This explicit API prevents subtle bugs where stale observations might be used.
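For example, between rollouts (continuing the hypothetical buffer sketch above):

```python
# Start of a new rollout: carry the bootstrap observation forward explicitly.
final_obs = buffer.obs[-1].copy()   # last observation of the previous rollout
buffer.reset()                      # clear the step counter
buffer.set_initial_obs(final_obs)   # becomes obs[0] of the new rollout
```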
Located in src/benchback_rl/rl_torch/, this RL implementation uses PyTorch with an object-oriented design. The main training loop is in train.py, while the model definitions are in models.py. Environments run on the GPU via gymnax (JAX), and tensors are transferred between PyTorch and JAX with DLPack, avoiding unnecessary copies.
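A minimal sketch of such a DLPack handoff (illustrative only; the repo's transfer code may differ):

```python
import jax.numpy as jnp
import torch

# JAX array, e.g. observations produced by a gymnax environment.
jax_obs = jnp.ones((8, 4))

# JAX -> PyTorch: both sides view the same underlying buffer, no copy.
torch_obs = torch.from_dlpack(jax_obs)

# PyTorch -> JAX, e.g. actions chosen by the PyTorch policy.
torch_actions = torch.zeros(8, dtype=torch.int64)
jax_actions = jnp.from_dlpack(torch_actions)
```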
Located in src/benchback_rl/rl_jax_nnx/, this RL implementation uses JAX with the Flax.NNX library. The design is object oriented, similar to the PyTorch implementation, while allowing jittable jax exectution under the hood, as per Flax.NNX's design philosophy. The main training loop is in train.py, while the model definitions are in models.py.
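For reference, a minimal Flax.NNX module in this style (an illustration, not the repo's models.py):

```python
from flax import nnx

class ValueHead(nnx.Module):
    # Tiny critic head; layer sizes are illustrative.
    def __init__(self, obs_dim: int, hidden: int, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(obs_dim, hidden, rngs=rngs)
        self.fc2 = nnx.Linear(hidden, 1, rngs=rngs)

    def __call__(self, x):
        return self.fc2(nnx.relu(self.fc1(x)))

# Modules hold their own state and are constructed eagerly, much like PyTorch.
model = ValueHead(obs_dim=4, hidden=64, rngs=nnx.Rngs(0))
```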
Located in src/benchback_rl/rl_jax_linen/, this RL implementation uses JAX with the Flax.Linen library. The design is functional, following Flax.Linen's design philosophy. The main training loop is in train.py, while the model definitions are in models.py.
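The equivalent module in Flax.Linen's functional style (again illustrative only):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class ValueHead(nn.Module):
    hidden: int = 64  # layer size is illustrative

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(1)(x)

# Parameters live outside the module and flow through init/apply explicitly.
model = ValueHead()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
values = model.apply(params, jnp.ones((8, 4)))
```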