Commit 20fde35

wip
1 parent 8fc2bb7 commit 20fde35

5 files changed: +27 −570 lines changed

cookbook/rl/dpo.sh

Lines changed: 0 additions & 84 deletions
This file was deleted.

cookbook/rl/dpo_lora.py

Lines changed: 12 additions & 13 deletions
```diff
@@ -58,20 +58,20 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.loss import DPOLoss
 from twinkle.metric import DPOMetric
-from twinkle.model import MultiLoraMegatronModel
+from twinkle.model import MegatronModel
 from twinkle.preprocessor import EmojiDPOProcessor
 from twinkle.processor import InputProcessor

 logger = get_logger()

 # ── Configuration ─────────────────────────────────────────────────────────────
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')
 DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')

 MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))

-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2))  # Number of preference pairs
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))  # Number of preference pairs
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 8))
 GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
 LEARNING_RATE = float(os.environ.get('LR', 1e-4))  # LoRA DPO requires higher LR (1e-4 to 3e-4)
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
@@ -85,7 +85,7 @@

 def create_dpo_dataset():
     """Create DPO dataset with positive/negative format."""
-    dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(30000)))
+    dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(6000)))
     dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
     dataset.map(
         EmojiDPOProcessor,
@@ -137,7 +137,7 @@ def main():
         DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
     ]

-    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, pp_size=2, cp_size=2, tp_size=2)
+    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2)
     twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups)

     # ── DataLoader Setup ──────────────────────────────────────────────────────
@@ -152,20 +152,18 @@ def main():
     # ── Policy Model Setup with LoRA ──────────────────────────────────────────
     lora_config = LoraConfig(
         target_modules='all-linear',
-        r=16,
+        r=8,
         lora_alpha=32,
         lora_dropout=0.05,
     )

-    policy_model = MultiLoraMegatronModel(
+    policy_model = MegatronModel(
         model_id=MODEL_ID,
         device_mesh=policy_mesh,
         remote_group='policy',
     )
     MAX_STEPS = len(dataloader)
     policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
-    # policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
-    # policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, adapter_name=ADAPTER_NAME)
     policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
     policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)
@@ -205,16 +203,17 @@ def main():

             # Gradient clipping and optimizer step
             policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
-            optim_step += 1

             # Logging
-            if optim_step % 1 == 0:
+            if optim_step % 16 == 0:
                 metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
-                logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}')
+                logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')

             # Checkpointing
             if optim_step % SAVE_STEPS == 0:
                 policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)
+
+            optim_step += 1

     # ── Save Final Checkpoint ─────────────────────────────────────────────────
     logger.info(f'Training completed. Total steps: {optim_step}')
```
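The mesh change trades tensor and context parallelism (tp=2, cp=2) for more data parallelism (dp=4) on the same 8 GPUs. A quick sanity check of the arithmetic this relies on; `check_mesh` is a hypothetical helper for illustration, not a twinkle API:

```python
# Hypothetical helper (not part of twinkle): a device mesh is valid only if
# the product of its parallel sizes equals world_size; omitted dims default to 1.
def check_mesh(world_size: int, **sizes: int) -> None:
    product = 1
    for value in sizes.values():
        product *= value
    assert product == world_size, f'{sizes} does not factor world_size={world_size}'

check_mesh(8, dp_size=4, pp_size=2)                        # mesh after this commit
check_mesh(8, dp_size=1, pp_size=2, cp_size=2, tp_size=2)  # mesh before this commit
```

With four data-parallel replicas instead of one, each replica sees a quarter of the global batch, which is consistent with BATCH_SIZE growing from 2 to 8 in the same commit.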

cookbook/transformers/fsdp2.py

Lines changed: 14 additions & 44 deletions
```diff
@@ -7,11 +7,10 @@
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.model import TransformersModel
-from twinkle.data_format import Message, Trajectory
-from twinkle.preprocessor import SelfCognitionProcessor, Preprocessor
+from twinkle.preprocessor import SelfCognitionProcessor

 # Construct a device_mesh, dp=2
-device_mesh = DeviceMesh.from_sizes(dp_size=8)
+device_mesh = DeviceMesh.from_sizes(dp_size=2)
 # use torchrun mode
 twinkle.initialize(mode='local', global_device_mesh=device_mesh)

@@ -21,7 +20,7 @@
 def eval(model):
     # 100 Samples
     dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+    dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
     dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
     dataset.encode()
     dataloader = DataLoader(dataset=dataset, batch_size=8)
@@ -32,55 +31,19 @@ def eval(model):
     return metrics


-class EmojiDPOProcessor(Preprocessor):
-    def __init__(
-        self,
-        system = 'You are a helpful assistant.',
-        chosen_key: str = 'answer_zh',
-        rejected_key: str = 'answer_en',
-        prompt_key: str = 'prompt',
-    ):
-        self.system = system
-        self.chosen_key = chosen_key
-        self.rejected_key = rejected_key
-        self.prompt_key = prompt_key
-
-    def __call__(self, rows):
-        rows = self.map_col_to_row(rows)
-        rows = [self.preprocess(row) for row in rows]
-        rows = self.map_row_to_col(rows)
-        return rows
-
-    def preprocess(self, row):
-        """Process a single row."""
-        prompt = row.get(self.prompt_key, '')
-        chosen = row.get(self.chosen_key, '')
-        rejected = row.get(self.rejected_key, '')
-
-        prompt_messages = []
-        if self.system:
-            prompt_messages.append(Message(role='system', content=self.system))
-        prompt_messages.append(Message(role='user', content=prompt))
-
-        chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)]
-        rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)]
-
-        return Trajectory(messages=chosen_messages)
-
-
 def train():
     # 1000 samples
-    dataset = Dataset(dataset_meta=DatasetMeta('ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'))
+    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
     # Set template to prepare encoding
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+    dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
     # Preprocess the dataset to standard format
-    dataset.map(EmojiDPOProcessor)
+    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
     # Encode dataset
     dataset.encode()
     # Global batch size = 8, for GPUs, so 1 sample per GPU
     dataloader = DataLoader(dataset=dataset, batch_size=8)
     # Use a TransformersModel
-    model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+    model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')
     model.model._no_split_modules = {'Qwen3_5DecoderLayer'}

     lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
@@ -109,6 +72,13 @@ def train():
         # Print metric
         metric = model.calculate_metric(is_training=True)
         logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+        if step > 0 and step % 40 == 0:
+            metrics = eval(model)
+            logger.info(f'Eval metric: {metrics}')
+            metrics['step'] = step
+            if loss_metric > float(metrics['loss']):
+                model.save(f'checkpoint-{step}')
+                loss_metric = float(metrics['loss'])
     model.save(f'last-checkpoint')
```
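The hook added to train() runs eval every 40 steps and saves a checkpoint only when the eval loss improves; it assumes a `loss_metric` tracker initialized earlier in the function, outside this hunk. A minimal self-contained sketch of the same pattern, with hypothetical stand-ins for the eval and save calls:

```python
import random

def run_eval() -> float:
    """Hypothetical stand-in for eval(model); returns a validation loss."""
    return random.random()

def train_loop(num_steps: int, eval_every: int = 40) -> None:
    best_loss = float('inf')  # tracker must exist before the first comparison
    for step in range(num_steps):
        # ... forward/backward/optimizer step would go here ...
        if step > 0 and step % eval_every == 0:
            loss = run_eval()
            if loss < best_loss:  # checkpoint only on improvement
                print(f'would save checkpoint-{step} (loss {loss:.4f})')
                best_loss = loss

train_loop(200)
```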
src/twinkle/metric/dpo.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -209,5 +209,5 @@ def calculate(self):
         results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}'
         results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}'
         results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%'
-
+        self.reset()
         return results
```
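For context on the aggregates reported just above: DPO's implicit reward is the β-scaled log-probability ratio between the policy and a frozen reference model, and 'accuracies' is the fraction of pairs with a positive margin. A minimal sketch of the standard formulation (illustrative names; not twinkle's actual implementation):

```python
import torch

def dpo_pair_stats(policy_chosen_logps: torch.Tensor,
                   policy_rejected_logps: torch.Tensor,
                   ref_chosen_logps: torch.Tensor,
                   ref_rejected_logps: torch.Tensor,
                   beta: float = 0.1):
    """Standard DPO implicit rewards over a batch of preference pairs."""
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    margins = chosen_rewards - rejected_rewards
    accuracy = (margins > 0).float().mean()  # fraction of correctly ranked pairs
    return chosen_rewards.mean(), rejected_rewards.mean(), margins.mean(), accuracy
```

The added self.reset() clears the metric's running totals after each calculate() call, so successive reports (e.g. per logging window) no longer accumulate into one another.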
