Lotus is a lightweight reinforcement learning library written in pure JAX (Flax). It supports jit, vmap, and pmap for fast and scalable training on hardware accelerators.
Clone the repository and install dependencies:
```bash
git clone https://github.com/auxeno/lotus
pip install -r lotus/requirements.txt
```
Train multiple PPO agents on 100 different seeds in parallel:
```python
import jax
import jax.numpy as jnp

from lotus import PPO

# Create agent and seeds
agent = PPO.create(env='Breakout-MinAtar')
seeds = jnp.arange(100)

# Vectorised training: map over seeds, broadcast the agent
train_fn = jax.vmap(agent.train, in_axes=(None, 0))
trained_agents = train_fn(agent, seeds)
```

See the Colab notebook for more examples and advanced usage.
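The same pattern extends to multiple accelerators with `jax.pmap`, which the library also supports. The snippet below is a minimal sketch, not an official recipe: it assumes the `agent.train(agent, seed)` interface shown above and simply runs one training seed per local device.

```python
import jax
import jax.numpy as jnp

from lotus import PPO

# One training run per local device (e.g. per GPU); sketch only
agent = PPO.create(env='Breakout-MinAtar')
seeds = jnp.arange(jax.local_device_count())

# Broadcast the agent config to every device, shard seeds along axis 0
pmap_train = jax.pmap(agent.train, in_axes=(None, 0))
trained_agents = pmap_train(agent, seeds)
```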
PPO performance comparison for Lotus and CleanRL on the MinAtar Breakout environment. Agents were trained for 500,000 steps on an RTX 4090.
| Algorithm | Discrete | Continuous | Paper |
|---|---|---|---|
| DQN | ✔ | | Mnih et al. 2013 |
| QR-DQN | ✔ | | Dabney et al. 2017 |
| PQN | ✔ | | Gallici et al. 2024 |
| DDPG | | ✔ | Lillicrap et al. 2015 |
| TD3 | | ✔ | Fujimoto et al. 2018 |
| SAC | | ✔ | Haarnoja et al. 2018 |
| PPO | ✔ | | Schulman et al. 2017 |
| Recurrent PPO | ✔ | | - |
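Assuming the other algorithms follow the same `create`/`train` interface as PPO above, switching algorithms is a one-line change. The sketch below is illustrative only: the `SAC` import comes from the table, but the environment name and the shared interface are assumptions rather than confirmed API.

```python
import jax
import jax.numpy as jnp

from lotus import SAC  # assumed to mirror the PPO interface shown above

# Hypothetical continuous-control example; the environment name is illustrative
agent = SAC.create(env='HalfCheetah')
seeds = jnp.arange(10)

# Same vmap pattern: train several SAC agents with different seeds in parallel
train_fn = jax.vmap(agent.train, in_axes=(None, 0))
trained_agents = train_fn(agent, seeds)
```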

