
Commit 91cad80

wip
1 parent 20fde35 commit 91cad80

3 files changed: +96 -77 lines changed
Lines changed: 45 additions & 39 deletions
@@ -1,13 +1,15 @@
-"""DPO (Direct Preference Optimization) Training via Ray.
+"""DPO (Direct Preference Optimization) Full-Parameter Training via Ray.
 
 Off-policy preference alignment: trains the model to prefer chosen responses
 over rejected responses using preference data, without explicit reward modeling.
 
+Supports both Transformers (FSDP) and Megatron backends via USE_MEGATRON flag.
+
 Pipeline:
 1. Load preference dataset with chosen/rejected pairs.
 2. Encode positive and negative separately.
 3. Compute reference model log probabilities (frozen).
-4. Train policy model using DPO loss.
+4. Train policy model using DPO loss (full-parameter, no LoRA).
 
 Architecture (Ray):
 ┌─────────────────────────────────────────────────────────────────┐
@@ -28,14 +30,15 @@
 set REF_MODEL_GPUS=0 to skip reference model computation.
 
 Environment variables (all optional):
-MODEL_ID – (default: ms://Qwen/Qwen3.5-4B)
+USE_MEGATRON – Use Megatron backend (default: 0, use Transformers)
+MODEL_ID – (default: ms://Qwen/Qwen3-4B)
 DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
 MODEL_GPUS – GPUs for policy model (default: 4)
 REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable)
 BATCH_SIZE – global batch size (preference pairs) (default: 8)
 MICRO_BATCH_SIZE – per-device micro batch size (default: 2)
 MAX_STEPS – total optimization steps (default: 1000)
-LR – learning rate (default: 5e-6)
+LR – learning rate (default: 1e-5)
 DPO_BETA – DPO temperature parameter (default: 0.1)
 LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid)
 SAVE_STEPS – checkpoint save interval (default: 100)
@@ -49,9 +52,7 @@
 """
 
 import os
-from typing import Any, Dict, List, Optional
-
-from peft import LoraConfig
+from typing import Any, Dict, List
 
 import twinkle
 from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
@@ -60,30 +61,29 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss
 from twinkle.metric import DPOMetric
-from twinkle.model import TransformersModel
 from twinkle.preprocessor import EmojiDPOProcessor
 from twinkle.processor import InputProcessor
 
 logger = get_logger()
 
 # ── Configuration ─────────────────────────────────────────────────────────────
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
+USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0))
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')
 DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
 
-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
-REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 2))
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4))
 NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS
 
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2))  # Number of preference pairs
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))  # Number of preference pairs
 MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
 GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
-LEARNING_RATE = float(os.environ.get('LR', 1e-4))  # TRL default for DPO is 5e-7 to 5e-6
+LEARNING_RATE = float(os.environ.get('LR', 1e-5))
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
 SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
 LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo, simpo, orpo, cpo
 SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
 MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
-ADAPTER_NAME = 'default'
 SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
 
 
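
Taken together, the batch knobs above interact roughly as follows. This is a back-of-envelope sketch, assuming the global batch of preference pairs is split evenly across data-parallel ranks and each rank then slices its share into micro batches; the actual splitting lives inside twinkle and is not shown in this diff:

def micro_steps_per_rank(batch_size: int, dp_size: int, micro_batch_size: int) -> int:
    # Illustrative helper, not part of the commit: pairs seen by one rank per
    # optimizer step, divided into micro batches of micro_batch_size pairs.
    pairs_per_rank = batch_size // dp_size
    return pairs_per_rank // micro_batch_size

# With the new defaults (BATCH_SIZE=8, MICRO_BATCH_SIZE=2) and a data-parallel
# degree of 2, each rank processes 4 pairs per step in 2 micro batches, which
# lines up with GRADIENT_ACCUMULATION_STEPS=2.
print(micro_steps_per_rank(8, 2, 2))  # -> 2
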
@@ -162,7 +162,20 @@ def main():
         DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
     ]
 
-    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+    # Configure device mesh based on backend
+    if USE_MEGATRON:
+        # Megatron: dp=2, pp=2 for each model
+        from twinkle.model import MegatronModel
+        policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=2)
+        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=2, pp_size=2)
+        ModelClass = MegatronModel
+    else:
+        # Transformers: fsdp=2, dp=2 for each model
+        from twinkle.model import TransformersModel
+        policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=2, dp_size=2)
+        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, fsdp_size=2, dp_size=2)
+        ModelClass = TransformersModel
+
     twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)
 
     # ── DataLoader Setup ──────────────────────────────────────────────────────
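
The mesh sizes hard-coded in this hunk assume the default GPU counts: DeviceMesh.from_sizes is given a world_size plus per-dimension sizes, so the product of the dimensions presumably has to match the GPUs in the group (4 = 2 x 2 for both layouts). A small illustrative check, not part of twinkle's API:

def check_mesh(world_size: int, **dims: int) -> None:
    # Illustrative only: the mesh dimensions must multiply out to the number
    # of GPUs assigned to the device group.
    product = 1
    for size in dims.values():
        product *= size
    if product != world_size:
        raise ValueError(f'mesh {dims} needs {product} GPUs, got {world_size}')

check_mesh(4, dp_size=2, pp_size=2)    # Megatron layout above
check_mesh(4, fsdp_size=2, dp_size=2)  # Transformers/FSDP layout above
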
@@ -172,31 +185,29 @@ def main():
         min_batch_size=BATCH_SIZE,
         device_mesh=policy_mesh,
     )
-    length = len(dataloader)
 
     # ── Policy Model Setup ────────────────────────────────────────────────────
-    lora_config = LoraConfig(
-        target_modules='all-linear',
-        r=16,
-        lora_alpha=32,
-        lora_dropout=0.05,
-    )
-
-    policy_model = TransformersModel(
+    policy_model = ModelClass(
         model_id=MODEL_ID,
         device_mesh=policy_mesh,
         remote_group='policy',
     )
     MAX_STEPS = len(dataloader)
-    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
-    policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
-    policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
 
     # Determine if we need reference model based on loss type
     reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo']
 
     # Set up loss function and metrics
     loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=False)
+
+    # Configure optimizer based on backend (full-parameter training)
+    if USE_MEGATRON:
+        policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01)
+        policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
+    else:
+        policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
+        policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
+
     policy_model.set_loss(loss_fn)
     policy_model.add_metric(DPOMetric, beta=DPO_BETA)
     policy_model.set_processor(InputProcessor)
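
For reference, the sigmoid variant configured by default implements the textbook DPO objective: the policy's chosen-vs-rejected log-probability margin is pushed to exceed the reference model's margin, with DPO_BETA scaling the strength. A sketch of that standard form (twinkle's own DPOLoss implementation is not shown in this diff):

import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Standard DPO (sigmoid) objective over per-sequence log-probs;
    # a reference sketch, not twinkle's code.
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()
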
@@ -205,8 +216,7 @@ def main():
     # ── Reference Model Setup ─────────────────────────────────────────────────
     ref_model = None
     if not reference_free:
-        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS)
-        ref_model = TransformersModel(
+        ref_model = ModelClass(
             model_id=MODEL_ID,
             device_mesh=ref_mesh,
             remote_group='reference',
@@ -218,8 +228,9 @@ def main():
         logger.info(f'Training without reference model (loss_type={LOSS_TYPE})')
 
     optim_step = 0
+    backend_name = 'Megatron' if USE_MEGATRON else 'Transformers'
     logger.info(get_device_placement())
-    logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}')
+    logger.info(f'Starting DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}')
 
     # ── Training Loop ─────────────────────────────────────────────────────────
     for batch in dataloader:
@@ -232,20 +243,15 @@ def main():
         ref_outputs = ref_model.forward_only(inputs=dpo_batch)
 
         # Forward-backward pass with DPO loss
-        # ref_outputs is passed to loss which extracts logps internally
-        policy_model.forward_backward(
-            inputs=dpo_batch,
-            ref_outputs=ref_outputs,
-        )
-
-        # Gradient clipping and optimizer step
+        policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs)
         policy_model.clip_grad_and_step()
+
         optim_step += 1
 
         # Logging
-        if optim_step % 1 == 0:
+        if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
             metrics = policy_model.calculate_metric(is_training=True)
-            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}')
+            logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')
 
         # Checkpointing
         if optim_step % SAVE_STEPS == 0:
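
One subtlety in the new logging: optim_step ticks once per dataloader batch, while the step printed in the log is that counter divided by GRADIENT_ACCUMULATION_STEPS, so the displayed step advances once per accumulation window. A tiny illustration of the arithmetic (values only, no training):

GRADIENT_ACCUMULATION_STEPS = 2

for optim_step in range(1, 9):  # pretend we saw 8 batches
    if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
        print(f'iteration {optim_step} -> logged step {optim_step // GRADIENT_ACCUMULATION_STEPS}')
# iterations 2, 4, 6, 8 are logged as steps 1, 2, 3, 4
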

cookbook/rl/dpo_lora.py

Lines changed: 49 additions & 37 deletions
@@ -1,35 +1,38 @@
 """DPO (Direct Preference Optimization) Training with LoRA (Single GPU Group).
 
 LoRA-based DPO training: uses the base model (without LoRA adapter) as reference
-model by calling forward_only with adapter_name=''. This eliminates the need for
+model by calling forward_only with disable_lora=True. This eliminates the need for
 a separate reference model GPU group.
 
+Supports both Transformers (FSDP) and Megatron backends via USE_MEGATRON flag.
+
 Pipeline:
 1. Load preference dataset with chosen/rejected pairs.
 2. Encode positive and negative separately.
-3. Compute reference model log probabilities using base model (adapter_name='').
+3. Compute reference model log probabilities using base model (disable_lora=True).
 4. Train policy model (with LoRA adapter) using DPO loss.
 
 Architecture (Ray - Single Group):
 ┌─────────────────────────────────────────────────────────────────┐
 │ Driver (CPU) │
 │ dataloader ──► batched preference pairs │
-│ policy_model.forward_only(adapter_name='') ──► reference logps│
+│ policy_model.forward_only(disable_lora=True) ──► ref logps │
 │ policy_model.forward_backward() ──► DPO loss + gradient │
 └─────────────────────────────────────────────────────────────────┘
 
 PolicyModel (with LoRA adapter)
-- forward_only(adapter_name='') → base model inference (reference)
+- forward_only(disable_lora=True) → base model inference (reference)
 - forward_backward() → LoRA adapter training (policy)
 
 DPO data format (after preprocessing):
 - positive: List[Trajectory] - chosen responses
 - negative: List[Trajectory] - rejected responses
 
 Environment variables (all optional):
-MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct)
+USE_MEGATRON – Use Megatron backend (default: 0, use Transformers)
+MODEL_ID – (default: ms://Qwen/Qwen3-4B)
 DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
-MODEL_GPUS – GPUs for policy model (default: 4)
+MODEL_GPUS – GPUs for policy model (default: 8)
 BATCH_SIZE – global batch size (preference pairs) (default: 8)
 MICRO_BATCH_SIZE – per-device micro batch size (default: 2)
 MAX_STEPS – total optimization steps (default: 1000)
@@ -58,20 +61,20 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.loss import DPOLoss
 from twinkle.metric import DPOMetric
-from twinkle.model import MegatronModel
 from twinkle.preprocessor import EmojiDPOProcessor
 from twinkle.processor import InputProcessor
 
 logger = get_logger()
 
 # ── Configuration ─────────────────────────────────────────────────────────────
+USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0))
 MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')
 DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
 
 MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))
 
 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))  # Number of preference pairs
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8))
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
 GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
 LEARNING_RATE = float(os.environ.get('LR', 1e-4))  # LoRA DPO requires higher LR (1e-4 to 3e-4)
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
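
As a quick numeric feel for DPO_BETA under the sigmoid objective sketched earlier: for a fixed one-nat advantage of the chosen response over the rejected one (relative to the reference), a larger beta turns the same margin into a smaller per-pair loss, i.e. it sharpens how strongly the margin is rewarded. Illustrative numbers only:

import math

def sigmoid_pair_loss(beta: float, margin_gap: float = 1.0) -> float:
    # -log(sigmoid(beta * gap)) for a single preference pair (illustration).
    return math.log(1.0 + math.exp(-beta * margin_gap))

for beta in (0.05, 0.1, 0.5):
    print(beta, round(sigmoid_pair_loss(beta), 3))
# 0.05 -> 0.669, 0.1 -> 0.644, 0.5 -> 0.474
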
@@ -137,8 +140,19 @@ def main():
         DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
     ]
 
-    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2)
-    twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups)
+    # Configure device mesh based on backend
+    if USE_MEGATRON:
+        # Megatron: dp=4, pp=2
+        from twinkle.model import MegatronModel
+        policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2)
+        ModelClass = MegatronModel
+    else:
+        # Transformers: fsdp=4, dp=2
+        from twinkle.model import TransformersModel
+        policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, fsdp_size=4, dp_size=2)
+        ModelClass = TransformersModel
+
+    twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)
 
     # ── DataLoader Setup ──────────────────────────────────────────────────────
     dataloader = DataLoader(
@@ -147,7 +161,6 @@ def main():
         min_batch_size=BATCH_SIZE,
         device_mesh=policy_mesh,
     )
-    length = len(dataloader)
 
     # ── Policy Model Setup with LoRA ──────────────────────────────────────────
     lora_config = LoraConfig(
@@ -157,15 +170,21 @@ def main():
         lora_dropout=0.05,
     )
 
-    policy_model = MegatronModel(
+    policy_model = ModelClass(
         model_id=MODEL_ID,
         device_mesh=policy_mesh,
         remote_group='policy',
     )
     MAX_STEPS = len(dataloader)
     policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
-    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
-    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)
+
+    # Configure optimizer based on backend
+    if USE_MEGATRON:
+        policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
+        policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)
+    else:
+        policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
+        policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
 
     # Set up loss function and metrics
     loss_fn = DPOLoss(
@@ -174,14 +193,16 @@ def main():
         reference_free=False,  # We use base model as reference via disable_lora=True
         sft_weight=SFT_WEIGHT,
     )
-    policy_model.set_loss(loss_fn, adapter_name=ADAPTER_NAME)
-    policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME)
-    policy_model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME)
-    policy_model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME)
+
+    policy_model.set_loss(loss_fn)
+    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
+    policy_model.set_processor(InputProcessor)
+    policy_model.set_template('Template', model_id=MODEL_ID)
 
     optim_step = 0
+    backend_name = 'Megatron' if USE_MEGATRON else 'Transformers'
     logger.info(get_device_placement())
-    logger.info(f'Starting LoRA DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}')
+    logger.info(f'Starting LoRA DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}')
     logger.info(f'Using base model (disable_lora=True) as reference model')
 
     # ── Training Loop ─────────────────────────────────────────────────────────
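
The single-group trick relies on LoRA being a purely additive modification of a frozen base: a LoRA linear computes x @ (W0 + scaling * A @ B), and W0 never receives gradients, so running the same layer with the adapter bypassed reproduces the original pre-trained weights, which is exactly the frozen DPO reference policy. A minimal numeric sketch of that identity (toy shapes, not twinkle internals):

import torch

torch.manual_seed(0)
x = torch.randn(2, 16)         # toy activations
W0 = torch.randn(16, 16)       # frozen base weight
A = torch.randn(16, 4) * 0.01  # LoRA down-projection (toy rank 4)
B = torch.zeros(4, 16)         # LoRA up-projection (zero-init, as in peft)
scaling = 2.0                  # lora_alpha / r (e.g. 32 / 16 in lora_config above)

policy_out = x @ (W0 + scaling * (A @ B))  # adapter active  -> policy path
reference_out = x @ W0                     # adapter bypassed -> reference path
# At initialization the two paths coincide (B is zero); after training only the
# A/B factors have moved, while the bypassed path stays the original model.
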
@@ -191,33 +212,24 @@ def main():
 
         # Get reference outputs using base model (without LoRA adapter)
         # disable_lora=True tells the model to skip LoRA and use base weights
-        ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True, adapter_name=ADAPTER_NAME)
-        # Forward-backward pass with DPO loss (using LoRA adapter)
-        # ref_outputs is passed to loss which extracts logps internally
-        policy_model.forward_backward(
-            inputs=dpo_batch,
-            ref_outputs=ref_outputs,
-            micro_batch_size=2,
-            adapter_name=ADAPTER_NAME
-        )
-
-        # Gradient clipping and optimizer step
-        policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
+        ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True)
+        policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs)
+        policy_model.clip_grad_and_step()
+
+        optim_step += 1
 
         # Logging
-        if optim_step % 16 == 0:
-            metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
+        if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
+            metrics = policy_model.calculate_metric(is_training=True)
             logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')
 
         # Checkpointing
         if optim_step % SAVE_STEPS == 0:
-            policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)
-
-            optim_step += 1
+            policy_model.save(f'dpo-lora-checkpoint-{optim_step}')
 
     # ── Save Final Checkpoint ─────────────────────────────────────────────────
     logger.info(f'Training completed. Total steps: {optim_step}')
-    policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME)
+    policy_model.save('dpo-lora-final-checkpoint')
 
 
 if __name__ == '__main__':
src/twinkle/model/megatron/megatron.py

Lines changed: 2 additions & 1 deletion
@@ -410,7 +410,8 @@ def forward_backward(self,
         assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding'
 
         if micro_batch_size is None:
-            micro_batch_size = 1
+            # Compatible with DPO
+            micro_batch_size = 2
         inputs = processor(inputs, micro_batch_size=micro_batch_size, variable_seq_lengths=self.variable_seq_lengths)
 
 
