idanshen/multi_ref

Multi-Reference KL RL(H)F

This repository contains the code for the experiments in "Theoretical Analysis of KL-regularized RLHF with Multiple Reference Models" (arXiv:2502.01203).

It implements exact multi-reference objectives for both Reverse KL (RKL) and Forward KL (FKL) formulations, including training and evaluation pipelines for reproducing our online GRPO and offline DPO/RLHF results.


🧩 Installation

We recommend Python 3.10+ and a CUDA-enabled environment.

# Create environment
conda create -n multiref python=3.10 -y
conda activate multiref

# Core dependencies
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install "transformers>=4.43" "datasets>=2.19" "accelerate>=0.33" "trl>=0.9"
pip install vllm "einops>=0.7" "peft>=0.12" bitsandbytes
pip install wandb sentencepiece evaluate tqdm

🚀 Quickstart

Below are minimal example commands to reproduce the experiments. Run each script with --help to see all available options.

1️⃣ Online RL (GRPO) on GSM8K

# Flags (comments cannot follow a trailing backslash in the shell, so they are listed here):
#   --alpha    RKL weight (geometric mean)
#   --beta     FKL weight (arithmetic mean)
#   --kl_type  rkl | fkl
#   --gamma    KL strength
python gsm8k_grpo.py \
  --dataset gsm8k \
  --ref_a Qwen/Qwen2.5-0.5B-Instruct \
  --ref_b Qwen/Qwen2.5-Math-1.5B \
  --alpha 0.5 \
  --beta 0.5 \
  --kl_type rkl \
  --gamma 0.1 \
  --model_out ./runs/gsm8k_rkl_a05 \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 8 \
  --max_steps 2000 \
  --bf16
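
The flag comments describe --alpha as a geometric-mean weight (RKL) and --beta as an arithmetic-mean weight (FKL). The sketch below illustrates what those two ways of combining references look like on a toy distribution over three candidate outputs. It is a minimal illustration under stated assumptions, not the repository's implementation: the function names `rkl_reference` and `fkl_reference` and the probabilities are hypothetical, and the actual scripts may combine references differently (for example, at the sequence level).

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) for a list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def rkl_reference(logp_a, logp_b, alpha):
    """Normalized weighted geometric mean: pi_a^alpha * pi_b^(1 - alpha) / Z (hypothetical helper)."""
    mixed = [alpha * a + (1.0 - alpha) * b for a, b in zip(logp_a, logp_b)]
    log_z = logsumexp(mixed)  # renormalize so the result is a distribution
    return [m - log_z for m in mixed]


def fkl_reference(logp_a, logp_b, beta):
    """Weighted arithmetic mean: beta * pi_a + (1 - beta) * pi_b (hypothetical helper)."""
    return [
        math.log(beta * math.exp(a) + (1.0 - beta) * math.exp(b))
        for a, b in zip(logp_a, logp_b)
    ]


# Made-up reference distributions over three candidate outputs.
logp_a = [math.log(p) for p in (0.7, 0.2, 0.1)]
logp_b = [math.log(p) for p in (0.1, 0.3, 0.6)]

rkl = rkl_reference(logp_a, logp_b, alpha=0.5)
fkl = fkl_reference(logp_a, logp_b, beta=0.5)

print("geometric :", [round(math.exp(x), 3) for x in rkl])
print("arithmetic:", [round(math.exp(x), 3) for x in fkl])
```

The geometric mean concentrates mass where both references agree, while the arithmetic mean keeps mass wherever either reference places it. Setting the weight to 1.0 recovers the first reference in both cases.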

2️⃣ Offline DPO / RLHF on UltraFeedback

# Flags (comments cannot follow a trailing backslash in the shell, so they are listed here):
#   --kl_type  rkl or fkl
#   --alpha    (RKL) geometric mean weight
#   --beta     (FKL) arithmetic mean weight
#   --gamma    KL coefficient
#   --rm_name  reward model
python dpo_training.py \
  --dataset ultrafeedback \
  --ref_a Qwen/Qwen2.5-0.5B-Instruct \
  --ref_b Qwen/Qwen2.5-1.5B \
  --kl_type rkl \
  --alpha 0.5 \
  --beta 0.5 \
  --gamma 1.0 \
  --rm_name Skywork/Reward-Llama-3.1-8B-v0.2 \
  --output_dir ./runs/dpo_ultra_rkl_a05 \
  --per_device_train_batch_size 2 \
  --gradient_accumulation_steps 16 \
  --num_train_epochs 1 \
  --bf16
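
When adapting these commands to different hardware, it can help to keep the effective batch size fixed. The sketch below assumes the flags follow the usual Hugging Face `TrainingArguments` semantics (samples per device times accumulation steps times number of devices); the scripts may count samples differently, for example per completion in GRPO. The helper name and the `num_devices` parameter are hypothetical.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Samples contributing to one optimizer step (hypothetical helper)."""
    return per_device_batch * grad_accum_steps * num_devices


# Values taken from the two example commands, on a single device.
grpo_batch = effective_batch_size(4, 8)
dpo_batch = effective_batch_size(2, 16)

print(grpo_batch, dpo_batch)  # 32 32
```

Under that assumption both examples use 32 samples per optimizer step on one GPU, so halving `--per_device_train_batch_size` to save memory would call for doubling `--gradient_accumulation_steps`.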

🧾 Citation

If you use this code, please cite:

@article{aminian2025theoretical,
  title={Theoretical Analysis of KL-regularized RLHF with Multiple Reference Models},
  author={Aminian, Gholamali and Asadi, Amir R and Shenfeld, Idan and Mroueh, Youssef},
  journal={arXiv preprint arXiv:2502.01203},
  year={2025}
}
