Jaxnasium lets you:
- 🕹️ Import your favourite environments from various libraries with a single API and automatically wrap them to a common standard.
- 🚀 Bootstrap new JAX RL projects with a single CLI command and get started instantly with a complete codebase.
- 🤖 Use standard, general RL implementations built on a near-single-file philosophy. You can either import these as off-the-shelf algorithms or copy the code over and tweak it for your problem. These algorithms follow the ideas of PureJaxRL for extremely fast end-to-end RL training in JAX.
For more details, see the 📖 Documentation.
Jaxnasium lets you bootstrap new reinforcement learning projects directly from the command line, so the easiest way to start a new project is via uv:
```bash
uvx jaxnasium <projectname>
uv run example_train.py

# ... or via pipx
pipx run jaxnasium <projectname>
# activate a virtual environment in your preferred way, e.g. conda
python example_train.py
```
This will set up a Python project folder structure with (optionally) an environment template and (optionally) algorithm code for you to tailor to your problem.
For existing projects, you can simply install Jaxnasium via pip and import the required functionality.
```bash
pip install jaxnasium
```
```python
import jax
import jaxnasium as jym
from jaxnasium.algorithms import PPO

env = jym.make("CartPole-v1")
env = jym.LogWrapper(env)

rng = jax.random.PRNGKey(0)
agent = PPO(total_timesteps=5e5, learning_rate=2.5e-3)
agent = agent.train(rng, env)
```
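Because the algorithms follow PureJaxRL's end-to-end JAX approach, training should compose with JAX transformations. The following is a rough sketch, assuming `agent.train` is jit- and vmap-compatible (which the text above suggests but does not explicitly state), of training several seeds in parallel:

```python
import jax
import jaxnasium as jym
from jaxnasium.algorithms import PPO

env = jym.LogWrapper(jym.make("CartPole-v1"))
agent = PPO(total_timesteps=5e5, learning_rate=2.5e-3)

# Assumption: agent.train is a pure JAX function (PureJaxRL-style), so it can
# be vmapped over random seeds to train multiple agents in parallel.
seeds = jax.random.split(jax.random.PRNGKey(0), 4)
trained = jax.vmap(lambda key: agent.train(key, env))(seeds)
```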
Jaxnasium is not aimed at delivering a full environment suite. However, it does come equipped with a jym.make(...) command to import environments from existing suites (provided that these are installed) and wrap them appropriately to the Jaxnasium API standard. For example, using environments from Gymnax:
```python
import jaxnasium as jym
from jaxnasium.algorithms import PPO
import jax

env = jym.make("Breakout-MinAtar")
env = jym.FlattenObservationWrapper(env)
env = jym.LogWrapper(env)

agent = PPO(**some_good_hyperparameters)
agent = agent.train(jax.random.PRNGKey(0), env)
# > Using an environment from Gymnax via gymnax.make(Breakout-MinAtar).
# > Wrapping Gymnax environment with GymnaxWrapper
# > Disable this behavior by passing wrapper=False
# > Wrapping environment in VecEnvWrapper
# > ... training results
```

!!! info
    For convenience, Jaxnasium does include the 5 classic-control environments.
See the Environments page for a complete list of available environments.
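If you would rather control wrapping yourself, the log output above indicates that automatic wrapping can be turned off. A minimal sketch, assuming `wrapper=False` is passed directly to `jym.make` as the log message suggests:

```python
import jaxnasium as jym

# Opt out of automatic wrapping; you are then responsible for adapting the
# raw Gymnax environment to the Jaxnasium API yourself.
env = jym.make("Breakout-MinAtar", wrapper=False)
```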
The Jaxnasium API stays close to the somewhat established Gymnax API for the reset() and step() functions, but allows for truncated episodes in a manner closer to Gymnasium.
```python
env = jym.make(...)
obs, env_state = env.reset(key)  # <-- Mirroring Gymnax

# env.step(): Gymnasium-style timestep tuple with state information
(obs, reward, terminated, truncated, info), env_state = env.step(key, state, action)
```

Algorithms in jaxnasium.algorithms are built with a near-single-file implementation philosophy in mind. In contrast to the implementations in CleanRL or PureJaxRL, Jaxnasium algorithms are built on Equinox and follow a class-based design with a familiar Stable-Baselines-style API.
```python
from jaxnasium.algorithms import PPO
import jax

env = ...
agent = PPO(**some_good_hyperparameters)
agent = agent.train(jax.random.PRNGKey(0), env)
```

See the Algorithms page for more details on the included algorithms.
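Since the agents are Equinox modules, a trained agent is an ordinary pytree. As a hedged illustration (checkpointing is not a documented Jaxnasium feature), you could save and restore one with Equinox's generic serialization helpers:

```python
import equinox as eqx
import jax
import jaxnasium as jym
from jaxnasium.algorithms import PPO

env = jym.LogWrapper(jym.make("CartPole-v1"))
agent = PPO(total_timesteps=5e5, learning_rate=2.5e-3)
agent = agent.train(jax.random.PRNGKey(0), env)

# Assumption: the trained agent is a pytree of arrays (plus static fields),
# so Equinox's tree_serialise_leaves / tree_deserialise_leaves apply.
eqx.tree_serialise_leaves("ppo_cartpole.eqx", agent)
restored = eqx.tree_deserialise_leaves("ppo_cartpole.eqx", like=agent)
```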
| Algorithm | Multi-Agent¹ | Observation Spaces | Action Spaces | Composite (nested) Spaces² |
|---|---|---|---|---|
| PPO | ✅ | Box, Discrete, MultiDiscrete | Box, Discrete, MultiDiscrete | ✅ |
| DQN | ✅ | Box, Discrete, MultiDiscrete | Discrete, MultiDiscrete³ | ✅ |
| PQN | ✅ | Box, Discrete, MultiDiscrete | Discrete, MultiDiscrete³ | ✅ |
| SAC | ✅ | Box, Discrete, MultiDiscrete | Box, Discrete, MultiDiscrete | ✅ |
¹ All algorithms support automatic multi-agent transformation through the auto_upgrade_multi_agent parameter. See the Multi-Agent documentation for more information.

² Algorithms support composite (nested) spaces. See the Spaces documentation for more information.

³ MultiDiscrete action spaces in PQN and DQN are only supported when flattened to a Discrete action space, e.g. via the FlattenActionSpaceWrapper.
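As an illustration of footnote 3, here is a hedged sketch of training DQN on an environment with a MultiDiscrete action space by first flattening it to a Discrete space. The environment id is a placeholder, the hyperparameters are stand-ins as in the examples above, and the wrapper is assumed to be exposed at the package top level like the other wrappers:

```python
import jax
import jaxnasium as jym
from jaxnasium.algorithms import DQN

# Placeholder id: substitute an environment with a MultiDiscrete action space.
env = jym.make("SomeMultiDiscreteEnv-v0")
env = jym.FlattenActionSpaceWrapper(env)  # MultiDiscrete -> Discrete (footnote 3)
env = jym.LogWrapper(env)

agent = DQN(**some_good_hyperparameters)
agent = agent.train(jax.random.PRNGKey(0), env)
```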