Skip to content

Commit a28ea8d

Browse files
committed
update megatron dpo
1 parent ea6df52 commit a28ea8d

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed

cookbook/client/tinker/self_host/dpo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def train():
145145
# C. DPO forward_backward
146146
# Server detects ref_logps → sets DPOLoss + DPOMetric automatically.
147147
# Optional DPO hyper-params can be forwarded via loss_fn_config.
148+
# (e.g. beta, sft_weight; dpo_loss_type is not supported for tinker)
148149
# -----------------------------------------------------------------
149150
fwdbwd_result = training_client.forward_backward(
150151
input_datums,

cookbook/rl/dpo_multi_lora.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
"""DPO (Direct Preference Optimization) Training with MultiLoRA (Megatron Backend).
2+
3+
MultiLoRA-based DPO training: uses the base model (without LoRA adapter) as reference
4+
model by calling forward_only with disable_lora=True. This eliminates the need for
5+
a separate reference model GPU group.
6+
7+
Uses Megatron backend with MultiLoRAMegatronModel for efficient multi-tenant LoRA training.
8+
9+
Pipeline:
10+
1. Load preference dataset with chosen/rejected pairs.
11+
2. Encode positive and negative separately.
12+
3. Compute reference model log probabilities using base model (disable_lora=True).
13+
4. Train policy model (with LoRA adapter) using DPO loss.
14+
15+
Architecture (Ray - Single Group):
16+
┌─────────────────────────────────────────────────────────────────┐
17+
│ Driver (CPU) │
18+
│ dataloader ──► batched preference pairs │
19+
│ policy_model.forward_only(disable_lora=True) ──► ref logps │
20+
│ policy_model.forward_backward() ──► DPO loss + gradient │
21+
└─────────────────────────────────────────────────────────────────┘
22+
23+
PolicyModel (with LoRA adapter)
24+
- forward_only(disable_lora=True) → base model inference (reference)
25+
- forward_backward() → LoRA adapter training (policy)
26+
27+
DPO data format (after preprocessing):
28+
- positive: List[Trajectory] - chosen responses
29+
- negative: List[Trajectory] - rejected responses
30+
31+
Environment variables (all optional):
32+
MODEL_ID – (default: ms://Qwen/Qwen3.5-4B)
33+
DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
34+
MODEL_GPUS – GPUs for policy model (default: 2)
35+
BATCH_SIZE – global batch size (preference pairs) (default: 8)
36+
MAX_STEPS – total optimization steps (default: 1000)
37+
LR – learning rate (default: 1e-4)
38+
    DPO_BETA – DPO temperature parameter (default: 0.1)
    SFT_WEIGHT – SFT regularization loss weight (default: 1.0)
    GRADIENT_ACCUMULATION_STEPS – micro-batches per optimizer step (default: 2)
    SYSTEM_PROMPT – system prompt prepended to each sample (default: 'You are a helpful assistant.')
39+
LOSS_TYPE – DPO variant (sigmoid/hinge/ipo) (default: sigmoid)
40+
SAVE_STEPS – checkpoint save interval (default: 100)
41+
MAX_LENGTH – max sequence length (default: 2048)
42+
"""
43+
44+
import os
45+
from typing import Any, Dict, List, Optional
46+
47+
from peft import LoraConfig
48+
49+
import twinkle
50+
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
51+
from twinkle.data_format import Trajectory
52+
from twinkle.dataloader import DataLoader
53+
from twinkle.dataset import Dataset, DatasetMeta
54+
from twinkle.loss import DPOLoss
55+
from twinkle.metric import DPOMetric
56+
from twinkle.preprocessor import EmojiDPOProcessor
57+
from twinkle.processor import InputProcessor
58+
59+
logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
# Every knob is read once, at import time, from the environment; the defaults
# below are the fallbacks documented in the module docstring.
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')

# Number of GPUs assigned to the single 'policy' device group.
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))  # Number of preference pairs
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
LEARNING_RATE = float(os.environ.get('LR', 1e-4))  # LoRA DPO requires higher LR (1e-4 to 3e-4)
DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
# Adapter id passed to every per-adapter model call (add/step/save/metrics).
ADAPTER_NAME = 'default_0'
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
77+
78+
79+
def create_dpo_dataset():
    """Build and encode the preference dataset in positive/negative format."""
    # Only the first 50 rows are used (small demo slice).
    ds = Dataset(DatasetMeta(DATASET_ID, data_slice=range(50)))
    ds.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
    # The DPO preprocessor emits {'positive': [...], 'negative': [...]} rows;
    # batch_encode understands this layout without extra handling.
    ds.map(EmojiDPOProcessor, init_args={'system': SYSTEM_PROMPT})
    ds.encode(load_from_cache_file=True)
    return ds
93+
94+
95+
def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten preference rows into a DP-safe interleaved sample list.

    Args:
        batch: Rows that each carry a 'positive' and a 'negative' InputFeature
            dict alongside any number of shared fields (question, etc.)

    Returns:
        Samples ordered [pos_1, neg_1, pos_2, neg_2, ...]; interleaving keeps
        every chosen/rejected pair intact when DP workers slice the batch.
        Each sample merges the row's shared fields with one InputFeature
        (InputFeature keys win on collision, as in a dict merge).
    """
    interleaved: List[Dict[str, Any]] = []

    for row in batch:
        # Everything except the two InputFeatures is common to both samples.
        shared = {
            key: value for key, value in row.items()
            if key not in ('positive', 'negative')
        }
        # Emit the chosen sample immediately followed by its rejected twin.
        interleaved.extend((
            {**shared, **row['positive']},
            {**shared, **row['negative']},
        ))

    return interleaved
123+
124+
125+
# ── Main Training Loop ────────────────────────────────────────────────────────
126+
127+
def main():
    """Run MultiLoRA DPO training on the Megatron backend.

    Reference log-probs come from the frozen base model (forward_only with
    disable_lora=True), so no separate reference-model GPU group is needed:
    one 'policy' device group serves both roles.
    """
    # Set up device groups - only one group for LoRA training
    device_groups = [
        DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
    ]

    # Configure device mesh for MultiLoRA Megatron: dp=2, pp=1
    # NOTE(review): dp_size is hard-coded to 2, so MODEL_GPUS presumably must
    # be divisible by 2 — confirm against DeviceMesh.from_sizes semantics.
    from twinkle.model import MultiLoraMegatronModel
    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=1)
    ModelClass = MultiLoraMegatronModel

    twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)

    # ── DataLoader Setup ──────────────────────────────────────────────────────
    # The dataset factory (not an instance) is handed over so construction can
    # happen where the loader decides to run it.
    dataloader = DataLoader(
        dataset=create_dpo_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        device_mesh=policy_mesh,
    )

    # ── Policy Model Setup with LoRA ──────────────────────────────────────────
    lora_config = LoraConfig(
        target_modules='all-linear',
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
    )

    policy_model = ModelClass(
        model_id=MODEL_ID,
        device_mesh=policy_mesh,
        remote_group='policy',
    )
    # One pass over the dataloader defines the schedule length used by the LR
    # scheduler and the progress log below.
    MAX_STEPS = len(dataloader)
    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)

    # Configure optimizer based on backend
    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)


    # Set up loss function and metrics
    loss_fn = DPOLoss(
        beta=DPO_BETA,
        loss_type=LOSS_TYPE,
        reference_free=False,  # We use base model as reference via disable_lora=True
        sft_weight=SFT_WEIGHT,
    )

    policy_model.set_loss(loss_fn)
    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
    policy_model.set_processor(InputProcessor)
    policy_model.set_template('Qwen3_5Template', model_id=MODEL_ID)

    optim_step = 0
    backend_name = 'MultiLoRA Megatron'
    logger.info(get_device_placement())
    logger.info(f'Starting MultiLoRA DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}')
    logger.info(f'Using base model (disable_lora=True) as reference model')

    # ── Training Loop ─────────────────────────────────────────────────────────
    for batch in dataloader:
        # batch is List[Dict] with 'positive' and 'negative' keys
        dpo_batch = prepare_dpo_batch(batch)

        # Get reference outputs using base model (without LoRA adapter)
        # disable_lora=True tells the model to skip LoRA and use base weights
        ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True, adapter_name=ADAPTER_NAME)
        policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs, adapter_name=ADAPTER_NAME)
        policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)

        # NOTE(review): optim_step counts dataloader batches, but the adapter
        # was configured with gradient_accumulation_steps — yet the step logged
        # below is optim_step // GRADIENT_ACCUMULATION_STEPS while MAX_STEPS is
        # the total batch count, so the 'x/MAX_STEPS' ratio looks inconsistent.
        # Confirm the intended progress semantics.
        optim_step += 1

        # Logging
        if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
            metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
            logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')

        # Checkpointing
        if optim_step % SAVE_STEPS == 0:
            policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)

    # ── Save Final Checkpoint ─────────────────────────────────────────────────
    logger.info(f'Training completed. Total steps: {optim_step}')
    policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME)
213+
214+
215+
# Script entry point: run the full DPO training pipeline when executed directly.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)