Skip to content

Commit c365198

Browse files
committed
Merge remote-tracking branch 'origin/dev' into kernels_unittest_fix_ljl
2 parents 2b35007 + 0e89db9 commit c365198

File tree

10 files changed

+51
-77
lines changed

10 files changed

+51
-77
lines changed

cookbook/client/tinker/short_math_grpo.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,6 @@
5656
'2. Final answer after ####\n\n'
5757
'Example:\n<step>Key step1 -> Key step 2 -> conclusion</step>\n#### 42')
5858

59-
# SwanLab experiment tracking
60-
USE_SWANLAB = True
61-
if USE_SWANLAB:
62-
import swanlab
63-
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
64-
swanlab.init(
65-
project='twinkle-Math', config={
66-
'model_id': BASE_MODEL,
67-
})
6859

6960

7061
class MathPreprocessor(Preprocessor):
@@ -403,8 +394,6 @@ def main():
403394
log_dict['train/num_training_samples'] = len(training_data)
404395
logger.info(f'Step {step}: {log_dict}')
405396
step += 1
406-
if USE_SWANLAB:
407-
swanlab.log(log_dict)
408397

409398
# Save final checkpoint
410399
save_future = training_client.save_state('Math-grpo-final')

cookbook/megatron/tp.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@
99
from twinkle.model import MegatronModel
1010
from twinkle.preprocessor import SelfCognitionProcessor
1111

12-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
13-
# rank0 recording
14-
import swanlab
15-
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)
16-
17-
run = swanlab.init(project='twinkle', )
18-
1912
# Construct a device_mesh, tp=pp=cp=2, dp=1
2013
device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2)
2114
# use torchrun mode
@@ -75,8 +68,6 @@ def train():
7568
if step % 5 == 0:
7669
# Print metric
7770
metric = model.calculate_metric(is_training=True)
78-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
79-
swanlab.log(metric)
8071
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
8172
if step > 0 and step % 20 == 0:
8273
metrics = eval(model)

cookbook/megatron/tp_moe.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@
99
from twinkle.model import MegatronModel
1010
from twinkle.preprocessor import SelfCognitionProcessor
1111

12-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
13-
# rank0 recording
14-
import swanlab
15-
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)
16-
17-
run = swanlab.init(project='twinkle', )
18-
1912
# Construct a device_mesh, tp=pp=cp=ep=2, dp=1
2013
device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2, ep_size=2)
2114
# use torchrun mode
@@ -74,8 +67,6 @@ def train():
7467
if step % 5 == 0:
7568
# Print metric
7669
metric = model.calculate_metric(is_training=True)
77-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
78-
swanlab.log(metric)
7970
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
8071
if step > 0 and step % 20 == 0:
8172
metrics = eval(model)

cookbook/ray/single_controller.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@
99
from twinkle.model import TransformersModel
1010
from twinkle.preprocessor import SelfCognitionProcessor
1111

12-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
13-
# rank0 recording
14-
import swanlab
15-
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)
16-
17-
run = swanlab.init(project='twinkle', )
18-
1912
device_group = [DeviceGroup(
2013
name='default',
2114
ranks=8,
@@ -83,8 +76,6 @@ def train():
8376
if step % 20 == 0:
8477
# Print metric
8578
metric = model.calculate_metric(is_training=True)
86-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
87-
swanlab.log(metric)
8879
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
8980
if step > 0 and step % 40 == 0:
9081
metrics = eval(model)

cookbook/rl/grpo.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# WIP, not working yet
21
import os
32
from typing import List, Tuple, Dict, Any
43

@@ -32,7 +31,9 @@
3231
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
3332
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
3433
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
35-
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
34+
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
35+
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16)) # global completion-level mini-batch-size
36+
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward
3637
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
3738
ADAPTER_NAME = 'default'
3839

@@ -150,19 +151,31 @@ def main():
150151
},
151152
)
152153

153-
advantages = advantage_fn(
154-
total_rewards,
155-
num_generations=NUM_GENERATIONS,
156-
scale='group',
157-
)
158-
advantages = advantages.tolist()
159-
160-
model.forward_backward(inputs=all_input_data, old_logps=all_old_logps, advantages=advantages, micro_batch_size=2)
161-
model.clip_grad_and_step()
162-
optim_step += 1
163-
log_dict = metrics.calculate()
164-
log_dict.update(model.calculate_metric(is_training=True))
165-
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
154+
advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
155+
156+
# Split completions into mini-batches and run one optim step per mini-batch.
157+
total_completions = len(all_input_data)
158+
for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
159+
mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
160+
mb_inputs = all_input_data[mb_start:mb_end]
161+
mb_old_logps = all_old_logps[mb_start:mb_end]
162+
mb_advantages = advantages[mb_start:mb_end]
163+
164+
model.forward_backward(
165+
inputs=mb_inputs,
166+
old_logps=mb_old_logps,
167+
advantages=mb_advantages,
168+
micro_batch_size=MICRO_BATCH_SIZE,
169+
)
170+
model.clip_grad_and_step()
171+
optim_step += 1
172+
173+
if optim_step >= MAX_STEPS:
174+
break
175+
log_dict = metrics.calculate()
176+
log_dict.update(model.calculate_metric(is_training=True))
177+
metrics.reset()
178+
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
166179

167180
logger.info(f'Training completed. optim_steps={optim_step}')
168181
model.save('grpo-gsm8k-checkpoint')

cookbook/transformers/fsdp2.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@
99
from twinkle.model import TransformersModel
1010
from twinkle.preprocessor import SelfCognitionProcessor
1111

12-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
13-
# rank0 recording
14-
import swanlab
15-
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)
16-
17-
run = swanlab.init(project='twinkle', )
18-
1912
# Construct a device_mesh, fsdp=4, dp=2
2013
device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
2114
# use torchrun mode
@@ -77,8 +70,6 @@ def train():
7770
if step % 20 == 0:
7871
# Print metric
7972
metric = model.calculate_metric(is_training=True)
80-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
81-
swanlab.log(metric)
8273
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
8374
if step > 0 and step % 40 == 0:
8475
metrics = eval(model)

cookbook/transformers/fsdp2_moe.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,6 @@
99
from twinkle.model import TransformersModel
1010
from twinkle.preprocessor import SelfCognitionProcessor
1111

12-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
13-
# rank0 recording
14-
import swanlab
15-
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)
16-
17-
run = swanlab.init(
18-
project='twinkle',
19-
)
20-
21-
2212
# Construct a device_mesh, fsdp=4, dp=2
2313
device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
2414
# use torchrun mode
@@ -83,8 +73,6 @@ def train():
8373
if step % 20 == 0:
8474
# Print metric
8575
metric = model.calculate_metric(is_training=True)
86-
if Platform.get_rank() == 0 and os.environ.get('SWANLAB_API_KEY'):
87-
swanlab.log(metric)
8876
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
8977
if step > 0 and step % 40 == 0:
9078
metrics = eval(model)

src/twinkle/checkpoint_engine/manager.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,25 @@ def decide_backend_engine(platform: Optional[str] = None) -> 'CheckpointEngine':
7777
raise NotImplementedError
7878

7979
def sync_weights(self, merge_and_sync=True):
80+
"""
81+
Synchronize the weights between the model and the sampler.
82+
83+
This method ensures that the sampler's weights are consistent with the model's
84+
current state. It supports two synchronization modes: full merge-and-sync or
85+
separate base-and-LoRA sync.
86+
87+
Args:
88+
merge_and_sync (bool, optional): Whether to merge and sync the weights.
89+
- If True: LoRA weights are merged into the base model, then the
90+
combined weights are synchronized to the sampler on every call.
91+
- If False: On the first call, base model weights are synced to the
92+
sampler. On subsequent calls, only the LoRA adapter weights are
93+
synced incrementally.
94+
Defaults to True.
95+
96+
Returns:
97+
None
98+
"""
8099
start_time = time.time()
81100
model_metadata = self.model.prepare_checkpoint_engine([True]
82101
+ [False] * (self.model.device_mesh.world_size - 1))

src/twinkle/metric/completion_and_reward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def _std(statistic_list: List[float]) -> float:
5757

5858
def calculate(self) -> Dict[str, Any]:
5959
metric_dict = {}
60-
if self.weight_sync_time is not None:
60+
if self.weight_sync_time:
6161
metric_dict['profiling/Time taken: move_model_to_sampler'] = self._mean(self.weight_sync_time)
62-
if self.generate_time is not None:
62+
if self.generate_time:
6363
metric_dict['profiling/Time taken: generate'] = self._mean(self.generate_time)
6464
for key, values in self.rewards.items():
6565
metric_dict[f'train/{key}_reward'] = self._mean(values)

src/twinkle/sampler/vllm_sampler/vllm_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,8 @@ async def _get_or_load_lora(
410410
await self.engine.add_lora(lora_request)
411411
self._lora_request_cache[lora_path] = lora_request
412412
return lora_request
413-
except Exception: # noqa
413+
except Exception as e:
414+
logger.error(f'Failed to load LoRA from {lora_path}: {e}')
414415
return None
415416

416417
async def sleep(self, level: int = 2) -> None:

0 commit comments

Comments
 (0)