"""Sample from a LoRA checkpoint served by a twinkle/tinker-compatible server.

The server must be running first (see server.py and server_config.yaml).
Requires the MODELSCOPE_SDK_TOKEN environment variable for authentication.
"""

import os  # required: api_key below reads os.environ (was missing -> NameError)

from tinker import types

from twinkle.data_format import Message, Trajectory
from twinkle.template import Template
from twinkle_client import init_tinker_compat_client
from modelscope import AutoTokenizer  # NOTE(review): unused in this script; kept in case other tooling imports it from here — confirm before removing

# Step 1: Define the base model and connect to the server.
base_model = "Qwen/Qwen3-30B-A3B-Instruct-2507"
service_client = init_tinker_compat_client(
    base_url='http://www.modelscope.cn/twinkle',
    api_key=os.environ.get('MODELSCOPE_SDK_TOKEN'),  # None if unset; the server will reject unauthenticated requests
)

# Step 2: Create a sampling client by loading weights from a saved checkpoint.
# The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
# The server will load the base model and apply the LoRA adapter weights.
sampling_client = service_client.create_sampling_client(
    model_path="twinkle://xxx-Qwen_Qwen3-30B-A3B-Instruct-2507-xxx/weights/twinkle-lora-1",
    base_model=base_model)

# Step 3: Encode the chat prompt locally with the model's template
# (replaces direct tokenizer usage: the template applies the chat format
# and the generation prompt for us).
print(f"Using model {base_model}")
template = Template(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')

trajectory = Trajectory(
    messages=[
        Message(role='system', content='You are a helpful assistant'),
        Message(role='user', content="你是谁?"),
    ]
)

# batch_encode returns one feature dict per trajectory; we only sent one.
input_features = template.batch_encode([trajectory], add_generation_prompt=True)
input_ids = input_features[0]['input_ids']

# Step 4: Prepare the prompt and sampling parameters.
prompt = types.ModelInput.from_ints(list(input_ids))
params = types.SamplingParams(
    max_tokens=128,    # Maximum number of tokens to generate
    temperature=0.0,   # Greedy sampling (deterministic, always pick the top token)
    stop=["\n"]        # Stop generation when a newline character is produced
)

# Step 5: Send the sampling request to the server.
# num_samples controls how many independent completions are generated per prompt.
print("Sampling...")
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result = future.result()

# Step 6: Decode and print the generated responses.
print("Responses:")
for i, seq in enumerate(result.sequences):
    print(f"{i}: {repr(template.decode(seq.tokens))}")