diff --git a/cookbook/client/tinker/transformer/grpo.py b/cookbook/client/tinker/grpo.py
similarity index 99%
rename from cookbook/client/tinker/transformer/grpo.py
rename to cookbook/client/tinker/grpo.py
index c08c6fe5..94145f2a 100644
--- a/cookbook/client/tinker/transformer/grpo.py
+++ b/cookbook/client/tinker/grpo.py
@@ -34,7 +34,7 @@
logger = get_logger()
# ========== Configuration ==========
-BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct'
+BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 1024
LEARNING_RATE = 1e-5
diff --git a/cookbook/client/tinker/transformer/gsm8k.py b/cookbook/client/tinker/gsm8k_grpo.py
similarity index 90%
rename from cookbook/client/tinker/transformer/gsm8k.py
rename to cookbook/client/tinker/gsm8k_grpo.py
index b7f86c61..139dd40e 100644
--- a/cookbook/client/tinker/transformer/gsm8k.py
+++ b/cookbook/client/tinker/gsm8k_grpo.py
@@ -17,7 +17,7 @@
#
# The server must be running first (see server.py and server_config.yaml).
# Requires both model and sampler services to be configured.
-
+import os
import gc
import re
import numpy as np
@@ -38,12 +38,12 @@
logger = get_logger()
# ========== Configuration ==========
-BASE_MODEL = 'Qwen/Qwen2.5-3B-Instruct'
+BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
NUM_GENERATIONS = 4
-MAX_NEW_TOKENS = 2048
-LEARNING_RATE = 1e-5
+MAX_NEW_TOKENS = 1024
+LEARNING_RATE = 1e-4
MAX_STEPS = 100
-BATCH_SIZE = 2
+BATCH_SIZE = 4
TEMPERATURE = 1.0
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
LORA_RANK = 8
@@ -56,6 +56,14 @@
"For example:\n ... reasoning ... \n#### 42"
)
+# SwanLab experiment tracking
+USE_SWANLAB = True
+if USE_SWANLAB:
+ import swanlab
+ swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
+ swanlab.init(project="twinkle-gsm8k", config={
+ 'model_id': BASE_MODEL,
+ })
class GSM8KProcessor(Preprocessor):
"""Preprocessor for GSM8K dataset.
@@ -354,18 +362,15 @@ def main():
ob_len = len(prompt_ids) - 1
input_tokens = prompt_ids + sampled_tokens[:-1]
target_tokens = [0] * ob_len + sampled_tokens
+ weights = [0] * ob_len + [1] * len(sampled_tokens)
padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens)
padded_logprobs = [0.0] * ob_len + logprobs
-
- # Verify lengths match
- assert len(input_tokens) == len(target_tokens) == len(padded_logprobs) == len(padded_advantages), \
- f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, " \
- f"logprobs={len(padded_logprobs)}, advantages={len(padded_advantages)}"
datum = types.Datum(
model_input=types.ModelInput.from_ints(input_tokens),
loss_fn_inputs={
'target_tokens': target_tokens,
+ 'weights': weights,
'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)),
'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)),
},
@@ -380,40 +385,22 @@ def main():
# Forward-backward pass with importance_sampling (GRPO) loss
# The training data already contains logprobs and advantages for the GRPO loss
- fwdbwd_future = training_client.forward_backward(
- training_data, "importance_sampling")
- optim_future = training_client.optim_step(
- types.AdamParams(learning_rate=LEARNING_RATE))
-
- fwdbwd_result = fwdbwd_future.result()
- optim_result = optim_future.result()
-
- # Compute metrics from the forward-backward result
- # For importance_sampling, we get logprobs and elementwise_loss
- logprobs_list = []
- elementwise_losses = []
- for output in fwdbwd_result.loss_fn_outputs:
- if output.get('logprobs') is not None:
- logprobs_list.append(output['logprobs'].to_numpy())
- if output.get('elementwise_loss') is not None:
- elementwise_losses.append(output['elementwise_loss'].to_numpy())
-
- # Compute average loss per token (weighted by advantages)
- if elementwise_losses:
- all_losses = np.concatenate(elementwise_losses)
- avg_loss = np.mean(all_losses) if len(all_losses) > 0 else 0.0
- else:
- avg_loss = 0.0
+ fwdbwd_result = training_client.forward_backward(training_data, "importance_sampling").result()
+
+ optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()
gc.collect()
# ========== 7. Log ==========
log_dict = metrics.calculate()
- log_dict['train/loss_per_token'] = float(avg_loss)
+ if optim_result.metrics:
+ log_dict.update(optim_result.metrics)
log_dict['train/frac_reward_zero_std'] = frac_zero_std
log_dict['train/num_training_samples'] = len(training_data)
logger.info(f"Step {step}: {log_dict}")
step += 1
+ if USE_SWANLAB:
+ swanlab.log(log_dict)
# Save final checkpoint
save_future = training_client.save_state("gsm8k-grpo-final")
diff --git a/cookbook/client/tinker/transformer/lora.py b/cookbook/client/tinker/lora.py
similarity index 98%
rename from cookbook/client/tinker/transformer/lora.py
rename to cookbook/client/tinker/lora.py
index 44fc94c5..24ca2e52 100644
--- a/cookbook/client/tinker/transformer/lora.py
+++ b/cookbook/client/tinker/lora.py
@@ -36,7 +36,7 @@
# 2. A model id on hub: "/"
# Example:
# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
-# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-3B-Instruct-385d5c17_pig-latin-lora-epoch-1"
+# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1"
resume_path = ""
print(f"Found {len(response.training_runs)} training runs")
@@ -51,7 +51,7 @@
# Step 5: Create or resume a training client.
# If resume_path is set, it restores both model weights and optimizer state.
-base_model = "Qwen/Qwen2.5-3B-Instruct"
+base_model = "Qwen/Qwen2.5-7B-Instruct"
if not resume_path:
training_client = service_client.create_lora_training_client(
base_model=base_model
diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml
index 08399d91..375a04c0 100644
--- a/cookbook/client/tinker/megatron/server_config.yaml
+++ b/cookbook/client/tinker/megatron/server_config.yaml
@@ -10,7 +10,7 @@ proxy_location: EveryNode
# HTTP listener settings
http_options:
host: 0.0.0.0 # Listen on all network interfaces
- port: 8000 # Port number for the server
+ port: 9000 # Port number for the server
# Applications: each entry defines a service component deployed on the server
applications:
@@ -50,12 +50,13 @@ applications:
enable_lora: true # Allow loading LoRA adapters during inference
device_group: # Logical device group for the sampler
name: sampler
- gpus_per_worker: 1
- ranks: [0,1,2,3] # GPU rank indices to use
+ gpus_per_worker: 2
+ ranks: [0,1,2,3,4,5,6,7] # GPU rank indices to use
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 4
+ tp_size: 2
deployments:
- name: SamplerManagement
autoscaling_config:
@@ -67,7 +68,7 @@ applications:
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
- DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "16"
# 2. Model Service (commented out) - Would host the base model for training.
# Uncomment and configure if you need a training model worker.
@@ -80,12 +81,13 @@ applications:
nproc_per_node: 4 # Number of GPU processes per node
device_group:
name: model
- ranks: [4,5,6,7] # GPU rank indices
+ ranks: [8,9,10,11,12,13,14,15] # GPU rank indices
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 2
tp_size: 2
+ pp_size: 2
ep_size: 2
queue_config:
@@ -105,4 +107,4 @@ applications:
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
- DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "16"
diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml
new file mode 100644
index 00000000..1dac6d6c
--- /dev/null
+++ b/cookbook/client/tinker/megatron/server_config_7b.yaml
@@ -0,0 +1,105 @@
+# Twinkle Server Configuration - Tinker-Compatible Transformers Backend
+
+# Server protocol type: "tinker" enables the Tinker-compatible API
+server_type: tinker
+
+# proxy_location: determines where the HTTP proxy runs.
+# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
+proxy_location: EveryNode
+
+# HTTP listener settings
+http_options:
+ host: 0.0.0.0 # Listen on all network interfaces
+ port: 8000 # Port number for the server
+
+# Applications: each entry defines a service component deployed on the server
+applications:
+
+ # 1. TinkerCompatServer - The central API server
+ # Handles client connections, training run tracking, checkpoint listing.
+ - name: server
+ route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
+ import_path: server # Python module to import
+ args:
+
+ deployments:
+ - name: TinkerCompatServer
+ autoscaling_config:
+ min_replicas: 1 # Minimum number of replicas
+ max_replicas: 1 # Maximum number of replicas
+ target_ongoing_requests: 128 # Target concurrent requests per replica
+ ray_actor_options:
+ num_cpus: 0.1 # CPU resources allocated to this actor
+
+ # 2. Model Service - Hosts the base model for training.
+ # Adjust device ranks and parallelism to match your hardware.
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ use_megatron: true
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ max_length: 10240
+ nproc_per_node: 2 # Number of GPU processes per node
+ device_group:
+ name: model
+ ranks: [0,1] # GPU rank indices
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2
+ queue_config:
+ rps_limit: 100 # Max requests per second
+ tps_limit: 100000 # Max tokens per second
+ adapter_config:
+ per_token_adapter_limit: 30 # Max concurrent LoRA adapters
+ adapter_timeout: 1800 # Seconds before idle adapter unload
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+
+ # 3. Sampler Service - Runs inference / sampling using vLLM engine
+ # Used for generating text from the model (e.g., evaluating LoRA results).
+ - name: sampler-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
+ import_path: sampler
+ args:
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ nproc_per_node: 2 # Number of GPU processes per node
+ sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
+ engine_args: # vLLM engine-specific settings
+ max_model_len: 4096 # Maximum sequence length the engine supports
+ gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
+ enable_lora: true # Allow loading LoRA adapters during inference
+ logprobs_mode: processed_logprobs # Logprobs mode for sampling results
+ device_group: # Logical device group for the sampler
+ name: sampler
+ ranks: [2] # GPU rank indices to use
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 1
+ queue_config:
+ rps_limit: 100 # Max requests per second
+ tps_limit: 100000 # Max tokens per second
+ deployments:
+ - name: SamplerManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/tinker/transformer/sample.py b/cookbook/client/tinker/sample.py
similarity index 98%
rename from cookbook/client/tinker/transformer/sample.py
rename to cookbook/client/tinker/sample.py
index 84faa689..4752c5b9 100644
--- a/cookbook/client/tinker/transformer/sample.py
+++ b/cookbook/client/tinker/sample.py
@@ -9,7 +9,7 @@
from modelscope import AutoTokenizer
# Step 1: Define the base model and connect to the server
-base_model = "Qwen/Qwen2.5-3B-Instruct"
+base_model = "Qwen/Qwen2.5-7B-Instruct"
service_client = init_tinker_compat_client(base_url='http://localhost:8000', api_key="tml-EMPTY_TOKEN")
# Step 2: Create a sampling client by loading weights from a saved checkpoint.
diff --git a/cookbook/client/tinker/transformer/self_congnition.py b/cookbook/client/tinker/self_congnition.py
similarity index 95%
rename from cookbook/client/tinker/transformer/self_congnition.py
rename to cookbook/client/tinker/self_congnition.py
index e1f5b7a3..d846b5f5 100644
--- a/cookbook/client/tinker/transformer/self_congnition.py
+++ b/cookbook/client/tinker/self_congnition.py
@@ -18,7 +18,7 @@
from modelscope import AutoTokenizer
# The base model to fine-tune / evaluate
-base_model = "Qwen/Qwen2.5-3B-Instruct"
+base_model = "Qwen/Qwen2.5-7B-Instruct"
def train():
@@ -83,7 +83,7 @@ def eval():
# Step 1: Load the trained LoRA checkpoint for inference
# Path to a previously saved LoRA checkpoint (twinkle:// URI)
- weight_path = "twinkle://20260207_110850-Qwen_Qwen2_5-0_5B-Instruct-ce7e819f/weights/twinkle-lora-2"
+ weight_path = "twinkle://20260211_112719-Qwen_Qwen2_5-7B-Instruct-a74a4826/weights/twinkle-lora-2"
# Connect to the server and create a sampling client with the trained weights
service_client = init_tinker_compat_client(base_url='http://localhost:8000')
@@ -136,5 +136,5 @@ def eval():
if __name__ == "__main__":
- # train() # Uncomment to run training
- eval() # Run evaluation / inference
+ train() # Run training
+ # eval() # Uncomment to run evaluation / inference
diff --git a/cookbook/client/tinker/transformer/server.py b/cookbook/client/tinker/transformer/server.py
index f8669622..573aa6f8 100644
--- a/cookbook/client/tinker/transformer/server.py
+++ b/cookbook/client/tinker/transformer/server.py
@@ -7,8 +7,6 @@
import os
-# Enable Ray debug mode for verbose logging during development
-# os.environ['RAY_DEBUG'] = '1'
os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'
from twinkle.server import launch_server
diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml
index fbdd29c0..00e57387 100644
--- a/cookbook/client/tinker/transformer/server_config.yaml
+++ b/cookbook/client/tinker/transformer/server_config.yaml
@@ -33,24 +33,24 @@ applications:
# 2. Model Service (commented out) - Would host the base model for training.
# Uncomment and configure if you need a training model worker.
- - name: models-Qwen2.5-3B-Instruct
- route_prefix: /api/v1/model/Qwen/Qwen2.5-3B-Instruct
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
import_path: model
args:
use_megatron: false # Use HuggingFace Transformers backend
- model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ max_length: 10240
nproc_per_node: 2 # Number of GPU processes per node
device_group:
name: model
- ranks: [0, 1] # GPU rank indices
+ ranks: [0,1] # GPU rank indices
device_type: cuda
device_mesh:
device_type: cuda
- mesh: [0, 1]
- mesh_dim_names: ['dp'] # 'dp' = data parallel
+ dp_size: 2
queue_config:
rps_limit: 100 # Max requests per second
- tps_limit: 10000 # Max tokens per second
+ tps_limit: 100000 # Max tokens per second
adapter_config:
per_token_adapter_limit: 30 # Max concurrent LoRA adapters
adapter_timeout: 1800 # Seconds before idle adapter unload
@@ -62,28 +62,35 @@ applications:
target_ongoing_requests: 16
ray_actor_options:
num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
# 3. Sampler Service - Runs inference / sampling using vLLM engine
# Used for generating text from the model (e.g., evaluating LoRA results).
- - name: sampler-Qwen2.5-3B-Instruct
- route_prefix: /api/v1/sampler/Qwen/Qwen2.5-3B-Instruct
+ - name: sampler-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
import_path: sampler
args:
- model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier
- nproc_per_node: 1 # Number of GPU processes per node
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ nproc_per_node: 2 # Number of GPU processes per node
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
engine_args: # vLLM engine-specific settings
max_model_len: 4096 # Maximum sequence length the engine supports
gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
enable_lora: true # Allow loading LoRA adapters during inference
+ logprobs_mode: processed_logprobs # Logprobs mode for sampling results
device_group: # Logical device group for the sampler
name: sampler
- ranks: [0] # GPU rank indices to use
+ ranks: [2] # GPU rank indices to use
device_type: cuda
device_mesh:
device_type: cuda
- mesh: [0]
- mesh_dim_names: ['dp']
+ dp_size: 1
+ queue_config:
+ rps_limit: 100 # Max requests per second
+ tps_limit: 100000 # Max tokens per second
deployments:
- name: SamplerManagement
autoscaling_config:
@@ -92,4 +99,7 @@ applications:
target_ongoing_requests: 16
ray_actor_options:
num_cpus: 0.1
- num_gpus: 1 # Sampler needs a full GPU for inference
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/twinkle/transformer/grpo.py b/cookbook/client/twinkle/grpo.py
similarity index 100%
rename from cookbook/client/twinkle/transformer/grpo.py
rename to cookbook/client/twinkle/grpo.py
diff --git a/cookbook/client/twinkle/transformer/sampler.py b/cookbook/client/twinkle/sample.py
similarity index 100%
rename from cookbook/client/twinkle/transformer/sampler.py
rename to cookbook/client/twinkle/sample.py
diff --git a/cookbook/client/twinkle/transformer/lora.py b/cookbook/client/twinkle/self_congnition.py
similarity index 100%
rename from cookbook/client/twinkle/transformer/lora.py
rename to cookbook/client/twinkle/self_congnition.py
diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py
index af64fdd0..ff2603c9 100644
--- a/src/twinkle/model/transformers/multi_lora_transformers.py
+++ b/src/twinkle/model/transformers/multi_lora_transformers.py
@@ -222,6 +222,11 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool]
self._check_adapter_valid(kwargs.get("adapter_name"))
super().add_metric(metric_cls, is_training, **kwargs)
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ self._check_adapter_valid(kwargs.get("adapter_name"))
+ return super().calculate_metric(is_training, **kwargs)
+
@remote_function()
def remove_adapter(self, adapter_name: str):
if adapter_name in self.optimizer_group:
diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py
new file mode 100644
index 00000000..60d4f5ce
--- /dev/null
+++ b/src/twinkle/server/tinker/common/compat_base.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+from typing import List
+from tinker import types
+from twinkle.template import Template
+
+
+def collect_forward_backward_results(results):
+ """Custom collect function for forward_backward that handles list [outputs, loss].
+
+ Args:
+ results: List of lists from each worker, where each list is [outputs_list, loss_float]
+
+ Returns:
+ List of [flattened_outputs, averaged_loss]
+ """
+ if not results:
+ return results
+
+ # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...]
+ # Flatten outputs (first element of each list)
+ all_outputs = []
+ all_losses = []
+ for result in results:
+ outputs, loss = result
+ all_outputs.extend(outputs)
+ all_losses.append(loss)
+
+ # Average the losses
+ avg_loss = float(np.mean(all_losses))
+
+ return [all_outputs, avg_loss]
+
+
+def clean_metrics(metrics: dict) -> dict:
+ cleaned = {}
+ for key, value in metrics.items():
+ if isinstance(value, str):
+ import re
+ match = re.match(r'^([+-]?\d*\.?\d+)', value.strip())
+ if match:
+ cleaned[key] = float(match.group(1))
+ else:
+ cleaned[key] = value
+ return cleaned
+
+
+class TwinkleCompatModelBase:
+ """Base class containing common logic for Twinkle compatibility wrappers."""
+
+ def get_template(self, adapter_name: str) -> Template:
+ return self.optimizer_group[adapter_name].template
+
+ @staticmethod
+ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]:
+ """Convert raw logits to the expected output format with logprobs and elementwise_loss."""
+ results = []
+ for i, feature in enumerate(inputs):
+ # Ensure 1D shape and correct device to avoid dimension mismatch and device errors
+ labels = feature.loss_fn_inputs['target_tokens'].to_torch(
+ ).long().view(-1).to(logits.device) # shape (seq_len,)
+ weights = feature.loss_fn_inputs['weights'].to_torch(
+ ).view(-1).to(logits.device) # shape (seq_len,)
+
+ # Slice logits to match the sequence length of labels
+ # Labels are assumed to be already shifted/aligned with logits
+ seq_len = labels.numel()
+
+ # Check if index is within logits bounds
+ if i < logits.shape[0]:
+ feature_logits = logits[i, :seq_len, :]
+
+ # Calculate log probs for all labels
+ # Apply log_softmax to convert raw logits to log-probabilities
+ feature_log_probs = torch.log_softmax(feature_logits, dim=-1)
+ token_log_probs = feature_log_probs.gather(
+ dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
+
+ # elementwise_loss: positive NLL loss (0.0 where masked)
+ elementwise_loss = -token_log_probs * weights
+
+ results.append({
+ 'logprobs': types.TensorData.from_torch(token_log_probs),
+ 'elementwise_loss': types.TensorData.from_torch(elementwise_loss)
+ })
+ else:
+ # Handle case where batch index exceeds logits batch size
+ results.append({
+ 'logprobs': None,
+ 'elementwise_loss': None
+ })
+ return results
diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py
index c744800c..e76ce706 100644
--- a/src/twinkle/server/tinker/common/megatron_model.py
+++ b/src/twinkle/server/tinker/common/megatron_model.py
@@ -1,13 +1,13 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-import numpy as np
+
import torch
from typing import List, TYPE_CHECKING, Tuple, Optional, Any
from tinker import types
-from twinkle.template import Template
from twinkle import remote_class, remote_function
from twinkle.utils import exists, requires
from .datum import datum_to_input_feature, extract_rl_feature
from .io_utils import create_checkpoint_manager
+from .compat_base import TwinkleCompatModelBase, collect_forward_backward_results, clean_metrics
if TYPE_CHECKING:
@@ -22,35 +22,8 @@ def __init__(self, *args, **kwargs):
requires('megatron_core')
-def _collect_forward_backward_results(results):
- """Custom collect function for forward_backward that handles list [outputs, loss].
-
- Args:
- results: List of lists from each worker, where each list is [outputs_list, loss_float]
-
- Returns:
- List of [flattened_outputs, averaged_loss]
- """
- if not results:
- return results
-
- # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...]
- # Flatten outputs (first element of each list)
- all_outputs = []
- all_losses = []
- for result in results:
- outputs, loss = result
- all_outputs.extend(outputs)
- all_losses.append(loss)
-
- # Average the losses
- avg_loss = float(np.mean(all_losses))
-
- return [all_outputs, avg_loss]
-
-
@remote_class(execute='all')
-class TwinkleCompatMegatronModel(_MegatronBase):
+class TwinkleCompatMegatronModel(_MegatronBase, TwinkleCompatModelBase):
"""
Compatibility wrapper around :class:`MultiLoraMegatronModel` for Twinkle/Tinker.
@@ -76,7 +49,7 @@ class TwinkleCompatMegatronModel(_MegatronBase):
This wrapper provides a direct forward_backward interface.
"""
- @remote_function(dispatch='slice_dp', collect=_collect_forward_backward_results, sync=True)
+ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True)
def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs):
"""Combined forward and backward pass.
@@ -179,6 +152,13 @@ def step(self, *, adam_params: types.AdamParams, **kwargs):
# Zero gradients
super().zero_grad(**kwargs)
+
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ metric = super().calculate_metric(is_training, **kwargs)
+ return clean_metrics(metric)
+
+
@remote_function(dispatch='all', sync=True)
def load(self, checkpoint_dir: str, **kwargs):
"""
@@ -206,43 +186,3 @@ def load(self, checkpoint_dir: str, **kwargs):
# Load from hub
return super().load(name=resolved.checkpoint_name, **kwargs)
- def get_template(self, adapter_name: str) -> Template:
- return self.optimizer_group[adapter_name].template
-
- @staticmethod
- def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]:
- """Convert raw logits to the expected output format with logprobs and elementwise_loss."""
- results = []
- for i, feature in enumerate(inputs):
- # Ensure 1D shape and correct device to avoid dimension mismatch and device errors
- labels = feature.loss_fn_inputs['target_tokens'].to_torch(
- ).long().view(-1).to(logits.device) # shape (seq_len,)
- weights = feature.loss_fn_inputs['weights'].to_torch(
- ).view(-1).to(logits.device) # shape (seq_len,)
-
- # Slice logits to match the sequence length of labels
- # Labels are assumed to be already shifted/aligned with logits
- seq_len = labels.numel()
- if i < logits.shape[0]:
- feature_logits = logits[i, :seq_len, :]
-
- # Calculate log probs for all labels
- # Apply log_softmax to convert raw logits to log-probabilities
- feature_log_probs = torch.log_softmax(feature_logits, dim=-1)
- token_log_probs = feature_log_probs.gather(
- dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
-
- # elementwise_loss: positive NLL loss (0.0 where masked)
- elementwise_loss = -token_log_probs * weights
-
- results.append({
- 'logprobs': types.TensorData.from_torch(token_log_probs),
- 'elementwise_loss': types.TensorData.from_torch(elementwise_loss)
- })
- else:
- # Handle case where batch index exceeds logits batch size
- results.append({
- 'logprobs': None,
- 'elementwise_loss': None
- })
- return results
diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py
index 4110dd65..10a08a57 100644
--- a/src/twinkle/server/tinker/common/transformers_model.py
+++ b/src/twinkle/server/tinker/common/transformers_model.py
@@ -1,43 +1,14 @@
-import torch
-import numpy as np
from tinker import types
from typing import List
-from twinkle.template import Template
from twinkle.model import MultiLoraTransformersModel
from twinkle import remote_class, remote_function
from .datum import datum_to_input_feature, extract_rl_feature
from .io_utils import create_checkpoint_manager
-
-
-def _collect_forward_backward_results(results):
- """Custom collect function for forward_backward that handles list [outputs, loss].
-
- Args:
- results: List of lists from each worker, where each list is [outputs_list, loss_float]
-
- Returns:
- List of [flattened_outputs, averaged_loss]
- """
- if not results:
- return results
-
- # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...]
- # Flatten outputs (first element of each list)
- all_outputs = []
- all_losses = []
- for result in results:
- outputs, loss = result
- all_outputs.extend(outputs)
- all_losses.append(loss)
-
- # Average the losses
- avg_loss = float(np.mean(all_losses))
-
- return [all_outputs, avg_loss]
+from .compat_base import TwinkleCompatModelBase, collect_forward_backward_results, clean_metrics
@remote_class()
-class TwinkleCompatTransformersModel(MultiLoraTransformersModel):
+class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase):
"""
Compatibility wrapper around :class:`MultiLoraTransformersModel` for Twinkle/Tinker.
@@ -83,7 +54,7 @@ def forward_only(self, *, inputs: List[types.Datum], **kwargs):
results = self._get_forward_output(inputs, logits)
return results
- @remote_function(dispatch='slice_dp', collect=_collect_forward_backward_results)
+ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results)
def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs):
# Set loss first based on loss_fn
if loss_fn == 'cross_entropy':
@@ -102,7 +73,7 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss
# Convert Datum to InputFeature
input_features = datum_to_input_feature(inputs, template)
-
+
# Forward pass
outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs)
@@ -139,6 +110,11 @@ def step(self, *, adam_params: types.AdamParams, **kwargs):
# Zero gradients
super().zero_grad(**kwargs)
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ metric = super().calculate_metric(is_training, **kwargs)
+ return clean_metrics(metric)
+
@remote_function()
def load(self, checkpoint_dir: str, **kwargs):
"""
@@ -165,36 +141,3 @@ def load(self, checkpoint_dir: str, **kwargs):
else:
# Load from hub
return super().load(name=resolved.checkpoint_name, **kwargs)
-
- def get_template(self, adapter_name: str) -> Template:
- return self.optimizer_group[adapter_name].template
-
- @staticmethod
- def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]:
- results = []
- for i, feature in enumerate(inputs):
- # Ensure 1D shape and correct device to avoid dimension mismatch and device errors
- labels = feature.loss_fn_inputs['target_tokens'].to_torch(
- ).long().view(-1).to(logits.device) # shape (seq_len,)
- weights = feature.loss_fn_inputs['weights'].to_torch(
- ).view(-1).to(logits.device) # shape (seq_len,)
-
- # Slice logits to match the sequence length of labels
- # Labels are assumed to be already shifted/aligned with logits
- seq_len = labels.numel()
- feature_logits = logits[i, :seq_len, :]
-
- # Calculate log probs for all labels
- # Apply log_softmax to convert raw logits to log-probabilities
- feature_log_probs = torch.log_softmax(feature_logits, dim=-1)
- token_log_probs = feature_log_probs.gather(
- dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
-
- # elementwise_loss: positive NLL loss (0.0 where masked)
- elementwise_loss = -token_log_probs * weights
-
- results.append({
- 'logprobs': types.TensorData.from_torch(token_log_probs),
- 'elementwise_loss': types.TensorData.from_torch(elementwise_loss)
- })
- return results
diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py
index 54a9717b..671d5197 100644
--- a/src/twinkle/server/tinker/model.py
+++ b/src/twinkle/server/tinker/model.py
@@ -421,7 +421,8 @@ async def _do_optim():
self.model.step(adam_params=body.adam_params,
adapter_name=adapter_name)
- return types.OptimStepResponse(metrics=None)
+ metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name)
+ return types.OptimStepResponse(metrics=metrics)
except Exception:
logger.error(traceback.format_exc())
return types.RequestFailedResponse(
diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py
index 0b292a72..3e74e63b 100644
--- a/src/twinkle/server/tinker/server.py
+++ b/src/twinkle/server/tinker/server.py
@@ -95,6 +95,7 @@ def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None
types.SupportedModel(model_name="Qwen/Qwen2.5-3B-Instruct"),
types.SupportedModel(model_name="Qwen/Qwen2.5-7B-Instruct"),
types.SupportedModel(model_name="Qwen/Qwen2.5-72B-Instruct"),
+ types.SupportedModel(model_name="Qwen/Qwen3-30B-A3B-Instruct-2507"),
]
# Lock for ModelScope config file operations (login writes, get_user_info reads)
self._modelscope_config_lock = asyncio.Lock()