|
1 | | -# WIP, not working yet |
2 | 1 | import os |
3 | 2 | from typing import List, Tuple, Dict, Any |
4 | 3 |
|
|
32 | 31 | MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) |
33 | 32 | LEARNING_RATE = float(os.environ.get('LR', 1e-5)) |
34 | 33 | MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) |
35 | | -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) |
| 34 | +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16))  # global prompt-level batch size; the global completion-level batch size = BATCH_SIZE * num_generations * dp_size
| 35 | +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16))  # global completion-level mini-batch size (completions consumed per optimizer step)
| 36 | +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))  # per-device completion-level micro-batch size; passed as micro_batch_size to forward_backward
36 | 37 | GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) |
37 | 38 | ADAPTER_NAME = 'default' |
38 | 39 |
|
@@ -150,19 +151,31 @@ def main(): |
150 | 151 | }, |
151 | 152 | ) |
152 | 153 |
|
153 | | - advantages = advantage_fn( |
154 | | - total_rewards, |
155 | | - num_generations=NUM_GENERATIONS, |
156 | | - scale='group', |
157 | | - ) |
158 | | - advantages = advantages.tolist() |
159 | | - |
160 | | - model.forward_backward(inputs=all_input_data, old_logps=all_old_logps, advantages=advantages, micro_batch_size=2) |
161 | | - model.clip_grad_and_step() |
162 | | - optim_step += 1 |
163 | | - log_dict = metrics.calculate() |
164 | | - log_dict.update(model.calculate_metric(is_training=True)) |
165 | | - logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') |
| 154 | + advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() |
| 155 | + |
| 156 | + # Split completions into mini-batches and run one optim step per mini-batch. |
| 157 | + total_completions = len(all_input_data) |
| 158 | + for mb_start in range(0, total_completions, MINI_BATCH_SIZE): |
| 159 | + mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) |
| 160 | + mb_inputs = all_input_data[mb_start:mb_end] |
| 161 | + mb_old_logps = all_old_logps[mb_start:mb_end] |
| 162 | + mb_advantages = advantages[mb_start:mb_end] |
| 163 | + |
| 164 | + model.forward_backward( |
| 165 | + inputs=mb_inputs, |
| 166 | + old_logps=mb_old_logps, |
| 167 | + advantages=mb_advantages, |
| 168 | + micro_batch_size=MICRO_BATCH_SIZE, |
| 169 | + ) |
| 170 | + model.clip_grad_and_step() |
| 171 | + optim_step += 1 |
| 172 | + |
| 173 | + if optim_step >= MAX_STEPS: |
| 174 | + break |
| 175 | + log_dict = metrics.calculate() |
| 176 | + log_dict.update(model.calculate_metric(is_training=True)) |
| 177 | + metrics.reset() |
| 178 | + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') |
166 | 179 |
|
167 | 180 | logger.info(f'Training completed. optim_steps={optim_step}') |
168 | 181 | model.save('grpo-gsm8k-checkpoint') |
|
0 commit comments