
Conversation

@natolambert (Collaborator) commented Jan 19, 2026

Summary

Adds an FP32 LM head projection option for GRPO training to reduce the logprob mismatch between the vLLM generator and the trainer forward pass.

Background

In GRPO/RL training, the generator (vLLM) and trainer compute logprobs on the same sequences. Due to bf16 precision loss in the LM head projection (hidden @ lm_head.T -> logits -> softmax -> logprobs), these can diverge by 2-24+ nats, which affects training stability (see ScaleRL paper, Figure 3).
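
As a toy illustration (not code from this PR), the gap can be reproduced by running the same projection in bf16 and in fp32 and comparing the resulting logprobs; shapes and values below are made up:

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(4, 1024, dtype=torch.bfloat16)       # [num_tokens, hidden_dim]
w_head = torch.randn(32000, 1024, dtype=torch.bfloat16)   # [vocab_size, hidden_dim]

# bf16 path (default): projection and softmax in half precision.
logprobs_bf16 = torch.log_softmax(hidden @ w_head.T, dim=-1).float()

# fp32 path: same weights, projection and softmax upcast to full precision.
logprobs_fp32 = torch.log_softmax(hidden.float() @ w_head.float().T, dim=-1)

# The difference is the generator<->trainer style mismatch this PR targets.
print((logprobs_bf16 - logprobs_fp32).abs().max())
```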

The LM Head and Weight Tying

Many models (Qwen, LLaMA, etc.) use weight tying: the input embedding matrix (embed_tokens) and output projection (lm_head) share the same underlying weight tensor to save memory.

  • embed_tokens: Lookup operation - W[token_id] returns that token's embedding vector
  • lm_head: Matrix multiply - hidden @ W.T computes similarity to all token embeddings

Same matrix, different operations. This matters for the fp32 implementation because naively converting lm_head to fp32 would also convert the embedding layer.
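
A minimal sketch of what weight tying means in practice (generic PyTorch, not repo code):

```python
import torch
import torch.nn as nn

vocab_size, hidden_dim = 8, 4
embed_tokens = nn.Embedding(vocab_size, hidden_dim)
lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight          # tie: one shared Parameter

token_ids = torch.tensor([3, 5])
hidden = embed_tokens(token_ids)              # lookup: W[token_id]
logits = lm_head(hidden)                      # matmul: hidden @ W.T

# Same underlying storage, so an in-place cast of lm_head.weight would
# also turn the embedding table fp32.
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```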

Two Modes Implemented

Cache Mode (--fp32_lm_head true):

  • Keeps bf16 weights unchanged
  • Patches forward pass to cast inputs/weights to fp32 on-the-fly
  • vLLM maintains a separate _open_instruct_fp32_weight cache updated after each weight sync
  • Safe with tied weights (no modification to underlying storage)
  • Small compute overhead from casting each forward pass
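
Conceptually, the cache side looks something like this (a simplified sketch of what `_sync_fp32_lm_head` does, not the actual implementation):

```python
import torch

def sync_fp32_lm_head(lm_head: torch.nn.Module) -> None:
    # Refresh the fp32 copy after each trainer -> vLLM weight sync.
    # The bf16 Parameter itself is never modified, so tied embeddings are safe.
    lm_head._open_instruct_fp32_weight = lm_head.weight.data.float()
```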

Permanent Mode (--fp32_lm_head true --fp32_lm_head_permanent true):

  • Unties lm_head from embed_tokens (creates independent copy if needed)
  • Converts lm_head weights to fp32 in-place
  • vLLM receives fp32 weights directly during sync
  • No runtime casting overhead
  • ~15-20% extra memory for the untied fp32 lm_head matrix

Both modes use roughly the same amount of extra memory (~500 MB for a 0.5B model), but permanent mode avoids the per-forward casting overhead.
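
Permanent mode, roughly (a simplified sketch of the untie-then-upcast idea, assuming an HF-style model with get_input_embeddings/get_output_embeddings; the real logic lives in enable_fp32_lm_head / _untie_lm_head_if_needed):

```python
import torch

def untie_and_upcast_lm_head(model) -> None:
    embed_tokens = model.get_input_embeddings()
    lm_head = model.get_output_embeddings()
    # If the head shares storage with the embedding, give it its own copy first;
    # otherwise the in-place fp32 cast would also upcast the embedding table.
    if lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr():
        lm_head.weight = torch.nn.Parameter(lm_head.weight.data.clone())
    lm_head.weight.data = lm_head.weight.data.float()
```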

Changes

Trainer side (model_utils.py):

  • enable_fp32_lm_head(model, permanent=False) - patches or converts lm_head
  • _untie_lm_head_if_needed(model) - safely separates tied weights before fp32 conversion

vLLM side (vllm_utils.py, vllm_utils_workerwrap.py):

  • patch_vllm_for_fp32_logits() - patches LogitsProcessor._get_logits for fp32 computation
  • _sync_fp32_lm_head() - maintains fp32 cache or converts weights after trainer sync
  • Env var OPEN_INSTRUCT_FP32_LM_HEAD controls mode ("1"=cache, "2"=permanent)
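
The generator-side patch is easiest to picture as a monkey patch along these lines (a simplified sketch: the `_get_logits` signature is assumed, and tensor-parallel / quantized-head / permanent-mode handling is omitted):

```python
import torch.nn.functional as F
from vllm.model_executor.layers.logits_processor import LogitsProcessor

_orig_get_logits = LogitsProcessor._get_logits

def _get_logits_fp32(self, hidden_states, lm_head, embedding_bias):
    fp32_weight = getattr(lm_head, "_open_instruct_fp32_weight", None)
    if fp32_weight is None:
        # No fp32 weights yet (e.g. before the first sync): keep the bf16 path.
        return _orig_get_logits(self, hidden_states, lm_head, embedding_bias)
    bias = embedding_bias.float() if embedding_bias is not None else None
    return F.linear(hidden_states.float(), fp32_weight, bias)

LogitsProcessor._get_logits = _get_logits_fp32
```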

GRPO (grpo_fast.py):

  • --fp32_lm_head flag enables fp32 LM head
  • --fp32_lm_head_permanent flag selects permanent mode
  • Applied to policy model, reference policy, and vLLM engines
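
Usage (illustrative; other arguments elided, as in the commit messages below):

```bash
# Cache mode (per the PR description, sets OPEN_INSTRUCT_FP32_LM_HEAD=1 for the vLLM engines)
python open_instruct/grpo_fast.py ... --fp32_lm_head

# Permanent mode (sets OPEN_INSTRUCT_FP32_LM_HEAD=2; unties and upcasts the head once at startup)
python open_instruct/grpo_fast.py ... --fp32_lm_head --fp32_lm_head_permanent
```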

Test plan

  • make style && make quality
  • uv run pytest open_instruct/test_vllm_utils.py open_instruct/test_rl_utils.py
  • GRPO smoke test: Qwen2.5-0.5B GSM8K with fp32 enabled

Beaker GPU Tests (post-merge with main)

| Mode | Beaker Job | Status |
| --- | --- | --- |
| Cache (OPEN_INSTRUCT_FP32_LM_HEAD=1) | 01KFHHEEKEGV6GR341XXKJ3RH9 | ✅ Passed |
| Permanent (OPEN_INSTRUCT_FP32_LM_HEAD=2) | 01KFHHF9Z2QEVQ3A7H8SG606KC | ✅ Passed |

DGX Spark Long Runs (312 steps each, ~55 min per run)

| Run | Config | Wandb |
| --- | --- | --- |
| Long + NO FP32 | baseline | https://wandb.ai/ai2-llm/open_instruct_internal/runs/fw73k9s8 |
| Long + FP32 PERMANENT | --fp32_lm_head --fp32_lm_head_permanent | https://wandb.ai/ai2-llm/open_instruct_internal/runs/gmif0osr |

DGX Spark Short Runs (78 steps each)

| Run | Config | Wandb |
| --- | --- | --- |
| Short + NO FP32 | baseline | https://wandb.ai/ai2-llm/open_instruct_internal/runs/ajujoluc |
| Short + FP32 PERMANENT | --fp32_lm_head --fp32_lm_head_permanent | https://wandb.ai/ai2-llm/open_instruct_internal/runs/r0i7x38r |

@gemini-code-assist (Contributor)

Summary of Changes

Hello @natolambert, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the GRPO training framework by introducing an FP32 LM head option to improve log-probability consistency between the trainer and generator, and by implementing a validation holdout mechanism for better overfitting detection. It also broadens hardware compatibility and development environment by upgrading to CUDA 13.0.0 and adding robust support for NVIDIA DGX Spark, complete with new documentation and dedicated training scripts.

Highlights

  • FP32 LM Head Option for GRPO: Introduced a new --fp32_lm_head flag to enable FP32 precision for the Language Model (LM) head in both the GRPO trainer and the vLLM generator. This aims to reduce log-probability mismatch between the generator and trainer for improved alignment.
  • Validation Reward Tracking: Added a validation_holdout_ratio parameter to split training data into separate training and validation sets. This allows monitoring performance on held-out data during GRPO training to detect overfitting, with metrics appearing under the eval/ prefix.
  • DGX Spark (Blackwell) Support: Enhanced support for NVIDIA DGX Spark (GB10 Blackwell, CUDA 13, aarch64) by updating the Dockerfile to CUDA 13.0.0, adjusting pyproject.toml for vllm and flash-attn compatibility on aarch64, and adding comprehensive documentation and example training scripts for SFT, DPO, and GRPO on this hardware.
  • Gradient Checkpointing for LoRA: Enabled gradient checkpointing for LoRA (non-QLoRA) training configurations in finetune.py when specified, which can help manage memory usage during fine-tuning.
  • Configurable Reference Logprobs Cache Path: Made the REFERENCE_LOGPROBS_CACHE_PATH configurable via an environment variable, providing more flexibility for cache management.


@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces several valuable enhancements, including an fp32_lm_head option for GRPO to improve logprob alignment, support for DGX Spark (aarch64, CUDA 13), and a new validation_holdout_ratio feature for monitoring overfitting. The changes are well-implemented, with corresponding documentation and example scripts. The code quality is high, and the new features are thoughtfully integrated. I have one suggestion regarding temporary notes in the documentation.

AGENTS.md Outdated
Comment on lines 36 to 51
# TEMP: FP32 LM head rollout notes (ScaleRL alignment)
- Goal: reduce generator↔trainer logprob mismatch; apply fp32 at LM head on both sides.
- References: TRL GRPO `cast_lm_head_to_fp32` (PR #4303, #4446); vLLM patch point is `LogitsProcessor._get_logits`; vLLM plugin system can host the patch; LLaMA-Factory upcast_lmhead_output.
- Implementation in repo:
  - Trainer: `enable_fp32_lm_head()` patches output embeddings to run in fp32 without re-allocating weights (tied embeddings preserved).
  - Reference policy: same fp32 head toggle to keep KL alignment.
  - Generator: vLLM LogitsProcessor `_get_logits` patched to upcast hidden_states (+ bias) before projection.
- Flag: `--fp32_lm_head` (GRPO). Ensure vLLM engines receive the flag too.
- Sanity check (step 0): compare trainer logprobs vs vLLM logprobs on identical sequences; expect tighter deltas with fp32 head.
- OOM safety (DGX Spark): check `free -h`, kill leftover Ray/vLLM if needed, set `PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:128"`.
- Suggested local debug: `scripts/train/dgx-spark/grpo_qwen_gsm8k.sh` (Qwen2.5-0.5B GSM8K), add `--fp32_lm_head` when testing.
- Long comparison runs: `FP32_LM_HEAD=0 ./scripts/train/dgx-spark/grpo_qwen_gsm8k.sh` then `FP32_LM_HEAD=1 ./scripts/train/dgx-spark/grpo_qwen_gsm8k.sh`; monitor `debug/vllm_vs_local_logprob_diff_*`.
- Plugin vs monkey patch: plugin is just a startup hook that can apply the same patch; in-process patch is used unless a vLLM plugin package is added.
gemini-code-assist bot (severity: medium)

This section with temporary notes about the FP32 LM head feature seems intended for development and should be removed before merging this pull request to keep the documentation clean and focused on user-facing information. If this information is valuable long-term, consider moving it to a design document or a more permanent part of the documentation.

@natolambert (Collaborator, Author)

Note on vLLM fp32 LM head options we discussed:

  • Option A (true fp32 matmul on the fly): patch _get_logits to use F.linear(hidden_states.float(), lm_head.weight.float(), bias.float()) each call. Simple but slower (per-token cast + fp32 GEMM).
  • Option B (cached fp32 head weights): keep an fp32 copy of lm_head.weight updated on weight sync; _get_logits uses cached fp32 weights with fp32 hidden_states. Lower runtime cost; modest extra VRAM (head size). Skips quantized heads.

You chose Option B because the KV cache dominates memory in RL, so the extra head cache is relatively small while still keeping the head matmul in fp32.

TODO: the current debug metric (debug/vllm_vs_local_logprob_diff_*) is overwritten per sample; we should log batch-averaged stats for a smoother comparison curve.

Nathan Lambert and others added 23 commits January 19, 2026 13:51
Set OPEN_INSTRUCT_FP32_LM_HEAD environment variable at the start of
main() before ray.init(), so it propagates to all worker processes
via Ray's runtime_env.

Previously, the env var was only set in the LLMRayActor process after
Ray was initialized, so vLLM worker subprocesses didn't inherit it.
This caused _maybe_update_fp32_lm_head_cache() to early-return, meaning
the fp32 weight cache was never created and vLLM was still using bf16
weights for the lm_head projection.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Two modes now available via --fp32_lm_head + --fp32_lm_head_permanent:

1. Cache mode (default): Keep bf16 weights, maintain separate fp32 cache
   - Set OPEN_INSTRUCT_FP32_LM_HEAD=1
   - Lower memory overhead but more complex

2. Permanent mode: Convert vLLM lm_head weights to fp32 in-place
   - Set OPEN_INSTRUCT_FP32_LM_HEAD=2 via --fp32_lm_head_permanent
   - Similar to TRL/Flash-RL approach
   - Simpler but uses more memory

Both modes cast hidden_states to fp32 for the matmul. The permanent mode
converts lm_head.weight.data to fp32 after each weight sync, while cache
mode maintains a separate _open_instruct_fp32_weight attribute.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add --save_logprob_samples flag to save raw vLLM vs trainer logprob
pairs during training. These can be used to create scatter plots
showing the alignment between inference and training logprobs,
similar to Figure 3 in the ScaleRL paper.

New files:
- scripts/analysis/plot_logprob_alignment.py: Creates scatter plots
  with Pearson correlation, supports comparison between runs

Usage:
  # Run training with sample saving
  python open_instruct/grpo_fast.py ... --save_logprob_samples

  # Plot results
  python scripts/analysis/plot_logprob_alignment.py \
    --data_dir /tmp/run/logprob_samples --output plot.png

  # Compare runs (e.g., with/without fp32)
  python scripts/analysis/plot_logprob_alignment.py \
    --data_dirs /tmp/run1/logprob_samples /tmp/run2/logprob_samples \
    --labels "Without FP32" "With FP32" --output comparison.png

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ilable

The previous code unconditionally cast hidden_states to fp32 at the start
of _get_logits_fp32(), then passed them to the original implementation in
the fallback path. This caused dtype mismatches when the original weights
were still bf16 (before the first weight sync in permanent mode).

Now we only cast hidden_states to fp32 when we actually have fp32 weights
to compute with.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When tie_word_embeddings=True (common in Qwen, Llama, etc.), the lm_head
weight and input embedding share the same tensor. Converting lm_head to
fp32 in-place would also convert the embedding layer, causing dtype
mismatches in early layers.

This fix checks if weights are tied by comparing data_ptr(), and falls
back to cache mode when weights are tied to avoid the issue.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add --use_probs flag to convert logprobs to probabilities via exp()
- Probabilities plot on 0-1 scale for better visualization
- Supports comparison plots with multiple runs side-by-side

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Numpy doesn't support BFloat16, so convert to float32 before calling
.cpu().numpy(). This fixes the TypeError when saving logprob samples
without fp32_lm_head enabled.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The previous code saved vLLM vs trainer logprobs on every mini-batch,
but trainer weights get updated after each backward/step. This caused
apparent misalignment because we were comparing:
- vLLM logprobs: computed during generation with weights W
- Trainer logprobs: computed after optimizer.step() with weights W'

Now only save samples on epoch 0, mini-batch 0 of each training step,
when trainer weights still match the vLLM weights from generation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- maybe_enable_fp32_lm_head() → patch_vllm_for_fp32_logits()
- _maybe_update_fp32_lm_head_cache() → _sync_fp32_lm_head()

The new names better describe what the functions do:
- patch_vllm_for_fp32_logits: patches vLLM LogitsProcessor for fp32 computation
- _sync_fp32_lm_head: syncs fp32 LM head weights after trainer broadcast

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Permanent mode now converts lm_head weights to fp32 once at training start,
matching Flash-RL and TRL implementations:

Trainer side (model_utils.py):
- enable_fp32_lm_head() now accepts `permanent` parameter
- permanent=True converts weights to fp32 in-place
- Handles tied weights by untying lm_head from embed_tokens first

vLLM side (vllm_utils_workerwrap.py):
- update_weight() now accepts fp32 dtype for lm_head in permanent mode
- Weight sync receives fp32 directly, no conversion needed after sync

This is more efficient than the cache approach:
- No bf16→fp32 conversion on each weight sync
- No separate fp32 cache to maintain
- Weights are synced as fp32 directly from trainer

References:
- Flash-RL: https://github.com/yaof20/Flash-RL/blob/main/flash_rl/vllm_patch.py
- TRL: huggingface/trl#4303

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tests the new permanent fp32 lm_head mode on GSM8K with Qwen2.5-0.5B.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- fp32_comparison_overnight.sh: 4 sequential runs comparing fp32 modes
  1. No FP32 (baseline)
  2. FP32 cache mode
  3. FP32 permanent mode
  4. No FP32 (repeat for variance)

- grpo_qwen_gsm8k_short.sh: Config with shorter response_length=256
  that showed good reward scores (faster training)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- fp32_overnight_4runs.sh: Runs all 4 combinations
  1. grpo_qwen_gsm8k.sh (long) + no fp32
  2. grpo_qwen_gsm8k.sh (long) + fp32 permanent
  3. grpo_qwen_gsm8k_short.sh (short) + no fp32
  4. grpo_qwen_gsm8k_short.sh (short) + fp32 permanent

- Updated grpo_qwen_gsm8k.sh to support FP32_PERMANENT env var
- Added --save_logprob_samples for alignment analysis

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Disable local eval (--local_eval_every 99999) to avoid hanging
- Lower vLLM memory to 0.4 for unified memory systems
- Add fp32_long_2runs.sh and fp32_short_2runs.sh for experiments

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- run_gpu_tests_fp32_cache.sh: Tests cache mode (OPEN_INSTRUCT_FP32_LM_HEAD=1)
- run_gpu_tests_fp32_permanent.sh: Tests permanent mode (OPEN_INSTRUCT_FP32_LM_HEAD=2)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflict in grpo_fast.py:
- Main refactored Args to grpo_utils.ExperimentConfig
- Added fp32_lm_head fields to ExperimentConfig
- Re-added fp32 implementation code to grpo_fast.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Move test scripts to scripts/test/
- Add launcher scripts (run_gpu_pytest_fp32_*.sh) that use mason.py
- Container scripts (run_gpu_tests_fp32_*.sh) run inside Beaker

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add noqa for late vLLM import (intentional - vLLM may not be installed)
- Add changelog entry for fp32 LM head feature

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- generate_logprobs.py: Generate sequences with vLLM and score with HF
  - Compares bf16 vs fp32 LM head precision
  - Supports Qwen3 sampling params (temperature, top_p, top_k, min_p)
  - Sequential loading allows high GPU memory utilization (0.85)

- plot_logprobs.py: Visualize logprob alignment from saved data
  - Default: probabilities (0-1), use --use-logprobs for raw scale
  - Scatter plots and histograms for bf16 vs fp32 comparison

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Moved to private repo: natolambert/ai2-scripts

Keep only grpo_qwen_gsm8k.sh and grpo_qwen_gsm8k_short.sh for
DGX Spark examples (both have fp32 mode documentation).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Nathan Lambert and others added 7 commits January 22, 2026 20:08
- Lock histogram x/y axes to same scale for visual comparison
- Remove temporary FP32 notes from AGENTS.md (moved to ~/dev/CLAUDE.md)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Links to related implementations:
- TRL GRPO cast_lm_head_to_fp32 PRs
- Flash-RL vLLM patch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Document options for fixing the vLLM memory issue:
1. Fix vLLM V1 engine cleanup
2. Split into separate vLLM and HF scripts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove plot_logprob_alignment.py (--save_logprob_samples flag removed)
- Remove debug_vllm_cleanup.py (no longer needed)
- Move fp32_lm_head_test.sh to private scripts repo
- Reset AGENTS.md and run_gpu_tests.sh to main

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@natolambert natolambert marked this pull request as ready for review January 23, 2026 04:24
@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 435ffb258c


Comment on lines +259 to +260:

if args.fp32_lm_head:
    enable_fp32_lm_head(self.policy, permanent=args.fp32_lm_head_permanent)


P2: Apply fp32 LM head to reference policy too

When --fp32_lm_head is enabled this only patches the policy model, but the reference policy loaded later still runs in bf16 because load_ref_policy(...) is called without the new fp32_lm_head flag (it defaults to False). In runs with --load_ref_policy (default) and a nonzero --beta, the KL term is computed between fp32 policy logprobs and bf16 ref logprobs, which reintroduces the precision mismatch the flag is meant to eliminate. Consider passing fp32_lm_head=args.fp32_lm_head into load_ref_policy so the KL uses matched precision.


Nathan Lambert and others added 15 commits January 22, 2026 20:36
Split generate_logprobs.py into separate scripts that run in separate
processes, avoiding vLLM V1 memory cleanup issues:

- get_vllm_logprobs.py: Generate sequences with vLLM (bf16 or fp32 mode)
- get_hf_logprobs.py: Score sequences with HuggingFace
- run_logprobs_comparison.sh: Wrapper to run full pipeline

Each vLLM run now uses 0.85 GPU utilization instead of 0.4, since
separate processes get clean GPU memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Pass fp32_lm_head flag to load_ref_policy() so ref policy uses
  matched precision when computing KL penalty with --fp32_lm_head
- Remove old generate_logprobs.py (replaced by split scripts)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Adds configurable test script for testing fp32 LM head with larger
models that require tensor parallelism. Supports env var overrides:
- MODEL_NAME: Model to test (default: Qwen/Qwen2.5-7B)
- VLLM_TP: Tensor parallel size per engine (default: 2)
- VLLM_ENGINES: Number of vLLM engines (default: 2)
- NUM_LEARNERS: Training GPUs per node (default: 4)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Reorganizes fp32 analysis scripts into dedicated directory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add run_logprobs_on_beaker.sh to run the logprobs comparison pipeline
  on Beaker clusters with results saved to /output
- Add tensor_parallel_size support to get_vllm_logprobs.py for larger models
- Remove incorrect training script (was in wrong location)

Usage:
  ./scripts/train/build_image_and_launch.sh scripts/analysis/fp32-lm-head/run_logprobs_on_beaker.sh

  # For larger models:
  MODEL_NAME=Qwen/Qwen2.5-7B NUM_GPUS=2 ./scripts/train/build_image_and_launch.sh \
    scripts/analysis/fp32-lm-head/run_logprobs_on_beaker.sh

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Use vLLM's apply_model() to execute fp32 cache setup on all worker
processes, making it TP-safe. The old driver_worker path only worked
with single-process execution.

Training code already handles this via WorkerWrap._sync_fp32_lm_head()
which runs inside each worker after weight updates.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Update all internal path references in bash scripts and Python docstrings.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…kers

The previous approach patched LogitsProcessor in the main process, but
with TP>1, vLLM spawns separate worker processes via MultiprocExecutor.
These workers import vllm fresh and get the unpatched LogitsProcessor.

The fix uses apply_model() to run a setup function INSIDE each worker
that does both:
1. Patches LogitsProcessor._get_logits to check for fp32 weights
2. Sets lm_head._open_instruct_fp32_weight cache on the model

This ensures the fp32 code path is actually used when generating.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of duplicating the LogitsProcessor patch implementation,
import and call the existing patch_vllm_for_fp32_logits from
open_instruct.vllm_utils inside each worker process.

The worker function (setup_fp32_in_worker) still needs to be in a
separate module for pickle serialization, but now it delegates the
patching logic to the existing implementation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@natolambert force-pushed the fp32-lm-head-grpo branch 2 times, most recently from 0151976 to f647a0c on January 28, 2026 at 00:31