# Copyright (c) ModelScope Contributors. All rights reserved.
"""DPO-specific metrics for preference optimization training."""
from typing import List, Union

from twinkle.data_format import InputFeature, ModelOutput
from .base import Metric


class DPOMetric(Metric):
    """Metrics for DPO (Direct Preference Optimization) training.

    Computes TRL-style metrics:
    - logps/chosen: Average sequence-level log prob of chosen responses
    - logps/rejected: Average sequence-level log prob of rejected responses
    - rewards/chosen: β * (policy_chosen - ref_chosen)
    - rewards/rejected: β * (policy_rejected - ref_rejected)
    - rewards/margins: chosen_reward - rejected_reward
    - rewards/accuracies: Percentage of pairs where chosen_reward > rejected_reward

    Args:
        device_mesh: The device mesh.
        process_group: The process group to collect data from.
        ignore_index: Label index to ignore (default: -100).
        beta: DPO beta parameter for reward scaling (default: 0.1).
    """
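
    # Worked example of the reward arithmetic (hypothetical numbers): with
    # beta = 0.1 and sequence logps of -12.0 (policy chosen), -14.0 (ref
    # chosen), -20.0 (policy rejected), -18.0 (ref rejected):
    #   chosen_reward   = 0.1 * (-12.0 - (-14.0)) =  0.2
    #   rejected_reward = 0.1 * (-20.0 - (-18.0)) = -0.2
    # The margin is 0.4 > 0, so this pair counts toward rewards/accuracies.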

    def __init__(self, device_mesh, process_group, ignore_index: int = -100, beta: float = 0.1, **kwargs):
        super().__init__(device_mesh, process_group, **kwargs)
        self.ignore_index = ignore_index
        self.beta = beta
        self.reset()

    def _compute_sequence_logps(self, per_token_logps, labels):
        """Compute sequence-level log probs by summing valid token logps."""
        loss_mask = (labels != self.ignore_index).float()
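        # Hypothetical example: per_token_logps = [-1.0, -2.0, -3.0] with
        # labels = [-100, 5, 7] masks out the first token, giving a sequence
        # logp of (-2.0) + (-3.0) = -5.0.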
        return (per_token_logps * loss_mask).sum(dim=-1)

    def _split_chosen_rejected(self, tensor):
        """Split interleaved tensor into chosen and rejected.

        Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing)
        Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...])
        """
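        # Sketch: a batch laid out as [c1, r1, c2, r2] splits into
        # ([c1, c2], [r1, r2]); this assumes an even batch size with strict
        # chosen/rejected interleaving.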
        return tensor[0::2], tensor[1::2]

    def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
        """Accumulate DPO metrics from model outputs.

        Expects:
        - outputs['logps']: [batch, seq_len] per-token log probabilities
        - inputs['labels']: [batch, seq_len] labels with ignore_index for non-target tokens
        - kwargs['ref_outputs']: Optional reference model outputs with 'logps'
        - Rows interleaved as [pos_1, neg_1, pos_2, neg_2, ...] so chosen/rejected pairs can be split
        """
        import torch

        logps = outputs.get('logps')
        if logps is None:
            return

        # Get labels from inputs
        if isinstance(inputs, list):
            # Stack labels from a list of inputs, left-padding to the longest sequence
            labels_list = [torch.as_tensor(inp['labels']) for inp in inputs]
            max_len = max(lab.shape[0] for lab in labels_list)
            padded = []
            for lab in labels_list:
                if lab.shape[0] < max_len:
                    pad = torch.full((max_len - lab.shape[0],), self.ignore_index, dtype=lab.dtype)
                    lab = torch.cat([pad, lab])
                padded.append(lab)
            labels = torch.stack(padded)
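            # NOTE: labels are left-padded here; this assumes the per-token
            # logps for a ragged batch were left-padded the same way upstream,
            # otherwise the mask in _compute_sequence_logps is misaligned.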
        else:
            labels = torch.as_tensor(inputs['labels'])
            if labels.dim() == 1:
                labels = labels.unsqueeze(0)

        # Ensure logps and labels are on the same device
        if logps.device != labels.device:
            labels = labels.to(logps.device)

        # Align sequence lengths if needed (truncate on the right)
        if logps.shape[1] != labels.shape[1]:
            min_len = min(logps.shape[1], labels.shape[1])
            logps = logps[:, :min_len]
            labels = labels[:, :min_len]

        # Compute sequence-level logps
        seq_logps = self._compute_sequence_logps(logps, labels)

        # Split into chosen and rejected (interleaved format)
        chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps)

        # Accumulate policy logps
        self.total_chosen_logps += chosen_logps.sum().item()
        self.total_rejected_logps += rejected_logps.sum().item()

        # Compute rewards if ref_outputs available
        ref_outputs = kwargs.get('ref_outputs')
        if ref_outputs is not None:
            ref_logps = ref_outputs.get('logps')
            if ref_logps is not None:
                # Align ref_logps (and a matching view of labels) in device and length
                if ref_logps.device != labels.device:
                    ref_logps = ref_logps.to(labels.device)
                ref_labels = labels
                if ref_logps.shape[1] != labels.shape[1]:
                    min_len = min(ref_logps.shape[1], labels.shape[1])
                    ref_logps = ref_logps[:, :min_len]
                    ref_labels = labels[:, :min_len]

                ref_seq_logps = self._compute_sequence_logps(ref_logps, ref_labels)
                ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps)

                # Compute rewards: β * (policy - ref)
                chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps)
                rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps)

                self.total_chosen_rewards += chosen_rewards.sum().item()
                self.total_rejected_rewards += rejected_rewards.sum().item()
                margins = chosen_rewards - rejected_rewards
                self.total_reward_margin += margins.sum().item()
                self.total_reward_correct += (margins > 0).sum().item()
                self.has_rewards = True

        # Count pairs, not rows: each chosen/rejected pair contributes one unit
        self.total_count += chosen_logps.shape[0]

    def reset(self):
        """Reset all accumulated values."""
        self.total_chosen_logps = 0.0
        self.total_rejected_logps = 0.0
        self.total_chosen_rewards = 0.0
        self.total_rejected_rewards = 0.0
        self.total_reward_margin = 0.0
        self.total_reward_correct = 0
        self.total_count = 0
        self.has_rewards = False

    def calculate(self):
        """Calculate and return aggregated metrics."""
        local_results = [{
            'chosen_logps': self.total_chosen_logps,
            'rejected_logps': self.total_rejected_logps,
            'chosen_rewards': self.total_chosen_rewards,
            'rejected_rewards': self.total_rejected_rewards,
            'reward_margin': self.total_reward_margin,
            'reward_correct': self.total_reward_correct,
            'count': self.total_count,
            'has_rewards': self.has_rewards,
        }]
        all_results = self.gather_results(local_results)

        total_chosen_logps = sum(r['chosen_logps'] for r in all_results)
        total_rejected_logps = sum(r['rejected_logps'] for r in all_results)
        total_chosen_rewards = sum(r['chosen_rewards'] for r in all_results)
        total_rejected_rewards = sum(r['rejected_rewards'] for r in all_results)
        total_reward_margin = sum(r['reward_margin'] for r in all_results)
        total_reward_correct = sum(r['reward_correct'] for r in all_results)
        total_count = sum(r['count'] for r in all_results)
        has_rewards = any(r['has_rewards'] for r in all_results)

        self.reset()

        if total_count == 0:
            return {}

        results = {
            'logps/chosen': f'{total_chosen_logps / total_count:.2f}',
            'logps/rejected': f'{total_rejected_logps / total_count:.2f}',
        }

        if has_rewards:
            results['rewards/chosen'] = f'{total_chosen_rewards / total_count:.4f}'
            results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}'
            results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}'
            results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%'

        return results
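

# Usage sketch (hypothetical wiring; `device_mesh`/`process_group` come from
# the trainer, and outputs['logps'] must be per-token log probs over an
# interleaved chosen/rejected batch):
#
#   metric = DPOMetric(device_mesh, process_group, beta=0.1)
#   for batch in dataloader:
#       outputs = model(batch)                  # {'logps': [B, T], ...}
#       ref_outputs = ref_model(batch)          # optional, enables rewards/*
#       metric.accumulate(batch, outputs, ref_outputs=ref_outputs)
#   print(metric.calculate())
#   # e.g. {'logps/chosen': '-83.12', ..., 'rewards/accuracies': '87.5%'}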