Skip to content

Commit 68c0c1e

Browse files
Add a sample script (#97)
1 parent bbbeddd commit 68c0c1e

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

cookbook/client/tinker/custom_service/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
trajectory = Trajectory(
4040
messages=[
4141
Message(role='system', content='You are a helpful assistant'),
42-
Message(role='user', content='你是谁?'),
42+
Message(role='user', content='Who are you?'),
4343
]
4444
)
4545

@@ -56,7 +56,7 @@
5656
)
5757

5858
# Step 6: Send the sampling request to the server.
59-
# num_samples=8 generates 8 independent completions for the same prompt.
59+
# num_samples=1 generates one independent completion for the same prompt.
6060
print('Sampling...')
6161
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
6262
result = future.result()
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
# Tinker-Compatible Client - Sampling / Inference Example
#
# This script demonstrates how to use a previously trained LoRA checkpoint
# for text generation (sampling) via the Tinker-compatible client API.
# The server must be running first (see server.py and server_config.yaml).

import os

from tinker import types

from twinkle.data_format import Message, Trajectory
from twinkle.template import Template
from twinkle import init_tinker_client

# Step 1: Initialize the Tinker-compatible client shims.
# NOTE: this must run before `from tinker import ServiceClient` below,
# hence the deliberately late import.
init_tinker_client()

from tinker import ServiceClient

base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
base_url = 'http://www.modelscope.cn/twinkle'

# Step 2: Define the base model and connect to the server.
# The API key is read from the MODELSCOPE_TOKEN environment variable;
# if it is unset, api_key will be None and the server may reject requests.
service_client = ServiceClient(
    base_url=base_url,
    api_key=os.environ.get('MODELSCOPE_TOKEN')
)

# Step 3: Create a sampling client by loading weights from a saved checkpoint.
# The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
# The server will load the base model and apply the LoRA adapter weights.
sampling_client = service_client.create_sampling_client(
    model_path='twinkle://xxx-Qwen_Qwen3-30B-A3B-Instruct-2507-xxx/weights/twinkle-lora-1',
    base_model=base_model
)

# Step 4: Load the chat template locally to encode the prompt and decode the results.
print(f'Using model {base_model}')

template = Template(model_id=f'ms://{base_model}')

trajectory = Trajectory(
    messages=[
        Message(role='system', content='You are a helpful assistant'),
        Message(role='user', content='Who are you?'),
    ]
)

# add_generation_prompt=True appends the assistant-turn marker so the model
# continues as the assistant rather than extending the user message.
input_feature = template.encode(trajectory, add_generation_prompt=True)

input_ids = input_feature['input_ids'].tolist()

# Step 5: Prepare the prompt and sampling parameters.
prompt = types.ModelInput.from_ints(input_ids)
params = types.SamplingParams(
    max_tokens=128,   # Maximum number of tokens to generate
    temperature=0.7,
    stop=['\n']       # Stop generation when a newline character is produced
)

# Step 6: Send the sampling request to the server.
# num_samples=1 generates one independent completion for the same prompt;
# raise it to get several alternative completions back in result.sequences.
print('Sampling...')
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result = future.result()

# Step 7: Decode and print the generated responses.
print('Responses:')
for i, seq in enumerate(result.sequences):
    print(f'{i}: {repr(template.decode(seq.tokens))}')

0 commit comments

Comments
 (0)