@@ -76,11 +76,10 @@
 
 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))  # preference pairs per batch
 MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-LEARNING_RATE = float(os.environ.get('LR', 5e-5))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 4))
+LEARNING_RATE = float(os.environ.get('LR', 5e-6))  # typical DPO LRs (e.g. in TRL) run 5e-7 to 5e-6
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
-SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 0.1))  # SFT loss weight for regularization
+SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
 LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo, simpo, orpo, cpo
 SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200))
 MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
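
The jump in SFT_WEIGHT from 0.1 to 1.0 turns the chosen-response NLL from a light anchor into a full-strength term next to the preference loss. Below is a minimal sketch of how DPO_BETA, SFT_WEIGHT, and the default 'sigmoid' LOSS_TYPE typically combine; the helper name and exact wiring are assumptions, not this repo's actual loss code.

```python
import torch.nn.functional as F

# Hypothetical helper illustrating the combined objective; the actual loss
# code in this repo may differ in naming, masking, and reduction.
def dpo_with_sft_loss(policy_chosen_logps, policy_rejected_logps,
                      ref_chosen_logps, ref_rejected_logps,
                      chosen_nll, beta=0.1, sft_weight=1.0):
    # 'sigmoid' DPO loss: beta-scaled margin between the policy's and the
    # frozen reference's log-prob advantage on chosen vs. rejected responses.
    logits = beta * ((policy_chosen_logps - ref_chosen_logps)
                     - (policy_rejected_logps - ref_rejected_logps))
    dpo_loss = -F.logsigmoid(logits).mean()
    # SFT_WEIGHT mixes the chosen-response NLL back in: 1.0 puts it on par
    # with the preference term, while the old 0.1 kept it a mild regularizer.
    return dpo_loss + sft_weight * chosen_nll
```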
@@ -90,7 +89,7 @@
 
 def create_dpo_dataset():
     """Create DPO dataset with positive/negative format."""
-    dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(15000)))
+    dataset = Dataset(DatasetMeta(DATASET_ID))
     dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
     dataset.map(
         EmojiDPOProcessor,
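
Dropping data_slice=range(15000) makes the run consume the full DATASET_ID set instead of only the first 15,000 pairs; that is also why MAX_STEPS is derived from the dataloader in the next hunk rather than read from an env var.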
@@ -188,6 +187,7 @@ def main(): |
         device_mesh=policy_mesh,
         remote_group='policy',
     )
+    MAX_STEPS = len(dataloader)
     policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
     policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
     policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
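
Computing MAX_STEPS from len(dataloader) keeps the CosineAnnealingLR horizon in sync with one pass over the now-unsliced dataset, so the schedule bottoms out at eta_min exactly when training ends; the old fixed MAX_STEPS=1000 could drift away from the real step count. One caveat worth checking: if len(dataloader) counts micro-batches rather than optimizer steps, T_max ends up GRADIENT_ACCUMULATION_STEPS times too long. A standalone sketch of the resulting decay, with a stand-in step count since the real value depends on dataset size and batching:

```python
import torch

# Stand-in values: 'steps' plays the role of len(dataloader), and the single
# parameter stands in for the LoRA adapter weights.
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=5e-6, weight_decay=0.01)
steps = 250
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=steps, eta_min=5e-6 * 0.1)

lrs = []
for _ in range(steps):
    opt.step()        # gradients omitted; only the LR trajectory matters here
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])

print(lrs[0], lrs[-1])  # starts just under 5e-6, ends at eta_min = 5e-7
```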
|