[Feat.] Support on policy distillation in fsdp #420
Conversation
Summary of Changes

Hello @PopSoda2002, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the FSDP training backend by adding a new method for calculating advantages: on-policy distillation. It lets the training process leverage a teacher model's policy by using the difference between teacher and student log probabilities as the advantage signal, a novel approach to policy optimization that could improve learning efficiency or stability.
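For reviewers unfamiliar with the technique, here is a minimal, self-contained sketch of the advantage computation described above (the tensor shapes and the flat packed layout are assumptions for illustration, not code from the PR):

```python
import torch

# Illustrative shapes: student_log_probs is one flat tensor over all
# response tokens in the micro-batch, while teacher_log_probs is a list
# with one tensor per sequence (mirroring the review discussion below).
response_lengths = [3, 2]
student_log_probs = torch.randn(sum(response_lengths))
teacher_log_probs = [torch.randn(n) for n in response_lengths]

# On-policy distillation: the per-token advantage is the gap between
# teacher and student log probabilities, so optimization pushes the
# student toward tokens the teacher rates as more likely.
student_split = student_log_probs.split(response_lengths, dim=0)
advantages = [t - s for t, s in zip(teacher_log_probs, student_split)]
```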
Code Review
This pull request introduces support for on-policy distillation in FSDP. The core logic added in `_train_step` has a few critical issues. Notably, it attempts to access `rollout_data`, which is out of scope and will raise a `NameError`. There are also incorrect tensor operations in the advantage calculation that would cause runtime errors or wrong behavior. I've left specific comments and suggestions to address these problems.
```python
teacher_log_probs = rollout_data.get("teacher_log_probs")
response_lengths = rollout_data.get("response_lengths")
```
The variable `rollout_data` is used here but is not defined within the `_train_step` method's scope, which will cause a `NameError`. Additionally, line 636 incorrectly re-assigns `response_lengths` from this out-of-scope variable; the correct micro-batch-scoped `response_lengths` is already available from line 629.

To fix this, pack `teacher_log_probs` into `packed_batch` (as is done for the other data) and retrieve it from there, and remove the re-assignment of `response_lengths`; a sketch follows.
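A minimal sketch of that fix, assuming `packed_batch` is a plain dict and using a hypothetical `pack_teacher_log_probs` helper (neither detail is taken from the actual diff):

```python
import torch

def pack_teacher_log_probs(micro_batch, packed_batch):
    """Hypothetical helper: carry per-sample teacher log-probs into the
    packed micro-batch dict, mirroring how the other fields are packed."""
    packed_batch["teacher_log_probs"] = [
        sample["teacher_log_probs"] for sample in micro_batch
    ]
    return packed_batch

# Inside _train_step, read from the micro-batch-scoped packed_batch
# instead of the out-of-scope rollout_data:
micro_batch = [
    {"teacher_log_probs": torch.randn(3)},
    {"teacher_log_probs": torch.randn(2)},
]
packed_batch = pack_teacher_log_probs(micro_batch, {})
teacher_log_probs = packed_batch["teacher_log_probs"]
# response_lengths keeps the value already computed earlier in
# _train_step (line 629 of the diff); the re-assignment is dropped.
```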
```python
advantages = [
    teacher_log_prob - student_log_prob
    for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)
]
```
`student_log_probs` is a single flat tensor containing the log probabilities for all sequences in the micro-batch. However, the list comprehension zips it with `teacher_log_probs` (a list of tensors), so it would incorrectly iterate over the scalar elements of `student_log_probs`.

You need to split `student_log_probs` into a list of per-sequence tensors using `response_lengths` before zipping, as in the suggestion below.
Suggested change:

```diff
-advantages = [
-    teacher_log_prob - student_log_prob
-    for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs)
-]
+student_log_probs_list = list(student_log_probs.split(response_lengths, dim=0))
+advantages = [
+    teacher_log_prob - student_log_prob
+    for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs_list)
+]
```
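For context on the suggestion: `torch.Tensor.split` accepts a list of chunk sizes, so splitting the flat log-prob tensor by `response_lengths` yields exactly one variable-length tensor per sequence, which then pairs correctly with the per-sequence `teacher_log_probs` list under `zip`.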
```python
student_log_probs = log_probs
teacher_log_probs = rollout_data.get("teacher_log_probs")
response_lengths = rollout_data.get("response_lengths")
device = student_log_probs[0].device
```
No description provided.