
Conversation

@finbarrtimbers
Collaborator

Summary

  • Add olmo_core_callbacks.py with BeakerCallbackV2 and PerfCallback for MFU and tokens_per_second metrics
  • Add 7b_instruct_dpo_olmocore.sh script for 4-node DPO training with 16k sequence length
  • Update dpo_utils.py with new TrainingConfig fields (see the sketch after this list):
    • gradient_checkpointing_mode for activation checkpointing mode (full, selected_blocks, budget)
    • activation_memory_budget for budget-based checkpointing
    • shard_degree and num_replicas for FSDP configuration
    • compile_model for torch.compile support
    • DataCollatorForSeq2SeqDPO updated with max_length and token_count support
  • Update dpo.py:
    • Import callbacks from olmo_core_callbacks
    • Add compile_model support with apply_compile()
    • Update _apply_parallelism for shard_degree/num_replicas
    • Add PerfCallback and ProfilerCallback to callbacks
    • Pass rank_microbatch_size to DPOTrainModule
  • Update olmo_core_train_modules.py:
    • Add split_batch_dpo() for microbatch gradient accumulation
    • Update DPOTrainModule with rank_microbatch_size parameter
    • Add proper FSDP sync handling in train_batch

Test plan

  • Run single-GPU debug script: ./scripts/train/build_image_and_launch.sh scripts/train/debug/dpo/single_gpu.sh
  • Run multi-node debug script: ./scripts/train/build_image_and_launch.sh scripts/train/debug/dpo/multi_node.sh
  • Verify MFU metrics appear in wandb logs

🤖 Generated with Claude Code

@gemini-code-assist
Contributor

Summary of Changes

Hello @finbarrtimbers, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly upgrades the DPO training pipeline by integrating OLMo-core specific features and improving distributed training capabilities. It introduces new callbacks for detailed performance monitoring, expands configuration options for memory optimization and parallelism, and refines data handling and gradient accumulation strategies. These changes aim to enable more efficient and scalable DPO training, particularly for large models and distributed environments.

Highlights

  • Enhanced Performance Metrics: Introduced new OLMo-core callbacks, BeakerCallbackV2 and PerfCallback, for comprehensive metric tracking, including MFU (Model Flops Utilization) and tokens-per-second, providing deeper insights into training efficiency.
  • Advanced Distributed Training Configuration: Extended the TrainingConfig with new fields for flexible gradient checkpointing modes (full, selected_blocks, budget), detailed FSDP configuration (shard_degree, num_replicas), and torch.compile support for potential speedups.
  • Improved Data Handling and Gradient Accumulation: The DataCollatorForSeq2SeqDPO now supports max_length and token_count, and the DPOTrainModule implements microbatch gradient accumulation with proper FSDP/DDP synchronization, leading to more efficient memory usage and stable training (a generic sketch of this pattern follows the list).
  • Robust Reference Logprobs Caching: The reference logprobs caching mechanism has been enhanced with incremental checkpointing, allowing for recovery from interruptions and more reliable caching for large datasets.
  • New DPO Training Script: Added a dedicated DPO training script (7b_instruct_dpo_olmocore.sh) demonstrating 4-node, 16k sequence length training with the new OLMo-core features, including budget-based activation checkpointing and model compilation.


BeakerCallbackV2 is now in olmo_core_callbacks.py, so the old
beaker_callback.py is no longer needed. Rename test file to match.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces significant enhancements for DPO training with OLMo-core, including support for multi-node training, longer sequences, and improved MFU (Model Flops Utilization) metrics. The changes are extensive, adding configurability for gradient checkpointing and FSDP, introducing torch.compile support, and implementing checkpointing for the reference log-probabilities cache build process. The addition of PerfCallback for MFU monitoring is a great improvement. The code is well-structured, but I've found a critical bug in the training module and a couple of minor issues that should be addressed.

Comment on lines 162 to 168
                policy_chosen_logps, policy_rejected_logps, aux_loss = self._forward_fn(
                    self.model,
                    micro_batch,
                    average_log_prob=is_average,
                    packing=self.args.packing,
                    output_router_logits=self.args.load_balancing_loss,
                )

critical

There's a potential TypeError here. self._forward_fn is set to dpo_utils.separate_forward_olmo if args.concatenated_forward is False. However, dpo_utils.separate_forward_olmo does not accept the packing keyword argument, which is passed unconditionally on line 166. This will cause the training to crash if concatenated_forward is disabled.

                forward_kwargs = {
                    "average_log_prob": is_average,
                    "output_router_logits": self.args.load_balancing_loss,
                }
                if self._forward_fn == dpo_utils.concatenated_forward_olmo:
                    forward_kwargs["packing"] = self.args.packing

                policy_chosen_logps, policy_rejected_logps, aux_loss = self._forward_fn(
                    self.model,
                    micro_batch,
                    **forward_kwargs,
                )

--cluster ai2/jupiter \
--description "OLMo3-7B DPO with OLMo-core, 4 nodes, 16k seq len" \
--workspace ai2/olmo-instruct \
	--no_auto_dataset_cache \

high

This line starts with a literal tab character (\t). This will likely be interpreted as part of the argument string by the shell, passing "\t--no_auto_dataset_cache" to mason.py, which is probably not intended and could cause the script to fail or behave unexpectedly. The tab should be removed.

Suggested change
-	--no_auto_dataset_cache \
+--no_auto_dataset_cache \

Comment on lines +106 to +108
for group in self.optim.param_groups:
    new_lr = self.scheduler.set_lr(group, self.trainer)
    self.trainer.record_metric(f"LR (group {group_idx})", new_lr, namespace="optim")
self.trainer.record_metric("lr", new_lr, namespace="optim")

medium

The learning rate is now logged only for the last parameter group. While it's likely that all parameter groups share the same learning rate in the current setup, the previous implementation was more robust as it logged the LR for each group. This could be important for debugging if different LR schedules are used for different parameter groups in the future. Consider restoring a behavior similar to the previous implementation for more detailed logging.

Suggested change
-for group in self.optim.param_groups:
-    new_lr = self.scheduler.set_lr(group, self.trainer)
-    self.trainer.record_metric(f"LR (group {group_idx})", new_lr, namespace="optim")
-self.trainer.record_metric("lr", new_lr, namespace="optim")
+for i, group in enumerate(self.optim.param_groups):
+    new_lr = self.scheduler.set_lr(group, self.trainer)
+    self.trainer.record_metric(f"lr/group_{i}", new_lr, namespace="optim")


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 01bfbd6ff6


Comment on lines 162 to 166
                policy_chosen_logps, policy_rejected_logps, aux_loss = self._forward_fn(
                    self.model,
                    micro_batch,
                    average_log_prob=is_average,
                    packing=self.args.packing,


P2: Avoid passing packing kwarg to separate forward

When --concatenated_forward=false, _forward_fn is separate_forward_olmo, which does not accept a packing keyword. The new call always passes packing=self.args.packing, so any run with --concatenated_forward=false will raise a TypeError on the first batch and halt training. Gate the packing argument on the concatenated path or update separate_forward_olmo to accept/ignore it.


@@ -388,8 +415,6 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC
)

forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo


P2: Use packing-aware forward for reference cache

With --packing enabled, the collator emits padding-free batches (flattened inputs with cu_seq_lens_*). The reference cache now uses a forward_fn that is always the plain concatenated_forward_olmo/separate_forward_olmo without packing=True, so cached logprobs will be computed with the padded path (or hit shape/key errors) and become invalid. Pass packing=True (or use a packing-aware partial) when building the reference cache for packing runs.


Comment on lines 213 to 216
        if total_aux_loss is not None:
            self.record_metric("train/aux_loss", total_aux_loss, ReduceType.mean)

        self.record_metric("train/token_count", float(batch["token_count"]), ReduceType.sum)


P2: Don't assume token_count exists in packing batches

In train_batch, batch["token_count"] is recorded unconditionally, but the padding-free collator (TensorDataCollatorWithFlatteningDPO) does not populate token_count. Any --packing run will therefore crash with a KeyError on the first step. Compute token_count for packed batches or guard this metric when the key is absent.


finbarrtimbers and others added 17 commits January 30, 2026 10:41
- Narrow build_reference_logprobs_cache to accept HFDataLoader only
- Remove supports_checkpointing conditionals (always use HFDataLoader)
- Remove manual dataloader.batches_processed increment (auto-incremented)
- Remove max_length field from DataCollatorForSeq2SeqDPO

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix packing arg only passed to concatenated_forward_olmo (not separate)
- Fix missing token_count check for packing batches
- Fix packing-aware forward for reference cache with functools.partial
- Fix tab character in shell script

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add packing parameter to separate_forward_olmo
- Use pf_get_batch_logps when packing=True for both chosen and rejected
- Remove concatenated_forward check for packing (both forward fns support it)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Change device_peak_flops_per_second to device_peak_flops to match
the current olmo-core API.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restore dpo_utils.py from main and add packing parameter to
separate_forward_olmo for padding-free training support.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add compile_model, activation_memory_budget, shard_degree, and
num_replicas fields to TrainingConfig.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add gradient_checkpointing_mode, gradient_checkpointing_block_interval,
and gradient_checkpointing_modules fields to TrainingConfig. Convert
string mode to TransformerActivationCheckpointingMode enum in dpo.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use concatenated_cu_seq_lens_q (not concatenated_cu_seq_lens) for
  padding-free logps calculation
- Add token_count to DataCollatorForSeq2SeqDPO output for accurate
  MFU metrics

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When compile_model is enabled, pad all tensors to max_seq_length to prevent
torch.compile recompilations due to varying sequence lengths. This improves
MFU from ~9% to >40% for 7B DPO training.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This enables torch.compile and the fixed-length padding for improved MFU.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The combination of gradient_checkpointing and torch.compile causes
tensor metadata mismatches during recomputation. Disable checkpointing
when using compile_model.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Budget mode activation checkpointing is compatible with torch.compile,
unlike full mode which causes tensor metadata mismatches.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>