A research framework for learning and evaluating state embeddings in reinforcement learning (RL) environments. This project provides tools for pretraining, training, and evaluating RL agents with learned state representations, including transformer-based embeddings and reconstruction objectives.
## Features
- Transformer-based State Embeddings: Learn compact representations of environment states using transformer encoders.
- Reconstruction Objectives: Jointly train RL agents with state reconstruction losses.
- Flexible Environment Wrappers: Context and embedding wrappers for stacking observations and using learned embeddings.
- Pretraining & Evaluation: Pretrain embeddings with various objectives and evaluate with linear probes.
- Support for Multiple Environments: Includes experiments for CartPole, MinAtar Breakout, MinAtar Seaquest, MiniGrid Memory, and MiniGrid Unlock.
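To illustrate the context-wrapper idea (stacking the last k observations into a single input), here is a minimal stand-alone sketch. It is illustrative only; the actual wrapper in `state_embedding/env.py` follows the environment API and will differ in interface.

```python
from collections import deque

class ContextStack:
    """Minimal sketch of a context wrapper: keeps the last k observations
    and returns them as one stacked context. Illustrative only; the real
    wrapper in state_embedding/env.py wraps an environment directly."""

    def __init__(self, k):
        self.k = k
        self.buffer = deque(maxlen=k)

    def reset(self, obs):
        # Fill the context with copies of the first observation.
        self.buffer.clear()
        for _ in range(self.k):
            self.buffer.append(obs)
        return list(self.buffer)

    def step(self, obs):
        # The newest observation displaces the oldest one.
        self.buffer.append(obs)
        return list(self.buffer)

stack = ContextStack(k=3)
print(stack.reset([0.0]))  # [[0.0], [0.0], [0.0]]
print(stack.step([1.0]))   # [[0.0], [0.0], [1.0]]
```

A transformer encoder can then consume the stacked context as a short sequence of observations.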
## Project Structure

```text
RLProject_StateEmbeddings/
│
├── src/
│   ├── main.py                              # Entry point for running experiments
│   ├── breakout_experiment.py               # Breakout (MinAtar) experiment routines
│   ├── cartpole_experiment.py               # CartPole experiment routines
│   ├── minigrid_memory_experiment.py        # MiniGrid Memory experiment routines
│   ├── minigrid_unlock_experiment.py        # MiniGrid Unlock experiment routines
│   ├── seaquest_experiment.py               # MinAtar Seaquest experiment routines
│   ├── seaquestmarkov_unlock_experiment.py  # MinAtar Seaquest_alt experiment routines
│   ├── state_embedding/
│   │   ├── env.py                           # Context and embedding environment wrappers
│   │   ├── embedding.py                     # StateEmbedding and StateDecoder modules
│   │   ├── embedding_eval.py                # Linear probe for embedding evaluation
│   │   ├── train.py                         # Pretraining routines for embeddings
│   │   ├── callbacks.py                     # Custom RL training callbacks
│   │   └── dqn/
│   │       ├── dqn.py                       # DQN with reconstruction loss
│   │       └── qnetwork.py                  # Q-network with embedding and reconstruction
│   └── envs/
│       └── seaquest_markov.py               # Implementation of an alternative Seaquest env
│
└── README.md                                # This file
```
## Installation

This project uses uv for Python package management and virtual environments.

1. Clone this repo:

   ```shell
   git clone https://github.com/thibautklenke/RLProject_StateEmbeddings
   cd RLProject_StateEmbeddings
   ```

2. Create and activate a virtual environment:

   ```shell
   uv venv
   source .venv/bin/activate
   ```

3. Install dependencies:

   ```shell
   uv sync
   ```
## Usage

The main entry point is `src/main.py`, which manages pretraining and training for all supported environments.

To run all experiments (pretraining and training for each environment and seed):

```shell
uv run main
```

- By default, this runs pretraining once per environment, then trains agents with different seeds.
- You can comment/uncomment lines in `main.py` to select which environments to run.
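The comment/uncomment pattern for selecting environments might look like the following sketch. The experiment names mirror the repository layout, but the exact entry-point function names here are assumptions, not the project's actual API.

```python
# Hypothetical sketch of the dispatch pattern in src/main.py.
# Each experiment module is assumed to expose a run(seed) entry point;
# comment out an entry to skip that environment.

def run_experiments(experiments, seeds):
    """Run every enabled experiment once per seed and collect results."""
    results = []
    for name, run in experiments.items():
        for seed in seeds:
            results.append((name, seed, run(seed)))
    return results

if __name__ == "__main__":
    # Stand-ins for the real experiment modules' entry points.
    experiments = {
        "cartpole": lambda seed: f"cartpole trained (seed={seed})",
        # "breakout": lambda seed: ...,  # commented out: skipped
    }
    for name, seed, outcome in run_experiments(experiments, seeds=[0, 1, 2]):
        print(name, seed, outcome)
```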
## Customization

- Add new environments: Create a new `*_experiment.py` file following the structure of `cartpole_experiment.py`.
- Change embedding architecture: Modify `state_embedding/embedding.py` or pass different `embedding_kwargs` in the experiment files.
- Adjust training parameters: Edit variables like `n_pretrain`, `n_train`, or `net_arch` in the experiment scripts.
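A new experiment file might follow a skeleton like this. Only `n_pretrain`, `n_train`, `net_arch`, and `embedding_kwargs` come from this README; every other name (the helper functions, the `embedding_kwargs` keys) is an assumption for illustration.

```python
# Hypothetical skeleton for a new *_experiment.py, built around the
# customization knobs this README names. Real module and function
# names in the repository may differ.

n_pretrain = 10_000          # pretraining steps for the embedding
n_train = 100_000            # RL training steps
net_arch = [64, 64]          # hidden-layer sizes of the Q-network
embedding_kwargs = {"embed_dim": 32, "n_heads": 4}  # assumed keys

def pretrain_embedding(steps, **kwargs):
    """Stand-in for the pretraining routine in state_embedding/train.py."""
    return {"steps": steps, **kwargs}

def train_agent(steps, arch, embedding):
    """Stand-in for the RL training loop that uses the pretrained embedding."""
    return {"steps": steps, "arch": arch, "embedding": embedding}

def run(seed=0):
    """One full experiment: pretrain the embedding, then train the agent."""
    embedding = pretrain_embedding(n_pretrain, **embedding_kwargs)
    return train_agent(n_train, net_arch, embedding)
```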