Skip to content

Commit 5906f0b

Browse files
committed
update
1 parent 07e0d52 commit 5906f0b

10 files changed

Lines changed: 81 additions & 40 deletions

File tree

cookbook/client/tinker/gsm8k_grpo.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#
1818
# The server must be running first (see server.py and server_config.yaml).
1919
# Requires both model and sampler services to be configured.
20-
20+
import os
2121
import gc
2222
import re
2323
import numpy as np
@@ -38,12 +38,12 @@
3838
logger = get_logger()
3939

4040
# ========== Configuration ==========
41-
BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct'
41+
BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
4242
NUM_GENERATIONS = 4
43-
MAX_NEW_TOKENS = 2048
44-
LEARNING_RATE = 1e-5
43+
MAX_NEW_TOKENS = 1024
44+
LEARNING_RATE = 1e-4
4545
MAX_STEPS = 100
46-
BATCH_SIZE = 2
46+
BATCH_SIZE = 4
4747
TEMPERATURE = 1.0
4848
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
4949
LORA_RANK = 8
@@ -60,7 +60,7 @@
6060
USE_SWANLAB = True
6161
if USE_SWANLAB:
6262
import swanlab
63-
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)
63+
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
6464
swanlab.init(project="twinkle-gsm8k", config={
6565
'model_id': BASE_MODEL,
6666
})
@@ -363,8 +363,8 @@ def main():
363363
input_tokens = prompt_ids + sampled_tokens[:-1]
364364
target_tokens = [0] * ob_len + sampled_tokens
365365
weights = [0] * ob_len + [1] * len(sampled_tokens)
366-
padded_advantages = [advantage] * len(sampled_tokens)
367-
padded_logprobs = logprobs
366+
padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens)
367+
padded_logprobs = [0.0] * ob_len + logprobs
368368

369369
datum = types.Datum(
370370
model_input=types.ModelInput.from_ints(input_tokens),
@@ -393,7 +393,8 @@ def main():
393393

394394
# ========== 7. Log ==========
395395
log_dict = metrics.calculate()
396-
log_dict['train/loss_per_token'] = float(avg_loss)
396+
if optim_result.metrics:
397+
log_dict.update(optim_result.metrics)
397398
log_dict['train/frac_reward_zero_std'] = frac_zero_std
398399
log_dict['train/num_training_samples'] = len(training_data)
399400
logger.info(f"Step {step}: {log_dict}")

cookbook/client/tinker/megatron/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
# Resolve the path to server_config.yaml relative to this script's location
1717
file_dir = os.path.abspath(os.path.dirname(__file__))
18-
config_path = os.path.join(file_dir, 'server_config_3b.yaml')
18+
config_path = os.path.join(file_dir, 'server_config.yaml')
1919

2020
# Launch the Twinkle server — this call blocks until the server is shut down
2121
launch_server(config_path=config_path)

cookbook/client/tinker/megatron/server_config_3b.yaml renamed to cookbook/client/tinker/megatron/server_config_7b.yaml

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,24 @@ applications:
3333

3434
# 2. Model Service - Hosts the base model for training.
3535
# Configure this section for the training model worker.
36-
- name: models-Qwen2.5-3B-Instruct
37-
route_prefix: /api/v1/model/Qwen/Qwen2.5-3B-Instruct
36+
- name: models-Qwen2.5-7B-Instruct
37+
route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
3838
import_path: model
3939
args:
40-
use_megatron: true # Use HuggingFace Transformers backend
41-
model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier
40+
use_megatron: true
41+
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
42+
max_length: 10240
4243
nproc_per_node: 2 # Number of GPU processes per node
4344
device_group:
4445
name: model
45-
ranks: [0, 1] # GPU rank indices
46+
ranks: [0,1] # GPU rank indices
4647
device_type: cuda
4748
device_mesh:
4849
device_type: cuda
49-
50+
dp_size: 2
5051
queue_config:
5152
rps_limit: 100 # Max requests per second
52-
tps_limit: 10000 # Max tokens per second
53+
tps_limit: 100000 # Max tokens per second
5354
adapter_config:
5455
per_token_adapter_limit: 30 # Max concurrent LoRA adapters
5556
adapter_timeout: 1800 # Seconds before idle adapter unload
@@ -68,24 +69,28 @@ applications:
6869

6970
# 3. Sampler Service - Runs inference / sampling using vLLM engine
7071
# Used for generating text from the model (e.g., evaluating LoRA results).
71-
- name: sampler-Qwen2.5-3B-Instruct
72-
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-3B-Instruct
72+
- name: sampler-Qwen2.5-7B-Instruct
73+
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
7374
import_path: sampler
7475
args:
75-
model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier
76-
nproc_per_node: 1 # Number of GPU processes per node
76+
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
77+
nproc_per_node: 2 # Number of GPU processes per node
7778
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
7879
engine_args: # vLLM engine-specific settings
7980
max_model_len: 4096 # Maximum sequence length the engine supports
80-
gpu_memory_utilization: 0.7 # Fraction of GPU memory to use (0.0-1.0)
81+
gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
8182
enable_lora: true # Allow loading LoRA adapters during inference
83+
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
8284
device_group: # Logical device group for the sampler
8385
name: sampler
84-
gpus_per_worker: 1
85-
ranks: [0] # GPU rank indices to use
86+
ranks: [2] # GPU rank indices to use
8687
device_type: cuda
8788
device_mesh:
8889
device_type: cuda
90+
dp_size: 1
91+
queue_config:
92+
rps_limit: 100 # Max requests per second
93+
tps_limit: 100000 # Max tokens per second
8994
deployments:
9095
- name: SamplerManagement
9196
autoscaling_config:

cookbook/client/tinker/transformer/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
import os
99

10-
# Enable Ray debug mode for verbose logging during development
11-
# os.environ['RAY_DEBUG'] = '1'
1210
os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'
1311

1412
from twinkle.server import launch_server

cookbook/client/tinker/transformer/server_config.yaml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,17 @@ applications:
3333

3434
# 2. Model Service - Hosts the base model for training.
3535
# Configure this section for the training model worker.
36-
- name: models-Qwen2.5-3B-Instruct
37-
route_prefix: /api/v1/model/Qwen/Qwen2.5-3B-Instruct
36+
- name: models-Qwen2.5-7B-Instruct
37+
route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
3838
import_path: model
3939
args:
4040
use_megatron: false # Use HuggingFace Transformers backend
41-
model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier
41+
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
42+
max_length: 10240
4243
nproc_per_node: 2 # Number of GPU processes per node
4344
device_group:
4445
name: model
45-
ranks: [1,2] # GPU rank indices
46+
ranks: [0,1] # GPU rank indices
4647
device_type: cuda
4748
device_mesh:
4849
device_type: cuda
@@ -68,12 +69,12 @@ applications:
6869

6970
# 3. Sampler Service - Runs inference / sampling using vLLM engine
7071
# Used for generating text from the model (e.g., evaluating LoRA results).
71-
- name: sampler-Qwen2.5-3B-Instruct
72-
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-3B-Instruct
72+
- name: sampler-Qwen2.5-7B-Instruct
73+
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
7374
import_path: sampler
7475
args:
75-
model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier
76-
nproc_per_node: 1 # Number of GPU processes per node
76+
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
77+
nproc_per_node: 2 # Number of GPU processes per node
7778
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
7879
engine_args: # vLLM engine-specific settings
7980
max_model_len: 4096 # Maximum sequence length the engine supports
@@ -82,7 +83,7 @@ applications:
8283
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
8384
device_group: # Logical device group for the sampler
8485
name: sampler
85-
ranks: [3] # GPU rank indices to use
86+
ranks: [2] # GPU rank indices to use
8687
device_type: cuda
8788
device_mesh:
8889
device_type: cuda

src/twinkle/infra/_ray/resource_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self,
7171
self.min_node_idx = 0
7272
self.nnodes = math.ceil(cpu_proc_count / ncpu_proc_per_node)
7373

74-
breakpoint()
74+
# breakpoint()
7575
self.nodes = []
7676
for node in ray.nodes():
7777
# get available nodes

src/twinkle/loss/grpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def __call__(
305305
Returns:
306306
loss: Scalar loss value
307307
"""
308-
breakpoint()
308+
# breakpoint()
309309
import torch
310310
labels = inputs.get('labels')
311311
assert labels is not None, "inputs must contain 'labels'"

src/twinkle/server/tinker/common/megatron_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ def _collect_forward_backward_results(results):
4848

4949
return [all_outputs, avg_loss]
5050

51+
def _clean_metrics(metrics: dict) -> dict:
52+
cleaned = {}
53+
for key, value in metrics.items():
54+
if isinstance(value, str):
55+
import re
56+
match = re.match(r'^([+-]?\d*\.?\d+)', value.strip())
57+
if match:
58+
cleaned[key] = float(match.group(1))
59+
else:
60+
cleaned[key] = value
61+
return cleaned
62+
5163

5264
@remote_class(execute='all')
5365
class TwinkleCompatMegatronModel(_MegatronBase):
@@ -179,6 +191,13 @@ def step(self, *, adam_params: types.AdamParams, **kwargs):
179191
# Zero gradients
180192
super().zero_grad(**kwargs)
181193

194+
195+
@remote_function(collect='first', lazy_collect=False)
196+
def calculate_metric(self, is_training, **kwargs):
197+
metric = super().calculate_metric(is_training, **kwargs)
198+
return _clean_metrics(metric)
199+
200+
182201
@remote_function(dispatch='all', sync=True)
183202
def load(self, checkpoint_dir: str, **kwargs):
184203
"""

src/twinkle/server/tinker/common/transformers_model.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ def _collect_forward_backward_results(results):
3535

3636
return [all_outputs, avg_loss]
3737

38+
def _clean_metrics(metrics: dict) -> dict:
39+
cleaned = {}
40+
for key, value in metrics.items():
41+
if isinstance(value, str):
42+
import re
43+
match = re.match(r'^([+-]?\d*\.?\d+)', value.strip())
44+
if match:
45+
cleaned[key] = float(match.group(1))
46+
else:
47+
cleaned[key] = value
48+
return cleaned
49+
3850

3951
@remote_class()
4052
class TwinkleCompatTransformersModel(MultiLoraTransformersModel):
@@ -102,7 +114,7 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss
102114

103115
# Convert Datum to InputFeature
104116
input_features = datum_to_input_feature(inputs, template)
105-
117+
# breakpoint()
106118
# Forward pass
107119
outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs)
108120

@@ -139,7 +151,11 @@ def step(self, *, adam_params: types.AdamParams, **kwargs):
139151
# Zero gradients
140152
super().zero_grad(**kwargs)
141153

142-
return super().calculate_metric(is_training=True, **kwargs)
154+
@remote_function(collect='first', lazy_collect=False)
155+
def calculate_metric(self, is_training, **kwargs):
156+
metric = super().calculate_metric(is_training, **kwargs)
157+
return _clean_metrics(metric)
158+
143159
@remote_function()
144160
def load(self, checkpoint_dir: str, **kwargs):
145161
"""

src/twinkle/server/tinker/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,9 @@ async def _do_optim():
419419
# Touch adapter to reset inactivity counter
420420
self.touch_adapter(adapter_name)
421421

422-
metrics = self.model.step(adam_params=body.adam_params,
422+
self.model.step(adam_params=body.adam_params,
423423
adapter_name=adapter_name)
424+
metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name)
424425
return types.OptimStepResponse(metrics=metrics)
425426
except Exception:
426427
logger.error(traceback.format_exc())

0 commit comments

Comments
 (0)