
benchback-rl

Benchmarking Backends for Reinforcement Learning: PyTorch vs JAX (Flax.NNX) vs JAX (Flax.Linen)

Setup

Environment Variables

This project uses two separate .env files:

1. Docker User IDs (setup/docker/.env)

Run the setup script to create setup/docker/.env with your user/group IDs (required for proper file permissions in containers):

./setup/scripts/create-env.sh

This creates:

UID=1000
GID=1000
DOCKER_GID=999

2. WandB Credentials (.env)

Copy the example file and add your WandB credentials:

cp .env.example .env

Then edit .env:

WANDB_API_KEY=your_api_key_here
WANDB_ENTITY=your_username_or_team

Running Benchmarks

python -m benchback_rl.benchmarks.runner

Repository Structure

Packages and Installation

setup/docker/Dockerfile.run with setup/docker/docker-compose.run.yml is used to run benchmarks reproducibly, installing dependencies from requirements.txt.

setup/docker/Dockerfile.dev with setup/docker/docker-compose.dev.yml and .devcontainer/devcontainer.json is used for development, installing dependencies from pyproject.toml. setup/scripts/export_requirements.sh is used to generate requirements.txt from within the development container.

This repo installs JAX from the nvcr.io/nvidia/jax:25.10-py3 Docker image for GPU support, and installs torch with its bundled CUDA dependencies. This way each package uses its own CUDA libraries for best performance and compatibility, at the cost of a larger container image.
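
For example, one way to build and start the benchmark container with the run compose file (the exact service definitions live in the compose file itself, so adjust as needed):

docker compose -f setup/docker/docker-compose.run.yml up --build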

Implementations

All implementations follow the 13 core implementation details from The 37 Implementation Details of Proximal Policy Optimization.

Design Decisions

Rollout Buffer Storage Layout

The buffer stores transitions with the following semantics:

  • obs[t] — observation fed to the network at step t
  • action[t], log_prob[t], value[t] — network outputs given obs[t]
  • reward[t], done[t] — result of taking action[t] in the environment
  • obs[t+1] — next observation (stored at next index)

This means done[t] indicates whether the episode ended after taking action[t], not whether obs[t] is the first observation of a new episode. The buffer stores num_steps + 1 observations (including the final bootstrap observation) but only num_steps of everything else.
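
A minimal sketch of this layout (array names and shapes are illustrative assumptions, not the actual buffer class):

# Illustrative sketch of the storage layout described above; names and shapes
# are assumptions, not the repo's implementation.
import numpy as np

num_steps, num_envs, obs_dim = 128, 8, 4

obs = np.zeros((num_steps + 1, num_envs, obs_dim))        # includes the final bootstrap observation
actions = np.zeros((num_steps, num_envs), dtype=np.int64)
log_probs = np.zeros((num_steps, num_envs))
values = np.zeros((num_steps, num_envs))
rewards = np.zeros((num_steps, num_envs))
dones = np.zeros((num_steps, num_envs), dtype=bool)

# At step t: the network sees obs[t] and produces actions[t], log_probs[t], values[t];
# the environment returns rewards[t], dones[t] and the next observation, stored at obs[t + 1].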

Termination vs Truncation

Gymnax environments combine true terminations (agent reached a terminal state) and truncations (time limit reached) into a single done flag. We accept this simplification, which introduces a small bias for truncated episodes: when an episode is truncated due to the time limit, the bootstrap value should ideally be V(final_obs) rather than 0, since the episode could have continued (see the sketch after the list below). However:

  1. For environments with natural termination conditions (CartPole, Atari), true terminations dominate
  2. The bias is typically small for well-tuned time limits
  3. Handling truncation separately would require modifications to gymnax or manual time tracking
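
To make the bias concrete, here is an illustrative one-step target comparison; gamma, reward, value_next and done are placeholder names, not the repo's variables.

# Illustrative sketch (placeholder names): effect of folding truncation into `done`
# on the one-step bootstrap target at the final step of a time-limited episode.
gamma, reward, value_next, done = 0.99, 1.0, 0.5, 1.0     # done = 1.0 because gymnax flags truncation as done

biased_target = reward + gamma * (1.0 - done) * value_next   # bootstrap dropped, target collapses to reward
unbiased_target = reward + gamma * value_next                # what a pure time-limit truncation warrants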

Buffer Reset Behavior

The buffer does NOT automatically carry forward the final observation to the next rollout. The caller must explicitly:

  1. Call buffer.reset() to clear the step counter
  2. Call buffer.set_initial_obs(obs) with the appropriate starting observation

This explicit API prevents subtle bugs where stale observations might be used.
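
A usage sketch of this convention; only reset() and set_initial_obs() come from the list above, the surrounding names are assumptions and the rollout body is elided:

# Sketch of the explicit reset convention; `buffer`, `env` and the rollout body
# are placeholders, only reset() / set_initial_obs() are the documented calls.
def run_training(buffer, env, num_iterations, num_steps):
    obs = env.reset()
    for _ in range(num_iterations):
        buffer.reset()                 # 1. clear the step counter
        buffer.set_initial_obs(obs)    # 2. pass the starting observation explicitly
        for _ in range(num_steps):
            obs = ...                  # step the env, store the transition, keep the next obs
        # `obs` now holds the final (bootstrap) observation and is reused next iteration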

PyTorch

Located in src/benchback_rl/rl_torch/, this RL implementation uses PyTorch with an object-oriented design. The main training loop is in train.py, while the model definitions are in models.py. Environments run on the GPU via gymnax (JAX), and tensors are transferred between PyTorch and JAX using DLPack for efficiency.
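
A minimal sketch of the DLPack hand-off between JAX and PyTorch; this only illustrates the mechanism, not the repo's actual training loop, and the variable names are assumptions:

# Illustrative DLPack transfer between a gymnax (JAX) observation batch and PyTorch.
import jax.dlpack
import jax.numpy as jnp
import torch.utils.dlpack

obs_jax = jnp.ones((8, 4))                                                  # e.g. a batch of CartPole observations
obs_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(obs_jax))   # zero-copy into PyTorch

# ... run the PyTorch policy on obs_torch, then hand actions back to JAX ...
actions_torch = torch.zeros(8, dtype=torch.int32)
actions_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(actions_torch))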

JAX (Flax.NNX)

Located in src/benchback_rl/rl_jax_nnx/, this RL implementation uses JAX with the Flax.NNX library. The design is object-oriented, similar to the PyTorch implementation, while allowing jittable JAX execution under the hood, as per Flax.NNX's design philosophy. The main training loop is in train.py, while the model definitions are in models.py.
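
To give a flavour of the object-oriented NNX style, here is a minimal actor-critic sketch; it is illustrative only and not the repo's models.py:

# Minimal Flax.NNX sketch: modules hold their parameters as attributes,
# so the model is used like a regular Python object.
from flax import nnx
import jax.numpy as jnp

class ActorCritic(nnx.Module):
    def __init__(self, obs_dim: int, num_actions: int, *, rngs: nnx.Rngs):
        self.actor = nnx.Linear(obs_dim, num_actions, rngs=rngs)
        self.critic = nnx.Linear(obs_dim, 1, rngs=rngs)

    def __call__(self, obs):
        return self.actor(obs), self.critic(obs)

model = ActorCritic(4, 2, rngs=nnx.Rngs(0))
logits, value = model(jnp.ones((8, 4)))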

JAX (Flax.Linen)

Located in src/benchback_rl/rl_jax_linen/, this RL implementation uses JAX with the Flax.Linen library. The design is functional, following Flax.Linen's design philosophy. The main training loop is in train.py, while the model definitions are in models.py.
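
For contrast with the NNX version, a minimal functional Linen sketch; again illustrative only, not the repo's models.py:

# Minimal Flax.Linen sketch: the module only describes the computation,
# while parameters live outside and are threaded through init/apply.
import flax.linen as nn
import jax
import jax.numpy as jnp

class ActorCritic(nn.Module):
    num_actions: int

    @nn.compact
    def __call__(self, obs):
        logits = nn.Dense(self.num_actions)(obs)
        value = nn.Dense(1)(obs)
        return logits, value

model = ActorCritic(num_actions=2)
params = model.init(jax.random.PRNGKey(0), jnp.ones((8, 4)))   # parameters returned to the caller
logits, value = model.apply(params, jnp.ones((8, 4)))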
