This repository contains the code used for the experiments in “Theoretical Analysis of KL-regularized RLHF with Multiple Reference Models” (arXiv:2502.01203).
It implements exact multi-reference objectives for both Reverse KL (RKL) and Forward KL (FKL) formulations, including training and evaluation pipelines for reproducing our online GRPO and offline DPO/RLHF results.
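
The central quantity in both formulations is a KL penalty measured against a mixture of the reference models: a geometric mean of the references for RKL and an arithmetic mean for FKL, weighted by the `--alpha` and `--beta` flags used below, with `--gamma` as the penalty coefficient. The snippet below is a minimal sketch of how such a penalty can be formed from per-token log-probabilities; it is not the repository's code, the function names and per-token Monte Carlo estimators are illustrative, and the geometric mixture's prompt-dependent normalizer is dropped for simplicity.

```python
import math
import torch


def mixture_log_probs(logp_ref_a: torch.Tensor,
                      logp_ref_b: torch.Tensor,
                      weight: float,
                      kl_type: str = "rkl") -> torch.Tensor:
    """Combine per-token log-probs of two references into a single mixture.

    rkl -> geometric mean: weighted average of log-probs (normalizer dropped),
    fkl -> arithmetic mean: log of the weighted average of probabilities.
    """
    if kl_type == "rkl":
        return weight * logp_ref_a + (1.0 - weight) * logp_ref_b
    if kl_type == "fkl":
        stacked = torch.stack([logp_ref_a + math.log(weight),
                               logp_ref_b + math.log(1.0 - weight)])
        return torch.logsumexp(stacked, dim=0)
    raise ValueError(f"unknown kl_type: {kl_type}")


def kl_penalty(logp_pi: torch.Tensor,
               logp_mix: torch.Tensor,
               kl_type: str = "rkl") -> torch.Tensor:
    """Per-token KL estimate from samples drawn under the current policy."""
    log_ratio = logp_pi - logp_mix
    if kl_type == "rkl":
        # KL(pi || mix) estimated as E_pi[log pi - log mix].
        return log_ratio
    # KL(mix || pi) estimated under policy samples via importance weights mix/pi.
    return torch.exp(-log_ratio) * (-log_ratio)
```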
We recommend Python 3.10+ and a CUDA-enabled environment.
```bash
# Create environment
conda create -n multiref python=3.10 -y
conda activate multiref

# Core dependencies
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install "transformers>=4.43" "datasets>=2.19" "accelerate>=0.33" "trl>=0.9"
pip install vllm "einops>=0.7" "peft>=0.12" bitsandbytes
pip install wandb sentencepiece evaluate tqdm
```

Below are minimal example commands to reproduce the experiments.
Run each script with `--help` to see all available options.

Online GRPO training on GSM8K. In this command and the DPO command further below, `--ref_a`/`--ref_b` name the two reference models, `--kl_type` selects the regularizer (`rkl` or `fkl`), `--alpha` weights the geometric-mean reference mixture used by RKL, `--beta` weights the arithmetic-mean mixture used by FKL, and `--gamma` sets the KL strength.

```bash
python gsm8k_grpo.py \
--dataset gsm8k \
--ref_a Qwen/Qwen2.5-0.5B-Instruct \
--ref_b Qwen/Qwen2.5-Math-1.5B \
--alpha 0.5 \
--beta 0.5 \
--kl_type rkl \
--gamma 0.1 \
--model_out ./runs/gsm8k_rkl_a05 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--max_steps 2000 \
--bf16
```
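
For reference, here is one way the per-token log-probabilities consumed by the sketch above could be gathered with `transformers`. This is an illustrative sketch, not the repository's pipeline; combining per-token scores across references assumes the two reference models share a tokenizer, which is worth verifying for the checkpoints you pick.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@torch.no_grad()
def per_token_logps(model_name: str, text: str) -> torch.Tensor:
    """Log-probability of each token of `text` under `model_name` (batching omitted)."""
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    ids = tok(text, return_tensors="pt").input_ids              # (1, T)
    logits = model(ids).logits                                  # (1, T, vocab)
    logps = torch.log_softmax(logits[:, :-1].float(), dim=-1)   # token t+1 predicted from prefix
    return logps.gather(-1, ids[:, 1:, None]).squeeze(-1)       # (1, T-1)


# Hypothetical usage with the two references from the GRPO command above:
# logp_a = per_token_logps("Qwen/Qwen2.5-0.5B-Instruct", completion)
# logp_b = per_token_logps("Qwen/Qwen2.5-Math-1.5B", completion)
# logp_mix = mixture_log_probs(logp_a, logp_b, weight=0.5, kl_type="rkl")
```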

Offline DPO training on UltraFeedback (`--rm_name` selects the reward model):

```bash
python dpo_training.py \
--dataset ultrafeedback \
--ref_a Qwen/Qwen2.5-0.5B-Instruct \
--ref_b Qwen/Qwen2.5-1.5B \
--kl_type rkl \
--alpha 0.5 \
--beta 0.5 \
--gamma 1.0 \
--rm_name Skywork/Reward-Llama-3.1-8B-v0.2 \
--output_dir ./runs/dpo_ultra_rkl_a05 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 16 \
--num_train_epochs 1 \
--bf16
```
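
In the offline setting, the effect of multiple references is easiest to see in the preference loss itself. The sketch below is illustrative only (RKL/geometric-mixture case; function and argument names are not the repo's): the single DPO reference is replaced by an `--alpha`-weighted geometric mixture of the two references, whose normalizer depends only on the prompt and therefore cancels in the chosen-minus-rejected margin. The FKL variant uses the arithmetic mixture instead; see the paper for the exact objectives.

```python
import torch
import torch.nn.functional as F


def multi_ref_dpo_loss(logp_pi_w, logp_pi_l,    # policy sequence log-probs: chosen / rejected
                       logp_a_w, logp_a_l,      # reference A sequence log-probs
                       logp_b_w, logp_b_l,      # reference B sequence log-probs
                       alpha: float = 0.5, gamma: float = 1.0) -> torch.Tensor:
    """DPO-style loss with a geometric mixture of two references (RKL case)."""
    ref_w = alpha * logp_a_w + (1.0 - alpha) * logp_b_w   # mixed reference, chosen
    ref_l = alpha * logp_a_l + (1.0 - alpha) * logp_b_l   # mixed reference, rejected
    margin = (logp_pi_w - ref_w) - (logp_pi_l - ref_l)
    return -F.logsigmoid(gamma * margin).mean()
```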

If you use this code, please cite:

```bibtex
@article{aminian2025theoretical,
  title={Theoretical Analysis of KL-regularized RLHF with Multiple Reference Models},
  author={Aminian, Gholamali and Asadi, Amir R. and Shenfeld, Idan and Mroueh, Youssef},
  journal={arXiv preprint arXiv:2502.01203},
  year={2025}
}
```