Add DPO OLMo-core support with MFU improvements #1440
base: main
Conversation
- Add olmo_core_callbacks.py with BeakerCallbackV2 and PerfCallback for MFU and tokens_per_second metrics
- Add 7b_instruct_dpo_olmocore.sh script for 4-node DPO training with 16k sequence length
- Update dpo_utils.py with new TrainingConfig fields:
  - gradient_checkpointing_mode for activation checkpointing mode
  - activation_memory_budget for budget-based checkpointing
  - shard_degree and num_replicas for FSDP configuration
  - compile_model for torch.compile support
- Update dpo.py:
  - Import PerfCallback and BeakerCallbackV2 from olmo_core_callbacks
  - Add compile_model support with apply_compile()
  - Update _apply_parallelism for shard_degree/num_replicas
  - Add PerfCallback and ProfilerCallback to callbacks
  - Pass rank_microbatch_size to DPOTrainModule
- Update olmo_core_train_modules.py:
  - Add split_batch_dpo() for microbatch gradient accumulation (see the sketch after this list)
  - Update DPOTrainModule with rank_microbatch_size parameter
  - Add proper FSDP sync handling in train_batch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
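The microbatch gradient accumulation mentioned above (split_batch_dpo() plus rank_microbatch_size) amounts to slicing each per-rank batch along its batch dimension. A hedged sketch, assuming dict-of-tensor batches; the "chosen_input_ids" key is illustrative, and the real split_batch_dpo() in olmo_core_train_modules.py may handle packed batches and non-tensor entries differently:

```python
import torch


def split_batch_dpo(batch: dict, rank_microbatch_size: int) -> list[dict]:
    """Split a per-rank batch into microbatches of at most rank_microbatch_size examples."""
    # Any tensor keyed on the batch dimension works here; the key name is an assumption.
    batch_size = batch["chosen_input_ids"].shape[0]
    micro_batches = []
    for start in range(0, batch_size, rank_microbatch_size):
        end = min(start + rank_microbatch_size, batch_size)
        micro_batches.append(
            {
                key: value[start:end]
                if torch.is_tensor(value) and value.dim() > 0 and value.shape[0] == batch_size
                else value
                for key, value in batch.items()
            }
        )
    return micro_batches
```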
Summary of Changes

Hello @finbarrtimbers, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request upgrades the DPO training pipeline by integrating OLMo-core-specific features and improving distributed training capabilities. It introduces new callbacks for detailed performance monitoring, expands configuration options for memory optimization and parallelism, and refines data handling and gradient accumulation strategies. These changes aim to enable more efficient and scalable DPO training, particularly for large models and distributed environments.
BeakerCallbackV2 is now in olmo_core_callbacks.py, so the old beaker_callback.py is no longer needed. Rename test file to match. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Code Review
This pull request introduces significant enhancements for DPO training with OLMo-core, including support for multi-node training, longer sequences, and improved MFU (Model Flops Utilization) metrics. The changes are extensive, adding configurability for gradient checkpointing and FSDP, introducing torch.compile support, and implementing checkpointing for the reference log-probabilities cache build process. The addition of PerfCallback for MFU monitoring is a great improvement. The code is well-structured, but I've found a critical bug in the training module and a couple of minor issues that should be addressed.
```python
policy_chosen_logps, policy_rejected_logps, aux_loss = self._forward_fn(
    self.model,
    micro_batch,
    average_log_prob=is_average,
    packing=self.args.packing,
    output_router_logits=self.args.load_balancing_loss,
)
```
There's a potential TypeError here. self._forward_fn is set to dpo_utils.separate_forward_olmo if args.concatenated_forward is False. However, dpo_utils.separate_forward_olmo does not accept the packing keyword argument, which is passed unconditionally on line 166. This will cause the training to crash if concatenated_forward is disabled.
```python
forward_kwargs = {
    "average_log_prob": is_average,
    "output_router_logits": self.args.load_balancing_loss,
}
if self._forward_fn == dpo_utils.concatenated_forward_olmo:
    forward_kwargs["packing"] = self.args.packing
policy_chosen_logps, policy_rejected_logps, aux_loss = self._forward_fn(
    self.model,
    micro_batch,
    **forward_kwargs,
)
```

```bash
--cluster ai2/jupiter \
--description "OLMo3-7B DPO with OLMo-core, 4 nodes, 16k seq len" \
--workspace ai2/olmo-instruct \
	--no_auto_dataset_cache \
```
This line starts with a literal tab character (\t). This will likely be interpreted as part of the argument string by the shell, passing "\t--no_auto_dataset_cache" to mason.py, which is probably not intended and could cause the script to fail or behave unexpectedly. The tab should be removed.
Suggested change:

```bash
--no_auto_dataset_cache \
```
```diff
 for group in self.optim.param_groups:
     new_lr = self.scheduler.set_lr(group, self.trainer)
-    self.trainer.record_metric(f"LR (group {group_idx})", new_lr, namespace="optim")
+    self.trainer.record_metric("lr", new_lr, namespace="optim")
```
The learning rate is now logged only for the last parameter group. While it's likely that all parameter groups share the same learning rate in the current setup, the previous implementation was more robust as it logged the LR for each group. This could be important for debugging if different LR schedules are used for different parameter groups in the future. Consider restoring a behavior similar to the previous implementation for more detailed logging.
Suggested change:

```python
for i, group in enumerate(self.optim.param_groups):
    new_lr = self.scheduler.set_lr(group, self.trainer)
    self.trainer.record_metric(f"lr/group_{i}", new_lr, namespace="optim")
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 01bfbd6ff6
```python
policy_chosen_logps, policy_rejected_logps, aux_loss = self._forward_fn(
    self.model,
    micro_batch,
    average_log_prob=is_average,
    packing=self.args.packing,
```
Avoid passing packing kwarg to separate forward
When --concatenated_forward=false, _forward_fn is separate_forward_olmo, which does not accept a packing keyword. The new call always passes packing=self.args.packing, so any run with --concatenated_forward=false will raise a TypeError on the first batch and halt training. Gate the packing argument on the concatenated path or update separate_forward_olmo to accept/ignore it.
```diff
@@ -388,8 +415,6 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
 )


 forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo
```
Use packing-aware forward for reference cache
With --packing enabled, the collator emits padding-free batches (flattened inputs with cu_seq_lens_*). The reference cache now uses a forward_fn that is always the plain concatenated_forward_olmo/separate_forward_olmo without packing=True, so cached logprobs will be computed with the padded path (or hit shape/key errors) and become invalid. Pass packing=True (or use a packing-aware partial) when building the reference cache for packing runs.
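A hedged sketch of the packing-aware partial suggested here, reusing the names from the quoted dpo.py snippet (args, dpo_utils); it is not the exact code that landed in the PR:

```python
import functools

forward_fn = (
    dpo_utils.concatenated_forward_olmo
    if args.concatenated_forward
    else dpo_utils.separate_forward_olmo
)
if args.packing:
    # Bind packing=True so the reference-cache build uses the padding-free
    # (cu_seq_lens-based) log-prob path instead of the padded one.
    forward_fn = functools.partial(forward_fn, packing=True)
```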
```python
if total_aux_loss is not None:
    self.record_metric("train/aux_loss", total_aux_loss, ReduceType.mean)

self.record_metric("train/token_count", float(batch["token_count"]), ReduceType.sum)
```
Don’t assume token_count exists in packing batches
In train_batch, batch["token_count"] is recorded unconditionally, but the padding-free collator (TensorDataCollatorWithFlatteningDPO) does not populate token_count. Any --packing run will therefore crash with a KeyError on the first step. Compute token_count for packed batches or guard this metric when the key is absent.
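One way to guard the metric, sketched against the train_batch snippet quoted above; the cu_seq_lens_q key for packed batches is an assumption based on the cu_seq_lens_* outputs this review mentions:

```python
if "token_count" in batch:
    token_count = float(batch["token_count"])
else:
    # For padding-free batches, the last cumulative sequence length equals the
    # total number of tokens in the flattened batch.
    token_count = float(batch["cu_seq_lens_q"][-1])
self.record_metric("train/token_count", token_count, ReduceType.sum)
```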
- Narrow build_reference_logprobs_cache to accept HFDataLoader only
- Remove supports_checkpointing conditionals (always use HFDataLoader)
- Remove manual dataloader.batches_processed increment (auto-incremented)
- Remove max_length field from DataCollatorForSeq2SeqDPO

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix packing arg only passed to concatenated_forward_olmo (not separate)
- Fix missing token_count check for packing batches
- Fix packing-aware forward for reference cache with functools.partial
- Fix tab character in shell script

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add packing parameter to separate_forward_olmo
- Use pf_get_batch_logps when packing=True for both chosen and rejected
- Remove concatenated_forward check for packing (both forward fns support it)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Change device_peak_flops_per_second to device_peak_flops to match the current olmo-core API. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restore dpo_utils.py from main and add packing parameter to separate_forward_olmo for padding-free training support. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add compile_model, activation_memory_budget, shard_degree, and num_replicas fields to TrainingConfig. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add gradient_checkpointing_mode, gradient_checkpointing_block_interval, and gradient_checkpointing_modules fields to TrainingConfig. Convert string mode to TransformerActivationCheckpointingMode enum in dpo.py. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
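Roughly, the TrainingConfig additions named in these commits look like the following sketch; the field types and defaults are assumptions, and the enum conversion is shown in comments because only the enum name (not its import path) comes from the commit message:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TrainingConfig:
    # New OLMo-core-related fields described in this PR (existing fields omitted).
    compile_model: bool = False
    gradient_checkpointing_mode: Optional[str] = None  # "full", "selected_blocks", or "budget"
    gradient_checkpointing_block_interval: Optional[int] = None
    gradient_checkpointing_modules: Optional[List[str]] = None
    activation_memory_budget: Optional[float] = None
    shard_degree: Optional[int] = None
    num_replicas: Optional[int] = None


# In dpo.py the string mode is converted to olmo-core's enum, e.g. (import path assumed):
# from olmo_core.nn.transformer import TransformerActivationCheckpointingMode
# ac_mode = TransformerActivationCheckpointingMode(args.gradient_checkpointing_mode)
```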
- Use concatenated_cu_seq_lens_q (not concatenated_cu_seq_lens) for packing-free logps calculation - Add token_count to DataCollatorForSeq2SeqDPO output for accurate MFU metrics Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When compile_model is enabled, pad all tensors to max_seq_length to prevent torch.compile recompilations due to varying sequence lengths. This improves MFU from ~9% to >40% for 7B DPO training. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
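A sketch of what padding every tensor to a fixed max_seq_length can look like; the helper name and pad values are assumptions, not code from the PR:

```python
import torch
import torch.nn.functional as F


def pad_to_max_seq_length(batch: dict, max_seq_length: int, pad_token_id: int) -> dict:
    """Pad sequence-dimension tensors to one static length so torch.compile stops recompiling."""
    padded = {}
    for key, value in batch.items():
        if not torch.is_tensor(value) or value.dim() < 2:
            padded[key] = value
            continue
        pad_len = max_seq_length - value.shape[1]
        if "input_ids" in key:
            fill = pad_token_id
        elif "labels" in key:
            fill = -100  # conventional ignore index for the loss
        else:
            fill = 0  # e.g. attention masks
        # F.pad pads the last dimension by (left, right) amounts.
        padded[key] = F.pad(value, (0, pad_len), value=fill)
    return padded
```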
This enables torch.compile and the fixed-length padding for improved MFU. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The combination of gradient_checkpointing and torch.compile causes tensor metadata mismatches during recomputation. Disable checkpointing when using compile_model. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Budget mode activation checkpointing is compatible with torch.compile, unlike full mode which causes tensor metadata mismatches. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
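Taken together, the last two commits land on a combination like this (field names from this PR; the budget value is illustrative, not taken from the PR):

```python
# Budget-mode activation checkpointing is the variant that coexists with torch.compile here.
config_overrides = dict(
    compile_model=True,
    gradient_checkpointing_mode="budget",
    activation_memory_budget=0.5,  # illustrative fraction, not a value from the PR
)
```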
Summary
- Add `olmo_core_callbacks.py` with `BeakerCallbackV2` and `PerfCallback` for MFU and tokens_per_second metrics
- Add `7b_instruct_dpo_olmocore.sh` script for 4-node DPO training with 16k sequence length
- Update `dpo_utils.py` with new TrainingConfig fields:
  - `gradient_checkpointing_mode` for activation checkpointing mode (full, selected_blocks, budget)
  - `activation_memory_budget` for budget-based checkpointing
  - `shard_degree` and `num_replicas` for FSDP configuration
  - `compile_model` for torch.compile support
  - `DataCollatorForSeq2SeqDPO` updated with `max_length` and `token_count` support
- Update `dpo.py`:
  - Import callbacks from `olmo_core_callbacks`
  - Add `compile_model` support with `apply_compile()`
  - Update `_apply_parallelism` for `shard_degree`/`num_replicas`
  - Add `PerfCallback` and `ProfilerCallback` to callbacks
  - Pass `rank_microbatch_size` to `DPOTrainModule`
- Update `olmo_core_train_modules.py`:
  - Add `split_batch_dpo()` for microbatch gradient accumulation
  - Update `DPOTrainModule` with `rank_microbatch_size` parameter
  - Add proper FSDP sync handling in `train_batch`

Test plan

- `./scripts/train/build_image_and_launch.sh scripts/train/debug/dpo/single_gpu.sh`
- `./scripts/train/build_image_and_launch.sh scripts/train/debug/dpo/multi_node.sh`

🤖 Generated with Claude Code