diff --git a/cookbook/client/tinker/transformer/grpo.py b/cookbook/client/tinker/grpo.py similarity index 99% rename from cookbook/client/tinker/transformer/grpo.py rename to cookbook/client/tinker/grpo.py index c08c6fe5..94145f2a 100644 --- a/cookbook/client/tinker/transformer/grpo.py +++ b/cookbook/client/tinker/grpo.py @@ -34,7 +34,7 @@ logger = get_logger() # ========== Configuration ========== -BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct' +BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct' NUM_GENERATIONS = 4 MAX_NEW_TOKENS = 1024 LEARNING_RATE = 1e-5 diff --git a/cookbook/client/tinker/transformer/gsm8k.py b/cookbook/client/tinker/gsm8k_grpo.py similarity index 90% rename from cookbook/client/tinker/transformer/gsm8k.py rename to cookbook/client/tinker/gsm8k_grpo.py index b7f86c61..139dd40e 100644 --- a/cookbook/client/tinker/transformer/gsm8k.py +++ b/cookbook/client/tinker/gsm8k_grpo.py @@ -17,7 +17,7 @@ # # The server must be running first (see server.py and server_config.yaml). # Requires both model and sampler services to be configured. - +import os import gc import re import numpy as np @@ -38,12 +38,12 @@ logger = get_logger() # ========== Configuration ========== -BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct' +BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct' NUM_GENERATIONS = 4 -MAX_NEW_TOKENS = 2048 -LEARNING_RATE = 1e-5 +MAX_NEW_TOKENS = 1024 +LEARNING_RATE = 1e-4 MAX_STEPS = 100 -BATCH_SIZE = 2 +BATCH_SIZE = 4 TEMPERATURE = 1.0 SYNC_INTERVAL = 1 # Save weights for sampler every N steps LORA_RANK = 8 @@ -56,6 +56,14 @@ "For example:\n ... reasoning ... \n#### 42" ) +# SwanLab experiment tracking +USE_SWANLAB = True +if USE_SWANLAB: + import swanlab + swanlab.login(api_key=os.environ['SWANLAB_API_KEY']) + swanlab.init(project="twinkle-gsm8k", config={ + 'model_id': BASE_MODEL, + }) class GSM8KProcessor(Preprocessor): """Preprocessor for GSM8K dataset. 
@@ -354,18 +362,15 @@ def main(): ob_len = len(prompt_ids) - 1 input_tokens = prompt_ids + sampled_tokens[:-1] target_tokens = [0] * ob_len + sampled_tokens + weights = [0] * ob_len + [1] * len(sampled_tokens) padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens) padded_logprobs = [0.0] * ob_len + logprobs - - # Verify lengths match - assert len(input_tokens) == len(target_tokens) == len(padded_logprobs) == len(padded_advantages), \ - f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, " \ - f"logprobs={len(padded_logprobs)}, advantages={len(padded_advantages)}" datum = types.Datum( model_input=types.ModelInput.from_ints(input_tokens), loss_fn_inputs={ 'target_tokens': target_tokens, + 'weights': weights, 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)), 'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)), }, @@ -380,40 +385,22 @@ def main(): # Forward-backward pass with importance_sampling (GRPO) loss # The training data already contains logprobs and advantages for the GRPO loss - fwdbwd_future = training_client.forward_backward( - training_data, "importance_sampling") - optim_future = training_client.optim_step( - types.AdamParams(learning_rate=LEARNING_RATE)) - - fwdbwd_result = fwdbwd_future.result() - optim_result = optim_future.result() - - # Compute metrics from the forward-backward result - # For importance_sampling, we get logprobs and elementwise_loss - logprobs_list = [] - elementwise_losses = [] - for output in fwdbwd_result.loss_fn_outputs: - if output.get('logprobs') is not None: - logprobs_list.append(output['logprobs'].to_numpy()) - if output.get('elementwise_loss') is not None: - elementwise_losses.append(output['elementwise_loss'].to_numpy()) - - # Compute average loss per token (weighted by advantages) - if elementwise_losses: - all_losses = np.concatenate(elementwise_losses) - avg_loss = np.mean(all_losses) if len(all_losses) > 0 else 
0.0 - else: - avg_loss = 0.0 + fwdbwd_result = training_client.forward_backward(training_data, "importance_sampling").result() + + optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result() gc.collect() # ========== 7. Log ========== log_dict = metrics.calculate() - log_dict['train/loss_per_token'] = float(avg_loss) + if optim_result.metrics: + log_dict.update(optim_result.metrics) log_dict['train/frac_reward_zero_std'] = frac_zero_std log_dict['train/num_training_samples'] = len(training_data) logger.info(f"Step {step}: {log_dict}") step += 1 + if USE_SWANLAB: + swanlab.log(log_dict) # Save final checkpoint save_future = training_client.save_state("gsm8k-grpo-final") diff --git a/cookbook/client/tinker/transformer/lora.py b/cookbook/client/tinker/lora.py similarity index 98% rename from cookbook/client/tinker/transformer/lora.py rename to cookbook/client/tinker/lora.py index 44fc94c5..24ca2e52 100644 --- a/cookbook/client/tinker/transformer/lora.py +++ b/cookbook/client/tinker/lora.py @@ -36,7 +36,7 @@ # 2. A model id on hub: "/" # Example: # resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1" -# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-3B-Instruct-385d5c17_pig-latin-lora-epoch-1" +# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1" resume_path = "" print(f"Found {len(response.training_runs)} training runs") @@ -51,7 +51,7 @@ # Step 5: Create or resume a training client. # If resume_path is set, it restores both model weights and optimizer state. 
-base_model = "Qwen/Qwen2.5-3B-Instruct" +base_model = "Qwen/Qwen2.5-7B-Instruct" if not resume_path: training_client = service_client.create_lora_training_client( base_model=base_model diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index 08399d91..375a04c0 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -10,7 +10,7 @@ proxy_location: EveryNode # HTTP listener settings http_options: host: 0.0.0.0 # Listen on all network interfaces - port: 8000 # Port number for the server + port: 9000 # Port number for the server # Applications: each entry defines a service component deployed on the server applications: @@ -50,12 +50,13 @@ applications: enable_lora: true # Allow loading LoRA adapters during inference device_group: # Logical device group for the sampler name: sampler - gpus_per_worker: 1 - ranks: [0,1,2,3] # GPU rank indices to use + gpus_per_worker: 2 + ranks: [0,1,2,3,4,5,6,7] # GPU rank indices to use device_type: cuda device_mesh: device_type: cuda dp_size: 4 + tp_size: 2 deployments: - name: SamplerManagement autoscaling_config: @@ -67,7 +68,7 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" + DEVICE_COUNT_PER_PHYSICAL_NODE: "16" # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. 
@@ -80,12 +81,13 @@ applications: nproc_per_node: 4 # Number of GPU processes per node device_group: name: model - ranks: [4,5,6,7] # GPU rank indices + ranks: [8,9,10,11,12,13,14,15] # GPU rank indices device_type: cuda device_mesh: device_type: cuda dp_size: 2 tp_size: 2 + pp_size: 2 ep_size: 2 queue_config: @@ -105,4 +107,4 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" + DEVICE_COUNT_PER_PHYSICAL_NODE: "16" diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml new file mode 100644 index 00000000..1dac6d6c --- /dev/null +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -0,0 +1,105 @@ +# Twinkle Server Configuration - Tinker-Compatible Transformers Backend + +# Server protocol type: "tinker" enables the Tinker-compatible API +server_type: tinker + +# proxy_location: determines where the HTTP proxy runs. +# "EveryNode" means each Ray node runs its own proxy (good for multi-node). +proxy_location: EveryNode + +# HTTP listener settings +http_options: + host: 0.0.0.0 # Listen on all network interfaces + port: 8000 # Port number for the server + +# Applications: each entry defines a service component deployed on the server +applications: + + # 1. TinkerCompatServer - The central API server + # Handles client connections, training run tracking, checkpoint listing. + - name: server + route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible) + import_path: server # Python module to import + args: + + deployments: + - name: TinkerCompatServer + autoscaling_config: + min_replicas: 1 # Minimum number of replicas + max_replicas: 1 # Maximum number of replicas + target_ongoing_requests: 128 # Target concurrent requests per replica + ray_actor_options: + num_cpus: 0.1 # CPU resources allocated to this actor + + # 2. Model Service - Hosts the base model for training. 
+ # Uncomment and configure if you need a training model worker. + - name: models-Qwen2.5-7B-Instruct + route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct + import_path: model + args: + use_megatron: true + model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier + max_length: 10240 + nproc_per_node: 2 # Number of GPU processes per node + device_group: + name: model + ranks: [0,1] # GPU rank indices + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + adapter_config: + per_token_adapter_limit: 30 # Max concurrent LoRA adapters + adapter_timeout: 1800 # Seconds before idle adapter unload + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" + + # 3. Sampler Service - Runs inference / sampling using vLLM engine + # Used for generating text from the model (e.g., evaluating LoRA results). 
+ - name: sampler-Qwen2.5-7B-Instruct + route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct + import_path: sampler + args: + model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier + nproc_per_node: 2 # Number of GPU processes per node + sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + engine_args: # vLLM engine-specific settings + max_model_len: 4096 # Maximum sequence length the engine supports + gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + enable_lora: true # Allow loading LoRA adapters during inference + logprobs_mode: processed_logprobs # Logprobs mode for sampling results + device_group: # Logical device group for the sampler + name: sampler + ranks: [2] # GPU rank indices to use + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" diff --git a/cookbook/client/tinker/transformer/sample.py b/cookbook/client/tinker/sample.py similarity index 98% rename from cookbook/client/tinker/transformer/sample.py rename to cookbook/client/tinker/sample.py index 84faa689..4752c5b9 100644 --- a/cookbook/client/tinker/transformer/sample.py +++ b/cookbook/client/tinker/sample.py @@ -9,7 +9,7 @@ from modelscope import AutoTokenizer # Step 1: Define the base model and connect to the server -base_model = "Qwen/Qwen2.5-3B-Instruct" +base_model = "Qwen/Qwen2.5-7B-Instruct" service_client = init_tinker_compat_client(base_url='http://localhost:8000', api_key="tml-EMPTY_TOKEN") # Step 2: Create a sampling client by loading weights from a saved checkpoint. 
diff --git a/cookbook/client/tinker/transformer/self_congnition.py b/cookbook/client/tinker/self_congnition.py similarity index 95% rename from cookbook/client/tinker/transformer/self_congnition.py rename to cookbook/client/tinker/self_congnition.py index e1f5b7a3..d846b5f5 100644 --- a/cookbook/client/tinker/transformer/self_congnition.py +++ b/cookbook/client/tinker/self_congnition.py @@ -18,7 +18,7 @@ from modelscope import AutoTokenizer # The base model to fine-tune / evaluate -base_model = "Qwen/Qwen2.5-3B-Instruct" +base_model = "Qwen/Qwen2.5-7B-Instruct" def train(): @@ -83,7 +83,7 @@ def eval(): # Step 1: Load the trained LoRA checkpoint for inference # Path to a previously saved LoRA checkpoint (twinkle:// URI) - weight_path = "twinkle://20260207_110850-Qwen_Qwen2_5-0_5B-Instruct-ce7e819f/weights/twinkle-lora-2" + weight_path = "twinkle://20260211_112719-Qwen_Qwen2_5-7B-Instruct-a74a4826/weights/twinkle-lora-2" # Connect to the server and create a sampling client with the trained weights service_client = init_tinker_compat_client(base_url='http://localhost:8000') @@ -136,5 +136,5 @@ def eval(): if __name__ == "__main__": - # train() # Uncomment to run training - eval() # Run evaluation / inference + train() # Uncomment to run training + # eval() # Run evaluation / inference diff --git a/cookbook/client/tinker/transformer/server.py b/cookbook/client/tinker/transformer/server.py index f8669622..573aa6f8 100644 --- a/cookbook/client/tinker/transformer/server.py +++ b/cookbook/client/tinker/transformer/server.py @@ -7,8 +7,6 @@ import os -# Enable Ray debug mode for verbose logging during development -# os.environ['RAY_DEBUG'] = '1' os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0' from twinkle.server import launch_server diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml index fbdd29c0..00e57387 100644 --- a/cookbook/client/tinker/transformer/server_config.yaml +++ 
b/cookbook/client/tinker/transformer/server_config.yaml @@ -33,24 +33,24 @@ applications: # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. - - name: models-Qwen2.5-3B-Instruct - route_prefix: /api/v1/model/Qwen/Qwen2.5-3B-Instruct + - name: models-Qwen2.5-7B-Instruct + route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct import_path: model args: use_megatron: false # Use HuggingFace Transformers backend - model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier + model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier + max_length: 10240 nproc_per_node: 2 # Number of GPU processes per node device_group: name: model - ranks: [0, 1] # GPU rank indices + ranks: [0,1] # GPU rank indices device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] # 'dp' = data parallel + dp_size: 2 queue_config: rps_limit: 100 # Max requests per second - tps_limit: 10000 # Max tokens per second + tps_limit: 100000 # Max tokens per second adapter_config: per_token_adapter_limit: 30 # Max concurrent LoRA adapters adapter_timeout: 1800 # Seconds before idle adapter unload @@ -62,28 +62,35 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
- - name: sampler-Qwen2.5-3B-Instruct - route_prefix: /api/v1/sampler/Qwen/Qwen2.5-3B-Instruct + - name: sampler-Qwen2.5-7B-Instruct + route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct import_path: sampler args: - model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier - nproc_per_node: 1 # Number of GPU processes per node + model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier + nproc_per_node: 2 # Number of GPU processes per node sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) engine_args: # vLLM engine-specific settings max_model_len: 4096 # Maximum sequence length the engine supports gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) enable_lora: true # Allow loading LoRA adapters during inference + logprobs_mode: processed_logprobs # Logprobs mode for sampling results device_group: # Logical device group for the sampler name: sampler - ranks: [0] # GPU rank indices to use + ranks: [2] # GPU rank indices to use device_type: cuda device_mesh: device_type: cuda - mesh: [0] - mesh_dim_names: ['dp'] + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second deployments: - name: SamplerManagement autoscaling_config: @@ -92,4 +99,7 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 - num_gpus: 1 # Sampler needs a full GPU for inference + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" diff --git a/cookbook/client/twinkle/transformer/grpo.py b/cookbook/client/twinkle/grpo.py similarity index 100% rename from cookbook/client/twinkle/transformer/grpo.py rename to cookbook/client/twinkle/grpo.py diff --git a/cookbook/client/twinkle/transformer/sampler.py b/cookbook/client/twinkle/sample.py similarity index 100% rename from cookbook/client/twinkle/transformer/sampler.py rename to cookbook/client/twinkle/sample.py diff --git 
a/cookbook/client/twinkle/transformer/lora.py b/cookbook/client/twinkle/self_congnition.py similarity index 100% rename from cookbook/client/twinkle/transformer/lora.py rename to cookbook/client/twinkle/self_congnition.py diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index af64fdd0..ff2603c9 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -222,6 +222,11 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] self._check_adapter_valid(kwargs.get("adapter_name")) super().add_metric(metric_cls, is_training, **kwargs) + @remote_function(collect='first', lazy_collect=False) + def calculate_metric(self, is_training, **kwargs): + self._check_adapter_valid(kwargs.get("adapter_name")) + return super().calculate_metric(is_training, **kwargs) + @remote_function() def remove_adapter(self, adapter_name: str): if adapter_name in self.optimizer_group: diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py new file mode 100644 index 00000000..60d4f5ce --- /dev/null +++ b/src/twinkle/server/tinker/common/compat_base.py @@ -0,0 +1,92 @@ +import torch +import numpy as np +from typing import List +from tinker import types +from twinkle.template import Template + + +def collect_forward_backward_results(results): + """Custom collect function for forward_backward that handles list [outputs, loss]. + + Args: + results: List of lists from each worker, where each list is [outputs_list, loss_float] + + Returns: + List of [flattened_outputs, averaged_loss] + """ + if not results: + return results + + # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...] 
+ # Flatten outputs (first element of each list) + all_outputs = [] + all_losses = [] + for result in results: + outputs, loss = result + all_outputs.extend(outputs) + all_losses.append(loss) + + # Average the losses + avg_loss = float(np.mean(all_losses)) + + return [all_outputs, avg_loss] + + +def clean_metrics(metrics: dict) -> dict: + cleaned = {} + for key, value in metrics.items(): + if isinstance(value, str): + import re + match = re.match(r'^([+-]?\d*\.?\d+)', value.strip()) + if match: + cleaned[key] = float(match.group(1)) + else: + cleaned[key] = value + return cleaned + + +class TwinkleCompatModelBase: + """Base class containing common logic for Twinkle compatibility wrappers.""" + + def get_template(self, adapter_name: str) -> Template: + return self.optimizer_group[adapter_name].template + + @staticmethod + def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]: + """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" + results = [] + for i, feature in enumerate(inputs): + # Ensure 1D shape and correct device to avoid dimension mismatch and device errors + labels = feature.loss_fn_inputs['target_tokens'].to_torch( + ).long().view(-1).to(logits.device) # shape (seq_len,) + weights = feature.loss_fn_inputs['weights'].to_torch( + ).view(-1).to(logits.device) # shape (seq_len,) + + # Slice logits to match the sequence length of labels + # Labels are assumed to be already shifted/aligned with logits + seq_len = labels.numel() + + # Check if index is within logits bounds + if i < logits.shape[0]: + feature_logits = logits[i, :seq_len, :] + + # Calculate log probs for all labels + # Apply log_softmax to convert raw logits to log-probabilities + feature_log_probs = torch.log_softmax(feature_logits, dim=-1) + token_log_probs = feature_log_probs.gather( + dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + + # elementwise_loss: positive NLL loss (0.0 where masked) + elementwise_loss = 
-token_log_probs * weights + + results.append({ + 'logprobs': types.TensorData.from_torch(token_log_probs), + 'elementwise_loss': types.TensorData.from_torch(elementwise_loss) + }) + else: + # Handle case where batch index exceeds logits batch size + results.append({ + 'logprobs': None, + 'elementwise_loss': None + }) + return results diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py index c744800c..e76ce706 100644 --- a/src/twinkle/server/tinker/common/megatron_model.py +++ b/src/twinkle/server/tinker/common/megatron_model.py @@ -1,13 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np + import torch from typing import List, TYPE_CHECKING, Tuple, Optional, Any from tinker import types -from twinkle.template import Template from twinkle import remote_class, remote_function from twinkle.utils import exists, requires from .datum import datum_to_input_feature, extract_rl_feature from .io_utils import create_checkpoint_manager +from .compat_base import TwinkleCompatModelBase, collect_forward_backward_results, clean_metrics if TYPE_CHECKING: @@ -22,35 +22,8 @@ def __init__(self, *args, **kwargs): requires('megatron_core') -def _collect_forward_backward_results(results): - """Custom collect function for forward_backward that handles list [outputs, loss]. - - Args: - results: List of lists from each worker, where each list is [outputs_list, loss_float] - - Returns: - List of [flattened_outputs, averaged_loss] - """ - if not results: - return results - - # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...] 
- # Flatten outputs (first element of each list) - all_outputs = [] - all_losses = [] - for result in results: - outputs, loss = result - all_outputs.extend(outputs) - all_losses.append(loss) - - # Average the losses - avg_loss = float(np.mean(all_losses)) - - return [all_outputs, avg_loss] - - @remote_class(execute='all') -class TwinkleCompatMegatronModel(_MegatronBase): +class TwinkleCompatMegatronModel(_MegatronBase, TwinkleCompatModelBase): """ Compatibility wrapper around :class:`MultiLoraMegatronModel` for Twinkle/Tinker. @@ -76,7 +49,7 @@ class TwinkleCompatMegatronModel(_MegatronBase): This wrapper provides a direct forward_backward interface. """ - @remote_function(dispatch='slice_dp', collect=_collect_forward_backward_results, sync=True) + @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True) def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): """Combined forward and backward pass. @@ -179,6 +152,13 @@ def step(self, *, adam_params: types.AdamParams, **kwargs): # Zero gradients super().zero_grad(**kwargs) + + @remote_function(collect='first', lazy_collect=False) + def calculate_metric(self, is_training, **kwargs): + metric = super().calculate_metric(is_training, **kwargs) + return clean_metrics(metric) + + @remote_function(dispatch='all', sync=True) def load(self, checkpoint_dir: str, **kwargs): """ @@ -206,43 +186,3 @@ def load(self, checkpoint_dir: str, **kwargs): # Load from hub return super().load(name=resolved.checkpoint_name, **kwargs) - def get_template(self, adapter_name: str) -> Template: - return self.optimizer_group[adapter_name].template - - @staticmethod - def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]: - """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" - results = [] - for i, feature in enumerate(inputs): - # Ensure 1D shape and correct device to avoid dimension mismatch and 
device errors - labels = feature.loss_fn_inputs['target_tokens'].to_torch( - ).long().view(-1).to(logits.device) # shape (seq_len,) - weights = feature.loss_fn_inputs['weights'].to_torch( - ).view(-1).to(logits.device) # shape (seq_len,) - - # Slice logits to match the sequence length of labels - # Labels are assumed to be already shifted/aligned with logits - seq_len = labels.numel() - if i < logits.shape[0]: - feature_logits = logits[i, :seq_len, :] - - # Calculate log probs for all labels - # Apply log_softmax to convert raw logits to log-probabilities - feature_log_probs = torch.log_softmax(feature_logits, dim=-1) - token_log_probs = feature_log_probs.gather( - dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - - # elementwise_loss: positive NLL loss (0.0 where masked) - elementwise_loss = -token_log_probs * weights - - results.append({ - 'logprobs': types.TensorData.from_torch(token_log_probs), - 'elementwise_loss': types.TensorData.from_torch(elementwise_loss) - }) - else: - # Handle case where batch index exceeds logits batch size - results.append({ - 'logprobs': None, - 'elementwise_loss': None - }) - return results diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py index 4110dd65..10a08a57 100644 --- a/src/twinkle/server/tinker/common/transformers_model.py +++ b/src/twinkle/server/tinker/common/transformers_model.py @@ -1,43 +1,14 @@ -import torch -import numpy as np from tinker import types from typing import List -from twinkle.template import Template from twinkle.model import MultiLoraTransformersModel from twinkle import remote_class, remote_function from .datum import datum_to_input_feature, extract_rl_feature from .io_utils import create_checkpoint_manager - - -def _collect_forward_backward_results(results): - """Custom collect function for forward_backward that handles list [outputs, loss]. 
- - Args: - results: List of lists from each worker, where each list is [outputs_list, loss_float] - - Returns: - List of [flattened_outputs, averaged_loss] - """ - if not results: - return results - - # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...] - # Flatten outputs (first element of each list) - all_outputs = [] - all_losses = [] - for result in results: - outputs, loss = result - all_outputs.extend(outputs) - all_losses.append(loss) - - # Average the losses - avg_loss = float(np.mean(all_losses)) - - return [all_outputs, avg_loss] +from .compat_base import TwinkleCompatModelBase, collect_forward_backward_results, clean_metrics @remote_class() -class TwinkleCompatTransformersModel(MultiLoraTransformersModel): +class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase): """ Compatibility wrapper around :class:`MultiLoraTransformersModel` for Twinkle/Tinker. @@ -83,7 +54,7 @@ def forward_only(self, *, inputs: List[types.Datum], **kwargs): results = self._get_forward_output(inputs, logits) return results - @remote_function(dispatch='slice_dp', collect=_collect_forward_backward_results) + @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): # Set loss first based on loss_fn if loss_fn == 'cross_entropy': @@ -102,7 +73,7 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss # Convert Datum to InputFeature input_features = datum_to_input_feature(inputs, template) - + # Forward pass outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs) @@ -139,6 +110,11 @@ def step(self, *, adam_params: types.AdamParams, **kwargs): # Zero gradients super().zero_grad(**kwargs) + @remote_function(collect='first', lazy_collect=False) + def calculate_metric(self, is_training, **kwargs): + metric = super().calculate_metric(is_training, **kwargs) + 
return clean_metrics(metric) + @remote_function() def load(self, checkpoint_dir: str, **kwargs): """ @@ -165,36 +141,3 @@ def load(self, checkpoint_dir: str, **kwargs): else: # Load from hub return super().load(name=resolved.checkpoint_name, **kwargs) - - def get_template(self, adapter_name: str) -> Template: - return self.optimizer_group[adapter_name].template - - @staticmethod - def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]: - results = [] - for i, feature in enumerate(inputs): - # Ensure 1D shape and correct device to avoid dimension mismatch and device errors - labels = feature.loss_fn_inputs['target_tokens'].to_torch( - ).long().view(-1).to(logits.device) # shape (seq_len,) - weights = feature.loss_fn_inputs['weights'].to_torch( - ).view(-1).to(logits.device) # shape (seq_len,) - - # Slice logits to match the sequence length of labels - # Labels are assumed to be already shifted/aligned with logits - seq_len = labels.numel() - feature_logits = logits[i, :seq_len, :] - - # Calculate log probs for all labels - # Apply log_softmax to convert raw logits to log-probabilities - feature_log_probs = torch.log_softmax(feature_logits, dim=-1) - token_log_probs = feature_log_probs.gather( - dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - - # elementwise_loss: positive NLL loss (0.0 where masked) - elementwise_loss = -token_log_probs * weights - - results.append({ - 'logprobs': types.TensorData.from_torch(token_log_probs), - 'elementwise_loss': types.TensorData.from_torch(elementwise_loss) - }) - return results diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 54a9717b..671d5197 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -421,7 +421,8 @@ async def _do_optim(): self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) - return types.OptimStepResponse(metrics=None) + metrics = self.model.calculate_metric(is_training=True, 
adapter_name=adapter_name) + return types.OptimStepResponse(metrics=metrics) except Exception: logger.error(traceback.format_exc()) return types.RequestFailedResponse( diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 0b292a72..3e74e63b 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -95,6 +95,7 @@ def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None types.SupportedModel(model_name="Qwen/Qwen2.5-3B-Instruct"), types.SupportedModel(model_name="Qwen/Qwen2.5-7B-Instruct"), types.SupportedModel(model_name="Qwen/Qwen2.5-72B-Instruct"), + types.SupportedModel(model_name="Qwen/Qwen3-30B-A3B-Instruct-2507"), ] # Lock for ModelScope config file operations (login writes, get_user_info reads) self._modelscope_config_lock = asyncio.Lock()