From d722128cd9c9e0d03e84791fc3fa608e5838607c Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 10 Apr 2026 16:00:51 +0800 Subject: [PATCH 1/6] update short math grpo --- .../server/megatron/server_config_4b.yaml | 3 +- .../tinker/modelscope/short_math_grpo.py | 195 +++++------------- .../tinker/self_host/short_math_grpo.py | 10 +- .../twinkle/self_host/self_congnition.py | 2 +- .../self_host/{grpo.py => short_math_grpo.py} | 107 ++++++++-- src/twinkle/server/model/tinker_handlers.py | 1 + src/twinkle/template/base.py | 12 +- 7 files changed, 158 insertions(+), 172 deletions(-) rename cookbook/client/twinkle/self_host/{grpo.py => short_math_grpo.py} (68%) diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index 5dd8a696..36adb332 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -76,13 +76,14 @@ applications: import_path: sampler args: model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node + nproc_per_node: 1 # Number of GPU processes per node sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) engine_args: # vLLM engine-specific settings max_model_len: 16000 # Maximum sequence length the engine supports gpu_memory_utilization: 0.7 # Fraction of GPU memory to use (0.0-1.0) enable_lora: true # Allow loading LoRA adapters during inference logprobs_mode: processed_logprobs # Logprobs mode for sampling results + enable_tower_connector_lora: true device_group: # Logical device group for the sampler name: sampler ranks: 1 # Number of GPUs to use diff --git a/cookbook/client/tinker/modelscope/short_math_grpo.py b/cookbook/client/tinker/modelscope/short_math_grpo.py index 47a7d24a..c361b934 100644 --- a/cookbook/client/tinker/modelscope/short_math_grpo.py +++ b/cookbook/client/tinker/modelscope/short_math_grpo.py @@ -1,12 +1,12 @@ -# Tinker-Compatible Client - Math GRPO Training Example +# Tinker-Compatible Client - GSM8K GRPO Training Example # -# This script demonstrates Math problem training using the +# This script demonstrates GSM8K math problem training using the # Tinker-compatible client API with save_weights_for_sampler for weight sync. # Instead of calling sync_weights directly, it periodically saves weights and # creates a sampling client for generation. # # Flow: -# 1. Prepare Math dataset (client-side) +# 1. Prepare GSM8K dataset (client-side) # 2. Initialize Tinker-compatible training & sampling clients # 3. Training loop: # a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client @@ -22,191 +22,102 @@ import os import re from tinker import types -from typing import List, Tuple +from typing import List, Tuple, Dict, Any from twinkle import init_tinker_client from twinkle import get_logger from twinkle.advantage import GRPOAdvantage -from twinkle.data_format import Message, Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta -from twinkle.preprocessor import Preprocessor +from twinkle.preprocessor.llm import GSM8KProcessor +from twinkle.reward import GSM8KAccuracyReward from twinkle.reward.base import Reward from twinkle.metric import CompletionRewardMetric -from twinkle.template import Template +from twinkle.template import Qwen3_5Template logger = get_logger() # ========== Configuration ========== BASE_MODEL = 'Qwen/Qwen3.5-27B' -NUM_GENERATIONS = 8 +NUM_GENERATIONS = 4 MAX_NEW_TOKENS = 4096 -LEARNING_RATE = 1e-4 +LEARNING_RATE = 2e-5 MAX_STEPS = 1000 BATCH_SIZE = 2 TEMPERATURE = 1.0 SYNC_INTERVAL = 1 # Save weights for sampler every N steps -LORA_RANK = 8 +LORA_RANK = 16 DATA_NUM = 2000 # Number of Math samples to use -SYSTEM_PROMPT = ('You are a math assistant that values brevity. ' - 'Solve problems with minimal but correct reasoning.\n\n' - 'Rules:\n' - '1. Use tags for reasoning\n' - '2. Final answer after ####\n\n' - 'Example:\nKey step1 -> Ket step 2 -> conclusion\n#### 42') +SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') +# ========== Reward Functions ========== +class GSM8KBrevityReward(Reward): + """Brevity reward: rewards shorter completions that contain a valid answer. -class MathPreprocessor(Preprocessor): - - def __call__(self, rows): - rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows - - def preprocess(self, sample): - if sample['level'] not in ('Level 4', 'Level 5'): - return Trajectory(messages=[], user_data=[]) - - def get_boxed_answer(text): - match = re.search(r'\\boxed{([^}]*)}', text) - return match.group(1) if match else None - - ground_truth = get_boxed_answer(sample['solution']) - if ground_truth is None: - return Trajectory(messages=[], user_data=[]) - problem = sample['problem'] - return Trajectory( - messages=[ - Message(role='system', content=SYSTEM_PROMPT), - Message(role='user', content=problem), - ], - user_data=[('ground_truth', ground_truth)], - ) - - -# ========== Math Reward Functions ========== -class MathAccuracyReward(Reward): - """Accuracy reward for Math: checks if the model's answer matches ground truth. - - Extracts the last '#### ' from model output and compares with ground truth. - Returns 1.0 for correct, 0.0 for incorrect. + Returns 0.0 if no valid answer format (\\boxed{} or ####). + Otherwise returns higher score for shorter completions (1.0 at <=200 chars). """ - @staticmethod - def extract_answer(completion: str) -> str: - """Extract the last #### answer from model completion.""" - # Only check last 500 chars for efficiency - text = completion[-500:] if len(completion) > 500 else completion - matches = re.findall(r'####\s*([\-\d,\.\s]+)', text) - if matches: - return matches[-1].replace(',', '').replace(' ', '').strip() - return '' - - def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]: + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: rewards = [] - for trajectory in trajectories: - messages = trajectory.get('messages', []) - # Get model completion (last assistant message) + for traj in trajectories: + messages = traj.get('messages', []) completion = '' for msg in reversed(messages): if msg.get('role') == 'assistant': completion = msg.get('content', '') break - # Get ground truth from user_data - gt = '' - user_data = trajectory.get('user_data', []) - if isinstance(user_data, list): - for item in user_data: - if isinstance(item, (list, tuple)) and len(item) == 2: - if item[0] == 'ground_truth': - gt = str(item[1]) - break - - predicted = self.extract_answer(completion) - - # Numeric comparison - correct = False - if predicted and gt: - try: - correct = abs(float(predicted) - float(gt)) < 1e-5 - except (ValueError, OverflowError): - correct = predicted == gt - - rewards.append(1.0 if correct else 0.0) - return rewards - - -class MathFormatReward(Reward): - """Format reward: checks format and rewards shorter completions. - - Returns higher score for shorter completions (1.0 at length 100 or less). - Returns 0.0 if format is incorrect. - """ - - def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]: - rewards = [] - for trajectory in trajectories: - messages = trajectory.get('messages', []) - completion = '' - for msg in reversed(messages): - if msg.get('role') == 'assistant': - completion = msg.get('content', '') - break - - has_think = bool(re.search(r'.*?', completion, re.DOTALL)) - has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion)) + has_answer = bool( + re.search(r'\\boxed\{[^}]+\}', completion) + or re.search(r'####\s*[\-\d,\.]+', completion) + ) - if not (has_think and has_answer): + if not has_answer: rewards.append(0.0) else: length = len(completion) - if length <= 100: + if length <= 200: rewards.append(1.0) else: - reward = max(0.0, 1.0 - (length - 100) / 2000) - rewards.append(reward) - + rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) return rewards -def create_math_dataset(): - """Create Math dataset.""" - meta = DatasetMeta( - 'ms://modelscope/competition_math', - subset_name='default', - split='train', - data_slice=range(DATA_NUM), - ) - dataset = Dataset(meta) - dataset.set_template('Qwen3_5Template', model_id=BASE_MODEL, max_length=4096, truncation_strategy='delete') - dataset.map(MathPreprocessor()) - dataset.filter(lambda row: bool(row['messages'])) +# ========== Dataset ========== +def create_gsm8k_dataset(): + """Create GSM8K dataset.""" + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM))) + dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096, + truncation_strategy='delete', enable_thinking=True) + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) dataset.encode(add_generation_prompt=True) return dataset -def compute_rewards(trajectories: List[Trajectory], ) -> Tuple[List[float], List[float], List[float]]: - """Compute accuracy and format rewards for Math.""" - accuracy_reward_fn = MathAccuracyReward() - format_reward_fn = MathFormatReward() +def compute_rewards( + trajectories: List[Dict[str, Any]], +) -> Tuple[List[float], List[float], List[float]]: + """Compute accuracy and brevity rewards for GSM8K.""" + accuracy_reward_fn = GSM8KAccuracyReward() + brevity_reward_fn = GSM8KBrevityReward() - accuracy_rewards = accuracy_reward_fn(trajectories, []) - format_rewards = format_reward_fn(trajectories, []) - total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)] - return total_rewards, format_rewards, accuracy_rewards + accuracy_rewards = accuracy_reward_fn(trajectories) + brevity_rewards = brevity_reward_fn(trajectories) + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] + return total_rewards, brevity_rewards, accuracy_rewards def main(): - logger.info('Starting Math GRPO training...') + logger.info('Starting GSM8K GRPO training...') # Step 1: Prepare dataset and dataloader (client-side) - dataset = create_math_dataset() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - template = Template(model_id=f'ms://{BASE_MODEL}') + dataset = create_gsm8k_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0) + template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}') logger.info('Dataset and template initialized') @@ -254,7 +165,7 @@ def main(): if step % SYNC_INTERVAL == 0: logger.info(f'Step {step}: Saving weights for sampler...') - sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'Math-step-{step}')) + sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}')) logger.info(f'Step {step}: Sampling client ready') if sampling_client is None: @@ -317,14 +228,12 @@ def main(): completion_lengths.append(len(seq.tokens)) # ========== 4. Compute rewards ========== - total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories) + total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(trajectories) metrics.accumulate( - None, - None, completion_lengths=completion_lengths, rewards={ 'total': total_rewards, - 'format': format_rewards, + 'brevity': brevity_rewards, 'accuracy': accuracy_rewards, }) @@ -407,7 +316,7 @@ def main(): step += 1 # Save final checkpoint - save_future = training_client.save_state('Math-grpo-final') + save_future = training_client.save_state('gsm8k-grpo-final') save_result = save_future.result() logger.info(f'Saved final checkpoint to {save_result.path}') diff --git a/cookbook/client/tinker/self_host/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py index f077c669..e19955a6 100644 --- a/cookbook/client/tinker/self_host/short_math_grpo.py +++ b/cookbook/client/tinker/self_host/short_math_grpo.py @@ -33,7 +33,7 @@ from twinkle.reward import GSM8KAccuracyReward from twinkle.reward.base import Reward from twinkle.metric import CompletionRewardMetric -from twinkle.template import Template +from twinkle.template import Qwen3_5Template logger = get_logger() @@ -41,9 +41,9 @@ BASE_MODEL = 'Qwen/Qwen3.5-4B' NUM_GENERATIONS = 4 MAX_NEW_TOKENS = 4096 -LEARNING_RATE = 1e-5 +LEARNING_RATE = 2e-5 MAX_STEPS = 1000 -BATCH_SIZE = 4 +BATCH_SIZE = 2 TEMPERATURE = 1.0 SYNC_INTERVAL = 1 # Save weights for sampler every N steps LORA_RANK = 16 @@ -92,7 +92,7 @@ def create_gsm8k_dataset(): """Create GSM8K dataset.""" dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM))) dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096, - truncation_strategy='delete', enable_thinking=False) + truncation_strategy='delete', enable_thinking=True) dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) dataset.encode(add_generation_prompt=True) return dataset @@ -117,7 +117,7 @@ def main(): # Step 1: Prepare dataset and dataloader (client-side) dataset = create_gsm8k_dataset() dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0) - template = Template(model_id=f'ms://{BASE_MODEL}') + template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}') logger.info('Dataset and template initialized') diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index f382956f..c6394f47 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -14,7 +14,7 @@ from twinkle import get_logger from twinkle.dataset import DatasetMeta -from twinkle_client import init_twinkle_client +from twinkle import init_twinkle_client from twinkle_client.dataloader import DataLoader from twinkle_client.dataset import Dataset from twinkle_client.model import MultiLoraTransformersModel diff --git a/cookbook/client/twinkle/self_host/grpo.py b/cookbook/client/twinkle/self_host/short_math_grpo.py similarity index 68% rename from cookbook/client/twinkle/self_host/grpo.py rename to cookbook/client/twinkle/self_host/short_math_grpo.py index d87bfa77..03993c96 100644 --- a/cookbook/client/twinkle/self_host/grpo.py +++ b/cookbook/client/twinkle/self_host/short_math_grpo.py @@ -25,38 +25,84 @@ import gc import os +import re from peft import LoraConfig from typing import List, Tuple, Dict, Any +import swanlab + from twinkle import get_logger -from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward +from twinkle.reward import GSM8KAccuracyReward +from twinkle.reward.base import Reward from twinkle.advantage import GRPOAdvantage from twinkle.dataset import DatasetMeta from twinkle.metric import CompletionRewardMetric -from twinkle_client import init_twinkle_client -from twinkle_client.dataloader import DataLoader -from twinkle_client.dataset import Dataset +from twinkle import init_twinkle_client +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset +from twinkle.preprocessor.llm import GSM8KProcessor from twinkle_client.model import MultiLoraTransformersModel from twinkle_client.sampler import vLLMSampler logger = get_logger() + +class GSM8KBrevityReward(Reward): + """Brevity reward: rewards shorter completions that contain a valid answer. + + Returns 0.0 if no valid answer format (\\boxed{} or ####). + Otherwise returns higher score for shorter completions (1.0 at <=200 chars). + """ + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards = [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = '' + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + + has_answer = bool( + re.search(r'\\boxed\{[^}]+\}', completion) + or re.search(r'####\s*[\-\d,\.]+', completion) + ) + + if not has_answer: + rewards.append(0.0) + else: + length = len(completion) + if length <= 200: + rewards.append(1.0) + else: + rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) + return rewards + # ========== Configuration ========== MODEL_ID = 'ms://Qwen/Qwen3.5-4B' NUM_GENERATIONS = 4 MAX_NEW_TOKENS = 1024 -LEARNING_RATE = 1e-5 -MAX_STEPS = 10 +LEARNING_RATE = 2e-5 +MAX_STEPS = 100 BATCH_SIZE = 2 TEMPERATURE = 1.0 SYNC_INTERVAL = 1 # Save weights for sampler every N steps -GRADIENT_ACCUMULATION_STEPS = 4 +GRADIENT_ACCUMULATION_STEPS = 1 +DATA_NUM = 2000 # Number of Math samples to use +USE_SWANLAB = True +SWANLAB_PROJECT = 'twinkle-grpo' +SWANLAB_EXPERIMENT_NAME = 'short-math-grpo' + + +SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') def create_gsm8k_dataset(): - dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048) - dataset.map('GSM8KProcessor') + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM))) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048, enable_thinking=False) + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) dataset.encode(add_generation_prompt=True) return dataset @@ -64,24 +110,44 @@ def compute_rewards( trajectories: List[Dict[str, Any]], ) -> Tuple[List[float], List[float], List[float]]: accuracy_reward_fn = GSM8KAccuracyReward() - format_reward_fn = GSM8KFormatReward() + brevity_reward_fn = GSM8KBrevityReward() accuracy_rewards = accuracy_reward_fn(trajectories) - format_rewards = format_reward_fn(trajectories) - total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)] - return total_rewards, format_rewards, accuracy_rewards + brevity_rewards = brevity_reward_fn(trajectories) + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] + return total_rewards, brevity_rewards, accuracy_rewards def train(): + # Step 0: Initialize SwanLab if enabled + if USE_SWANLAB: + swanlab.login(api_key=os.environ.get('SWANLAB_API_KEY', '')) + swanlab.init( + project=SWANLAB_PROJECT, + experiment_name=SWANLAB_EXPERIMENT_NAME, + config={ + 'model_id': MODEL_ID, + 'num_generations': NUM_GENERATIONS, + 'max_new_tokens': MAX_NEW_TOKENS, + 'learning_rate': LEARNING_RATE, + 'max_steps': MAX_STEPS, + 'batch_size': BATCH_SIZE, + 'temperature': TEMPERATURE, + 'sync_interval': SYNC_INTERVAL, + 'gradient_accumulation_steps': GRADIENT_ACCUMULATION_STEPS, + }, + ) + logger.info('SwanLab initialized') + # Step 1: Initialize the Twinkle client client = init_twinkle_client( base_url='http://127.0.0.1:8000', - api_key=os.environ.get('MODELSCOPE_TOKEN'), + api_key='EMPTY_TOKEN', ) # Step 2: Prepare dataset and dataloader dataset = create_gsm8k_dataset() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0) # Step 3: Configure the training model model = MultiLoraTransformersModel(model_id=MODEL_ID) @@ -172,14 +238,14 @@ def train(): # ========== 3. Compute rewards ========== - total_rewards, format_rewards, accuracy_rewards = compute_rewards( + total_rewards, brevity_rewards, accuracy_rewards = compute_rewards( all_input_data ) metrics.accumulate( completion_lengths=all_completion_lengths, rewards={ 'total': total_rewards, - 'format': format_rewards, + 'brevity': brevity_rewards, 'accuracy': accuracy_rewards, }, ) @@ -217,6 +283,11 @@ def train(): log_dict.update(model.calculate_metric(is_training=True).result) log_dict['train/frac_reward_zero_std'] = frac_zero_std logger.info(f'Step {step}: {log_dict}') + + # Log metrics to SwanLab + if USE_SWANLAB and log_dict: + swanlab.log(log_dict, step=step) + step += 1 metrics.reset() diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index 88a59b2c..37d9df60 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -43,6 +43,7 @@ async def _create_adapter(): try: _model_id = await self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) if body.lora_config: + # TODO: Make LoraConfig more flexible lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') adapter_name = self.get_adapter_name(adapter_name=_model_id) self.register_resource(adapter_name, token, session_id=body.session_id) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index d3a74c1d..69a1fcd6 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -167,10 +167,13 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L result['input_ids'] = input_ids result['labels'] = labels if 'mm_token_type_ids' in result: - token_ids_shape = result['mm_token_type_ids'].shape - device = result['mm_token_type_ids'].device + mm_token_type_ids = result['mm_token_type_ids'] + if isinstance(mm_token_type_ids, list): + mm_token_type_ids = torch.tensor(mm_token_type_ids) + token_ids_shape = mm_token_type_ids.shape + device = mm_token_type_ids.device padded_tokens = torch.zeros((token_ids_shape[0], len(new_tokens))).to(device) - result['mm_token_type_ids'] = torch.cat((result['mm_token_type_ids'], padded_tokens), dim=1) + result['mm_token_type_ids'] = torch.cat((mm_token_type_ids, padded_tokens), dim=1) new_input_feature = self._invoke_post_pipeline([result])[0] result.update(new_input_feature) messages: List[Message] = result.get('messages') @@ -466,6 +469,8 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo # Set default values for processor_kwargs if 'enable_thinking' not in kwargs: processor_kwargs['enable_thinking'] = self.enable_thinking + if 'padding' not in kwargs: + processor_kwargs['padding'] = False # Add remaining kwargs to processor_kwargs processor_kwargs.update(kwargs) @@ -473,7 +478,6 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo inputs = self.processor.apply_chat_template( messages, tools=tools, - padding=False, return_dict=True, add_generation_prompt=add_generation_prompt, return_tensors='pt', From fe2858529c0dec5b17cf0e70e0c335b2193aed70 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 10 Apr 2026 16:56:12 +0800 Subject: [PATCH 2/6] update run.sh --- cookbook/client/server/megatron/run.sh | 31 ++++++++++++++++++- cookbook/client/server/megatron/server.py | 21 ------------- .../client/server/megatron/server_config.yaml | 1 + 3 files changed, 31 insertions(+), 22 deletions(-) delete mode 100644 cookbook/client/server/megatron/server.py diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh index 14966ce9..958e0429 100644 --- a/cookbook/client/server/megatron/run.sh +++ b/cookbook/client/server/megatron/run.sh @@ -12,6 +12,8 @@ # --gpu-workers LIST GPU Worker 列表,分号分隔多个节点 (默认: 4,5,6,7:4) # --cpu-workers N CPU Worker 数量 (默认: 1) # --temp-dir DIR Ray 临时目录 (默认: /dashscope/caches/application/ray_logs) +# --save-dir DIR Twinkle 模型保存目录 (默认: /dashscope/caches/application/save) +# --server-config FILE Twinkle 服务器配置文件路径 (默认: /twinkle/cookbook/client/server/megatron/server_config.yaml) # --help 显示帮助信息 # # 示例: @@ -49,6 +51,8 @@ RAY_ADDRESS="127.0.0.1:$RAY_PORT" # --- 路径配置 --- DEFAULT_TEMP_DIR="/dashscope/caches/application/ray_logs" LOG_FILE="run.log" +DEFAULT_SAVE_DIR="/dashscope/caches/application/save" +DEFAULT_SERVER_CONFIG_FILE="/twinkle/cookbook/client/server/megatron/server_config.yaml" # --- Prometheus 监控配置 --- PROMETHEUS_BIN="/dashscope/caches/application/monitor/prometheus-3.10.0.linux-amd64/prometheus" @@ -67,6 +71,8 @@ HEAD_NODE="0,1,2,3" GPU_WORKERS_INPUT="4,5,6,7" CPU_WORKER_COUNT="1" TEMP_DIR="$DEFAULT_TEMP_DIR" +SAVE_DIR="$DEFAULT_SAVE_DIR" +SERVER_CONFIG_FILE="$DEFAULT_SERVER_CONFIG_FILE" # 解析命名参数 while [[ $# -gt 0 ]]; do @@ -103,6 +109,22 @@ while [[ $# -gt 0 ]]; do TEMP_DIR="${1#*=}" shift ;; + --save-dir) + SAVE_DIR="$2" + shift 2 + ;; + --save-dir=*) + SAVE_DIR="${1#*=}" + shift + ;; + --server-config) + SERVER_CONFIG_FILE="$2" + shift 2 + ;; + --server-config=*) + SERVER_CONFIG_FILE="${1#*=}" + shift + ;; --help|-h) echo "用法: ./run.sh [选项]" echo "" @@ -111,6 +133,8 @@ while [[ $# -gt 0 ]]; do echo " --gpu-workers LIST GPU Worker 列表,分号分隔多个节点 (默认: 4,5,6,7)" echo " --cpu-workers N CPU Worker 数量 (默认: 1)" echo " --temp-dir DIR Ray 临时目录" + echo " --save-dir DIR Twinkle 模型保存目录 (默认: $DEFAULT_SAVE_DIR)" + echo " --server-config FILE Twinkle 服务器配置文件路径 (默认: $DEFAULT_SERVER_CONFIG_FILE)" echo " --help, -h 显示帮助信息" echo "" echo "示例:" @@ -129,6 +153,9 @@ while [[ $# -gt 0 ]]; do esac done +# 将 SAVE_DIR export 给子进程(python server 通过环境变量读取) +export TWINKLE_DEFAULT_SAVE_DIR="$SAVE_DIR" + # 将分号分隔的字符串转为数组 if [ -z "$GPU_WORKERS_INPUT" ]; then GPU_WORKERS=() @@ -222,6 +249,8 @@ echo "" print_info "运行参数:" echo " - Ray 地址: $RAY_ADDRESS" echo " - 临时目录: $TEMP_DIR" +echo " - 保存目录: $TWINKLE_DEFAULT_SAVE_DIR" +echo " - 服务配置: $SERVER_CONFIG_FILE" echo " - 日志文件: $LOG_FILE" echo "" @@ -334,7 +363,7 @@ print_info "日志输出到: $LOG_FILE" echo "" # 启动服务器并实时显示日志 -nohup python server.py > "$LOG_FILE" 2>&1 & +nohup python -m twinkle.server --config "$SERVER_CONFIG_FILE" > "$LOG_FILE" 2>&1 & SERVER_PID=$! # 实时显示日志 diff --git a/cookbook/client/server/megatron/server.py b/cookbook/client/server/megatron/server.py deleted file mode 100644 index d6cb87c5..00000000 --- a/cookbook/client/server/megatron/server.py +++ /dev/null @@ -1,21 +0,0 @@ -# Twinkle Server Launcher - Tinker-Compatible Megatron Backend -# -# This script starts the Twinkle server with Tinker-compatible API support -# using the Megatron model backend. -# It reads the server_config.yaml in the same directory for all -# configuration (model, deployment settings, etc.). -# Run this script BEFORE running the client training script (lora.py). - -import os - -# Enable Ray debug mode for verbose logging during development -os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0' - -from twinkle.server import launch_server - -# Resolve the path to server_config.yaml relative to this script's location -file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config.yaml') - -# Launch the Twinkle server — this call blocks until the server is shut down -launch_server(config_path=config_path) diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 6d584455..90c8d9ac 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -50,6 +50,7 @@ applications: enable_lora: true # Allow loading LoRA adapters during inference max_loras: 5 # Max allowed loras working on vLLM at the same time max_lora_rank: 32 # Support up to rank 64 LoRA adapters + enable_tower_connector_lora: true device_group: # Logical device group for the sampler name: sampler gpus_per_worker: 2 From 641c1792cca8a99e77170007debc3a31b9504880 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 10 Apr 2026 17:02:58 +0800 Subject: [PATCH 3/6] update run.sh --- cookbook/client/server/megatron/run.sh | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh index 958e0429..8c149db2 100644 --- a/cookbook/client/server/megatron/run.sh +++ b/cookbook/client/server/megatron/run.sh @@ -264,6 +264,28 @@ fi # 停止已有 Ray 集群和 Prometheus # ============================================ print_header "清理环境" + +# 停止 Twinkle server.py(twinkle.server 模块) +print_info "停止已有的 Twinkle Server..." +pkill -f "twinkle.server" 2>/dev/null || true + +# 停止 vLLM 进程 +print_info "停止已有的 vLLM 进程..." +pkill -f "vllm" 2>/dev/null || true + +# 等待上述进程退出 +sleep 2 + +# 若仍有残留则强制 SIGKILL +if pgrep -f "twinkle.server" > /dev/null 2>&1; then + print_warning "Twinkle Server 未退出,强制终止..." + pkill -9 -f "twinkle.server" 2>/dev/null || true +fi +if pgrep -f "vllm" > /dev/null 2>&1; then + print_warning "vLLM 进程未退出,强制终止..." + pkill -9 -f "vllm" 2>/dev/null || true +fi + print_info "停止已有的 Ray 集群..." ray stop --force 2>/dev/null || true From 0f4f46f6b10a7a10b9f65566d8281900eef47225 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 10 Apr 2026 17:14:58 +0800 Subject: [PATCH 4/6] update run.sh --- cookbook/client/server/megatron/run.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh index 8c149db2..730a9a35 100644 --- a/cookbook/client/server/megatron/run.sh +++ b/cookbook/client/server/megatron/run.sh @@ -385,8 +385,10 @@ print_info "日志输出到: $LOG_FILE" echo "" # 启动服务器并实时显示日志 +touch "$LOG_FILE" # 预创建文件,避免 tail -f 在文件尚未写入时报错 nohup python -m twinkle.server --config "$SERVER_CONFIG_FILE" > "$LOG_FILE" 2>&1 & SERVER_PID=$! +print_success "Twinkle Server 已启动 (PID: $SERVER_PID)" # 实时显示日志 tail -f "$LOG_FILE" From 66b8b02b35e75528ffd2faf412e10ac8d152448c Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 10 Apr 2026 17:44:47 +0800 Subject: [PATCH 5/6] update --- cookbook/client/twinkle/modelscope/self_congnition.py | 6 +++--- cookbook/client/twinkle/self_host/self_congnition.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cookbook/client/twinkle/modelscope/self_congnition.py b/cookbook/client/twinkle/modelscope/self_congnition.py index 81c5ab4d..aeef5606 100644 --- a/cookbook/client/twinkle/modelscope/self_congnition.py +++ b/cookbook/client/twinkle/modelscope/self_congnition.py @@ -14,9 +14,9 @@ from twinkle import get_logger from twinkle.dataset import DatasetMeta -from twinkle_client import init_twinkle_client -from twinkle_client.dataloader import DataLoader -from twinkle_client.dataset import Dataset +from twinkle import init_twinkle_client +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset from twinkle_client.model import MultiLoraTransformersModel logger = get_logger() diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index c6394f47..c4e14d30 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -15,8 +15,8 @@ from twinkle import get_logger from twinkle.dataset import DatasetMeta from twinkle import init_twinkle_client -from twinkle_client.dataloader import DataLoader -from twinkle_client.dataset import Dataset +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset from twinkle_client.model import MultiLoraTransformersModel logger = get_logger() From 664521159039eee91b41a8c396df836750db744e Mon Sep 17 00:00:00 2001 From: Yunlin Mao Date: Fri, 10 Apr 2026 17:46:44 +0800 Subject: [PATCH 6/6] Update src/twinkle/template/base.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/twinkle/template/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 69a1fcd6..ce9abd6a 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -168,8 +168,8 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L result['labels'] = labels if 'mm_token_type_ids' in result: mm_token_type_ids = result['mm_token_type_ids'] - if isinstance(mm_token_type_ids, list): - mm_token_type_ids = torch.tensor(mm_token_type_ids) + if not isinstance(mm_token_type_ids, torch.Tensor): + mm_token_type_ids = torch.as_tensor(mm_token_type_ids) token_ids_shape = mm_token_type_ids.shape device = mm_token_type_ids.device padded_tokens = torch.zeros((token_ids_shape[0], len(new_tokens))).to(device)