A JAX/stax implementation of the paper Paper title [1].
The agent in lpg/agent.py implements the bsuite.baselines.base.Agent interface.
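For readers unfamiliar with the bsuite agent interface, here is a minimal sketch of its shape. An agent exposes select_action(timestep) and update(timestep, action, new_timestep). To stay self-contained, the sketch does not import bsuite or dm_env: StepType, TimeStep and Agent are local stand-ins that mirror the real classes. RandomAgent is a hypothetical example and is not the LPG agent in this repository.

```python
import abc
import enum
import random
from typing import Any, NamedTuple, Optional


class StepType(enum.IntEnum):
    """Local stand-in mirroring dm_env.StepType."""
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    """Local stand-in with the same fields as dm_env.TimeStep."""
    step_type: StepType
    reward: Optional[float]
    discount: Optional[float]
    observation: Any


class Agent(abc.ABC):
    """Stand-in with the same two methods as bsuite.baselines.base.Agent."""

    @abc.abstractmethod
    def select_action(self, timestep: TimeStep) -> int:
        """Chooses an action given the current timestep."""

    @abc.abstractmethod
    def update(self, timestep: TimeStep, action: int,
               new_timestep: TimeStep) -> None:
        """Learns from one transition (timestep, action, new_timestep)."""


class RandomAgent(Agent):
    """Hypothetical agent: acts uniformly at random and never learns."""

    def __init__(self, num_actions: int, seed: int = 0):
        self._num_actions = num_actions
        self._rng = random.Random(seed)
        self.num_updates = 0  # counts calls to update(), for illustration

    def select_action(self, timestep: TimeStep) -> int:
        return self._rng.randrange(self._num_actions)

    def update(self, timestep: TimeStep, action: int,
               new_timestep: TimeStep) -> None:
        # A learning agent would compute a loss and update parameters here.
        self.num_updates += 1
```

Because Agent is abstract, a subclass must define both methods before it can be instantiated; this is how the interface guarantees that the environment loop can call them.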
The modules in lpg/environments/*.py expose their environments through the dm_env.Environment interface.
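The dm_env interface drives an episode with reset(), which returns a FIRST timestep, followed by step(action) calls until a LAST timestep arrives. The sketch below shows that loop, assuming only this reset/step shape. StepType and TimeStep are local stand-ins for the dm_env classes, CountdownEnv and run_episode are hypothetical names used for illustration, and the observation_spec() and action_spec() methods that the real interface also requires are omitted.

```python
import enum
from typing import Any, NamedTuple, Optional


class StepType(enum.IntEnum):
    """Local stand-in mirroring dm_env.StepType."""
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    """Local stand-in with the same fields as dm_env.TimeStep."""
    step_type: StepType
    reward: Optional[float]
    discount: Optional[float]
    observation: Any

    def last(self) -> bool:
        return self.step_type == StepType.LAST


class CountdownEnv:
    """Hypothetical toy environment with dm_env's reset/step shape.

    The episode ends after `length` steps; action 1 earns reward 1.0.
    The observation is the number of steps remaining.
    """

    def __init__(self, length: int = 3):
        self._length = length
        self._t = 0

    def reset(self) -> TimeStep:
        self._t = 0
        # FIRST timesteps carry no reward or discount, as in dm_env.
        return TimeStep(StepType.FIRST, None, None, self._length)

    def step(self, action: int) -> TimeStep:
        self._t += 1
        reward = 1.0 if action == 1 else 0.0
        remaining = self._length - self._t
        if remaining <= 0:
            # Terminal transition: discount 0.0 stops bootstrapping.
            return TimeStep(StepType.LAST, reward, 0.0, 0)
        return TimeStep(StepType.MID, reward, 1.0, remaining)


def run_episode(env, policy) -> float:
    """Resets, then steps until a LAST timestep; returns the total reward."""
    timestep = env.reset()
    total = 0.0
    while not timestep.last():
        timestep = env.step(policy(timestep))
        total += timestep.reward
    return total
```

For example, run_episode(CountdownEnv(length=3), policy=lambda ts: 1) collects reward 1.0 on each of the three steps.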
We convert the Gym Atari suite to dm_env with the bsuite.utils.gym_wrapper.DMEnvFromGym adapter and wrap the result in a dqn.AtariEnv, which implements observation history and action repeat.
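The following minimal sketch shows what observation history and action repeat mean, assuming any environment with dm_env-style reset() and step() methods. HistoryAndRepeat is a hypothetical wrapper written for illustration. It is not the dqn.AtariEnv used in this repository, and it omits Atari-specific details such as max-pooling over consecutive frames. StepType and TimeStep are again local stand-ins for the dm_env classes.

```python
import collections
import enum
from typing import Any, NamedTuple, Optional


class StepType(enum.IntEnum):
    """Local stand-in mirroring dm_env.StepType."""
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    """Local stand-in with the same fields as dm_env.TimeStep."""
    step_type: StepType
    reward: Optional[float]
    discount: Optional[float]
    observation: Any

    def last(self) -> bool:
        return self.step_type == StepType.LAST


class HistoryAndRepeat:
    """Hypothetical wrapper: stacks the last `history` observations and
    repeats each action `repeat` times, summing the rewards."""

    def __init__(self, env, history: int = 4, repeat: int = 4):
        assert history >= 1 and repeat >= 1
        self._env = env
        self._repeat = repeat
        self._frames = collections.deque(maxlen=history)

    def reset(self) -> TimeStep:
        timestep = self._env.reset()
        # Fill the history with copies of the first observation.
        for _ in range(self._frames.maxlen):
            self._frames.append(timestep.observation)
        return timestep._replace(observation=tuple(self._frames))

    def step(self, action) -> TimeStep:
        total_reward, discount = 0.0, 1.0
        for _ in range(self._repeat):
            timestep = self._env.step(action)
            total_reward += timestep.reward
            discount *= timestep.discount
            if timestep.last():  # never step past the end of an episode
                break
        # Only the observation after the last repeated step enters the history.
        self._frames.append(timestep.observation)
        return TimeStep(timestep.step_type, total_reward, discount,
                        tuple(self._frames))
```

Returning a tuple keeps the sketch dependency-free; a JAX implementation would more likely stack the frames into one array along a channel axis.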
To run the algorithm on a GPU, I suggest installing the GPU version of JAX [4]. You can then install this repository with Anaconda Python and pip:
conda create -n template python
conda activate template
pip install git+https://github.com/epignatelli/template

[1] Paper title.