1- """DPO (Direct Preference Optimization) Training via Ray.
1+ """DPO (Direct Preference Optimization) Full-Parameter Training via Ray.
22
33Off-policy preference alignment: trains the model to prefer chosen responses
44over rejected responses using preference data, without explicit reward modeling.
55
6+ Supports both Transformers (FSDP) and Megatron backends via USE_MEGATRON flag.
7+
68Pipeline:
79 1. Load preference dataset with chosen/rejected pairs.
810 2. Encode positive and negative separately.
911 3. Compute reference model log probabilities (frozen).
10- 4. Train policy model using DPO loss.
12+ 4. Train policy model using DPO loss (full-parameter, no LoRA) .
1113
1214Architecture (Ray):
1315 ┌─────────────────────────────────────────────────────────────────┐
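For orientation, step 4 of the pipeline reduces to the standard sigmoid DPO objective. A minimal PyTorch sketch of that math follows; the names are illustrative and this is not twinkle's `DPOLoss` implementation, which is not shown in this diff:

import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: beta-scaled log-ratio of policy vs. frozen reference
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Push the chosen-vs-rejected reward margin up via -log(sigmoid(margin))
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()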
@@ -28,14 +30,15 @@
 set REF_MODEL_GPUS=0 to skip reference model computation.
 
 Environment variables (all optional):
-    MODEL_ID – (default: ms://Qwen/Qwen3.5-4B)
+    USE_MEGATRON – use the Megatron backend (default: 0, i.e. Transformers)
+    MODEL_ID – (default: ms://Qwen/Qwen3-4B)
     DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
     MODEL_GPUS – GPUs for policy model (default: 4)
     REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable)
     BATCH_SIZE – global batch size (preference pairs) (default: 8)
     MICRO_BATCH_SIZE – per-device micro batch size (default: 2)
     MAX_STEPS – total optimization steps (set to len(dataloader) in this script)
-    LR – learning rate (default: 5e-6)
+    LR – learning rate (default: 1e-5)
     DPO_BETA – DPO temperature parameter (default: 0.1)
     LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid)
     SAVE_STEPS – checkpoint save interval (default: 100)
4952"""
5053
5154import os
52- from typing import Any , Dict , List , Optional
53-
54- from peft import LoraConfig
55+ from typing import Any , Dict , List
5556
5657import twinkle
5758from twinkle import DeviceGroup , DeviceMesh , get_device_placement , get_logger
@@ -60,30 +61,29 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss
 from twinkle.metric import DPOMetric
-from twinkle.model import TransformersModel
 from twinkle.preprocessor import EmojiDPOProcessor
 from twinkle.processor import InputProcessor
 
 logger = get_logger()
 
 # ── Configuration ─────────────────────────────────────────────────────────────
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
+USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0))
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')
 DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
 
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
-REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 2))
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4))
 NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS
 
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2))  # Number of preference pairs
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))  # Number of preference pairs
 MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
 GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
-LEARNING_RATE = float(os.environ.get('LR', 1e-4))  # TRL default for DPO is 5e-7 to 5e-6
+LEARNING_RATE = float(os.environ.get('LR', 1e-5))
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
 SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
 LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo, simpo, orpo, cpo
 SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
 MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
-ADAPTER_NAME = 'default'
 SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
 
 
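A quick consistency check on these defaults. This assumes BATCH_SIZE is the global number of preference pairs consumed per optimizer step and that the policy mesh configured below gives 2-way data parallelism; both are inferences from this diff, not documented contracts:

# 2 (micro pairs/device) x 2 (grad-accum steps) x 2 (dp ranks) = 8 = BATCH_SIZE
assert MICRO_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * 2 == BATCH_SIZE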
@@ -162,7 +162,20 @@ def main():
         DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
     ]
 
-    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+    # Configure device mesh based on backend
+    if USE_MEGATRON:
+        # Megatron: dp=2, pp=2 for each model
+        from twinkle.model import MegatronModel
+        policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=2)
+        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=2, pp_size=2)
+        ModelClass = MegatronModel
+    else:
+        # Transformers: fsdp=2, dp=2 for each model
+        from twinkle.model import TransformersModel
+        policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=2, dp_size=2)
+        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, fsdp_size=2, dp_size=2)
+        ModelClass = TransformersModel
+
     twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)
 
     # ── DataLoader Setup ──────────────────────────────────────────────────────
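The two branches factor the same four GPUs differently; a sketch of the assumed layouts follows (the `DeviceMesh.from_sizes` internals are not shown in this diff). Note the hardcoded 2x2 factorization only works when MODEL_GPUS and REF_MODEL_GPUS are 4:

# Megatron, MODEL_GPUS=4: dp_size=2 x pp_size=2
#   dp rank 0: gpu0 -> gpu1   (two pipeline stages, one full replica per dp rank)
#   dp rank 1: gpu2 -> gpu3
# Transformers, MODEL_GPUS=4: fsdp_size=2 x dp_size=2
#   dp replica 0: parameters sharded across gpu0, gpu1
#   dp replica 1: parameters sharded across gpu2, gpu3
assert MODEL_GPUS == 2 * 2  # each factorization must multiply to world_size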
@@ -172,31 +185,29 @@ def main():
         min_batch_size=BATCH_SIZE,
         device_mesh=policy_mesh,
     )
-    length = len(dataloader)
 
     # ── Policy Model Setup ────────────────────────────────────────────────────
-    lora_config = LoraConfig(
-        target_modules='all-linear',
-        r=16,
-        lora_alpha=32,
-        lora_dropout=0.05,
-    )
-
-    policy_model = TransformersModel(
+    policy_model = ModelClass(
         model_id=MODEL_ID,
         device_mesh=policy_mesh,
         remote_group='policy',
     )
     MAX_STEPS = len(dataloader)
-    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
-    policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
-    policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
 
     # Determine if we need reference model based on loss type
     reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo']
 
     # Set up loss function and metrics
     loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=reference_free)
+
+    # Configure optimizer based on backend (full-parameter training)
+    if USE_MEGATRON:
+        policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01)
+        policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
+    else:
+        policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
+        policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
+
     policy_model.set_loss(loss_fn)
     policy_model.add_metric(DPOMetric, beta=DPO_BETA)
     policy_model.set_processor(InputProcessor)
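`create_loss` is defined earlier in this file and does not appear in the diff. Given the loss classes imported at the top, a plausible shape is the following sketch; the real signatures and keyword arguments may differ:

def create_loss(loss_type, beta, sft_weight=0.0, reference_free=False):
    """Map LOSS_TYPE to one of the imported twinkle loss classes (sketch)."""
    if loss_type in ('sigmoid', 'hinge', 'ipo'):
        return DPOLoss(beta=beta, loss_type=loss_type, sft_weight=sft_weight,
                       reference_free=reference_free)
    if loss_type == 'simpo':
        return SimPOLoss(beta=beta, sft_weight=sft_weight)
    if loss_type == 'orpo':
        return ORPOLoss(sft_weight=sft_weight)
    if loss_type == 'cpo':
        return CPOLoss(beta=beta, sft_weight=sft_weight)
    raise ValueError(f'Unknown loss type: {loss_type}')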
@@ -205,8 +216,7 @@ def main():
     # ── Reference Model Setup ─────────────────────────────────────────────────
     ref_model = None
     if not reference_free:
-        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS)
-        ref_model = TransformersModel(
+        ref_model = ModelClass(
             model_id=MODEL_ID,
             device_mesh=ref_mesh,
             remote_group='reference',
@@ -218,8 +228,9 @@ def main():
         logger.info(f'Training without reference model (loss_type={LOSS_TYPE})')
 
     optim_step = 0
+    backend_name = 'Megatron' if USE_MEGATRON else 'Transformers'
     logger.info(get_device_placement())
-    logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}')
+    logger.info(f'Starting DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}')
 
     # ── Training Loop ─────────────────────────────────────────────────────────
     for batch in dataloader:
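`dpo_batch` is built from `batch` in lines elided between these hunks. A common encoding, assumed here purely for illustration, concatenates the chosen and rejected sequences so a single forward pass scores both halves:

import torch

# Toy example: 2 preference pairs, chosen first, rejected second (assumed layout)
chosen_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
rejected_ids = torch.tensor([[1, 2, 9], [4, 5, 8]])
input_ids = torch.cat([chosen_ids, rejected_ids], dim=0)  # shape (4, 3)
# A DPO loss can then split per-sequence logps back into the two halves:
# chosen_logps, rejected_logps = logps[:2], logps[2:]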
@@ -232,20 +243,15 @@ def main():
             ref_outputs = ref_model.forward_only(inputs=dpo_batch)
 
         # Forward-backward pass with DPO loss
-        # ref_outputs is passed to loss which extracts logps internally
-        policy_model.forward_backward(
-            inputs=dpo_batch,
-            ref_outputs=ref_outputs,
-        )
-
-        # Gradient clipping and optimizer step
+        policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs)
         policy_model.clip_grad_and_step()
+
         optim_step += 1
 
         # Logging
-        if optim_step % 1 == 0:
+        if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
             metrics = policy_model.calculate_metric(is_training=True)
-            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}')
+            logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')
 
         # Checkpointing
         if optim_step % SAVE_STEPS == 0:
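One bookkeeping subtlety in the logging change: `optim_step` counts loop iterations (micro batches), and the printed step divides it by GRADIENT_ACCUMULATION_STEPS, but MAX_STEPS is still len(dataloader), i.e. iterations. The same mismatch applies to the SAVE_STEPS check, which fires on iteration counts. A small illustration, with a hypothetical fix for the denominator:

# With GRADIENT_ACCUMULATION_STEPS = 2 and MAX_STEPS = len(dataloader) = 6, the
# loop logs "Step 1/6", "Step 2/6", "Step 3/6" and then the data runs out:
# the numerator counts optimizer steps while the denominator counts iterations.
consistent_total = MAX_STEPS // GRADIENT_ACCUMULATION_STEPS  # hypothetical denominator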