Skip to content

Commit 3fe6615

Browse files
authored
Merge pull request #43 from modelscope/add_rl_server
update client server doc
2 parents 57af42d + 4c996af commit 3fe6615

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2856
-555
lines changed

.gitignore

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,6 @@ test.py
88
# C extensions
99
*.so
1010

11-
src/twinkle_client/dataloader
12-
src/twinkle_client/dataset
13-
src/twinkle_client/model
14-
src/twinkle_client/processor
15-
src/twinkle_client/reward
16-
src/twinkle_client/sampler
17-
1811
# Distribution / packaging
1912
.Python
2013
build/

client_tools/client_generator.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
from pathlib import Path
44
from typing import Dict, List, Tuple, Set
55

6+
AUTO_GEN_WARNING = """# ============================================================================
7+
# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
8+
# ============================================================================
9+
# This file is automatically generated by client_tools/client_generator.py
10+
# Any manual changes will be overwritten when the generator runs again.
11+
#
12+
# To update this file:
13+
# 1. Modify the source files in src/twinkle/
14+
# 2. Run: python client_tools/client_generator.py
15+
# ============================================================================
16+
"""
617

718
def generate_processors():
819
"""Generate client wrappers for all classes with @remote_function methods."""
@@ -326,7 +337,8 @@ def __next__(self):
326337
init_params = "self, **kwargs"
327338
kwargs_dict = "kwargs"
328339

329-
class_template = f'''{chr(10).join(import_lines)}
340+
class_template = f'''{AUTO_GEN_WARNING}
341+
{chr(10).join(import_lines)}
330342
class {class_name}({inheritance}):
331343
"""Client wrapper for {class_name} that calls server HTTP endpoints."""
332344
@@ -414,7 +426,8 @@ def write_init_files(module_files: Dict, src_client_path: Path) -> None:
414426
for source_filename, classes in sorted(source_files.items())
415427
for class_name, _, _, _ in classes
416428
]
417-
init_file.write_text('\n'.join(sorted(init_lines)) + '\n', encoding='utf-8')
429+
init_content = AUTO_GEN_WARNING + '\n'.join(sorted(init_lines)) + '\n'
430+
init_file.write_text(init_content, encoding='utf-8')
418431

419432
module_files = scan_modules(src_twinkle_path, module_mapping)
420433
write_client_files(module_files, src_client_path, processor_type_mapping)
@@ -432,7 +445,7 @@ def generate_models():
432445
client_module_path = src_client_path / 'model'
433446
client_module_path.mkdir(parents=True, exist_ok=True)
434447

435-
model_code = '''from typing import Any, Optional, Union, Type, Dict, Literal, List
448+
model_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, Union, Type, Dict, Literal, List
436449
import uuid
437450
from twinkle_client.http import TWINKLE_SERVER_URL
438451
from twinkle_client.http import http_post, heartbeat_manager
@@ -693,7 +706,7 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio
693706

694707
# Create/overwrite __init__.py
695708
init_file = client_module_path / '__init__.py'
696-
init_content = "from .multi_lora_transformers import MultiLoraTransformersModel\n"
709+
init_content = AUTO_GEN_WARNING + "from .multi_lora_transformers import MultiLoraTransformersModel\n"
697710
print(f"Writing {init_file}...")
698711
with open(init_file, 'w', encoding='utf-8') as f:
699712
f.write(init_content)
@@ -710,7 +723,7 @@ def generate_samplers():
710723
client_module_path = src_client_path / 'sampler'
711724
client_module_path.mkdir(parents=True, exist_ok=True)
712725

713-
sampler_code = '''from typing import Any, Optional, List, Dict, Union
726+
sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, List, Dict, Union
714727
import uuid
715728
from twinkle_client.http import TWINKLE_SERVER_URL
716729
from twinkle_client.http import http_post, heartbeat_manager
@@ -837,7 +850,7 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):
837850

838851
# Create/overwrite __init__.py
839852
init_file = client_module_path / '__init__.py'
840-
init_content = "from .vllm_sampler import VLLMSampler\n"
853+
init_content = AUTO_GEN_WARNING + "from .vllm_sampler import VLLMSampler\n"
841854
print(f"Writing {init_file}...")
842855
with open(init_file, 'w', encoding='utf-8') as f:
843856
f.write(init_content)

cookbook/legacy/client/tinker/megatron/lora.py renamed to cookbook/client/tinker/megatron/lora.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,45 @@
1-
#%%
1+
# Tinker-Compatible Client - Megatron LoRA Training & Sampling Example
2+
#
3+
# This script demonstrates end-to-end LoRA fine-tuning and inference using the
4+
# Tinker-compatible client API with a Megatron backend.
5+
# It covers: connecting to the server, preparing data manually with tokenizers,
6+
# running a training loop, saving checkpoints, and sampling from the model.
7+
# The server must be running first (see server.py and server_config.yaml).
8+
29
from twinkle_client import init_tinker_compat_client
10+
11+
# Step 1: Initialize the Tinker-compatible client to communicate with the server.
312
service_client = init_tinker_compat_client(base_url='http://localhost:8000')
413

14+
# Step 2: List models available on the server to verify the connection
515
print("Available models:")
616
for item in service_client.get_server_capabilities().supported_models:
717
print("- " + item.model_name)
818

919

10-
#%%
20+
# Step 3: Create a REST client for querying training runs and checkpoints.
21+
# This is useful for inspecting previous training sessions or resuming training.
1122
rest_client = service_client.create_rest_client()
1223

1324
future = rest_client.list_training_runs(limit=50)
1425
response = future.result()
26+
27+
# You can resume from a twinkle:// path. Example:
1528
# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
1629
resume_path = ""
30+
1731
print(f"Found {len(response.training_runs)} training runs")
1832
for tr in response.training_runs:
1933
print(tr.model_dump_json(indent=2))
2034

2135
chpts = rest_client.list_checkpoints(tr.training_run_id).result()
2236
for chpt in chpts.checkpoints:
2337
print(" " + chpt.model_dump_json(indent=2))
24-
# resume_path = chpt.tinker_path # Just get the last one for demo purposes
38+
# Uncomment the line below to resume from the last checkpoint:
39+
# resume_path = chpt.tinker_path
2540

26-
#%%
41+
# Step 4: Create or resume a training client.
42+
# If resume_path is set, it restores both model weights and optimizer state.
2743
base_model = "Qwen/Qwen2.5-0.5B-Instruct"
2844
if not resume_path:
2945
training_client = service_client.create_lora_training_client(
@@ -32,8 +48,10 @@
3248
else:
3349
training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path)
3450

35-
#%%
36-
# Create some training examples
51+
# Step 5: Prepare training data manually
52+
#
53+
# This example teaches the model to translate English into Pig Latin.
54+
# Each example has an "input" (English phrase) and "output" (Pig Latin).
3755
examples = [
3856
{"input": "banana split", "output": "anana-bay plit-say"},
3957
{"input": "quantum physics", "output": "uantum-qay ysics-phay"},
@@ -44,82 +62,97 @@
4462
{"input": "coding wizard", "output": "oding-cay izard-way"},
4563
]
4664

47-
# Convert examples into the format expected by the training client
4865
from tinker import types
4966
from modelscope import AutoTokenizer
50-
# Get the tokenizer from the training client
51-
# tokenizer = training_client.get_tokenizer() # NOTE: network call huggingface
67+
68+
# Load the tokenizer locally (avoids a network call to HuggingFace)
5269
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
5370

5471
def process_example(example: dict, tokenizer) -> types.Datum:
55-
# Format the input with Input/Output template
56-
# For most real use cases, you'll want to use a renderer / chat template,
57-
# (see later docs) but here, we'll keep it simple.
72+
"""Convert a raw example dict into a Datum suitable for the training API.
73+
74+
The Datum contains:
75+
- model_input: the token IDs fed into the LLM
76+
- loss_fn_inputs: target tokens and per-token weights (0 = ignore, 1 = train)
77+
"""
78+
# Build a simple prompt template
5879
prompt = f"English: {example['input']}\nPig Latin:"
5980

81+
# Tokenize the prompt; weights=0 means the loss ignores these tokens
6082
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
6183
prompt_weights = [0] * len(prompt_tokens)
62-
# Add a space before the output string, and finish with double newline
84+
85+
# Tokenize the completion; weights=1 means the loss is computed on these tokens
6386
completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
6487
completion_weights = [1] * len(completion_tokens)
6588

89+
# Concatenate prompt + completion
6690
tokens = prompt_tokens + completion_tokens
6791
weights = prompt_weights + completion_weights
6892

93+
# Shift by one: input is tokens[:-1], target is tokens[1:] (next-token prediction)
6994
input_tokens = tokens[:-1]
70-
target_tokens = tokens[1:] # We're predicting the next token, so targets need to be shifted.
95+
target_tokens = tokens[1:]
7196
weights = weights[1:]
7297

73-
# A datum is a single training example for the loss function.
74-
# It has model_input, which is the input sequence that'll be passed into the LLM,
75-
# loss_fn_inputs, which is a dictionary of extra inputs used by the loss function.
7698
return types.Datum(
7799
model_input=types.ModelInput.from_ints(tokens=input_tokens),
78100
loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
79101
)
80102

103+
# Process all examples into Datum objects
81104
processed_examples = [process_example(ex, tokenizer) for ex in examples]
82105

83-
# Visualize the first example for debugging purposes
106+
# Visualize the first example to verify tokenization and weight alignment
84107
datum0 = processed_examples[0]
85108
print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
86109
print("-" * 50)
87110
for i, (inp, tgt, wgt) in enumerate(zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(), datum0.loss_fn_inputs['weights'].tolist())):
88111
print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}")
89112

90-
#%%
113+
# Step 6: Run the training loop
114+
#
115+
# For each epoch, iterate over multiple batches:
116+
# - forward_backward: sends data to the server, computes loss & gradients
117+
# - optim_step: updates model weights using Adam optimizer
91118
import numpy as np
92119
for epoch in range(2):
93120
for batch in range(5):
94-
121+
# Send training data and get back logprobs (asynchronous futures)
95122
fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
96123
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
97124

98-
# Wait for the results
125+
# Wait for results from the server
99126
fwdbwd_result = fwdbwd_future.result()
100127
optim_result = optim_future.result()
101128

102-
# fwdbwd_result contains the logprobs of all the tokens we put in. Now we can compute the weighted
103-
# average log loss per token.
129+
# Compute the weighted average log-loss per token for monitoring
104130
print(f"Epoch {epoch}, Batch {batch}: ", end="")
105131
logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
106132
weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
107133
print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")
108134

135+
# Save checkpoint (model weights + optimizer state) after each epoch
109136
save_future = training_client.save_state(f"pig-latin-lora-epoch-{epoch}")
110137
save_result = save_future.result()
111138
print(f"Saved checkpoint for epoch {epoch} to {save_result.path}")
112139

113-
#%%
114-
# First, create a sampling client. We need to transfer weights
140+
# Step 7: Sample from the trained model
141+
#
142+
# Save the current weights and create a sampling client to generate text.
115143
sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model')
116144

117-
# Now, we can sample from the model.
145+
# Prepare a prompt and sampling parameters
118146
prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
119-
params = types.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"]) # Greedy sampling
147+
params = types.SamplingParams(
148+
max_tokens=20, # Maximum number of tokens to generate
149+
temperature=0.0, # Greedy sampling (deterministic)
150+
stop=["\n"] # Stop at newline
151+
)
152+
153+
# Generate 8 completions and print the results
120154
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
121155
result = future.result()
122156
print("Responses:")
123157
for i, seq in enumerate(result.sequences):
124158
print(f"{i}: {repr(tokenizer.decode(seq.tokens))}")
125-
# %%
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Twinkle Server Launcher - Tinker-Compatible Megatron Backend
2+
#
3+
# This script starts the Twinkle server with Tinker-compatible API support
4+
# using the Megatron model backend.
5+
# It reads the server_config.yaml in the same directory for all
6+
# configuration (model, deployment settings, etc.).
7+
# Run this script BEFORE running the client training script (lora.py).
8+
9+
import os
10+
11+
# Enable Ray debug mode for verbose logging during development
12+
os.environ['RAY_DEBUG'] = '1'
13+
14+
from twinkle.server import launch_server
15+
16+
# Resolve the path to server_config.yaml relative to this script's location
17+
file_dir = os.path.abspath(os.path.dirname(__file__))
18+
config_path = os.path.join(file_dir, 'server_config.yaml')
19+
20+
# Launch the Twinkle server — this call blocks until the server is shut down
21+
launch_server(config_path=config_path)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Twinkle Server Configuration - Tinker-Compatible Megatron Backend
2+
3+
# Server protocol type: "tinker" enables the Tinker-compatible API
4+
server_type: tinker
5+
6+
# proxy_location: determines where the HTTP proxy runs.
7+
# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
8+
proxy_location: EveryNode
9+
10+
# HTTP listener settings
11+
http_options:
12+
host: 0.0.0.0 # Listen on all network interfaces
13+
port: 8000 # Port number for the server
14+
15+
# Applications: each entry defines a service component deployed on the server
16+
applications:
17+
18+
# 1. TinkerCompatServer - The central API server
19+
# Handles client connections, training run tracking, checkpoint listing.
20+
- name: server
21+
route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
22+
import_path: server # Python module to import
23+
args:
24+
25+
deployments:
26+
- name: TinkerCompatServer
27+
autoscaling_config:
28+
min_replicas: 1 # Minimum number of replicas
29+
max_replicas: 1 # Maximum number of replicas
30+
target_ongoing_requests: 128 # Target concurrent requests per replica
31+
ray_actor_options:
32+
num_cpus: 0.1 # CPU resources allocated to this actor
33+
34+
# 2. Model Service - Hosts the base model for training (Megatron backend)
35+
# This is the actual model worker that performs forward/backward passes.
36+
- name: models-Qwen2.5-0.5B-Instruct
37+
route_prefix: /api/v1/model/Qwen/Qwen2.5-0.5B-Instruct # REST path for this model
38+
import_path: model
39+
args:
40+
use_megatron: true # Use Megatron-LM backend (not HuggingFace)
41+
model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" # ModelScope model identifier to load
42+
nproc_per_node: 2 # Number of GPU processes per node
43+
device_group: # Logical device group for this model
44+
name: model
45+
ranks: [0, 1] # GPU rank indices to use
46+
device_type: cuda
47+
device_mesh: # Distributed training mesh configuration
48+
device_type: cuda
49+
mesh: [0, 1] # Device indices in the mesh
50+
mesh_dim_names: ['dp'] # Mesh dimension names: 'dp' = data parallel
51+
deployments:
52+
- name: ModelManagement
53+
autoscaling_config:
54+
min_replicas: 1
55+
max_replicas: 1
56+
target_ongoing_requests: 16
57+
ray_actor_options:
58+
num_cpus: 0.1

0 commit comments

Comments
 (0)