diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh
index 14966ce9..730a9a35 100644
--- a/cookbook/client/server/megatron/run.sh
+++ b/cookbook/client/server/megatron/run.sh
@@ -12,6 +12,8 @@
# --gpu-workers LIST GPU worker list, semicolon-separated for multiple nodes (default: 4,5,6,7:4)
# --cpu-workers N Number of CPU workers (default: 1)
# --temp-dir DIR Ray temp directory (default: /dashscope/caches/application/ray_logs)
+# --save-dir DIR Twinkle model save directory (default: /dashscope/caches/application/save)
+# --server-config FILE Twinkle server config file path (default: /twinkle/cookbook/client/server/megatron/server_config.yaml)
# --help Show help message
#
# Examples:
@@ -49,6 +51,8 @@ RAY_ADDRESS="127.0.0.1:$RAY_PORT"
# --- Path configuration ---
DEFAULT_TEMP_DIR="/dashscope/caches/application/ray_logs"
LOG_FILE="run.log"
+DEFAULT_SAVE_DIR="/dashscope/caches/application/save"
+DEFAULT_SERVER_CONFIG_FILE="/twinkle/cookbook/client/server/megatron/server_config.yaml"
# --- Prometheus monitoring configuration ---
PROMETHEUS_BIN="/dashscope/caches/application/monitor/prometheus-3.10.0.linux-amd64/prometheus"
@@ -67,6 +71,8 @@ HEAD_NODE="0,1,2,3"
GPU_WORKERS_INPUT="4,5,6,7"
CPU_WORKER_COUNT="1"
TEMP_DIR="$DEFAULT_TEMP_DIR"
+SAVE_DIR="$DEFAULT_SAVE_DIR"
+SERVER_CONFIG_FILE="$DEFAULT_SERVER_CONFIG_FILE"
# Parse named arguments
while [[ $# -gt 0 ]]; do
@@ -103,6 +109,22 @@ while [[ $# -gt 0 ]]; do
TEMP_DIR="${1#*=}"
shift
;;
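+ # Both '--save-dir DIR' and '--save-dir=DIR' forms are accepted, mirroring the options above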
+ --save-dir)
+ SAVE_DIR="$2"
+ shift 2
+ ;;
+ --save-dir=*)
+ SAVE_DIR="${1#*=}"
+ shift
+ ;;
+ --server-config)
+ SERVER_CONFIG_FILE="$2"
+ shift 2
+ ;;
+ --server-config=*)
+ SERVER_CONFIG_FILE="${1#*=}"
+ shift
+ ;;
--help|-h)
echo "用法: ./run.sh [选项]"
echo ""
@@ -111,6 +133,8 @@ while [[ $# -gt 0 ]]; do
echo " --gpu-workers LIST GPU Worker 列表,分号分隔多个节点 (默认: 4,5,6,7)"
echo " --cpu-workers N CPU Worker 数量 (默认: 1)"
echo " --temp-dir DIR Ray 临时目录"
+ echo " --save-dir DIR Twinkle 模型保存目录 (默认: $DEFAULT_SAVE_DIR)"
+ echo " --server-config FILE Twinkle 服务器配置文件路径 (默认: $DEFAULT_SERVER_CONFIG_FILE)"
echo " --help, -h 显示帮助信息"
echo ""
echo "示例:"
@@ -129,6 +153,9 @@ while [[ $# -gt 0 ]]; do
esac
done
+# Export SAVE_DIR to child processes (the Python server reads it from the environment)
+export TWINKLE_DEFAULT_SAVE_DIR="$SAVE_DIR"
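+# (read on the Python side, e.g. via os.environ.get("TWINKLE_DEFAULT_SAVE_DIR"))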
+
# Convert the semicolon-separated string into an array
if [ -z "$GPU_WORKERS_INPUT" ]; then
GPU_WORKERS=()
@@ -222,6 +249,8 @@ echo ""
print_info "运行参数:"
echo " - Ray 地址: $RAY_ADDRESS"
echo " - 临时目录: $TEMP_DIR"
+echo " - 保存目录: $TWINKLE_DEFAULT_SAVE_DIR"
+echo " - 服务配置: $SERVER_CONFIG_FILE"
echo " - 日志文件: $LOG_FILE"
echo ""
@@ -235,6 +264,28 @@ fi
# Stop any existing Ray cluster and Prometheus
# ============================================
print_header "Cleaning up environment"
+
+# Stop any running Twinkle server (the twinkle.server module)
+print_info "Stopping existing Twinkle Server..."
+pkill -f "twinkle.server" 2>/dev/null || true
+
+# Stop vLLM processes
+print_info "Stopping existing vLLM processes..."
+pkill -f "vllm" 2>/dev/null || true
+
+# Wait for the processes above to exit
+sleep 2
+
+# Force SIGKILL on anything still lingering
+if pgrep -f "twinkle.server" > /dev/null 2>&1; then
+ print_warning "Twinkle Server 未退出,强制终止..."
+ pkill -9 -f "twinkle.server" 2>/dev/null || true
+fi
+if pgrep -f "vllm" > /dev/null 2>&1; then
+ print_warning "vLLM 进程未退出,强制终止..."
+ pkill -9 -f "vllm" 2>/dev/null || true
+fi
+
print_info "停止已有的 Ray 集群..."
ray stop --force 2>/dev/null || true
@@ -334,8 +385,10 @@ print_info "Logging output to: $LOG_FILE"
echo ""
# Start the server and stream its logs
-nohup python server.py > "$LOG_FILE" 2>&1 &
+touch "$LOG_FILE" # 预创建文件,避免 tail -f 在文件尚未写入时报错
+nohup python -m twinkle.server --config "$SERVER_CONFIG_FILE" > "$LOG_FILE" 2>&1 &
SERVER_PID=$!
+print_success "Twinkle Server 已启动 (PID: $SERVER_PID)"
# Stream logs in real time
tail -f "$LOG_FILE"
diff --git a/cookbook/client/server/megatron/server.py b/cookbook/client/server/megatron/server.py
deleted file mode 100644
index d6cb87c5..00000000
--- a/cookbook/client/server/megatron/server.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Twinkle Server Launcher - Tinker-Compatible Megatron Backend
-#
-# This script starts the Twinkle server with Tinker-compatible API support
-# using the Megatron model backend.
-# It reads the server_config.yaml in the same directory for all
-# configuration (model, deployment settings, etc.).
-# Run this script BEFORE running the client training script (lora.py).
-
-import os
-
-# Disable trusting remote code when loading models
-os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'
-
-from twinkle.server import launch_server
-
-# Resolve the path to server_config.yaml relative to this script's location
-file_dir = os.path.abspath(os.path.dirname(__file__))
-config_path = os.path.join(file_dir, 'server_config.yaml')
-
-# Launch the Twinkle server — this call blocks until the server is shut down
-launch_server(config_path=config_path)
diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml
index 6d584455..90c8d9ac 100644
--- a/cookbook/client/server/megatron/server_config.yaml
+++ b/cookbook/client/server/megatron/server_config.yaml
@@ -50,6 +50,7 @@ applications:
enable_lora: true # Allow loading LoRA adapters during inference
max_loras: 5 # Max allowed loras working on vLLM at the same time
max_lora_rank: 32 # Support up to rank 32 LoRA adapters
+ enable_tower_connector_lora: true # Apply LoRA to the tower/connector modules as well
device_group: # Logical device group for the sampler
name: sampler
gpus_per_worker: 2
diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml
index 5dd8a696..36adb332 100644
--- a/cookbook/client/server/megatron/server_config_4b.yaml
+++ b/cookbook/client/server/megatron/server_config_4b.yaml
@@ -76,13 +76,14 @@ applications:
import_path: sampler
args:
model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier
- nproc_per_node: 2 # Number of GPU processes per node
+ nproc_per_node: 1 # Number of GPU processes per node
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
engine_args: # vLLM engine-specific settings
max_model_len: 16000 # Maximum sequence length the engine supports
gpu_memory_utilization: 0.7 # Fraction of GPU memory to use (0.0-1.0)
enable_lora: true # Allow loading LoRA adapters during inference
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
+ enable_tower_connector_lora: true # Apply LoRA to the tower/connector modules as well
device_group: # Logical device group for the sampler
name: sampler
ranks: 1 # Number of GPUs to use
diff --git a/cookbook/client/tinker/modelscope/short_math_grpo.py b/cookbook/client/tinker/modelscope/short_math_grpo.py
index 47a7d24a..c361b934 100644
--- a/cookbook/client/tinker/modelscope/short_math_grpo.py
+++ b/cookbook/client/tinker/modelscope/short_math_grpo.py
@@ -1,12 +1,12 @@
-# Tinker-Compatible Client - Math GRPO Training Example
+# Tinker-Compatible Client - GSM8K GRPO Training Example
#
-# This script demonstrates Math problem training using the
+# This script demonstrates GSM8K math problem training using the
# Tinker-compatible client API with save_weights_for_sampler for weight sync.
# Instead of calling sync_weights directly, it periodically saves weights and
# creates a sampling client for generation.
#
# Flow:
-# 1. Prepare Math dataset (client-side)
+# 1. Prepare GSM8K dataset (client-side)
# 2. Initialize Tinker-compatible training & sampling clients
# 3. Training loop:
# a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client
@@ -22,191 +22,102 @@
import os
import re
from tinker import types
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Any
from twinkle import init_tinker_client
from twinkle import get_logger
from twinkle.advantage import GRPOAdvantage
-from twinkle.data_format import Message, Trajectory
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.preprocessor import Preprocessor
+from twinkle.preprocessor.llm import GSM8KProcessor
+from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.metric import CompletionRewardMetric
-from twinkle.template import Template
+from twinkle.template import Qwen3_5Template
logger = get_logger()
# ========== Configuration ==========
BASE_MODEL = 'Qwen/Qwen3.5-27B'
-NUM_GENERATIONS = 8
+NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 4096
-LEARNING_RATE = 1e-4
+LEARNING_RATE = 2e-5
MAX_STEPS = 1000
BATCH_SIZE = 2
TEMPERATURE = 1.0
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
-LORA_RANK = 8
+LORA_RANK = 16
DATA_NUM = 2000 # Number of Math samples to use
-SYSTEM_PROMPT = ('You are a math assistant that values brevity. '
- 'Solve problems with minimal but correct reasoning.\n\n'
- 'Rules:\n'
- '1. Use <think></think> tags for reasoning\n'
- '2. Final answer after ####\n\n'
- 'Example:\nKey step 1 -> Key step 2 -> conclusion\n#### 42')
+SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
+ 'and put your final answer within \\boxed{}.')
+# ========== Reward Functions ==========
+class GSM8KBrevityReward(Reward):
+ """Brevity reward: rewards shorter completions that contain a valid answer.
-class MathPreprocessor(Preprocessor):
-
- def __call__(self, rows):
- rows = self.map_col_to_row(rows)
- rows = [self.preprocess(row) for row in rows]
- rows = self.map_row_to_col(rows)
- return rows
-
- def preprocess(self, sample):
- if sample['level'] not in ('Level 4', 'Level 5'):
- return Trajectory(messages=[], user_data=[])
-
- def get_boxed_answer(text):
- match = re.search(r'\\boxed{([^}]*)}', text)
- return match.group(1) if match else None
-
- ground_truth = get_boxed_answer(sample['solution'])
- if ground_truth is None:
- return Trajectory(messages=[], user_data=[])
- problem = sample['problem']
- return Trajectory(
- messages=[
- Message(role='system', content=SYSTEM_PROMPT),
- Message(role='user', content=problem),
- ],
- user_data=[('ground_truth', ground_truth)],
- )
-
-
-# ========== Math Reward Functions ==========
-class MathAccuracyReward(Reward):
- """Accuracy reward for Math: checks if the model's answer matches ground truth.
-
- Extracts the last '#### ' from model output and compares with ground truth.
- Returns 1.0 for correct, 0.0 for incorrect.
+ Returns 0.0 if no valid answer format (\\boxed{} or ####).
+ Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
"""
- @staticmethod
- def extract_answer(completion: str) -> str:
- """Extract the last #### answer from model completion."""
- # Only check last 500 chars for efficiency
- text = completion[-500:] if len(completion) > 500 else completion
- matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
- if matches:
- return matches[-1].replace(',', '').replace(' ', '').strip()
- return ''
-
- def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
+ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
rewards = []
- for trajectory in trajectories:
- messages = trajectory.get('messages', [])
- # Get model completion (last assistant message)
+ for traj in trajectories:
+ messages = traj.get('messages', [])
completion = ''
for msg in reversed(messages):
if msg.get('role') == 'assistant':
completion = msg.get('content', '')
break
- # Get ground truth from user_data
- gt = ''
- user_data = trajectory.get('user_data', [])
- if isinstance(user_data, list):
- for item in user_data:
- if isinstance(item, (list, tuple)) and len(item) == 2:
- if item[0] == 'ground_truth':
- gt = str(item[1])
- break
-
- predicted = self.extract_answer(completion)
-
- # Numeric comparison
- correct = False
- if predicted and gt:
- try:
- correct = abs(float(predicted) - float(gt)) < 1e-5
- except (ValueError, OverflowError):
- correct = predicted == gt
-
- rewards.append(1.0 if correct else 0.0)
- return rewards
-
-
-class MathFormatReward(Reward):
- """Format reward: checks format and rewards shorter completions.
-
- Returns higher score for shorter completions (1.0 at length 100 or less).
- Returns 0.0 if format is incorrect.
- """
-
- def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
- rewards = []
- for trajectory in trajectories:
- messages = trajectory.get('messages', [])
- completion = ''
- for msg in reversed(messages):
- if msg.get('role') == 'assistant':
- completion = msg.get('content', '')
- break
-
- has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))
- has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
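+ # Accept either a LaTeX-style \boxed{...} answer or a GSM8K-style '#### <number>' answer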
+ has_answer = bool(
+ re.search(r'\\boxed\{[^}]+\}', completion)
+ or re.search(r'####\s*[\-\d,\.]+', completion)
+ )
- if not (has_think and has_answer):
+ if not has_answer:
rewards.append(0.0)
else:
length = len(completion)
- if length <= 100:
+ if length <= 200:
rewards.append(1.0)
else:
- reward = max(0.0, 1.0 - (length - 100) / 2000)
- rewards.append(reward)
-
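+ # Linear decay: 1.0 at 200 chars, falling to 0.0 at 3200 chars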
+ rewards.append(max(0.0, 1.0 - (length - 200) / 3000))
return rewards
-def create_math_dataset():
- """Create Math dataset."""
- meta = DatasetMeta(
- 'ms://modelscope/competition_math',
- subset_name='default',
- split='train',
- data_slice=range(DATA_NUM),
- )
- dataset = Dataset(meta)
- dataset.set_template('Qwen3_5Template', model_id=BASE_MODEL, max_length=4096, truncation_strategy='delete')
- dataset.map(MathPreprocessor())
- dataset.filter(lambda row: bool(row['messages']))
+# ========== Dataset ==========
+def create_gsm8k_dataset():
+ """Create GSM8K dataset."""
+ dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM)))
+ dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096,
+ truncation_strategy='delete', enable_thinking=True)
+ dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
dataset.encode(add_generation_prompt=True)
return dataset
-def compute_rewards(trajectories: List[Trajectory], ) -> Tuple[List[float], List[float], List[float]]:
- """Compute accuracy and format rewards for Math."""
- accuracy_reward_fn = MathAccuracyReward()
- format_reward_fn = MathFormatReward()
+def compute_rewards(
+ trajectories: List[Dict[str, Any]],
+) -> Tuple[List[float], List[float], List[float]]:
+ """Compute accuracy and brevity rewards for GSM8K."""
+ accuracy_reward_fn = GSM8KAccuracyReward()
+ brevity_reward_fn = GSM8KBrevityReward()
- accuracy_rewards = accuracy_reward_fn(trajectories, [])
- format_rewards = format_reward_fn(trajectories, [])
- total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
- return total_rewards, format_rewards, accuracy_rewards
+ accuracy_rewards = accuracy_reward_fn(trajectories)
+ brevity_rewards = brevity_reward_fn(trajectories)
+ total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)]
+ return total_rewards, brevity_rewards, accuracy_rewards
def main():
- logger.info('Starting Math GRPO training...')
+ logger.info('Starting GSM8K GRPO training...')
# Step 1: Prepare dataset and dataloader (client-side)
- dataset = create_math_dataset()
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
- template = Template(model_id=f'ms://{BASE_MODEL}')
+ dataset = create_gsm8k_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0)
+ template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}')
logger.info('Dataset and template initialized')
@@ -254,7 +165,7 @@ def main():
if step % SYNC_INTERVAL == 0:
logger.info(f'Step {step}: Saving weights for sampler...')
- sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'Math-step-{step}'))
+ sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}'))
logger.info(f'Step {step}: Sampling client ready')
if sampling_client is None:
@@ -317,14 +228,12 @@ def main():
completion_lengths.append(len(seq.tokens))
# ========== 4. Compute rewards ==========
- total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories)
+ total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(trajectories)
metrics.accumulate(
- None,
- None,
completion_lengths=completion_lengths,
rewards={
'total': total_rewards,
- 'format': format_rewards,
+ 'brevity': brevity_rewards,
'accuracy': accuracy_rewards,
})
@@ -407,7 +316,7 @@ def main():
step += 1
# Save final checkpoint
- save_future = training_client.save_state('Math-grpo-final')
+ save_future = training_client.save_state('gsm8k-grpo-final')
save_result = save_future.result()
logger.info(f'Saved final checkpoint to {save_result.path}')
diff --git a/cookbook/client/tinker/self_host/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py
index f077c669..e19955a6 100644
--- a/cookbook/client/tinker/self_host/short_math_grpo.py
+++ b/cookbook/client/tinker/self_host/short_math_grpo.py
@@ -33,7 +33,7 @@
from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.metric import CompletionRewardMetric
-from twinkle.template import Template
+from twinkle.template import Qwen3_5Template
logger = get_logger()
@@ -41,9 +41,9 @@
BASE_MODEL = 'Qwen/Qwen3.5-4B'
NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 4096
-LEARNING_RATE = 1e-5
+LEARNING_RATE = 2e-5
MAX_STEPS = 1000
-BATCH_SIZE = 4
+BATCH_SIZE = 2
TEMPERATURE = 1.0
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
LORA_RANK = 16
@@ -92,7 +92,7 @@ def create_gsm8k_dataset():
"""Create GSM8K dataset."""
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM)))
dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096,
- truncation_strategy='delete', enable_thinking=False)
+ truncation_strategy='delete', enable_thinking=True)
dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
dataset.encode(add_generation_prompt=True)
return dataset
@@ -117,7 +117,7 @@ def main():
# Step 1: Prepare dataset and dataloader (client-side)
dataset = create_gsm8k_dataset()
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0)
- template = Template(model_id=f'ms://{BASE_MODEL}')
+ template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}')
logger.info('Dataset and template initialized')
diff --git a/cookbook/client/twinkle/modelscope/self_congnition.py b/cookbook/client/twinkle/modelscope/self_congnition.py
index 81c5ab4d..aeef5606 100644
--- a/cookbook/client/twinkle/modelscope/self_congnition.py
+++ b/cookbook/client/twinkle/modelscope/self_congnition.py
@@ -14,9 +14,9 @@
from twinkle import get_logger
from twinkle.dataset import DatasetMeta
-from twinkle_client import init_twinkle_client
-from twinkle_client.dataloader import DataLoader
-from twinkle_client.dataset import Dataset
+from twinkle import init_twinkle_client
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset
from twinkle_client.model import MultiLoraTransformersModel
logger = get_logger()
diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py
index f382956f..c4e14d30 100644
--- a/cookbook/client/twinkle/self_host/self_congnition.py
+++ b/cookbook/client/twinkle/self_host/self_congnition.py
@@ -14,9 +14,9 @@
from twinkle import get_logger
from twinkle.dataset import DatasetMeta
-from twinkle_client import init_twinkle_client
-from twinkle_client.dataloader import DataLoader
-from twinkle_client.dataset import Dataset
+from twinkle import init_twinkle_client
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset
from twinkle_client.model import MultiLoraTransformersModel
logger = get_logger()
diff --git a/cookbook/client/twinkle/self_host/grpo.py b/cookbook/client/twinkle/self_host/short_math_grpo.py
similarity index 68%
rename from cookbook/client/twinkle/self_host/grpo.py
rename to cookbook/client/twinkle/self_host/short_math_grpo.py
index d87bfa77..03993c96 100644
--- a/cookbook/client/twinkle/self_host/grpo.py
+++ b/cookbook/client/twinkle/self_host/short_math_grpo.py
@@ -25,38 +25,84 @@
import gc
import os
+import re
from peft import LoraConfig
from typing import List, Tuple, Dict, Any
+import swanlab
+
from twinkle import get_logger
-from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward
+from twinkle.reward import GSM8KAccuracyReward
+from twinkle.reward.base import Reward
from twinkle.advantage import GRPOAdvantage
from twinkle.dataset import DatasetMeta
from twinkle.metric import CompletionRewardMetric
-from twinkle_client import init_twinkle_client
-from twinkle_client.dataloader import DataLoader
-from twinkle_client.dataset import Dataset
+from twinkle import init_twinkle_client
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset
+from twinkle.preprocessor.llm import GSM8KProcessor
from twinkle_client.model import MultiLoraTransformersModel
from twinkle_client.sampler import vLLMSampler
logger = get_logger()
+
+class GSM8KBrevityReward(Reward):
+ """Brevity reward: rewards shorter completions that contain a valid answer.
+
+ Returns 0.0 if no valid answer format (\\boxed{} or ####).
+ Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
+ """
+
+ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
+ rewards = []
+ for traj in trajectories:
+ messages = traj.get('messages', [])
+ completion = ''
+ for msg in reversed(messages):
+ if msg.get('role') == 'assistant':
+ completion = msg.get('content', '')
+ break
+
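+ # Accept either a LaTeX-style \boxed{...} answer or a GSM8K-style '#### <number>' answer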
+ has_answer = bool(
+ re.search(r'\\boxed\{[^}]+\}', completion)
+ or re.search(r'####\s*[\-\d,\.]+', completion)
+ )
+
+ if not has_answer:
+ rewards.append(0.0)
+ else:
+ length = len(completion)
+ if length <= 200:
+ rewards.append(1.0)
+ else:
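+ # Linear decay: 1.0 at 200 chars, falling to 0.0 at 3200 chars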
+ rewards.append(max(0.0, 1.0 - (length - 200) / 3000))
+ return rewards
+
# ========== Configuration ==========
MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 1024
-LEARNING_RATE = 1e-5
-MAX_STEPS = 10
+LEARNING_RATE = 2e-5
+MAX_STEPS = 100
BATCH_SIZE = 2
TEMPERATURE = 1.0
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
-GRADIENT_ACCUMULATION_STEPS = 4
+GRADIENT_ACCUMULATION_STEPS = 1
+DATA_NUM = 2000 # Number of GSM8K samples to use
+USE_SWANLAB = True
+SWANLAB_PROJECT = 'twinkle-grpo'
+SWANLAB_EXPERIMENT_NAME = 'short-math-grpo'
+
+
+SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
+ 'and put your final answer within \\boxed{}.')
def create_gsm8k_dataset():
- dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
- dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048)
- dataset.map('GSM8KProcessor')
+ dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM)))
+ dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048, enable_thinking=False)
+ dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
dataset.encode(add_generation_prompt=True)
return dataset
@@ -64,24 +110,44 @@ def compute_rewards(
trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
accuracy_reward_fn = GSM8KAccuracyReward()
- format_reward_fn = GSM8KFormatReward()
+ brevity_reward_fn = GSM8KBrevityReward()
accuracy_rewards = accuracy_reward_fn(trajectories)
- format_rewards = format_reward_fn(trajectories)
- total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
- return total_rewards, format_rewards, accuracy_rewards
+ brevity_rewards = brevity_reward_fn(trajectories)
+ total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)]
+ return total_rewards, brevity_rewards, accuracy_rewards
def train():
+ # Step 0: Initialize SwanLab if enabled
+ if USE_SWANLAB:
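+ # Assumes SWANLAB_API_KEY is exported; an empty string is passed otherwise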
+ swanlab.login(api_key=os.environ.get('SWANLAB_API_KEY', ''))
+ swanlab.init(
+ project=SWANLAB_PROJECT,
+ experiment_name=SWANLAB_EXPERIMENT_NAME,
+ config={
+ 'model_id': MODEL_ID,
+ 'num_generations': NUM_GENERATIONS,
+ 'max_new_tokens': MAX_NEW_TOKENS,
+ 'learning_rate': LEARNING_RATE,
+ 'max_steps': MAX_STEPS,
+ 'batch_size': BATCH_SIZE,
+ 'temperature': TEMPERATURE,
+ 'sync_interval': SYNC_INTERVAL,
+ 'gradient_accumulation_steps': GRADIENT_ACCUMULATION_STEPS,
+ },
+ )
+ logger.info('SwanLab initialized')
+
# Step 1: Initialize the Twinkle client
client = init_twinkle_client(
base_url='http://127.0.0.1:8000',
- api_key=os.environ.get('MODELSCOPE_TOKEN'),
+ api_key='EMPTY_TOKEN',
)
# Step 2: Prepare dataset and dataloader
dataset = create_gsm8k_dataset()
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0)
# Step 3: Configure the training model
model = MultiLoraTransformersModel(model_id=MODEL_ID)
@@ -172,14 +238,14 @@ def train():
# ========== 3. Compute rewards ==========
- total_rewards, format_rewards, accuracy_rewards = compute_rewards(
+ total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(
all_input_data
)
metrics.accumulate(
completion_lengths=all_completion_lengths,
rewards={
'total': total_rewards,
- 'format': format_rewards,
+ 'brevity': brevity_rewards,
'accuracy': accuracy_rewards,
},
)
@@ -217,6 +283,11 @@ def train():
log_dict.update(model.calculate_metric(is_training=True).result)
log_dict['train/frac_reward_zero_std'] = frac_zero_std
logger.info(f'Step {step}: {log_dict}')
+
+ # Log metrics to SwanLab
+ if USE_SWANLAB and log_dict:
+ swanlab.log(log_dict, step=step)
+
step += 1
metrics.reset()
diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py
index 88a59b2c..37d9df60 100644
--- a/src/twinkle/server/model/tinker_handlers.py
+++ b/src/twinkle/server/model/tinker_handlers.py
@@ -43,6 +43,7 @@ async def _create_adapter():
try:
_model_id = await self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id)
if body.lora_config:
+ # TODO: Make LoraConfig more flexible
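+ # For now only the requested rank is honored; target_modules is fixed to 'all-linear'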
lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')
adapter_name = self.get_adapter_name(adapter_name=_model_id)
self.register_resource(adapter_name, token, session_id=body.session_id)
diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py
index d3a74c1d..ce9abd6a 100644
--- a/src/twinkle/template/base.py
+++ b/src/twinkle/template/base.py
@@ -167,10 +167,13 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L
result['input_ids'] = input_ids
result['labels'] = labels
if 'mm_token_type_ids' in result:
- token_ids_shape = result['mm_token_type_ids'].shape
- device = result['mm_token_type_ids'].device
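+ # Some processors may return a plain list here; normalize to a tensor before reading shape/device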
+ mm_token_type_ids = result['mm_token_type_ids']
+ if not isinstance(mm_token_type_ids, torch.Tensor):
+ mm_token_type_ids = torch.as_tensor(mm_token_type_ids)
+ token_ids_shape = mm_token_type_ids.shape
+ device = mm_token_type_ids.device
padded_tokens = torch.zeros((token_ids_shape[0], len(new_tokens))).to(device)
- result['mm_token_type_ids'] = torch.cat((result['mm_token_type_ids'], padded_tokens), dim=1)
+ result['mm_token_type_ids'] = torch.cat((mm_token_type_ids, padded_tokens), dim=1)
new_input_feature = self._invoke_post_pipeline([result])[0]
result.update(new_input_feature)
messages: List[Message] = result.get('messages')
@@ -466,6 +469,8 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo
# Set default values for processor_kwargs
if 'enable_thinking' not in kwargs:
processor_kwargs['enable_thinking'] = self.enable_thinking
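+ # Preserve the previously hardcoded padding=False as an overridable default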
+ if 'padding' not in kwargs:
+ processor_kwargs['padding'] = False
# Add remaining kwargs to processor_kwargs
processor_kwargs.update(kwargs)
@@ -473,7 +478,6 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo
inputs = self.processor.apply_chat_template(
messages,
tools=tools,
- padding=False,
return_dict=True,
add_generation_prompt=add_generation_prompt,
return_tensors='pt',