
Jaxnasium: A Lightweight Utility Library for JAX-based RL Projects


Jaxnasium lets you

  1. 🕹️ Import your favourite environments from various libraries with a single API and automatically wrap them to a common standard.
  2. 🚀 Bootstrap new JAX RL projects with a single CLI command and get started instantly with a complete codebase.
  3. 🤖 Use standard, general-purpose RL implementations built with a near-single-file philosophy. You can either import these as off-the-shelf algorithms or copy the code over and tweak it for your problem. These algorithms follow the ideas of PureJaxRL for extremely fast end-to-end RL training in JAX.

For more details, see the 📖 Documentation.

🚀 Getting started

Jaxnasium lets you bootstrap new reinforcement learning projects directly from the command line. For new projects, the easiest way to get started is therefore via uv:

uvx jaxnasium <projectname>
uv run example_train.py

# ... or via pipx
pipx run jaxnasium <projectname>
# activate a virtual environment in your preferred way, e.g. conda
python example_train.py

This will set up a Python project folder structure with (optionally) an environment template and (optionally) algorithm code for you to tailor to your problem.

For existing projects, you can simply install Jaxnasium via pip and import the required functionality.

pip install jaxnasium
import jax
import jaxnasium as jym
from jaxnasium.algorithms import PPO

env = jym.make("CartPole-v1")
env = jym.LogWrapper(env)
rng = jax.random.PRNGKey(0)
agent = PPO(total_timesteps=5e5, learning_rate=2.5e-3)
agent = agent.train(rng, env)

🏠 Environments

Jaxnasium is not aimed at delivering a full environment suite. However, it does come equipped with a jym.make(...) command to import environments from existing suites (provided that these are installed) and wrap them appropriately to the Jaxnasium API standard. For example, using environments from Gymnax:

import jaxnasium as jym
from jaxnasium.algorithms import PPO
import jax

env = jym.make("Breakout-MinAtar")
env = jym.FlattenObservationWrapper(env)
env = jym.LogWrapper(env)

agent = PPO(**some_good_hyperparameters)
agent = agent.train(jax.random.PRNGKey(0), env)

# > Using an environment from Gymnax via gymnax.make(Breakout-MinAtar).
# > Wrapping Gymnax environment with GymnaxWrapper
# >  Disable this behavior by passing wrapper=False
# > Wrapping environment in VecEnvWrapper
# > ... training results
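
If you prefer to control wrapping yourself, the log output above indicates the automatic wrapping can be switched off. A minimal sketch based on that message:

import jaxnasium as jym

# Based on the log message above: skip the automatic GymnaxWrapper and
# apply wrappers manually instead.
env = jym.make("Breakout-MinAtar", wrapper=False)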

Note: For convenience, Jaxnasium does include the five classic-control environments.

See the Environments page for a complete list of available environments.

Environment API

The Jaxnasium API stays close to the widely used Gymnax API for the reset() and step() functions, but handles truncated episodes in a manner closer to Gymnasium.

env = jym.make(...)

obs, env_state = env.reset(key) # <-- Mirroring Gymnax

# env.step(): Gymnasium Timestep tuple with state information
(obs, reward, terminated, truncated, info), env_state = env.step(key, state, action)
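
To make the tuple above concrete, here is a minimal rollout sketch built only from the reset()/step() signatures shown above. The random-action policy and the episode length of 200 are illustrative assumptions, not part of the library.

import jax
import jaxnasium as jym

env = jym.make("CartPole-v1")

def select_action(key, obs):
    # Hypothetical placeholder policy: replace with e.g. a trained PPO agent.
    # CartPole-v1 has two discrete actions, so sample one uniformly at random.
    return jax.random.randint(key, (), 0, 2)

def rollout_step(carry, _):
    key, obs, env_state = carry
    key, action_key, step_key = jax.random.split(key, 3)
    action = select_action(action_key, obs)
    # Timestep tuple and env_state threading as in the API above
    (obs, reward, terminated, truncated, info), env_state = env.step(step_key, env_state, action)
    return (key, obs, env_state), reward

key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
obs, env_state = env.reset(reset_key)
_, rewards = jax.lax.scan(rollout_step, (key, obs, env_state), length=200)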

🤖 Algorithms

Algorithms in jaxnasium.algorithms are built with a near-single-file implementation philosophy in mind. In contrast to the implementations in CleanRL or PureJaxRL, Jaxnasium algorithms are built in Equinox and follow a class-based design with a familiar Stable-Baselines-style API.

from jaxnasium.algorithms import PPO
import jax

env = ...
agent = PPO(**some_good_hyperparameters)
agent = agent.train(jax.random.PRNGKey(0), env)
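
Because the agents are Equinox modules (i.e. PyTrees), one way to checkpoint a trained agent is Equinox's own serialization helpers. A sketch under that assumption (the file name and hyperparameters are placeholders):

import equinox as eqx

# Sketch: persist the trained agent's weights to disk with Equinox.
eqx.tree_serialise_leaves("ppo_agent.eqx", agent)

# Later: rebuild an agent with the same hyperparameters and load the weights back in.
template = PPO(**some_good_hyperparameters)
agent = eqx.tree_deserialise_leaves("ppo_agent.eqx", template)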

See the Algorithms page for more details on the included algorithms.

Available Algorithms

| Algorithm | Multi-Agent¹ | Observation Spaces | Action Spaces | Composite (nested) Spaces² |
| --- | --- | --- | --- | --- |
| PPO | ✓ | Box, Discrete, MultiDiscrete | Box, Discrete, MultiDiscrete | ✓ |
| DQN | ✓ | Box, Discrete, MultiDiscrete | Discrete, MultiDiscrete³ | ✓ |
| PQN | ✓ | Box, Discrete, MultiDiscrete | Discrete, MultiDiscrete³ | ✓ |
| SAC | ✓ | Box, Discrete, MultiDiscrete | Box, Discrete, MultiDiscrete | ✓ |

¹ All algorithms support automatic multi-agent transformation through the auto_upgrade_multi_agent parameter. See the Multi-Agent documentation for more information.

² Algorithms support composite (nested) spaces. See the Spaces documentation for more information.

³ MultiDiscrete action spaces in PQN and DQN are only supported when flattened to a Discrete action space, e.g. via the FlattenActionSpaceWrapper.
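
As an illustration of footnotes 1 and 3, the sketch below passes auto_upgrade_multi_agent (named in footnote 1) and flattens a MultiDiscrete action space for DQN via the FlattenActionSpaceWrapper (footnote 3). The environment name is a placeholder, and exposing the wrapper under the jym namespace and the parameter as a constructor argument are assumptions.

import jax
import jaxnasium as jym
from jaxnasium.algorithms import DQN  # DQN is listed in the table above

env = jym.make("YourMultiDiscreteEnv")        # placeholder environment name
env = jym.FlattenActionSpaceWrapper(env)      # footnote 3: flatten MultiDiscrete -> Discrete for DQN/PQN
env = jym.LogWrapper(env)

agent = DQN(auto_upgrade_multi_agent=True)    # footnote 1: automatic multi-agent transformation
agent = agent.train(jax.random.PRNGKey(0), env)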
