Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
logger = get_logger()

# ========== Configuration ==========
BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct'
BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 1024
LEARNING_RATE = 1e-5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#
# The server must be running first (see server.py and server_config.yaml).
# Requires both model and sampler services to be configured.

import os
import gc
import re
import numpy as np
Expand All @@ -38,12 +38,12 @@
logger = get_logger()

# ========== Configuration ==========
BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct'
BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 2048
LEARNING_RATE = 1e-5
MAX_NEW_TOKENS = 1024
LEARNING_RATE = 1e-4
MAX_STEPS = 100
BATCH_SIZE = 2
BATCH_SIZE = 4
TEMPERATURE = 1.0
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
LORA_RANK = 8
Expand All @@ -56,6 +56,14 @@
"For example:\n<think> ... reasoning ... </think>\n#### 42"
)

# SwanLab experiment tracking
USE_SWANLAB = True
if USE_SWANLAB:
import swanlab
swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
swanlab.init(project="twinkle-gsm8k", config={
'model_id': BASE_MODEL,
})
Comment on lines +63 to +66
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Accessing os.environ['SWANLAB_API_KEY'] directly will raise a KeyError if the environment variable is not set, causing the script to crash. It's safer to use os.environ.get() and handle the case where the key is missing by raising a more informative error.

    api_key = os.environ.get('SWANLAB_API_KEY')
    if not api_key:
        raise ValueError("SWANLAB_API_KEY environment variable not set, but USE_SWANLAB is True.")
    swanlab.login(api_key=api_key)
    swanlab.init(project="twinkle-gsm8k", config={
        'model_id': BASE_MODEL,
    })


class GSM8KProcessor(Preprocessor):
"""Preprocessor for GSM8K dataset.
Expand Down Expand Up @@ -354,18 +362,15 @@ def main():
ob_len = len(prompt_ids) - 1
input_tokens = prompt_ids + sampled_tokens[:-1]
target_tokens = [0] * ob_len + sampled_tokens
weights = [0] * ob_len + [1] * len(sampled_tokens)
padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens)
padded_logprobs = [0.0] * ob_len + logprobs

# Verify lengths match
assert len(input_tokens) == len(target_tokens) == len(padded_logprobs) == len(padded_advantages), \
f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, " \
f"logprobs={len(padded_logprobs)}, advantages={len(padded_advantages)}"

datum = types.Datum(
model_input=types.ModelInput.from_ints(input_tokens),
loss_fn_inputs={
'target_tokens': target_tokens,
'weights': weights,
'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)),
'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)),
},
Expand All @@ -380,40 +385,22 @@ def main():

# Forward-backward pass with importance_sampling (GRPO) loss
# The training data already contains logprobs and advantages for the GRPO loss
fwdbwd_future = training_client.forward_backward(
training_data, "importance_sampling")
optim_future = training_client.optim_step(
types.AdamParams(learning_rate=LEARNING_RATE))

fwdbwd_result = fwdbwd_future.result()
optim_result = optim_future.result()

# Compute metrics from the forward-backward result
# For importance_sampling, we get logprobs and elementwise_loss
logprobs_list = []
elementwise_losses = []
for output in fwdbwd_result.loss_fn_outputs:
if output.get('logprobs') is not None:
logprobs_list.append(output['logprobs'].to_numpy())
if output.get('elementwise_loss') is not None:
elementwise_losses.append(output['elementwise_loss'].to_numpy())

# Compute average loss per token (weighted by advantages)
if elementwise_losses:
all_losses = np.concatenate(elementwise_losses)
avg_loss = np.mean(all_losses) if len(all_losses) > 0 else 0.0
else:
avg_loss = 0.0
fwdbwd_result = training_client.forward_backward(training_data, "importance_sampling").result()

optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()

gc.collect()

# ========== 7. Log ==========
log_dict = metrics.calculate()
log_dict['train/loss_per_token'] = float(avg_loss)
if optim_result.metrics:
log_dict.update(optim_result.metrics)
log_dict['train/frac_reward_zero_std'] = frac_zero_std
log_dict['train/num_training_samples'] = len(training_data)
logger.info(f"Step {step}: {log_dict}")
step += 1
if USE_SWANLAB:
swanlab.log(log_dict)

# Save final checkpoint
save_future = training_client.save_state("gsm8k-grpo-final")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
# 2. A model id on hub: "<user>/<model_id>"
# Example:
# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-3B-Instruct-385d5c17_pig-latin-lora-epoch-1"
# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1"
resume_path = ""

print(f"Found {len(response.training_runs)} training runs")
Expand All @@ -51,7 +51,7 @@

# Step 5: Create or resume a training client.
# If resume_path is set, it restores both model weights and optimizer state.
base_model = "Qwen/Qwen2.5-3B-Instruct"
base_model = "Qwen/Qwen2.5-7B-Instruct"
if not resume_path:
training_client = service_client.create_lora_training_client(
base_model=base_model
Expand Down
14 changes: 8 additions & 6 deletions cookbook/client/tinker/megatron/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ proxy_location: EveryNode
# HTTP listener settings
http_options:
host: 0.0.0.0 # Listen on all network interfaces
port: 8000 # Port number for the server
port: 9000 # Port number for the server

# Applications: each entry defines a service component deployed on the server
applications:
Expand Down Expand Up @@ -50,12 +50,13 @@ applications:
enable_lora: true # Allow loading LoRA adapters during inference
device_group: # Logical device group for the sampler
name: sampler
gpus_per_worker: 1
ranks: [0,1,2,3] # GPU rank indices to use
gpus_per_worker: 2
ranks: [0,1,2,3,4,5,6,7] # GPU rank indices to use
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 4
tp_size: 2
deployments:
- name: SamplerManagement
autoscaling_config:
Expand All @@ -67,7 +68,7 @@ applications:
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
DEVICE_COUNT_PER_PHYSICAL_NODE: "16"

# 2. Model Service (commented out) - Would host the base model for training.
# Uncomment and configure if you need a training model worker.
Expand All @@ -80,12 +81,13 @@ applications:
nproc_per_node: 4 # Number of GPU processes per node
device_group:
name: model
ranks: [4,5,6,7] # GPU rank indices
ranks: [8,9,10,11,12,13,14,15] # GPU rank indices
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 2
tp_size: 2
        pp_size: 2
ep_size: 2

queue_config:
Expand All @@ -105,4 +107,4 @@ applications:
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
DEVICE_COUNT_PER_PHYSICAL_NODE: "16"
105 changes: 105 additions & 0 deletions cookbook/client/tinker/megatron/server_config_7b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Twinkle Server Configuration - Tinker-Compatible Transformers Backend

# Server protocol type: "tinker" enables the Tinker-compatible API
server_type: tinker

# proxy_location: determines where the HTTP proxy runs.
# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
proxy_location: EveryNode

# HTTP listener settings
http_options:
host: 0.0.0.0 # Listen on all network interfaces
port: 8000 # Port number for the server

# Applications: each entry defines a service component deployed on the server
applications:

# 1. TinkerCompatServer - The central API server
# Handles client connections, training run tracking, checkpoint listing.
- name: server
route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
import_path: server # Python module to import
args:

deployments:
- name: TinkerCompatServer
autoscaling_config:
min_replicas: 1 # Minimum number of replicas
max_replicas: 1 # Maximum number of replicas
target_ongoing_requests: 128 # Target concurrent requests per replica
ray_actor_options:
num_cpus: 0.1 # CPU resources allocated to this actor

# 2. Model Service (commented out) - Would host the base model for training.
# Uncomment and configure if you need a training model worker.
- name: models-Qwen2.5-7B-Instruct
route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
import_path: model
args:
use_megatron: true
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
max_length: 10240
nproc_per_node: 2 # Number of GPU processes per node
device_group:
name: model
ranks: [0,1] # GPU rank indices
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 2
queue_config:
rps_limit: 100 # Max requests per second
tps_limit: 100000 # Max tokens per second
adapter_config:
per_token_adapter_limit: 30 # Max concurrent LoRA adapters
adapter_timeout: 1800 # Seconds before idle adapter unload
deployments:
- name: ModelManagement
autoscaling_config:
min_replicas: 1
max_replicas: 1
target_ongoing_requests: 16
ray_actor_options:
num_cpus: 0.1
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

# 3. Sampler Service - Runs inference / sampling using vLLM engine
# Used for generating text from the model (e.g., evaluating LoRA results).
- name: sampler-Qwen2.5-7B-Instruct
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
import_path: sampler
args:
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
nproc_per_node: 2 # Number of GPU processes per node
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
engine_args: # vLLM engine-specific settings
max_model_len: 4096 # Maximum sequence length the engine supports
gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
enable_lora: true # Allow loading LoRA adapters during inference
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
device_group: # Logical device group for the sampler
name: sampler
ranks: [2] # GPU rank indices to use
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 1
queue_config:
rps_limit: 100 # Max requests per second
tps_limit: 100000 # Max tokens per second
deployments:
- name: SamplerManagement
autoscaling_config:
min_replicas: 1
max_replicas: 1
target_ongoing_requests: 16
ray_actor_options:
num_cpus: 0.1
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from modelscope import AutoTokenizer

# Step 1: Define the base model and connect to the server
base_model = "Qwen/Qwen2.5-3B-Instruct"
base_model = "Qwen/Qwen2.5-7B-Instruct"
service_client = init_tinker_compat_client(base_url='http://localhost:8000', api_key="tml-EMPTY_TOKEN")

# Step 2: Create a sampling client by loading weights from a saved checkpoint.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from modelscope import AutoTokenizer

# The base model to fine-tune / evaluate
base_model = "Qwen/Qwen2.5-3B-Instruct"
base_model = "Qwen/Qwen2.5-7B-Instruct"


def train():
Expand Down Expand Up @@ -83,7 +83,7 @@ def eval():
# Step 1: Load the trained LoRA checkpoint for inference

# Path to a previously saved LoRA checkpoint (twinkle:// URI)
weight_path = "twinkle://20260207_110850-Qwen_Qwen2_5-0_5B-Instruct-ce7e819f/weights/twinkle-lora-2"
weight_path = "twinkle://20260211_112719-Qwen_Qwen2_5-7B-Instruct-a74a4826/weights/twinkle-lora-2"

# Connect to the server and create a sampling client with the trained weights
service_client = init_tinker_compat_client(base_url='http://localhost:8000')
Expand Down Expand Up @@ -136,5 +136,5 @@ def eval():


if __name__ == "__main__":
# train() # Uncomment to run training
eval() # Run evaluation / inference
    train()  # Run training
    # eval()  # Uncomment to run evaluation / inference
2 changes: 0 additions & 2 deletions cookbook/client/tinker/transformer/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import os

# Enable Ray debug mode for verbose logging during development
# os.environ['RAY_DEBUG'] = '1'
os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'

from twinkle.server import launch_server
Expand Down
Loading
Loading