From 94dc474fae993817fca6ab5617c4a78e8152cc5d Mon Sep 17 00:00:00 2001 From: PopSoda2002 Date: Mon, 12 Jan 2026 06:55:00 +0000 Subject: [PATCH] Support on-policy distillation in FSDP --- miles/backends/fsdp_utils/actor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index f22a95546..9f78e4cfd 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -629,6 +629,20 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): response_lengths = [batch["response_lengths"] for batch in unpacked_batches] advantages = advantages.to(device=log_probs.device) + + if self.args.advantage_estimator == "on_policy_distillation": + student_log_probs = log_probs + teacher_log_probs = rollout_data.get("teacher_log_probs") + response_lengths = rollout_data.get("response_lengths") + device = student_log_probs[0].device + teacher_log_probs = [t_log_prob.to(device=device) for t_log_prob in teacher_log_probs] + teacher_log_probs = [ + t_log_prob[-response_length:] for t_log_prob, response_length in zip(teacher_log_probs, response_lengths) + ] + advantages = [ + teacher_log_prob - student_log_prob + for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs) + ] old_log_probs = old_log_probs.to(device=log_probs.device) ppo_kl = old_log_probs - log_probs