from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import DPOLoss
from twinkle.metric import DPOMetric
-from twinkle.model import MegatronModel
+from twinkle.model import MultiLoraMegatronModel
from twinkle.preprocessor import EmojiDPOProcessor
from twinkle.processor import InputProcessor

MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')

-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2))  # Number of preference pairs
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
@@ -137,7 +137,7 @@ def main():
        DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
    ]

-    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, pp_size=2, cp_size=2, tp_size=2)
    twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups)

    # ── DataLoader Setup ──────────────────────────────────────────────────────
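Note: the new mesh replaces the pure data-parallel layout with pipeline, context, and tensor parallelism. Assuming DeviceMesh.from_sizes follows the usual Megatron convention that the parallel sizes multiply to the world size, the 8 GPUs factor as dp × pp × cp × tp = 1 × 2 × 2 × 2, which is also why the MODEL_GPUS default above moves from 2 to 8. A minimal sanity check of that arithmetic (illustrative only, not part of this change):

    # Assumed constraint: dp * pp * cp * tp == world_size (MODEL_GPUS)
    dp_size, pp_size, cp_size, tp_size = 1, 2, 2, 2
    world_size = 8  # MODEL_GPUS default after this change
    assert dp_size * pp_size * cp_size * tp_size == world_size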
@@ -157,15 +157,17 @@ def main():
        lora_dropout=0.05,
    )

-    policy_model = MegatronModel(
+    policy_model = MultiLoraMegatronModel(
        model_id=MODEL_ID,
        device_mesh=policy_mesh,
        remote_group='policy',
    )
    MAX_STEPS = len(dataloader)
    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
-    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01)
-    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
+    # policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
+    # policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, adapter_name=ADAPTER_NAME)
+    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
+    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)

    # Set up loss function and metrics
    loss_fn = DPOLoss(
@@ -174,10 +176,10 @@ def main():
        reference_free=False,  # We use base model as reference via disable_lora=True
        sft_weight=SFT_WEIGHT,
    )
-    policy_model.set_loss(loss_fn)
-    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
-    policy_model.set_processor(InputProcessor)
-    policy_model.set_template('Template', model_id=MODEL_ID)
+    policy_model.set_loss(loss_fn, adapter_name=ADAPTER_NAME)
+    policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME)
+    policy_model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME)
+    policy_model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME)

    optim_step = 0
    logger.info(get_device_placement())
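Note: with MultiLoraMegatronModel, the optimizer, scheduler, loss, metric, processor, and template are all registered per adapter through adapter_name, so additional adapters could in principle be trained against the same base weights. A hypothetical sketch reusing only calls that appear in this file (the second adapter name and its reuse of lora_config are made up, not part of this commit):

    # Hypothetical second adapter alongside ADAPTER_NAME
    other_adapter = 'second-lora'
    policy_model.add_adapter_to_model(other_adapter, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=other_adapter)
    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=other_adapter)
    policy_model.set_loss(loss_fn, adapter_name=other_adapter)
    policy_model.set_template('Template', model_id=MODEL_ID, adapter_name=other_adapter)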
@@ -191,32 +193,32 @@ def main():

        # Get reference outputs using base model (without LoRA adapter)
        # disable_lora=True tells the model to skip LoRA and use base weights
-        ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True)
-
+        ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True, adapter_name=ADAPTER_NAME)
        # Forward-backward pass with DPO loss (using LoRA adapter)
        # ref_outputs is passed to loss which extracts logps internally
        policy_model.forward_backward(
            inputs=dpo_batch,
            ref_outputs=ref_outputs,
            micro_batch_size=2,
+            adapter_name=ADAPTER_NAME
        )

        # Gradient clipping and optimizer step
-        policy_model.clip_grad_and_step()
+        policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
        optim_step += 1

        # Logging
        if optim_step % 1 == 0:
-            metrics = policy_model.calculate_metric(is_training=True)
+            metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}')

        # Checkpointing
        if optim_step % SAVE_STEPS == 0:
-            policy_model.save(f'dpo-lora-checkpoint-{optim_step}')
+            policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)

    # ── Save Final Checkpoint ─────────────────────────────────────────────────
    logger.info(f'Training completed. Total steps: {optim_step}')
-    policy_model.save('dpo-lora-final-checkpoint')
+    policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME)


if __name__ == '__main__':
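Note: the loop's two passes mirror the standard DPO setup. forward_only with disable_lora=True produces frozen reference log-probabilities from the base weights, and forward_backward rescores the same preference pairs with the adapter enabled. DPOLoss is presumed to combine those log-probs as in the original DPO objective, plus an optional SFT term scaled by sft_weight; a rough sketch under that assumption, not the library's actual implementation:

    import torch.nn.functional as F

    def dpo_pair_loss(policy_chosen_logp, policy_rejected_logp,
                      ref_chosen_logp, ref_rejected_logp,
                      beta, sft_weight=0.0):
        # Standard DPO: prefer the chosen response over the rejected one,
        # measured relative to the frozen reference model.
        logits = beta * ((policy_chosen_logp - policy_rejected_logp)
                         - (ref_chosen_logp - ref_rejected_logp))
        dpo_term = -F.logsigmoid(logits)
        sft_term = -policy_chosen_logp  # optional NLL on the chosen response
        return dpo_term + sft_weight * sft_term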