Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ test.py
# C extensions
*.so

src/twinkle_client/dataloader
src/twinkle_client/dataset
src/twinkle_client/model
src/twinkle_client/processor
src/twinkle_client/reward
src/twinkle_client/sampler

# Distribution / packaging
.Python
build/
Expand Down
25 changes: 19 additions & 6 deletions client_tools/client_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
from pathlib import Path
from typing import Dict, List, Tuple, Set

# Banner prepended verbatim to every generated client file (wrapper modules
# and their __init__.py files) so readers know the file is machine-written
# and how to regenerate it.
# NOTE: the string is emitted into generated Python sources, so it must stay
# a block of valid `#` comment lines ending in a newline.
AUTO_GEN_WARNING = """# ============================================================================
# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
# ============================================================================
# This file is automatically generated by client_tools/client_generator.py
# Any manual changes will be overwritten when the generator runs again.
#
# To update this file:
# 1. Modify the source files in src/twinkle/
# 2. Run: python client_tools/client_generator.py
# ============================================================================
"""

def generate_processors():
"""Generate client wrappers for all classes with @remote_function methods."""
Expand Down Expand Up @@ -326,7 +337,8 @@ def __next__(self):
init_params = "self, **kwargs"
kwargs_dict = "kwargs"

class_template = f'''{chr(10).join(import_lines)}
class_template = f'''{AUTO_GEN_WARNING}
{chr(10).join(import_lines)}
class {class_name}({inheritance}):
"""Client wrapper for {class_name} that calls server HTTP endpoints."""

Expand Down Expand Up @@ -414,7 +426,8 @@ def write_init_files(module_files: Dict, src_client_path: Path) -> None:
for source_filename, classes in sorted(source_files.items())
for class_name, _, _, _ in classes
]
init_file.write_text('\n'.join(sorted(init_lines)) + '\n', encoding='utf-8')
init_content = AUTO_GEN_WARNING + '\n'.join(sorted(init_lines)) + '\n'
init_file.write_text(init_content, encoding='utf-8')

module_files = scan_modules(src_twinkle_path, module_mapping)
write_client_files(module_files, src_client_path, processor_type_mapping)
Expand All @@ -432,7 +445,7 @@ def generate_models():
client_module_path = src_client_path / 'model'
client_module_path.mkdir(parents=True, exist_ok=True)

model_code = '''from typing import Any, Optional, Union, Type, Dict, Literal, List
model_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, Union, Type, Dict, Literal, List
import uuid
from twinkle_client.http import TWINKLE_SERVER_URL
from twinkle_client.http import http_post, heartbeat_manager
Expand Down Expand Up @@ -693,7 +706,7 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio

# Create/overwrite __init__.py
init_file = client_module_path / '__init__.py'
init_content = "from .multi_lora_transformers import MultiLoraTransformersModel\n"
init_content = AUTO_GEN_WARNING + "from .multi_lora_transformers import MultiLoraTransformersModel\n"
print(f"Writing {init_file}...")
with open(init_file, 'w', encoding='utf-8') as f:
f.write(init_content)
Expand All @@ -710,7 +723,7 @@ def generate_samplers():
client_module_path = src_client_path / 'sampler'
client_module_path.mkdir(parents=True, exist_ok=True)

sampler_code = '''from typing import Any, Optional, List, Dict, Union
sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, List, Dict, Union
import uuid
from twinkle_client.http import TWINKLE_SERVER_URL
from twinkle_client.http import http_post, heartbeat_manager
Comment on lines +726 to 729
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sampler_code template defined here generates VLLMSampler.__init__ with a return response.json(). That makes the generated client invalid (constructors must return None) and will also fail when /create returns a non-JSON/empty body. Please adjust the template to not return from __init__ and to only parse JSON for endpoints that actually return JSON.

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -837,7 +850,7 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):

# Create/overwrite __init__.py
init_file = client_module_path / '__init__.py'
init_content = "from .vllm_sampler import VLLMSampler\n"
init_content = AUTO_GEN_WARNING + "from .vllm_sampler import VLLMSampler\n"
print(f"Writing {init_file}...")
with open(init_file, 'w', encoding='utf-8') as f:
f.write(init_content)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,45 @@
#%%
# Tinker-Compatible Client - Megatron LoRA Training & Sampling Example
#
# This script demonstrates end-to-end LoRA fine-tuning and inference using the
# Tinker-compatible client API with a Megatron backend.
# It covers: connecting to the server, preparing data manually with tokenizers,
# running a training loop, saving checkpoints, and sampling from the model.
# The server must be running first (see server.py and server_config.yaml).

from twinkle_client import init_tinker_compat_client

# Step 1: Initialize the Tinker-compatible client to communicate with the server.
service_client = init_tinker_compat_client(base_url='http://localhost:8000')

# Step 2: List models available on the server to verify the connection
print("Available models:")
for item in service_client.get_server_capabilities().supported_models:
print("- " + item.model_name)


#%%
# Step 3: Create a REST client for querying training runs and checkpoints.
# This is useful for inspecting previous training sessions or resuming training.
rest_client = service_client.create_rest_client()

future = rest_client.list_training_runs(limit=50)
response = future.result()

# You can resume from a twinkle:// path. Example:
# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
resume_path = ""

print(f"Found {len(response.training_runs)} training runs")
for tr in response.training_runs:
print(tr.model_dump_json(indent=2))

chpts = rest_client.list_checkpoints(tr.training_run_id).result()
for chpt in chpts.checkpoints:
print(" " + chpt.model_dump_json(indent=2))
# resume_path = chpt.tinker_path # Just get the last one for demo purposes
# Uncomment the line below to resume from the last checkpoint:
# resume_path = chpt.tinker_path

#%%
# Step 4: Create or resume a training client.
# If resume_path is set, it restores both model weights and optimizer state.
base_model = "Qwen/Qwen2.5-0.5B-Instruct"
if not resume_path:
training_client = service_client.create_lora_training_client(
Expand All @@ -32,8 +48,10 @@
else:
training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path)

#%%
# Create some training examples
# Step 5: Prepare training data manually
#
# This example teaches the model to translate English into Pig Latin.
# Each example has an "input" (English phrase) and "output" (Pig Latin).
examples = [
{"input": "banana split", "output": "anana-bay plit-say"},
{"input": "quantum physics", "output": "uantum-qay ysics-phay"},
Expand All @@ -44,82 +62,97 @@
{"input": "coding wizard", "output": "oding-cay izard-way"},
]

# Convert examples into the format expected by the training client
from tinker import types
from modelscope import AutoTokenizer
# Get the tokenizer from the training client
# tokenizer = training_client.get_tokenizer() # NOTE: network call huggingface

# Load the tokenizer locally (avoids a network call to HuggingFace)
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

def process_example(example: dict, tokenizer) -> types.Datum:
    """Turn one {"input", "output"} example into a ``types.Datum`` for training.

    The Datum carries:
    - model_input: the token IDs fed into the LLM
    - loss_fn_inputs: next-token targets plus per-token weights
      (0 = ignored by the loss, 1 = trained on)
    """
    # Prompt half of the sequence — masked out of the loss (weight 0).
    prompt_text = f"English: {example['input']}\nPig Latin:"
    prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=True)

    # Completion half (leading space, trailing double newline) — trained on
    # (weight 1). No special tokens: it continues the prompt sequence.
    completion_ids = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)

    # Full sequence and its per-token weight mask, aligned position by position.
    full_ids = prompt_ids + completion_ids
    full_weights = [0] * len(prompt_ids) + [1] * len(completion_ids)

    # Next-token prediction: the model reads every token but the last, and each
    # target (with its weight) is the sequence shifted left by one position.
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=full_ids[:-1]),
        loss_fn_inputs=dict(weights=full_weights[1:], target_tokens=full_ids[1:]),
    )

# Convert every raw example into a Datum the training client can consume.
processed_examples = [process_example(ex, tokenizer) for ex in examples]

# Visualize the first example to verify tokenization and weight alignment:
# weight-0 rows are prompt tokens the loss ignores, weight-1 rows are
# completion tokens the model is trained on.
datum0 = processed_examples[0]
print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
print("-" * 50)
# Walk the (input, target, weight) triples token by token.
# NOTE(review): `.tolist()` implies these fields come back as arrays/tensors
# rather than plain lists — confirm against the types.Datum definition.
for i, (inp, tgt, wgt) in enumerate(zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(), datum0.loss_fn_inputs['weights'].tolist())):
    print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}")

#%%
# Step 6: Run the training loop
#
# For each epoch, iterate over multiple batches:
#  - forward_backward: sends data to the server, computes loss & gradients
#  - optim_step: updates model weights using the Adam optimizer
import numpy as np
for epoch in range(2):
    for batch in range(5):

        # Submit both requests without waiting; each call returns a future.
        # NOTE(review): optim_step is queued before forward_backward's result
        # is awaited — this assumes the server applies requests in submission
        # order. Confirm against the training-client API contract.
        fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
        optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))

        # Wait for results from the server.
        fwdbwd_result = fwdbwd_future.result()
        optim_result = optim_future.result()

        # fwdbwd_result carries per-token logprobs for everything we sent in;
        # fold them against the per-token weights to report the weighted
        # average log-loss per trained token.
        print(f"Epoch {epoch}, Batch {batch}: ", end="")
        logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
        weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
        print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")

    # Save a checkpoint (model weights + optimizer state) after each epoch.
    save_future = training_client.save_state(f"pig-latin-lora-epoch-{epoch}")
    save_result = save_future.result()
    print(f"Saved checkpoint for epoch {epoch} to {save_result.path}")

#%%
# Step 7: Sample from the trained model
#
# Save the current weights and create a sampling client to generate text.
sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model')

# Prepare a prompt and sampling parameters.
prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
params = types.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"]) # Greedy sampling
# NOTE(review): the one-line assignment above is immediately superseded by the
# identical multi-line one below — it looks like a leftover from an edit;
# one of the two should be removed.
params = types.SamplingParams(
    max_tokens=20,  # Maximum number of tokens to generate
    temperature=0.0,  # Greedy sampling (deterministic)
    stop=["\n"]  # Stop at newline
)

# Generate 8 completions and print the results.
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
result = future.result()
print("Responses:")
for i, seq in enumerate(result.sequences):
    print(f"{i}: {repr(tokenizer.decode(seq.tokens))}")
# %%
21 changes: 21 additions & 0 deletions cookbook/client/tinker/megatron/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Twinkle Server Launcher - Tinker-Compatible Megatron Backend
#
# Boots the Twinkle server with the Tinker-compatible API on top of the
# Megatron model backend. Everything configurable (model, deployment
# settings, etc.) is read from server_config.yaml next to this file.
# Start this BEFORE running the client training script (lora.py).

import os

# Turn on Ray's verbose debug logging while developing.
# Must be set before twinkle.server (and Ray) are imported.
os.environ['RAY_DEBUG'] = '1'

from twinkle.server import launch_server

# server_config.yaml always lives in the same directory as this script,
# regardless of the current working directory.
_script_dir = os.path.abspath(os.path.dirname(__file__))
_config_file = os.path.join(_script_dir, 'server_config.yaml')

# Blocks here until the server is shut down.
launch_server(config_path=_config_file)
58 changes: 58 additions & 0 deletions cookbook/client/tinker/megatron/server_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Twinkle Server Configuration - Tinker-Compatible Megatron Backend

# Server protocol type: "tinker" enables the Tinker-compatible API
server_type: tinker

# proxy_location: determines where the HTTP proxy runs.
# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
proxy_location: EveryNode

# HTTP listener settings
http_options:
host: 0.0.0.0 # Listen on all network interfaces
port: 8000 # Port number for the server

# Applications: each entry defines a service component deployed on the server
applications:

# 1. TinkerCompatServer - The central API server
# Handles client connections, training run tracking, checkpoint listing.
- name: server
route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
import_path: server # Python module to import
args:

deployments:
- name: TinkerCompatServer
autoscaling_config:
min_replicas: 1 # Minimum number of replicas
max_replicas: 1 # Maximum number of replicas
target_ongoing_requests: 128 # Target concurrent requests per replica
ray_actor_options:
num_cpus: 0.1 # CPU resources allocated to this actor

# 2. Model Service - Hosts the base model for training (Megatron backend)
# This is the actual model worker that performs forward/backward passes.
- name: models-Qwen2.5-0.5B-Instruct
route_prefix: /api/v1/model/Qwen/Qwen2.5-0.5B-Instruct # REST path for this model
import_path: model
args:
use_megatron: true # Use Megatron-LM backend (not HuggingFace)
model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" # ModelScope model identifier to load
nproc_per_node: 2 # Number of GPU processes per node
device_group: # Logical device group for this model
name: model
ranks: [0, 1] # GPU rank indices to use
device_type: cuda
device_mesh: # Distributed training mesh configuration
device_type: cuda
mesh: [0, 1] # Device indices in the mesh
mesh_dim_names: ['dp'] # Mesh dimension names: 'dp' = data parallel
deployments:
- name: ModelManagement
autoscaling_config:
min_replicas: 1
max_replicas: 1
target_ongoing_requests: 16
ray_actor_options:
num_cpus: 0.1
Loading
Loading