Add fp32 LM head option for GRPO #1387
base: main
Conversation
Summary of Changes
Hello @natolambert, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the GRPO training framework by introducing an FP32 LM head option to improve log-probability consistency between the trainer and generator, and by implementing a validation holdout mechanism for better overfitting detection. It also broadens hardware compatibility and the development environment by upgrading to CUDA 13.0.0 and adding robust support for NVIDIA DGX Spark, complete with new documentation and dedicated training scripts.
Code Review
This pull request introduces several valuable enhancements, including an fp32_lm_head option for GRPO to improve logprob alignment, support for DGX Spark (aarch64, CUDA 13), and a new validation_holdout_ratio feature for monitoring overfitting. The changes are well-implemented, with corresponding documentation and example scripts. The code quality is high, and the new features are thoughtfully integrated. I have one suggestion regarding temporary notes in the documentation.
AGENTS.md
Outdated
# TEMP: FP32 LM head rollout notes (ScaleRL alignment)
- Goal: reduce generator↔trainer logprob mismatch; apply fp32 at LM head on both sides.
- References: TRL GRPO `cast_lm_head_to_fp32` (PR #4303, #4446); vLLM patch point is `LogitsProcessor._get_logits`; vLLM plugin system can host the patch; LLaMA-Factory upcast_lmhead_output.
- Implementation in repo:
  - Trainer: `enable_fp32_lm_head()` patches output embeddings to run in fp32 without re-allocating weights (tied embeddings preserved).
  - Reference policy: same fp32 head toggle to keep KL alignment.
  - Generator: vLLM LogitsProcessor `_get_logits` patched to upcast hidden_states (+ bias) before projection.
- Flag: `--fp32_lm_head` (GRPO). Ensure vLLM engines receive the flag too.
- Sanity check (step 0): compare trainer logprobs vs vLLM logprobs on identical sequences; expect tighter deltas with fp32 head.
- OOM safety (DGX Spark): check `free -h`, kill leftover Ray/vLLM if needed, set `PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:128"`.
- Suggested local debug: `scripts/train/dgx-spark/grpo_qwen_gsm8k.sh` (Qwen2.5-0.5B GSM8K), add `--fp32_lm_head` when testing.
- Long comparison runs: `FP32_LM_HEAD=0 ./scripts/train/dgx-spark/grpo_qwen_gsm8k.sh` then `FP32_LM_HEAD=1 ./scripts/train/dgx-spark/grpo_qwen_gsm8k.sh`; monitor `debug/vllm_vs_local_logprob_diff_*`.
- Plugin vs monkey patch: the plugin is just a startup hook that can apply the same patch; the in-process patch is used unless a vLLM plugin package is added.
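For context, a minimal sketch of the generator-side patch described in these notes, assuming the current `_get_logits(self, hidden_states, lm_head, embedding_bias)` signature in vLLM; the attribute name `_open_instruct_fp32_weight` and the fallback behavior mirror what is described elsewhere in this PR, not the exact repo code (TP gather and vocab padding are omitted):

```python
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor

_original_get_logits = LogitsProcessor._get_logits


def _get_logits_fp32(self, hidden_states, lm_head, embedding_bias):
    # Use an fp32 copy of the head weights if one was attached (cache mode);
    # otherwise fall back to the unpatched path so dtypes stay consistent.
    fp32_weight = getattr(lm_head, "_open_instruct_fp32_weight", None)
    if fp32_weight is None:
        return _original_get_logits(self, hidden_states, lm_head, embedding_bias)
    hidden_states = hidden_states.float()
    if embedding_bias is not None:
        embedding_bias = embedding_bias.float()
    # fp32 projection: hidden @ W_fp32.T (+ bias)
    return torch.nn.functional.linear(hidden_states, fp32_weight, embedding_bias)


LogitsProcessor._get_logits = _get_logits_fp32
```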
This section with temporary notes about the FP32 LM head feature seems intended for development and should be removed before merging this pull request to keep the documentation clean and focused on user-facing information. If this information is valuable long-term, consider moving it to a design document or a more permanent part of the documentation.
Force-pushed from 24574d7 to 6cb3bca
Note on the vLLM fp32 LM head options we discussed:
You chose Option B because the KV cache dominates memory in RL, so a separate fp32 head cache is relatively small while still preserving the fp32 head matmul. TODO: current debug metric (…)
Set the OPEN_INSTRUCT_FP32_LM_HEAD environment variable at the start of main(), before ray.init(), so it propagates to all worker processes via Ray's runtime_env.

Previously, the env var was only set in the LLMRayActor process after Ray was initialized, so vLLM worker subprocesses didn't inherit it. This caused _maybe_update_fp32_lm_head_cache() to early-return, meaning the fp32 weight cache was never created and vLLM was still using bf16 weights for the lm_head projection.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
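A minimal sketch of the ordering this commit describes, as an illustration rather than the exact grpo_fast.py code:

```python
import os
import ray

# Set the flag before ray.init() and forward it explicitly through runtime_env,
# so Ray actors and the vLLM worker subprocesses they spawn all see the same value.
os.environ["OPEN_INSTRUCT_FP32_LM_HEAD"] = "1"
ray.init(
    runtime_env={
        "env_vars": {"OPEN_INSTRUCT_FP32_LM_HEAD": os.environ["OPEN_INSTRUCT_FP32_LM_HEAD"]}
    }
)
```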
Two modes now available via --fp32_lm_head + --fp32_lm_head_permanent:

1. Cache mode (default): keep bf16 weights, maintain a separate fp32 cache
   - Set OPEN_INSTRUCT_FP32_LM_HEAD=1
   - Lower memory overhead but more complex
2. Permanent mode: convert vLLM lm_head weights to fp32 in-place
   - Set OPEN_INSTRUCT_FP32_LM_HEAD=2 via --fp32_lm_head_permanent
   - Similar to the TRL/Flash-RL approach
   - Simpler but uses more memory

Both modes cast hidden_states to fp32 for the matmul. The permanent mode converts lm_head.weight.data to fp32 after each weight sync, while cache mode maintains a separate _open_instruct_fp32_weight attribute.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
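Roughly how the two modes differ on the vLLM side, as a hedged sketch of the dispatch (the real logic lives in _sync_fp32_lm_head; the helper name here is hypothetical):

```python
import os

import torch


def sync_fp32_lm_head(lm_head: torch.nn.Module) -> None:
    # "1" = cache mode, "2" = permanent mode, unset/other = disabled.
    mode = os.environ.get("OPEN_INSTRUCT_FP32_LM_HEAD", "0")
    if mode == "1":
        # Keep the bf16 weights used for weight sync and attach a separate
        # fp32 copy that the patched _get_logits picks up.
        lm_head._open_instruct_fp32_weight = lm_head.weight.data.float()
    elif mode == "2":
        # Convert the projection weights to fp32 in place (TRL/Flash-RL style).
        lm_head.weight.data = lm_head.weight.data.float()
```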
Add --save_logprob_samples flag to save raw vLLM vs trainer logprob
pairs during training. These can be used to create scatter plots
showing the alignment between inference and training logprobs,
similar to Figure 3 in the ScaleRL paper.
New files:
- scripts/analysis/plot_logprob_alignment.py: Creates scatter plots
with Pearson correlation, supports comparison between runs
Usage:
# Run training with sample saving
python open_instruct/grpo_fast.py ... --save_logprob_samples
# Plot results
python scripts/analysis/plot_logprob_alignment.py \
--data_dir /tmp/run/logprob_samples --output plot.png
# Compare runs (e.g., with/without fp32)
python scripts/analysis/plot_logprob_alignment.py \
--data_dirs /tmp/run1/logprob_samples /tmp/run2/logprob_samples \
--labels "Without FP32" "With FP32" --output comparison.png
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ilable The previous code unconditionally cast hidden_states to fp32 at the start of _get_logits_fp32(), then passed them to the original implementation in the fallback path. This caused dtype mismatches when the original weights were still bf16 (before the first weight sync in permanent mode). Now we only cast hidden_states to fp32 when we actually have fp32 weights to compute with. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When tie_word_embeddings=True (common in Qwen, Llama, etc.), the lm_head weight and input embedding share the same tensor. Converting lm_head to fp32 in-place would also convert the embedding layer, causing dtype mismatches in early layers. This fix checks if weights are tied by comparing data_ptr(), and falls back to cache mode when weights are tied to avoid the issue. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
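The tied-weight check this commit describes can be sketched like this (the helper name is hypothetical, and it assumes an HF-style model that exposes get_input_embeddings/get_output_embeddings):

```python
import torch


def lm_head_is_tied(model: torch.nn.Module) -> bool:
    # Tied embeddings share a single storage, so identical data_ptr() values
    # mean converting lm_head in place would also convert embed_tokens.
    lm_head = model.get_output_embeddings()
    embeddings = model.get_input_embeddings()
    return lm_head.weight.data_ptr() == embeddings.weight.data_ptr()
```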
- Add --use_probs flag to convert logprobs to probabilities via exp()
- Probabilities plot on a 0-1 scale for better visualization
- Supports comparison plots with multiple runs side-by-side

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Numpy doesn't support BFloat16, so convert to float32 before calling .cpu().numpy(). This fixes the TypeError when saving logprob samples without fp32_lm_head enabled. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The previous code saved vLLM vs trainer logprobs on every mini-batch, but trainer weights get updated after each backward/step. This caused apparent misalignment because we were comparing:
- vLLM logprobs: computed during generation with weights W
- Trainer logprobs: computed after optimizer.step() with weights W'

Now samples are only saved on epoch 0, mini-batch 0 of each training step, when trainer weights still match the vLLM weights from generation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
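In the training loop this amounts to a guard like the following sketch (function and index names are assumptions, not the repo's exact code):

```python
def should_save_logprob_samples(save_flag: bool, epoch_idx: int, minibatch_idx: int) -> bool:
    # Only save on the first pass over the first mini-batch of a training step,
    # before any optimizer.step() has changed the trainer weights relative to
    # the weights vLLM used for generation.
    return save_flag and epoch_idx == 0 and minibatch_idx == 0
```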
- maybe_enable_fp32_lm_head() → patch_vllm_for_fp32_logits()
- _maybe_update_fp32_lm_head_cache() → _sync_fp32_lm_head()

The new names better describe what the functions do:
- patch_vllm_for_fp32_logits: patches the vLLM LogitsProcessor for fp32 computation
- _sync_fp32_lm_head: syncs fp32 LM head weights after the trainer broadcast

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Permanent mode now converts lm_head weights to fp32 once at training start, matching the Flash-RL and TRL implementations:

Trainer side (model_utils.py):
- enable_fp32_lm_head() now accepts a `permanent` parameter
- permanent=True converts weights to fp32 in-place
- Handles tied weights by untying lm_head from embed_tokens first

vLLM side (vllm_utils_workerwrap.py):
- update_weight() now accepts fp32 dtype for lm_head in permanent mode
- Weight sync receives fp32 directly, no conversion needed after sync

This is more efficient than the cache approach:
- No bf16→fp32 conversion on each weight sync
- No separate fp32 cache to maintain
- Weights are synced as fp32 directly from the trainer

References:
- Flash-RL: https://github.com/yaof20/Flash-RL/blob/main/flash_rl/vllm_patch.py
- TRL: huggingface/trl#4303

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
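A sketch of the trainer-side permanent path under the same HF-model assumptions as above (not the exact enable_fp32_lm_head body):

```python
import torch


def convert_lm_head_to_fp32(model) -> None:
    # Assumes an HF-style model exposing get_input/output_embeddings().
    lm_head = model.get_output_embeddings()
    embeddings = model.get_input_embeddings()
    if lm_head.weight.data_ptr() == embeddings.weight.data_ptr():
        # Untie first: give lm_head its own copy so embed_tokens stays bf16.
        lm_head.weight = torch.nn.Parameter(lm_head.weight.detach().clone())
        if hasattr(model.config, "tie_word_embeddings"):
            model.config.tie_word_embeddings = False
    # One-time in-place conversion; weight syncs can then ship fp32 directly.
    lm_head.weight.data = lm_head.weight.data.float()
```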
Tests the new permanent fp32 lm_head mode on GSM8K with Qwen2.5-0.5B. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- fp32_comparison_overnight.sh: 4 sequential runs comparing fp32 modes
  1. No FP32 (baseline)
  2. FP32 cache mode
  3. FP32 permanent mode
  4. No FP32 (repeat for variance)
- grpo_qwen_gsm8k_short.sh: config with shorter response_length=256 that showed good reward scores (faster training)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- fp32_overnight_4runs.sh: runs all 4 combinations
  1. grpo_qwen_gsm8k.sh (long) + no fp32
  2. grpo_qwen_gsm8k.sh (long) + fp32 permanent
  3. grpo_qwen_gsm8k_short.sh (short) + no fp32
  4. grpo_qwen_gsm8k_short.sh (short) + fp32 permanent
- Updated grpo_qwen_gsm8k.sh to support the FP32_PERMANENT env var
- Added --save_logprob_samples for alignment analysis

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Disable local eval (--local_eval_every 99999) to avoid hanging
- Lower vLLM memory to 0.4 for unified memory systems
- Add fp32_long_2runs.sh and fp32_short_2runs.sh for experiments

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- run_gpu_tests_fp32_cache.sh: tests cache mode (OPEN_INSTRUCT_FP32_LM_HEAD=1)
- run_gpu_tests_fp32_permanent.sh: tests permanent mode (OPEN_INSTRUCT_FP32_LM_HEAD=2)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Resolved conflict in grpo_fast.py:
- Main refactored Args to grpo_utils.ExperimentConfig
- Added fp32_lm_head fields to ExperimentConfig
- Re-added fp32 implementation code to grpo_fast.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Move test scripts to scripts/test/
- Add launcher scripts (run_gpu_pytest_fp32_*.sh) that use mason.py
- Container scripts (run_gpu_tests_fp32_*.sh) run inside Beaker

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Add noqa for the late vLLM import (intentional - vLLM may not be installed)
- Add changelog entry for the fp32 LM head feature

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- generate_logprobs.py: generate sequences with vLLM and score with HF
  - Compares bf16 vs fp32 LM head precision
  - Supports Qwen3 sampling params (temperature, top_p, top_k, min_p)
  - Sequential loading allows high GPU memory utilization (0.85)
- plot_logprobs.py: visualize logprob alignment from saved data
  - Default: probabilities (0-1), use --use-logprobs for raw scale
  - Scatter plots and histograms for bf16 vs fp32 comparison

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Moved to private repo: natolambert/ai2-scripts. Keep only grpo_qwen_gsm8k.sh and grpo_qwen_gsm8k_short.sh for DGX Spark examples (both have fp32 mode documentation).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Lock histogram x/y axes to the same scale for visual comparison
- Remove temporary FP32 notes from AGENTS.md (moved to ~/dev/CLAUDE.md)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Links to related implementations:
- TRL GRPO cast_lm_head_to_fp32 PRs
- Flash-RL vLLM patch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Document options for fixing the vLLM memory issue:
1. Fix vLLM V1 engine cleanup
2. Split into separate vLLM and HF scripts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Remove plot_logprob_alignment.py (--save_logprob_samples flag removed)
- Remove debug_vllm_cleanup.py (no longer needed)
- Move fp32_lm_head_test.sh to the private scripts repo
- Reset AGENTS.md and run_gpu_tests.sh to main

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 435ffb258c
if args.fp32_lm_head:
    enable_fp32_lm_head(self.policy, permanent=args.fp32_lm_head_permanent)
Apply fp32 LM head to reference policy too
When --fp32_lm_head is enabled this only patches the policy model, but the reference policy loaded later still runs in bf16 because load_ref_policy(...) is called without the new fp32_lm_head flag (it defaults to False). In runs with --load_ref_policy (default) and a nonzero --beta, the KL term is computed between fp32 policy logprobs and bf16 ref logprobs, which reintroduces the precision mismatch the flag is meant to eliminate. Consider passing fp32_lm_head=args.fp32_lm_head into load_ref_policy so the KL uses matched precision.
Split generate_logprobs.py into separate scripts that run in separate processes, avoiding vLLM V1 memory cleanup issues:
- get_vllm_logprobs.py: generate sequences with vLLM (bf16 or fp32 mode)
- get_hf_logprobs.py: score sequences with HuggingFace
- run_logprobs_comparison.sh: wrapper to run the full pipeline

Each vLLM run now uses 0.85 GPU utilization instead of 0.4, since separate processes get clean GPU memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Pass the fp32_lm_head flag to load_ref_policy() so the ref policy uses matched precision when computing the KL penalty with --fp32_lm_head
- Remove old generate_logprobs.py (replaced by the split scripts)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Adds a configurable test script for testing the fp32 LM head with larger models that require tensor parallelism. Supports env var overrides:
- MODEL_NAME: model to test (default: Qwen/Qwen2.5-7B)
- VLLM_TP: tensor parallel size per engine (default: 2)
- VLLM_ENGINES: number of vLLM engines (default: 2)
- NUM_LEARNERS: training GPUs per node (default: 4)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Reorganizes fp32 analysis scripts into dedicated directory. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add run_logprobs_on_beaker.sh to run the logprobs comparison pipeline on Beaker clusters with results saved to /output
- Add tensor_parallel_size support to get_vllm_logprobs.py for larger models
- Remove incorrect training script (was in the wrong location)

Usage:
./scripts/train/build_image_and_launch.sh scripts/analysis/fp32-lm-head/run_logprobs_on_beaker.sh
# For larger models:
MODEL_NAME=Qwen/Qwen2.5-7B NUM_GPUS=2 ./scripts/train/build_image_and_launch.sh scripts/analysis/fp32-lm-head/run_logprobs_on_beaker.sh

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Use vLLM's apply_model() to execute fp32 cache setup on all worker processes, making it TP-safe. The old driver_worker path only worked with single-process execution. Training code already handles this via WorkerWrap._sync_fp32_lm_head() which runs inside each worker after weight updates. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Update all internal path references in bash scripts and Python docstrings. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…kers

The previous approach patched LogitsProcessor in the main process, but with TP>1, vLLM spawns separate worker processes via MultiprocExecutor. These workers import vllm fresh and get the unpatched LogitsProcessor.

The fix uses apply_model() to run a setup function inside each worker that does both:
1. Patches LogitsProcessor._get_logits to check for fp32 weights
2. Sets the lm_head._open_instruct_fp32_weight cache on the model

This ensures the fp32 code path is actually used when generating.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of duplicating the LogitsProcessor patch implementation, import and call the existing patch_vllm_for_fp32_logits from open_instruct.vllm_utils inside each worker process. The worker function (setup_fp32_in_worker) still needs to be in a separate module for pickle serialization, but now it delegates the patching logic to the existing implementation. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
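A rough sketch of the apply_model flow described in the last two commits; the module path, the `.lm_head` attribute access, and the cache attribute are assumptions, while apply_model and patch_vllm_for_fp32_logits are the hooks named in the commit messages:

```python
# hypothetical module: open_instruct/fp32_worker_setup.py
# Kept in its own importable module so the function can be pickled to workers.
import torch


def setup_fp32_in_worker(model: torch.nn.Module) -> None:
    # Runs inside every tensor-parallel worker process, where vLLM was imported fresh.
    from open_instruct.vllm_utils import patch_vllm_for_fp32_logits

    patch_vllm_for_fp32_logits()  # patch LogitsProcessor._get_logits in this process
    lm_head = model.lm_head  # assumes the vLLM model exposes its head as .lm_head
    lm_head._open_instruct_fp32_weight = lm_head.weight.data.float()


# Driver / LLMRayActor side (illustrative):
#   llm.apply_model(setup_fp32_in_worker)
```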
Force-pushed from 0151976 to f647a0c
Force-pushed from f647a0c to 3842073
Summary
Adds FP32 LM head projection option for GRPO training to reduce logprob mismatch between the vLLM generator and trainer forward pass.
Background
In GRPO/RL training, the generator (vLLM) and trainer compute logprobs on the same sequences. Due to bf16 precision loss in the LM head projection (hidden @ lm_head.T -> logits -> softmax -> logprobs), these can diverge by 2-24+ nats, which affects training stability (see the ScaleRL paper, Figure 3).
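As a toy illustration (synthetic sizes, not the training code), the divergence comes entirely from the dtype of the head projection:

```python
import torch

torch.manual_seed(0)
vocab, hidden_dim = 32768, 1024  # small synthetic sizes
hidden = torch.randn(8, hidden_dim)                      # fp32 "reference" activations
lm_head = torch.randn(vocab, hidden_dim) / hidden_dim**0.5

logits_fp32 = hidden @ lm_head.T
logits_bf16 = (hidden.bfloat16() @ lm_head.bfloat16().T).float()

logprobs_fp32 = torch.log_softmax(logits_fp32, dim=-1)
logprobs_bf16 = torch.log_softmax(logits_bf16, dim=-1)

# Max per-token logprob gap introduced purely by the bf16 projection.
print((logprobs_fp32 - logprobs_bf16).abs().max())
```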
The LM Head and Weight Tying
Many models (Qwen, LLaMA, etc.) use weight tying: the input embedding matrix (embed_tokens) and output projection (lm_head) share the same underlying weight tensor to save memory.
- W[token_id] returns that token's embedding vector
- hidden @ W.T computes similarity to all token embeddings
Same matrix, different operations. This matters for the fp32 implementation because naively converting lm_head to fp32 would also convert the embedding layer.
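The tying can be seen directly with a toy pair of modules (illustrative only, not model code from the repo):

```python
import torch
import torch.nn as nn

vocab, hidden_dim = 1000, 64
embed_tokens = nn.Embedding(vocab, hidden_dim)
lm_head = nn.Linear(hidden_dim, vocab, bias=False)
lm_head.weight = embed_tokens.weight  # weight tying: one tensor, two roles

token_embedding = embed_tokens(torch.tensor([42]))  # W[token_id] lookup
logits = lm_head(torch.randn(1, hidden_dim))        # hidden @ W.T projection

# Same storage, so an in-place dtype change to lm_head would hit the embedding too.
print(lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr())  # True
```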
Two Modes Implemented
Cache Mode (--fp32_lm_head true): keep the bf16 lm_head weights and maintain a separate _open_instruct_fp32_weight cache, updated after each weight sync.
Permanent Mode (--fp32_lm_head true --fp32_lm_head_permanent true): convert the lm_head weights to fp32 in place (untying them from the embedding first if needed).
Both modes use roughly the same extra memory (~500 MB for a 0.5B model), but permanent mode avoids repeated casting.
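As a rough check on the "~500 MB" figure, assuming Qwen2.5-0.5B's approximate head dimensions:

```python
# Extra fp32 copy of the LM head for a Qwen2.5-0.5B-sized model (values approximate).
vocab_size = 151_936
hidden_size = 896
fp32_bytes = vocab_size * hidden_size * 4
print(f"{fp32_bytes / 2**20:.0f} MiB")  # ~519 MiB of additional fp32 head weights
```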
Changes
Trainer side (model_utils.py):
- enable_fp32_lm_head(model, permanent=False) - patches or converts lm_head
- _untie_lm_head_if_needed(model) - safely separates tied weights before fp32 conversion

vLLM side (vllm_utils.py, vllm_utils_workerwrap.py):
- patch_vllm_for_fp32_logits() - patches LogitsProcessor._get_logits for fp32 computation
- _sync_fp32_lm_head() - maintains the fp32 cache or converts weights after trainer sync
- OPEN_INSTRUCT_FP32_LM_HEAD controls the mode ("1" = cache, "2" = permanent)

GRPO (grpo_fast.py):
- --fp32_lm_head flag enables the fp32 LM head
- --fp32_lm_head_permanent flag selects permanent mode

Test plan
- make style && make quality
- uv run pytest open_instruct/test_vllm_utils.py open_instruct/test_rl_utils.py
- Beaker GPU tests (post-merge with main): cache mode (OPEN_INSTRUCT_FP32_LM_HEAD=1) and permanent mode (OPEN_INSTRUCT_FP32_LM_HEAD=2)
- DGX Spark long runs (312 steps each, ~55 min per run): with and without --fp32_lm_head --fp32_lm_head_permanent
- DGX Spark short runs (78 steps each): with and without --fp32_lm_head --fp32_lm_head_permanent