Add OLMo-core GRPO trainer implementation #1389
Conversation
New files:
- open_instruct/grpo.py - Main entry point for OLMo-core GRPO training
- open_instruct/grpo_train_module.py - GRPOTrainModule using OLMo-core's TrainModule
- open_instruct/grpo_callbacks.py - vLLM weight sync and ref policy update callbacks
- open_instruct/grpo_olmo_core_actor.py - Ray actor wrapping OLMo-core training

New test scripts:
- scripts/train/debug/single_gpu_grpo.sh - Single GPU OLMo-core test
- scripts/train/debug/multi_node_grpo.sh - Multi-node OLMo-core test
- scripts/train/debug/tool_grpo.sh - Tool use OLMo-core test

Key design:
- Reuses grpo_fast.make_tokenizer() and setup_datasets()
- Uses same Ray actor pattern as grpo_fast.py
- Same queue-based architecture with DataPreparationActor
- Weight sync callback converts OLMo-core param names to HuggingFace format

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary of Changes

Hello @finbarrtimbers, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a comprehensive implementation of Group Relative Policy Optimization (GRPO) training integrated with the OLMo-core framework. The core objective is to transition from DeepSpeed to FSDP for distributed training, using Ray actors to manage and scale the training process efficiently across various hardware configurations. This foundational work provides the necessary components for advanced policy optimization within the OLMo-core ecosystem, complete with new testing utilities.
Code Review
This pull request introduces the OLMo-core GRPO trainer implementation, migrating from DeepSpeed to FSDP and leveraging Ray actors for distributed training. The changes include new modules for GRPO training, callbacks for vLLM weight synchronization and reference policy updates, and Ray actors to wrap OLMo-core's training infrastructure. Several debug scripts are also added to test single-GPU, multi-node, and tool-use scenarios. The overall structure is well-organized, and the use of Ray actors for distributed components is consistent with modern distributed training practices. However, there are a few areas that could be improved for robustness, maintainability, and clarity, particularly regarding hardcoded model-specific logic and environment variable handling within Ray actors.
    --oe_eval_max_length 32768 \
    --oe_eval_tasks "codex_humanevalplus:0-shot-chat-v1::tulu-thinker,mbppplus:0-shot-chat::tulu-thinker,livecodebench_codegeneration::tulu-thinker" \
    --checkpoint_state_freq 2 \
    --checkpoint_state_dir /tmp/checkpoint_test \
Using /tmp/checkpoint_test as the checkpoint_state_dir means that checkpoints will be stored in a temporary directory. Data in /tmp is typically not persistent across reboots or job restarts, which can lead to data loss if the job fails or is preempted. For critical checkpoints, a persistent storage location should be used.
def olmo_core_to_hf_name(name: str) -> str:
    """Convert OLMo-core parameter name to HuggingFace format for Qwen3/LLaMA models."""
    if name == "embeddings.weight":
        return "model.embed_tokens.weight"
    if name == "lm_head.norm.weight":
        return "model.norm.weight"
    if name == "lm_head.w_out.weight":
        return "lm_head.weight"

    layer_match = re.match(r"blocks\.(\d+)\.(.*)", name)
    if layer_match:
        layer_idx = layer_match.group(1)
        rest = layer_match.group(2)

        mappings = {
            "attention.w_q.weight": "self_attn.q_proj.weight",
            "attention.w_k.weight": "self_attn.k_proj.weight",
            "attention.w_v.weight": "self_attn.v_proj.weight",
            "attention.w_out.weight": "self_attn.o_proj.weight",
            "attention.q_norm.weight": "self_attn.q_norm.weight",
            "attention.k_norm.weight": "self_attn.k_norm.weight",
            "feed_forward.w1.weight": "mlp.gate_proj.weight",
            "feed_forward.w2.weight": "mlp.down_proj.weight",
            "feed_forward.w3.weight": "mlp.up_proj.weight",
            "attention_norm.weight": "input_layernorm.weight",
            "feed_forward_norm.weight": "post_attention_layernorm.weight",
        }

        if rest in mappings:
            return f"model.layers.{layer_idx}.{mappings[rest]}"
The olmo_core_to_hf_name function contains hardcoded mappings and a specific heuristic for parameter name conversion. This approach is brittle and highly dependent on the internal naming conventions of Qwen3/LLaMA models and OLMo-core. If new models are introduced or existing model architectures change their parameter naming, this function will likely break, requiring manual updates. Consider a more generalized or configurable approach if supporting diverse model architectures is a goal.
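One hedged alternative is to keep the name tables in a per-architecture registry, so supporting a new model family means adding data rather than editing control flow. A minimal sketch; the `HF_TOP_LEVEL_MAPS`/`HF_LAYER_MAPS` names and the `qwen3_llama` key are illustrative, not part of this PR:

```python
import re

# Illustrative registry: one mapping table per supported model family.
HF_TOP_LEVEL_MAPS = {
    "qwen3_llama": {
        "embeddings.weight": "model.embed_tokens.weight",
        "lm_head.norm.weight": "model.norm.weight",
        "lm_head.w_out.weight": "lm_head.weight",
    }
}
HF_LAYER_MAPS = {
    "qwen3_llama": {
        "attention.w_q.weight": "self_attn.q_proj.weight",
        "attention.w_out.weight": "self_attn.o_proj.weight",
        "feed_forward.w1.weight": "mlp.gate_proj.weight",
        # ... remaining per-layer entries from the PR's mapping dict ...
    }
}


def olmo_core_to_hf_name(name: str, family: str = "qwen3_llama") -> str:
    """Convert an OLMo-core parameter name via a per-family lookup table."""
    if name in HF_TOP_LEVEL_MAPS[family]:
        return HF_TOP_LEVEL_MAPS[family][name]
    m = re.match(r"blocks\.(\d+)\.(.*)", name)
    if m and m.group(2) in HF_LAYER_MAPS[family]:
        return f"model.layers.{m.group(1)}.{HF_LAYER_MAPS[family][m.group(2)]}"
    raise KeyError(f"No HF mapping for parameter {name!r} (family={family!r})")
```

Failing loudly on unmapped names also surfaces architecture drift immediately instead of silently dropping weights during sync.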
| os.environ["NUM_NODES"] = str(self.num_nodes) | ||
| os.environ["LOCAL_WORLD_SIZE"] = str(self.local_world_size) |
Setting environment variables like NUM_NODES and LOCAL_WORLD_SIZE inside a Ray actor's method can lead to unexpected behavior or race conditions, especially if multiple actors are running on the same node or if the environment variables are expected to be static. It's generally safer to pass these values as direct arguments to the actor's __init__ method or ensure they are set in the Ray runtime environment before the actor is launched.
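A sketch of the alternative the comment suggests; the actor name is illustrative, not the PR's actual class. Topology values are either passed to `__init__` or pinned via Ray's `runtime_env` before the worker process starts:

```python
import ray


@ray.remote(num_gpus=1)
class TrainerActor:  # illustrative name
    def __init__(self, num_nodes: int, local_world_size: int):
        # Used directly as attributes; no os.environ mutation inside the actor.
        self.num_nodes = num_nodes
        self.local_world_size = local_world_size


# If a library really does read these from the environment, set them per-actor
# before the process starts instead of mutating os.environ later:
actor = TrainerActor.options(
    runtime_env={"env_vars": {"NUM_NODES": "2", "LOCAL_WORLD_SIZE": "8"}}
).remote(num_nodes=2, local_world_size=8)
```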
model_basename = self.model_name_or_path.split("/")[-1]
config_name = model_basename.replace("-", "_").replace(".", "_")
config_name = config_name[:-1].lower() + "B" if config_name.endswith("B") else config_name.lower()
The logic for deriving config_name from model_basename uses a specific heuristic (config_name[:-1].lower() + "B" if config_name.endswith("B") else config_name.lower()). This is brittle and may not generalize well to all model naming conventions, potentially causing issues if model names do not strictly follow this pattern.
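A hedged sketch of an explicit lookup with a clear failure mode; the dictionary values below mirror what the current heuristic would produce for these names and may not match OLMo-core's actual config identifiers:

```python
# Illustrative mapping from HF checkpoint basenames to OLMo-core config names.
MODEL_TO_OLMO_CORE_CONFIG = {
    "Qwen3-1.7B": "qwen3_1_7B",  # values assumed; verify against OLMo-core
    "Qwen3-8B": "qwen3_8B",
}


def resolve_config_name(model_name_or_path: str) -> str:
    basename = model_name_or_path.rstrip("/").split("/")[-1]
    try:
        return MODEL_TO_OLMO_CORE_CONFIG[basename]
    except KeyError as err:
        raise ValueError(
            f"Unknown model {basename!r}; add it to MODEL_TO_OLMO_CORE_CONFIG"
        ) from err
```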
    device=device,
)

os.environ["FS_LOCAL_RANK"] = str(self.rank)
Similar to the previous comment, setting FS_LOCAL_RANK as an environment variable within the actor's method is not ideal. It's better to pass such configuration directly as arguments or ensure the environment is correctly set up before actor launch to avoid potential issues in a distributed context.
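If OLMo-core must read `FS_LOCAL_RANK` from the environment, one option is to pin it per actor at creation time via `runtime_env`, so it is in place before the worker process imports anything; `TrainerActor` and `local_rank` below are illustrative:

```python
# Sketch: the variable is fixed before the actor process starts, rather than
# being mutated from inside a method after other code may already have read it.
trainer = TrainerActor.options(
    runtime_env={"env_vars": {"FS_LOCAL_RANK": str(local_rank)}}
).remote()
```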
else:
    logger.error(f"Timeout waiting for GPUs. Only {available_gpus} available, needed {expected_gpus}")

bundles = [{"GPU": n, "CPU": n} for n in args.num_learners_per_node]
The current bundle configuration requests an equal number of CPUs and GPUs per actor ("CPU": n for "GPU": n). While this might be sufficient for some workloads, Ray actors often benefit from more CPU resources than GPU resources, especially for tasks involving data preprocessing or orchestration. This could lead to CPU bottlenecks if the CPU-bound tasks within the actor are not adequately resourced.
Suggested change:
-bundles = [{"GPU": n, "CPU": n} for n in args.num_learners_per_node]
+bundles = [{"GPU": n, "CPU": n * 2} for n in args.num_learners_per_node]
elif is_fsdp:
    for name, param in model.named_parameters():
        count += 1
        vllm_name = get_vllm_name(name)
        with FSDP.summon_full_params(model, writeback=False, rank0_only=True):
            if is_rank0:
                refs = [
                    engine.update_weight.remote(
                        vllm_name, dtype=str(param.dtype), shape=param.shape, empty_cache=(count == num_params)
                    )
                    for engine in self.vllm_engines
                ]
                refss.extend(refs)

            if self.model_update_group is not None:
                dist.broadcast(param.data, 0, group=self.model_update_group)
The elif is_fsdp: block is identical to the if self.gather_whole_model and is_fsdp: block. This indicates redundant code that can be simplified. The logic can be streamlined to avoid repeating the same code path.
if is_fsdp:
    with FSDP.summon_full_params(model, writeback=False, rank0_only=True):
        for name, param in model.named_parameters():
            count += 1
            vllm_name = get_vllm_name(name)
            if is_rank0:
                refs = [
                    engine.update_weight.remote(
                        vllm_name, dtype=str(param.dtype), shape=param.shape, empty_cache=(count == num_params)
                    )
                    for engine in self.vllm_engines
                ]
                refss.extend(refs)
            if self.model_update_group is not None:
                dist.broadcast(param.data, 0, group=self.model_update_group)
else:
    for name, param in model.named_parameters():
        count += 1
        vllm_name = get_vllm_name(name)
        if is_rank0:
            refs = [
                engine.update_weight.remote(
                    vllm_name, dtype=str(param.dtype), shape=param.shape, empty_cache=(count == num_params)
                )
                for engine in self.vllm_engines
            ]
            refss.extend(refs)
        if self.model_update_group is not None:
            dist.broadcast(param.data, 0, group=self.model_update_group)

old_logprobs_BT = self.compute_logprobs(self.model, data_BT, use_grad=False)

num_samples = len(data_BT.query_responses)
accumulation_steps = num_samples * self.grpo_config.num_epochs * self.grpo_config.num_mini_batches
The variable accumulation_steps is calculated as num_samples * self.grpo_config.num_epochs * self.grpo_config.num_mini_batches. While this might represent the total number of inner loop iterations, the name accumulation_steps typically refers to the number of steps over which gradients are accumulated before an optimizer step. If this variable is intended for gradient accumulation, its current calculation might be misleading or incorrect for that purpose. If it's just a loop counter, a more descriptive name would improve clarity.
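If the value is only a loss normalizer for the inner loop, a rename makes that intent explicit. A minimal sketch with illustrative names, assuming one optimizer step follows the full inner loop (each factor should correspond to a real loop level; see the mini-batch comment further down):

```python
# Number of backward passes whose gradients accumulate into one optimizer step.
num_samples = len(data_BT.query_responses)
backward_passes_per_update = num_samples * self.grpo_config.num_epochs * self.grpo_config.num_mini_batches

# ... inside the inner loop ...
scaled_loss = loss / backward_passes_per_update  # keeps the summed gradient at batch-mean scale
scaled_loss.backward()
```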
    --load_ref_policy true \
    --seed 3 \
    --local_eval_every 1 \
    --vllm_sync_backend gloo \
The vllm_sync_backend is set to gloo. For GPU-based distributed training, NCCL is generally preferred over Gloo due to its superior performance characteristics, especially for NVIDIA GPUs. Using Gloo might lead to suboptimal synchronization performance.
Suggested change:
-    --vllm_sync_backend gloo \
+    --vllm_sync_backend nccl \
    --tools code search \
    --search_api_endpoint "http://saturn-cs-aus-248.reviz.ai2.in:47479/search" \
The search_api_endpoint and code_tool_api_endpoint are hardcoded to specific URLs. While this is a debug script, hardcoding external service endpoints makes the script less flexible and harder to manage in different environments (e.g., staging vs. production, or different user setups). Consider making these configurable via environment variables or command-line arguments, even for debug scripts, to promote reusability and easier adaptation.
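If the endpoints were resolved in the Python entry point rather than baked into the shell script, a minimal, hedged sketch could fall back to the current URL while allowing per-environment overrides (variable names are illustrative):

```python
import os

# Illustrative: allow overriding the tool endpoints per environment while
# keeping the current ai2 URL as the default; the same pattern would apply
# to the code tool endpoint.
SEARCH_API_ENDPOINT = os.environ.get(
    "SEARCH_API_ENDPOINT", "http://saturn-cs-aus-248.reviz.ai2.in:47479/search"
)
```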
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 749cbb926b
state_dict = {}
model = self.train_module.model
for name, param in model.named_parameters():
    hf_name = olmo_core_to_hf_name(name)
    state_dict[hf_name] = param.data.cpu()
Gather full FSDP params before saving checkpoint
When training is distributed (world_size > 1), parallelize_model(...) wraps the Transformer in FSDP/HSDP, so model.named_parameters() on rank 0 only exposes local shards. The current save path writes those shards directly, which produces an incomplete checkpoint that will not load, or will silently miss weights, in multi-GPU runs. This only shows up when FSDP is active, but that's the intended multi-node mode for this entry point. Consider using an FSDP full state dict (or FSDP.summon_full_params) to materialize full parameters on rank 0 before writing.
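A sketch of the full-state-dict route, assuming the model is wrapped with PyTorch's FSDP1 wrapper; if OLMo-core's parallelize_model uses a different wrapping, the equivalent unsharded state-dict API would be needed instead:

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Materialize full parameters on rank 0 before converting names and writing
# the HF-format checkpoint; other ranks skip the save.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    full_state = model.state_dict()  # full tensors on rank 0 only

if dist.get_rank() == 0:
    state_dict = {olmo_core_to_hf_name(k): v for k, v in full_state.items()}
    # ... hand state_dict to the existing save path ...
```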
num_samples = len(data_BT.query_responses)
accumulation_steps = num_samples * self.grpo_config.num_epochs * self.grpo_config.num_mini_batches
Remove num_mini_batches scaling or implement mini-batches
The loss normalization multiplies accumulation_steps by num_mini_batches, but the training loop only iterates over epochs and samples (there is no mini-batch loop). As a result, when num_mini_batches > 1, the loss is divided by an extra factor and gradients are under-scaled, while the scheduler still assumes extra steps. This only affects runs that set --num_mini_batches > 1; otherwise it’s fine. Either split the batch into mini-batches like grpo_fast or remove num_mini_batches from the accumulation divisor.
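A minimal sketch of the simpler of the two fixes (dropping the factor while no mini-batch loop exists); the surrounding loop structure is assumed from the review, not copied from the PR:

```python
# If the inner loop only iterates over epochs and samples, normalize by exactly
# that many backward passes; num_mini_batches is omitted until a real
# mini-batch loop (as in grpo_fast) is added.
num_samples = len(data_BT.query_responses)
accumulation_steps = num_samples * self.grpo_config.num_epochs
```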
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Use init=False so temperature is not exposed as CLI arg (it's already in StreamingDataLoaderConfig and copied at runtime). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The field is parsed as string from CLI and then converted in __post_init__. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This field defaults to model_name_or_path if not set, handled by make_tokenizer(). Adding an assert for type narrowing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Required for running on macOS where vllm is not available. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use tool names from TOOL_REGISTRY (python, serper_search) - Configure tools via --tool_configs instead of old CLI args Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
SerperSearchTool requires SERPER_API_KEY which is not available. Just test with python code execution tool. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Test Results
Ran single GPU GRPO (Beaker), single GPU GRPO with tools (Beaker), and multi-node GRPO (Beaker) scripts.