"""DPO (Direct Preference Optimization) Training with MultiLoRA (Megatron Backend).

MultiLoRA-based DPO training: uses the base model (without LoRA adapter) as reference
model by calling forward_only with disable_lora=True. This eliminates the need for
a separate reference model GPU group.

Uses Megatron backend with MultiLoRAMegatronModel for efficient multi-tenant LoRA training.

Pipeline:
    1. Load preference dataset with chosen/rejected pairs.
    2. Encode positive and negative separately.
    3. Compute reference model log probabilities using base model (disable_lora=True).
    4. Train policy model (with LoRA adapter) using DPO loss.

Architecture (Ray - Single Group):
    ┌─────────────────────────────────────────────────────────────────┐
    │                          Driver (CPU)                           │
    │   dataloader ──► batched preference pairs                       │
    │   policy_model.forward_only(disable_lora=True) ──► ref logps    │
    │   policy_model.forward_backward() ──► DPO loss + gradient       │
    └─────────────────────────────────────────────────────────────────┘
                               │
                  PolicyModel (with LoRA adapter)
                  - forward_only(disable_lora=True) → base model inference (reference)
                  - forward_backward() → LoRA adapter training (policy)

DPO data format (after preprocessing):
    - positive: List[Trajectory] - chosen responses
    - negative: List[Trajectory] - rejected responses

Environment variables (all optional):
    MODEL_ID                    – (default: ms://Qwen/Qwen3.5-4B)
    DATASET_ID                  – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
    MODEL_GPUS                  – GPUs for policy model (default: 2)
    BATCH_SIZE                  – global batch size (preference pairs) (default: 8)
    GRADIENT_ACCUMULATION_STEPS – micro-batches per optimizer/logging step (default: 2)
    LR                          – learning rate (default: 1e-4)
    DPO_BETA                    – DPO temperature parameter (default: 0.1)
    SFT_WEIGHT                  – SFT regularization loss weight (default: 1.0)
    LOSS_TYPE                   – DPO variant (sigmoid/hinge/ipo) (default: sigmoid)
    SAVE_STEPS                  – checkpoint save interval (default: 100)
    MAX_LENGTH                  – max sequence length (default: 2048)
    SYSTEM_PROMPT               – system prompt used during preprocessing

Note: the total step budget is derived from the dataset size (one pass over the
dataloader), not from a MAX_STEPS environment variable.
"""
| 43 | + |
| 44 | +import os |
| 45 | +from typing import Any, Dict, List, Optional |
| 46 | + |
| 47 | +from peft import LoraConfig |
| 48 | + |
| 49 | +import twinkle |
| 50 | +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger |
| 51 | +from twinkle.data_format import Trajectory |
| 52 | +from twinkle.dataloader import DataLoader |
| 53 | +from twinkle.dataset import Dataset, DatasetMeta |
| 54 | +from twinkle.loss import DPOLoss |
| 55 | +from twinkle.metric import DPOMetric |
| 56 | +from twinkle.preprocessor import EmojiDPOProcessor |
| 57 | +from twinkle.processor import InputProcessor |
| 58 | + |
# Module-level logger shared by the whole script.
logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
# Every value below can be overridden through an environment variable;
# defaults target a small 2-GPU demo run.
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))  # GPUs allocated to the policy group

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))  # Number of preference pairs
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))  # micro-batches per logged step
LEARNING_RATE = float(os.environ.get('LR', 1e-4))  # LoRA DPO requires higher LR (1e-4 to 3e-4)
DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))  # DPO temperature parameter
SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))  # checkpoint interval, in loop iterations
MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))  # max sequence length (tokens)
ADAPTER_NAME = 'default_0'  # name of the single LoRA adapter trained here
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
| 77 | + |
| 78 | + |
def create_dpo_dataset():
    """Build and encode the preference dataset (chosen/rejected pairs).

    Loads the first 50 rows of DATASET_ID, applies the Qwen3.5 chat
    template, preprocesses rows with EmojiDPOProcessor, and tokenizes
    everything up front (with caching).
    """
    ds = Dataset(DatasetMeta(DATASET_ID, data_slice=range(50)))
    ds.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
    ds.map(EmojiDPOProcessor, init_args={'system': SYSTEM_PROMPT})
    # The DPO preprocessor yields {'positive': [...], 'negative': [...]};
    # batch_encode handles this layout automatically.
    ds.encode(load_from_cache_file=True)
    return ds
| 93 | + |
| 94 | + |
def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten preference rows into a DP-safe interleaved sample list.

    Args:
        batch: Rows, each carrying 'positive' and 'negative' InputFeatures
            plus any additional fields (question, etc.).

    Returns:
        Samples ordered [pos_1, neg_1, pos_2, neg_2, ...] so that slicing by
        data-parallel rank always yields complete positive/negative pairs.
        Each sample merges the row's shared fields with one InputFeature
        (InputFeature keys win on collision).
    """
    interleaved: List[Dict[str, Any]] = []

    for row in batch:
        # Shared fields are everything except the two InputFeature entries.
        shared = {key: value for key, value in row.items()
                  if key not in ('positive', 'negative')}

        # Emit chosen sample immediately followed by its rejected partner.
        interleaved.extend((
            {**shared, **row['positive']},
            {**shared, **row['negative']},
        ))

    return interleaved
| 123 | + |
| 124 | + |
# ── Main Training Loop ────────────────────────────────────────────────────────

def main():
    """Run MultiLoRA DPO training end to end.

    Creates one Ray device group, builds the preference dataloader and a
    LoRA-equipped Megatron policy model, then trains with DPO loss while
    using the same model with LoRA disabled as the reference policy.
    """
    # Set up device groups - only one group for LoRA training
    device_groups = [
        DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
    ]

    # Configure device mesh for MultiLoRA Megatron: dp=2, pp=1
    # NOTE(review): dp_size is hard-coded to 2, so MODEL_GPUS presumably must
    # be divisible by 2 for a valid mesh — confirm against DeviceMesh.from_sizes.
    from twinkle.model import MultiLoraMegatronModel
    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=1)
    ModelClass = MultiLoraMegatronModel

    twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)

    # ── DataLoader Setup ──────────────────────────────────────────────────────
    # min_batch_size == batch_size: every yielded batch has exactly BATCH_SIZE pairs.
    dataloader = DataLoader(
        dataset=create_dpo_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        device_mesh=policy_mesh,
    )

    # ── Policy Model Setup with LoRA ──────────────────────────────────────────
    lora_config = LoraConfig(
        target_modules='all-linear',
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
    )

    policy_model = ModelClass(
        model_id=MODEL_ID,
        device_mesh=policy_mesh,
        remote_group='policy',
    )
    # Step budget is one pass over the dataloader — the MAX_STEPS env var
    # mentioned in the module docstring is not read here.
    MAX_STEPS = len(dataloader)
    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)

    # Configure optimizer based on backend
    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)


    # Set up loss function and metrics
    loss_fn = DPOLoss(
        beta=DPO_BETA,
        loss_type=LOSS_TYPE,
        reference_free=False,  # We use base model as reference via disable_lora=True
        sft_weight=SFT_WEIGHT,
    )

    policy_model.set_loss(loss_fn)
    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
    policy_model.set_processor(InputProcessor)
    policy_model.set_template('Qwen3_5Template', model_id=MODEL_ID)

    # optim_step counts loop iterations (micro-batches), not optimizer updates;
    # the logged "Step" divides by GRADIENT_ACCUMULATION_STEPS accordingly.
    optim_step = 0
    backend_name = 'MultiLoRA Megatron'
    logger.info(get_device_placement())
    logger.info(f'Starting MultiLoRA DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}')
    logger.info(f'Using base model (disable_lora=True) as reference model')

    # ── Training Loop ─────────────────────────────────────────────────────────
    for batch in dataloader:
        # batch is List[Dict] with 'positive' and 'negative' keys
        dpo_batch = prepare_dpo_batch(batch)

        # Get reference outputs using base model (without LoRA adapter)
        # disable_lora=True tells the model to skip LoRA and use base weights
        ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True, adapter_name=ADAPTER_NAME)
        policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs, adapter_name=ADAPTER_NAME)
        policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)

        optim_step += 1

        # Logging — once per accumulation window
        if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
            metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
            logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')

        # Checkpointing — every SAVE_STEPS loop iterations
        if optim_step % SAVE_STEPS == 0:
            policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)

    # ── Save Final Checkpoint ─────────────────────────────────────────────────
    logger.info(f'Training completed. Total steps: {optim_step}')
    policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME)


if __name__ == '__main__':
    main()