Skip to content

Commit 9871f3c

Browse files
committed
merge
2 parents dd84454 + 063686c commit 9871f3c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+968
-515
lines changed

README.md

Lines changed: 97 additions & 60 deletions
Large diffs are not rendered by default.

README_ZH.md

Lines changed: 213 additions & 141 deletions
Large diffs are not rendered by default.

assets/framework.jpg

33.9 KB
Loading

assets/multi_lora.png

-110 KB
Loading

cookbook/client/tinker/transformer/grpo.py renamed to cookbook/client/tinker/grpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
logger = get_logger()
3535

3636
# ========== Configuration ==========
37-
BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct'
37+
BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
3838
NUM_GENERATIONS = 4
3939
MAX_NEW_TOKENS = 1024
4040
LEARNING_RATE = 1e-5

cookbook/client/tinker/transformer/gsm8k.py renamed to cookbook/client/tinker/gsm8k_grpo.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#
1818
# The server must be running first (see server.py and server_config.yaml).
1919
# Requires both model and sampler services to be configured.
20-
20+
import os
2121
import gc
2222
import re
2323
import numpy as np
@@ -38,12 +38,12 @@
3838
logger = get_logger()
3939

4040
# ========== Configuration ==========
41-
BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct'
41+
BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
4242
NUM_GENERATIONS = 4
43-
MAX_NEW_TOKENS = 2048
44-
LEARNING_RATE = 1e-5
43+
MAX_NEW_TOKENS = 1024
44+
LEARNING_RATE = 1e-4
4545
MAX_STEPS = 100
46-
BATCH_SIZE = 2
46+
BATCH_SIZE = 4
4747
TEMPERATURE = 1.0
4848
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
4949
LORA_RANK = 8
@@ -56,6 +56,14 @@
5656
"For example:\n<think> ... reasoning ... </think>\n#### 42"
5757
)
5858

59+
# SwanLab experiment tracking
60+
USE_SWANLAB = True
61+
if USE_SWANLAB:
62+
import swanlab
63+
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
64+
swanlab.init(project="twinkle-gsm8k", config={
65+
'model_id': BASE_MODEL,
66+
})
5967

6068
class GSM8KProcessor(Preprocessor):
6169
"""Preprocessor for GSM8K dataset.
@@ -354,18 +362,15 @@ def main():
354362
ob_len = len(prompt_ids) - 1
355363
input_tokens = prompt_ids + sampled_tokens[:-1]
356364
target_tokens = [0] * ob_len + sampled_tokens
365+
weights = [0] * ob_len + [1] * len(sampled_tokens)
357366
padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens)
358367
padded_logprobs = [0.0] * ob_len + logprobs
359-
360-
# Verify lengths match
361-
assert len(input_tokens) == len(target_tokens) == len(padded_logprobs) == len(padded_advantages), \
362-
f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, " \
363-
f"logprobs={len(padded_logprobs)}, advantages={len(padded_advantages)}"
364368

365369
datum = types.Datum(
366370
model_input=types.ModelInput.from_ints(input_tokens),
367371
loss_fn_inputs={
368372
'target_tokens': target_tokens,
373+
'weights': weights,
369374
'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)),
370375
'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)),
371376
},
@@ -380,40 +385,22 @@ def main():
380385

381386
# Forward-backward pass with importance_sampling (GRPO) loss
382387
# The training data already contains logprobs and advantages for the GRPO loss
383-
fwdbwd_future = training_client.forward_backward(
384-
training_data, "importance_sampling")
385-
optim_future = training_client.optim_step(
386-
types.AdamParams(learning_rate=LEARNING_RATE))
387-
388-
fwdbwd_result = fwdbwd_future.result()
389-
optim_result = optim_future.result()
390-
391-
# Compute metrics from the forward-backward result
392-
# For importance_sampling, we get logprobs and elementwise_loss
393-
logprobs_list = []
394-
elementwise_losses = []
395-
for output in fwdbwd_result.loss_fn_outputs:
396-
if output.get('logprobs') is not None:
397-
logprobs_list.append(output['logprobs'].to_numpy())
398-
if output.get('elementwise_loss') is not None:
399-
elementwise_losses.append(output['elementwise_loss'].to_numpy())
400-
401-
# Compute average loss per token (weighted by advantages)
402-
if elementwise_losses:
403-
all_losses = np.concatenate(elementwise_losses)
404-
avg_loss = np.mean(all_losses) if len(all_losses) > 0 else 0.0
405-
else:
406-
avg_loss = 0.0
388+
fwdbwd_result = training_client.forward_backward(training_data, "importance_sampling").result()
389+
390+
optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()
407391

408392
gc.collect()
409393

410394
# ========== 7. Log ==========
411395
log_dict = metrics.calculate()
412-
log_dict['train/loss_per_token'] = float(avg_loss)
396+
if optim_result.metrics:
397+
log_dict.update(optim_result.metrics)
413398
log_dict['train/frac_reward_zero_std'] = frac_zero_std
414399
log_dict['train/num_training_samples'] = len(training_data)
415400
logger.info(f"Step {step}: {log_dict}")
416401
step += 1
402+
if USE_SWANLAB:
403+
swanlab.log(log_dict)
417404

418405
# Save final checkpoint
419406
save_future = training_client.save_state("gsm8k-grpo-final")

cookbook/client/tinker/transformer/lora.py renamed to cookbook/client/tinker/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
# 2. A model id on hub: "<user>/<model_id>"
3737
# Example:
3838
# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
39-
# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-3B-Instruct-385d5c17_pig-latin-lora-epoch-1"
39+
# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1"
4040
resume_path = ""
4141

4242
print(f"Found {len(response.training_runs)} training runs")
@@ -51,7 +51,7 @@
5151

5252
# Step 5: Create or resume a training client.
5353
# If resume_path is set, it restores both model weights and optimizer state.
54-
base_model = "Qwen/Qwen2.5-3B-Instruct"
54+
base_model = "Qwen/Qwen2.5-7B-Instruct"
5555
if not resume_path:
5656
training_client = service_client.create_lora_training_client(
5757
base_model=base_model

cookbook/client/tinker/megatron/server_config.yaml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ proxy_location: EveryNode
1010
# HTTP listener settings
1111
http_options:
1212
host: 0.0.0.0 # Listen on all network interfaces
13-
port: 8000 # Port number for the server
13+
port: 9000 # Port number for the server
1414

1515
# Applications: each entry defines a service component deployed on the server
1616
applications:
@@ -50,12 +50,13 @@ applications:
5050
enable_lora: true # Allow loading LoRA adapters during inference
5151
device_group: # Logical device group for the sampler
5252
name: sampler
53-
gpus_per_worker: 1
54-
ranks: [0,1,2,3] # GPU rank indices to use
53+
gpus_per_worker: 2
54+
ranks: [0,1,2,3,4,5,6,7] # GPU rank indices to use
5555
device_type: cuda
5656
device_mesh:
5757
device_type: cuda
5858
dp_size: 4
59+
tp_size: 2
5960
deployments:
6061
- name: SamplerManagement
6162
autoscaling_config:
@@ -67,7 +68,7 @@ applications:
6768
runtime_env:
6869
env_vars:
6970
TWINKLE_TRUST_REMOTE_CODE: "0"
70-
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
71+
DEVICE_COUNT_PER_PHYSICAL_NODE: "16"
7172

7273
# 2. Model Service (commented out) - Would host the base model for training.
7374
# Uncomment and configure if you need a training model worker.
@@ -80,12 +81,13 @@ applications:
8081
nproc_per_node: 4 # Number of GPU processes per node
8182
device_group:
8283
name: model
83-
ranks: [4,5,6,7] # GPU rank indices
84+
ranks: [8,9,10,11,12,13,14,15] # GPU rank indices
8485
device_type: cuda
8586
device_mesh:
8687
device_type: cuda
8788
dp_size: 2
8889
tp_size: 2
90+
      pp_size: 2
8991
ep_size: 2
9092

9193
queue_config:
@@ -105,4 +107,4 @@ applications:
105107
runtime_env:
106108
env_vars:
107109
TWINKLE_TRUST_REMOTE_CODE: "0"
108-
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
110+
DEVICE_COUNT_PER_PHYSICAL_NODE: "16"
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Twinkle Server Configuration - Tinker-Compatible Transformers Backend
2+
3+
# Server protocol type: "tinker" enables the Tinker-compatible API
4+
server_type: tinker
5+
6+
# proxy_location: determines where the HTTP proxy runs.
7+
# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
8+
proxy_location: EveryNode
9+
10+
# HTTP listener settings
11+
http_options:
12+
host: 0.0.0.0 # Listen on all network interfaces
13+
port: 8000 # Port number for the server
14+
15+
# Applications: each entry defines a service component deployed on the server
16+
applications:
17+
18+
# 1. TinkerCompatServer - The central API server
19+
# Handles client connections, training run tracking, checkpoint listing.
20+
- name: server
21+
route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
22+
import_path: server # Python module to import
23+
args:
24+
25+
deployments:
26+
- name: TinkerCompatServer
27+
autoscaling_config:
28+
min_replicas: 1 # Minimum number of replicas
29+
max_replicas: 1 # Maximum number of replicas
30+
target_ongoing_requests: 128 # Target concurrent requests per replica
31+
ray_actor_options:
32+
num_cpus: 0.1 # CPU resources allocated to this actor
33+
34+
  # 2. Model Service - Hosts the base model for training.
35+
  # Adjust the device group and parallelism settings below to match your hardware.
36+
- name: models-Qwen2.5-7B-Instruct
37+
route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
38+
import_path: model
39+
args:
40+
use_megatron: true
41+
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
42+
max_length: 10240
43+
nproc_per_node: 2 # Number of GPU processes per node
44+
device_group:
45+
name: model
46+
ranks: [0,1] # GPU rank indices
47+
device_type: cuda
48+
device_mesh:
49+
device_type: cuda
50+
dp_size: 2
51+
queue_config:
52+
rps_limit: 100 # Max requests per second
53+
tps_limit: 100000 # Max tokens per second
54+
adapter_config:
55+
per_token_adapter_limit: 30 # Max concurrent LoRA adapters
56+
adapter_timeout: 1800 # Seconds before idle adapter unload
57+
deployments:
58+
- name: ModelManagement
59+
autoscaling_config:
60+
min_replicas: 1
61+
max_replicas: 1
62+
target_ongoing_requests: 16
63+
ray_actor_options:
64+
num_cpus: 0.1
65+
runtime_env:
66+
env_vars:
67+
TWINKLE_TRUST_REMOTE_CODE: "0"
68+
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
69+
70+
# 3. Sampler Service - Runs inference / sampling using vLLM engine
71+
# Used for generating text from the model (e.g., evaluating LoRA results).
72+
- name: sampler-Qwen2.5-7B-Instruct
73+
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
74+
import_path: sampler
75+
args:
76+
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
77+
nproc_per_node: 2 # Number of GPU processes per node
78+
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
79+
engine_args: # vLLM engine-specific settings
80+
max_model_len: 4096 # Maximum sequence length the engine supports
81+
gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
82+
enable_lora: true # Allow loading LoRA adapters during inference
83+
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
84+
device_group: # Logical device group for the sampler
85+
name: sampler
86+
ranks: [2] # GPU rank indices to use
87+
device_type: cuda
88+
device_mesh:
89+
device_type: cuda
90+
dp_size: 1
91+
queue_config:
92+
rps_limit: 100 # Max requests per second
93+
tps_limit: 100000 # Max tokens per second
94+
deployments:
95+
- name: SamplerManagement
96+
autoscaling_config:
97+
min_replicas: 1
98+
max_replicas: 1
99+
target_ongoing_requests: 16
100+
ray_actor_options:
101+
num_cpus: 0.1
102+
runtime_env:
103+
env_vars:
104+
TWINKLE_TRUST_REMOTE_CODE: "0"
105+
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

cookbook/client/tinker/transformer/sample.py renamed to cookbook/client/tinker/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from modelscope import AutoTokenizer
1010

1111
# Step 1: Define the base model and connect to the server
12-
base_model = "Qwen/Qwen2.5-3B-Instruct"
12+
base_model = "Qwen/Qwen2.5-7B-Instruct"
1313
service_client = init_tinker_compat_client(base_url='http://localhost:8000', api_key="tml-EMPTY_TOKEN")
1414

1515
# Step 2: Create a sampling client by loading weights from a saved checkpoint.

0 commit comments

Comments
 (0)