@PopSoda2002
Contributor

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @PopSoda2002, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the FSDP training backend by integrating a new method for calculating advantages, specifically 'on-policy distillation'. This allows the training process to leverage a teacher model's policy by using the difference in log probabilities as the advantage signal, providing a novel approach to policy optimization that could lead to improved learning efficiency or stability.

Highlights

  • Policy Distillation Integration: Introduced support for 'on_policy_distillation' as an advantage estimator within the _train_step function in miles/backends/fsdp_utils/actor.py.
  • Dynamic Advantage Calculation: Implemented conditional logic to calculate advantages by subtracting student log probabilities from teacher log probabilities when 'on_policy_distillation' is enabled, utilizing teacher_log_probs and response_lengths from rollout_data.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for on-policy distillation in FSDP. The core logic added in _train_step has a few critical issues. Notably, it accesses rollout_data, which is out of scope in that method and leads to a NameError. The tensor operations used to calculate advantages are also incorrect and would cause runtime errors or wrong results. I've provided specific comments and suggestions to address these problems.

Comment on lines +635 to +636
teacher_log_probs = rollout_data.get("teacher_log_probs")
response_lengths = rollout_data.get("response_lengths")

critical

The variable rollout_data is used here but it's not defined within the _train_step method's scope, which will cause a NameError. Additionally, line 636 incorrectly re-assigns response_lengths from this out-of-scope variable; the correct micro-batch-scoped response_lengths is already available from line 629.

To fix this, you should pack teacher_log_probs into packed_batch (similar to other data) and retrieve it from there. You should also remove the re-assignment of response_lengths.
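As a rough sketch of what this suggestion could look like, the teacher log-probs can travel inside the packed micro-batch so that the training step reads them from its own argument rather than from an outer rollout_data. The helper pack_micro_batch, the dictionary keys, and the simplified train_step below are hypothetical stand-ins for illustration, not the actual code in miles/backends/fsdp_utils/actor.py.

```python
import torch


def pack_micro_batch(samples):
    """Hypothetical packer: gathers per-sample fields into one micro-batch dict."""
    teacher = [
        torch.as_tensor(s["teacher_log_probs"], dtype=torch.float32) for s in samples
    ]
    return {
        "response_lengths": [t.numel() for t in teacher],
        "teacher_log_probs": teacher,  # travels with the micro-batch
    }


def train_step(packed_batch):
    """Simplified stand-in for _train_step: reads only from its own argument."""
    response_lengths = packed_batch["response_lengths"]
    teacher_log_probs = packed_batch.get("teacher_log_probs")
    if teacher_log_probs is None:
        raise ValueError("on_policy_distillation needs teacher_log_probs in the packed batch")
    # Sanity check: one teacher tensor per response, with matching lengths.
    assert [t.numel() for t in teacher_log_probs] == response_lengths
    return teacher_log_probs, response_lengths


samples = [
    {"teacher_log_probs": [-0.5, -1.0]},
    {"teacher_log_probs": [-0.25, -0.75, -2.0]},
]
teacher_log_probs, response_lengths = train_step(pack_micro_batch(samples))
```

Because response_lengths and teacher_log_probs come from the same packed batch, they stay aligned per micro-batch, which is the property the re-assignment on line 636 would break.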

Comment on lines +642 to +645
advantages = [
    teacher_log_prob - student_log_prob
    for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)
]

critical

student_log_probs is a single flat tensor containing log probabilities for all sequences in the micro-batch. However, the list comprehension attempts to zip it with teacher_log_probs (a list of tensors), which will incorrectly iterate over the scalar elements of student_log_probs.

You need to split student_log_probs into a list of tensors, one for each sequence, using response_lengths before zipping.

Suggested change
- advantages = [
-     teacher_log_prob - student_log_prob
-     for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)
- ]
+ student_log_probs_list = list(student_log_probs.split(response_lengths, dim=0))
+ advantages = [
+     teacher_log_prob - student_log_prob
+     for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs_list)
+ ]
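
To make the difference concrete, here is a small self-contained sketch with made-up values. It assumes student_log_probs is a flat 1-D tensor and teacher_log_probs is a list with one tensor per response, as the comment describes; the numbers are illustrative only.

```python
import torch

# Flat student log-probs for a micro-batch of two responses (lengths 2 and 3).
student_log_probs = torch.tensor([-1.0, -2.0, -0.5, -0.25, -1.5])
response_lengths = [2, 3]
teacher_log_probs = [
    torch.tensor([-0.5, -1.0]),
    torch.tensor([-0.25, -0.75, -2.0]),
]

# Zipping the flat tensor directly pairs each teacher tensor with a single
# scalar, so the subtraction broadcasts instead of aligning token by token.
wrong = [t - s for t, s in zip(teacher_log_probs, student_log_probs)]

# Splitting first yields one student tensor per response.
student_log_probs_list = list(student_log_probs.split(response_lengths, dim=0))
advantages = [t - s for t, s in zip(teacher_log_probs, student_log_probs_list)]
```

Note that the unsplit version does not raise here: it silently produces tensors of the right shape with the wrong values, which is arguably worse than a crash.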

student_log_probs = log_probs
teacher_log_probs = rollout_data.get("teacher_log_probs")
response_lengths = rollout_data.get("response_lengths")
device = student_log_probs[0].device

medium

Accessing student_log_probs[0] to get the device can raise an IndexError if student_log_probs is an empty tensor. It's safer to get the device directly from the tensor property.

Suggested change
- device = student_log_probs[0].device
+ device = student_log_probs.device
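
The failure mode is easy to reproduce in isolation. The sketch below assumes the edge case is a micro-batch with no response tokens, represented by an empty 1-D tensor.

```python
import torch

empty = torch.empty(0)  # e.g. a micro-batch with no response tokens

try:
    _ = empty[0].device  # indexing an empty tensor fails before .device is read
    indexing_failed = False
except IndexError:
    indexing_failed = True

device = empty.device  # available on any tensor, empty or not
```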
