KluvaDa/lotus
🪷 Lotus

A high-performance JAX reinforcement learning library

Lotus is a lightweight reinforcement learning library written in pure JAX (Flax). It supports jit, vmap, and pmap for fast and scalable training on hardware accelerators.
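Because everything is pure JAX, a whole training run can be written as a pure function of its seed, and the JAX transforms then compose over it. The sketch below does not use Lotus. It assumes a hypothetical `train` function: a toy gradient-descent loop on a quadratic, not an RL algorithm. It shows `jax.vmap` batching many runs over seeds and `jax.jit` compiling the batched program once.

```python
import jax
import jax.numpy as jnp


def train(seed: jax.Array, num_steps: int = 100, lr: float = 0.1) -> jax.Array:
    # Hypothetical toy "training run" (not part of Lotus): gradient descent on a
    # quadratic, starting from a seed-dependent random initialisation.
    key = jax.random.PRNGKey(seed)
    w = jax.random.normal(key, (3,))  # random initial parameters

    def loss(w):
        return jnp.sum((w - 1.0) ** 2)  # minimised at w = [1, 1, 1]

    def step(_, w):
        return w - lr * jax.grad(loss)(w)  # one plain gradient-descent update

    w = jax.lax.fori_loop(0, num_steps, step, w)
    return loss(w)  # final loss for this seed


# vmap runs one training run per seed; jit compiles the batched program once.
train_many = jax.jit(jax.vmap(train))
final_losses = train_many(jnp.arange(8))
print(final_losses.shape)  # (8,)
```

Since nothing in `train` touches Python-side state, the same function could in principle also be mapped across devices with `jax.pmap`.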

Install

Clone the repository and install dependencies:

git clone https://github.com/auxeno/lotus
pip install -r lotus/requirements.txt

Quick Start

Train multiple PPO agents on 100 different seeds in parallel:

import jax
import jax.numpy as jnp
from lotus import PPO

# Create agent and seeds
agent = PPO.create(env='Breakout-MinAtar')
seeds = jnp.arange(100)

# Vectorised training
train_fn = jax.vmap(agent.train, in_axes=(None, 0))
trained_agents = train_fn(agent, seeds)
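In the Quick Start, `in_axes=(None, 0)` tells `jax.vmap` to pass the first argument (the agent) unchanged to every call while mapping over the leading axis of the second (the seeds). The sketch below shows the same pattern with a hypothetical `ToyAgent` pytree and a hypothetical `train` function. Neither is part of Lotus, and whether Lotus's `train` takes integer seeds or PRNG keys is an assumption here.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyAgent(NamedTuple):
    # Hypothetical stand-in for an agent pytree; Lotus's PPO class is not used here.
    lr: jax.Array
    init_scale: jax.Array


def train(agent: ToyAgent, seed: jax.Array) -> jax.Array:
    # Hypothetical training function: seed-dependent init plus one SGD step
    # on sum(params ** 2), using the shared agent's hyperparameters.
    key = jax.random.PRNGKey(seed)
    params = agent.init_scale * jax.random.normal(key, (4,))
    grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
    return params - agent.lr * grads


agent = ToyAgent(lr=jnp.asarray(0.1), init_scale=jnp.asarray(1.0))
seeds = jnp.arange(100)

# None: pass the same agent to every call; 0: map over the leading axis of seeds.
train_fn = jax.vmap(train, in_axes=(None, 0))
params = train_fn(agent, seeds)
print(params.shape)  # (100, 4)
```

The outputs are stacked along a new leading axis, one row per seed, and each row matches what an unbatched call with that seed returns.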


See the Colab notebook for more examples and advanced usage.

Performance

PPO performance comparison for Lotus and CleanRL on the MinAtar Breakout environment. Agents were trained for 500,000 steps on an RTX 4090.

Supported Algorithms

Algorithm        Paper
DQN              Mnih et al. 2013
QR-DQN           Dabney et al. 2017
PQN              Gallici et al. 2024
DDPG             Lillicrap et al. 2015
TD3              Fujimoto et al. 2018
SAC              Haarnoja et al. 2018
PPO              Schulman et al. 2017
Recurrent PPO    -
