minrl

Reproduction of GRPO, intended to be educational and efficient:

  • On an RTX 3090, we start with a base model (Qwen3-1.7B-Base) and train on the countdown task, reaching 60% accuracy in under 3 hours.
  • We incorporate improvements from Dr. GRPO and DAPO, namely averaging over tokens rather than over batches and dropping reward scaling.
  • We use LoRA to fit model training into 24GB of GPU memory.
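The two loss modifications mentioned above can be sketched as follows. This is a minimal illustration, not the repository's actual code; the function name, tensor shapes, and the absence of a clipping/KL term are all simplifying assumptions:

```python
import torch

def grpo_loss(logprobs, rewards, mask):
    """Sketch of a GRPO-style policy-gradient loss with the Dr. GRPO / DAPO
    modifications described above (hypothetical helper, not minrl's API).

    logprobs: (G, T) per-token log-probs of the G sampled completions
    rewards:  (G,)   scalar reward per completion in the group
    mask:     (G, T) 1 for completion tokens, 0 for padding
    """
    # Group-relative advantage: subtract the group mean reward only,
    # dropping the division by the reward std ("no reward scaling").
    adv = rewards - rewards.mean()
    # Broadcast the sequence-level advantage to every token.
    per_token = -adv[:, None] * logprobs * mask
    # Average over all valid tokens in the group (token-level mean),
    # rather than taking a per-sequence mean followed by a batch mean.
    return per_token.sum() / mask.sum()
```

Averaging over tokens keeps long and short completions on equal footing per token, which avoids the length bias of the original per-sequence normalization.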

Setup

To set up a training environment, install the dependencies, including flash-attn.

```sh
# Clone the repository
git clone git@github.com:janhuenermann/minrl.git && cd minrl

# It is recommended to create a new virtual environment
python -m venv .env
source .env/bin/activate

# Install package and core dependencies
pip install -e .

# Install flash-attn with ninja for faster build times (this may take a while)
pip install ninja
MAX_JOBS=4 pip install -v --no-build-isolation flash-attn
```

Training

To start training, run scripts/train_grpo.py, which automatically downloads the required artifacts (dataset and base model) and begins training.

```sh
python scripts/train_grpo.py
```

About

Educational implementation of RL for LLMs in PyTorch.
