Skip to content

Commit 8c662f0

Browse files
committed
wip
1 parent 0cf1ac3 commit 8c662f0

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

cookbook/rl/dpo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,10 @@
7676

7777
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs
7878
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
79-
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8))
80-
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
81-
LEARNING_RATE = float(os.environ.get('LR', 5e-5))
79+
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 4))
80+
LEARNING_RATE = float(os.environ.get('LR', 5e-6)) # TRL default for DPO is 5e-7 to 5e-6
8281
DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
83-
SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 0.1)) # SFT loss weight for regularization
82+
SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization
8483
LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo
8584
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200))
8685
MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
@@ -90,7 +89,7 @@
9089

9190
def create_dpo_dataset():
9291
"""Create DPO dataset with positive/negative format."""
93-
dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(15000)))
92+
dataset = Dataset(DatasetMeta(DATASET_ID))
9493
dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
9594
dataset.map(
9695
EmojiDPOProcessor,
@@ -188,6 +187,7 @@ def main():
188187
device_mesh=policy_mesh,
189188
remote_group='policy',
190189
)
190+
MAX_STEPS = len(dataloader)
191191
policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
192192
policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
193193
policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)

src/twinkle/loss/dpo.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,10 @@ def __call__(
346346
else:
347347
loss = dpo_loss
348348

349-
# Return sample count for gradient normalization (not token count)
350-
# DPO loss is already per-sample mean, so we just count samples for accumulation
351-
import torch
352-
num_samples = torch.tensor(chosen_labels.shape[0], device=loss.device)
353-
354-
return LossOutput(loss=loss, num_tokens=num_samples)
349+
# Return 0 to skip gradient normalization by num_tokens
350+
# DPO loss is already per-sample mean, unlike SFT which sums per-token loss
351+
# When num_tokens=0, normalize_and_clip_grad_norm defaults to 1 (no division)
352+
return LossOutput(loss=loss, num_tokens=0)
355353

356354

357355
class SimPOLoss(PreferenceLossBase):

0 commit comments

Comments (0)