|
1 | | -#%% |
| 1 | +# Tinker-Compatible Client - Megatron LoRA Training & Sampling Example |
| 2 | +# |
| 3 | +# This script demonstrates end-to-end LoRA fine-tuning and inference using the |
| 4 | +# Tinker-compatible client API with a Megatron backend. |
| 5 | +# It covers: connecting to the server, preparing data manually with tokenizers, |
| 6 | +# running a training loop, saving checkpoints, and sampling from the model. |
| 7 | +# The server must be running first (see server.py and server_config.yaml). |
| 8 | + |
2 | 9 | from twinkle_client import init_tinker_compat_client |
| 10 | + |
# Step 1: Create the Tinker-compatible service client pointed at the local server.
service_client = init_tinker_compat_client(base_url='http://localhost:8000')

# Step 2: Confirm connectivity by printing every model the server advertises.
print("Available models:")
for model_info in service_client.get_server_capabilities().supported_models:
    print("- " + model_info.model_name)
|
9 | 19 |
|
10 | | -#%% |
# Step 3: Create a REST client for querying training runs and checkpoints.
# Handy for inspecting previous training sessions or picking a resume point.
rest_client = service_client.create_rest_client()

runs_future = rest_client.list_training_runs(limit=50)
runs_response = runs_future.result()

# You can resume from a twinkle:// path. Example:
# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
resume_path = ""

print(f"Found {len(runs_response.training_runs)} training runs")
for run in runs_response.training_runs:
    print(run.model_dump_json(indent=2))

    ckpt_page = rest_client.list_checkpoints(run.training_run_id).result()
    for ckpt in ckpt_page.checkpoints:
        print(" " + ckpt.model_dump_json(indent=2))
        # Uncomment the line below to resume from the last checkpoint:
        # resume_path = ckpt.tinker_path
25 | 40 |
|
26 | | -#%% |
| 41 | +# Step 4: Create or resume a training client. |
| 42 | +# If resume_path is set, it restores both model weights and optimizer state. |
27 | 43 | base_model = "Qwen/Qwen2.5-0.5B-Instruct" |
28 | 44 | if not resume_path: |
29 | 45 | training_client = service_client.create_lora_training_client( |
|
32 | 48 | else: |
33 | 49 | training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path) |
34 | 50 |
|
35 | | -#%% |
36 | | -# Create some training examples |
| 51 | +# Step 5: Prepare training data manually |
| 52 | +# |
| 53 | +# This example teaches the model to translate English into Pig Latin. |
| 54 | +# Each example has an "input" (English phrase) and "output" (Pig Latin). |
37 | 55 | examples = [ |
38 | 56 | {"input": "banana split", "output": "anana-bay plit-say"}, |
39 | 57 | {"input": "quantum physics", "output": "uantum-qay ysics-phay"}, |
|
44 | 62 | {"input": "coding wizard", "output": "oding-cay izard-way"}, |
45 | 63 | ] |
46 | 64 |
|
47 | | -# Convert examples into the format expected by the training client |
from tinker import types
from modelscope import AutoTokenizer

# Load the tokenizer locally via ModelScope instead of
# training_client.get_tokenizer(), which would trigger a network call to
# HuggingFace. trust_remote_code is required for Qwen tokenizers.
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
53 | 70 |
|
def process_example(example: dict, tokenizer) -> types.Datum:
    """Build a training Datum from a raw {"input", "output"} example.

    The returned Datum carries:
    - model_input: the token IDs fed into the LLM
    - loss_fn_inputs: per-position target tokens and weights, where weight 0
      masks a position out of the loss and weight 1 trains on it.
    """
    # Simple prompt template; real use cases would go through a chat template.
    source_text = f"English: {example['input']}\nPig Latin:"
    completion_text = f" {example['output']}\n\n"

    # Prompt tokens are masked out of the loss; completion tokens are trained on.
    prompt_ids = tokenizer.encode(source_text, add_special_tokens=True)
    completion_ids = tokenizer.encode(completion_text, add_special_tokens=False)

    token_ids = prompt_ids + completion_ids
    mask = [0] * len(prompt_ids) + [1] * len(completion_ids)

    # Next-token prediction: feed tokens[:-1], predict tokens[1:], and align
    # the weights with the shifted targets.
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=token_ids[:-1]),
        loss_fn_inputs=dict(weights=mask[1:], target_tokens=token_ids[1:])
    )
80 | 102 |
|
# Process all examples into Datum objects
processed_examples = [process_example(ex, tokenizer) for ex in examples]

# Visualize the first example to verify tokenization and weight alignment.
# Each row shows the input token, the target (next) token, and whether the
# loss is applied at that position (weight 1) or masked (weight 0).
datum0 = processed_examples[0]
print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
print("-" * 50)
# NOTE: the original bound an enumerate() index that was never used; plain
# zip() over the three aligned sequences is sufficient.
for inp, tgt, wgt in zip(
    datum0.model_input.to_ints(),
    datum0.loss_fn_inputs['target_tokens'].tolist(),
    datum0.loss_fn_inputs['weights'].tolist(),
):
    print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}")
89 | 112 |
|
90 | | -#%% |
# Step 6: Run the training loop
#
# For each epoch, iterate over multiple batches:
# - forward_backward: sends data to the server, computes loss & gradients
# - optim_step: updates model weights using the Adam optimizer
import numpy as np

for epoch in range(2):
    for batch in range(5):
        # Submit both operations; each returns a future resolved by the server.
        fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
        optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))

        # Wait for results from the server. The optimizer step's return value
        # is not needed, so it is awaited without binding a name (the original
        # assigned it to an unused variable).
        fwdbwd_result = fwdbwd_future.result()
        optim_future.result()

        # fwdbwd_result contains per-token logprobs; compute the weighted
        # average log-loss per trained token for monitoring.
        print(f"Epoch {epoch}, Batch {batch}: ", end="")
        logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
        weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
        print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")

    # Save a checkpoint (model weights + optimizer state) after each epoch
    save_future = training_client.save_state(f"pig-latin-lora-epoch-{epoch}")
    save_result = save_future.result()
    print(f"Saved checkpoint for epoch {epoch} to {save_result.path}")
112 | 139 |
|
113 | | -#%% |
114 | | -# First, create a sampling client. We need to transfer weights |
| 140 | +# Step 7: Sample from the trained model |
| 141 | +# |
| 142 | +# Save the current weights and create a sampling client to generate text. |
115 | 143 | sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model') |
116 | 144 |
|
117 | | -# Now, we can sample from the model. |
| 145 | +# Prepare a prompt and sampling parameters |
118 | 146 | prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:")) |
119 | | -params = types.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"]) # Greedy sampling |
| 147 | +params = types.SamplingParams( |
| 148 | + max_tokens=20, # Maximum number of tokens to generate |
| 149 | + temperature=0.0, # Greedy sampling (deterministic) |
| 150 | + stop=["\n"] # Stop at newline |
| 151 | +) |
| 152 | + |
| 153 | +# Generate 8 completions and print the results |
120 | 154 | future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8) |
121 | 155 | result = future.result() |
122 | 156 | print("Responses:") |
123 | 157 | for i, seq in enumerate(result.sequences): |
124 | 158 | print(f"{i}: {repr(tokenizer.decode(seq.tokens))}") |
125 | | -# %% |
|
0 commit comments