From b1652bc91a0b153c3c5c7ec6d4b77f3d1c8915a6 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sat, 7 Feb 2026 18:33:19 +0800 Subject: [PATCH 1/9] update sampler --- client_tools/client_generator.py | 29 +- .../client/twinkle/transformer/sampler.py | 89 ++++ .../twinkle/transformer/server_config.yaml | 32 ++ src/twinkle/server/tinker/sampler.py | 34 +- src/twinkle/server/twinkle/sampler.py | 411 +++++++++++++----- src/twinkle/server/utils/io_utils.py | 27 ++ src/twinkle_client/sampler/vllm_sampler.py | 29 +- 7 files changed, 485 insertions(+), 166 deletions(-) create mode 100644 cookbook/client/twinkle/transformer/sampler.py diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 41f0419e..9e6b1bd0 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -750,7 +750,7 @@ def __init__(self, model_id: str, **kwargs): self.adapter_name = None if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/models/{model_id}' + self.server_url = f'{self.server_url}/samplers/{model_id}' response = http_post( url=f'{self.server_url}/create', json_data=kwargs @@ -799,25 +799,34 @@ def sample( self, inputs: Union[List[Trajectory], List[InputFeature]], sampling_params: Optional[Dict[str, Any]] = None, - adapter_name: str = '' - ) -> SampleResponse: + adapter_name: str = '', + adapter_uri: Optional[str] = None, + num_samples: int = 1, + ) -> Dict[str, Any]: """Sample from the model. Args: inputs: List of Trajectory or InputFeature to sample from. sampling_params: Sampling parameters dict. - adapter_name: Adapter name. + adapter_name: Adapter name for LoRA inference. + adapter_uri: Adapter URI (twinkle:// path or local path) for LoRA inference. + num_samples: Number of completions to generate per prompt. Returns: - SampleResponse with sampled sequences. + Dict with 'sequences' list, each containing tokens, logprobs, stop_reason. """ + json_data = { + 'inputs': inputs, + 'sampling_params': sampling_params, + 'adapter_name': adapter_name, + 'num_samples': num_samples, + } + if adapter_uri is not None: + json_data['adapter_uri'] = adapter_uri + response = http_post( url=f'{self.server_url}/sample', - json_data={ - 'inputs': inputs, - 'sampling_params': sampling_params, - 'adapter_name': adapter_name - } + json_data=json_data ) response.raise_for_status() return response.json() diff --git a/cookbook/client/twinkle/transformer/sampler.py b/cookbook/client/twinkle/transformer/sampler.py new file mode 100644 index 00000000..138d1529 --- /dev/null +++ b/cookbook/client/twinkle/transformer/sampler.py @@ -0,0 +1,89 @@ +# Twinkle Client - Sampler (Inference) Example +# +# This script demonstrates how to run text generation inference +# through the Twinkle client-server architecture. +# The server must be running first (see server.py and server_config.yaml). +# +# This is the client/server equivalent of cookbook/legacy/sampler/sampler_demo.py. +# Instead of running everything locally, the sampler runs on the server side +# while the client sends requests over HTTP. 
+ +# Step 1: Load environment variables from a .env file (e.g., API tokens) +import dotenv +dotenv.load_dotenv('.env') + +import os +from transformers import AutoTokenizer + +from twinkle import get_logger +from twinkle_client import init_twinkle_client +from twinkle_client.sampler import VLLMSampler + +logger = get_logger() + +MODEL_ID = 'Qwen/Qwen2.5-7B-Instruct' + +# Optional: adapter URI for LoRA inference +# This can be a twinkle:// path from a training run checkpoint +# or None to use the base model +ADAPTER_URI = None +# Example: +# ADAPTER_URI = "twinkle://tml-EMPTY_TOKEN/20260203_211942-Qwen_Qwen2_5-7B-Instruct-11cdabc7/weights/twinkle-lora-2" + + +def sample(): + # Step 2: Initialize the Twinkle client to communicate with the remote server. + client = init_twinkle_client( + base_url='http://127.0.0.1:8000', + api_key=os.environ.get('MODELSCOPE_SDK_TOKEN'), + ) + + # Step 3: Create the sampler client pointing to the model on the server + sampler = VLLMSampler(model_id=MODEL_ID) + + # Step 4: Set the chat template so the sampler can encode Trajectory inputs + sampler.set_template('Template', model_id=MODEL_ID) + + # Step 5: Prepare inputs as Trajectory dicts (messages format) + # Each trajectory is a conversation with system and user messages + trajectory = { + 'messages': [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'Who are you?'}, + ] + } + + num_prompts = 4 + num_samples = 2 # Generate 2 completions per prompt + + # Step 6: Configure sampling parameters + sampling_params = { + 'max_tokens': 128, + 'temperature': 1.0, + } + + # Step 7: Call the sampler + # - inputs: list of Trajectory dicts (will be encoded server-side using the template) + # - sampling_params: controls generation behavior + # - adapter_uri: optional LoRA adapter path for fine-tuned inference + # - num_samples: number of completions per prompt + response = sampler.sample( + inputs=[trajectory] * num_prompts, + sampling_params=sampling_params, + adapter_uri=ADAPTER_URI, + num_samples=num_samples, + ) + + # Step 8: Decode and print the results + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) + + logger.info(f"Generated {len(response['sequences'])} sequences " + f"({num_prompts} prompts x {num_samples} samples)") + + for i, seq in enumerate(response['sequences']): + text = tokenizer.decode(seq['tokens'], skip_special_tokens=True) + logger.info(f"Sequence {i}:\n {text}\n") + + +if __name__ == '__main__': + sample() diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml index e65c2e7c..bc742282 100644 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ b/cookbook/client/twinkle/transformer/server_config.yaml @@ -82,5 +82,37 @@ applications: min_replicas: 1 max_replicas: 1 target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 + + # 4. Sampler Service - Handles text generation inference + # Uses vLLM for efficient batched generation with optional LoRA adapters. 
+  - name: sampler-Qwen2.5-7B-Instruct
+    route_prefix: /samplers/Qwen/Qwen2.5-7B-Instruct  # REST path for this sampler
+    import_path: sampler
+    args:
+      model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier to load
+      sampler_type: vllm  # Sampler backend (vllm or torch)
+      nproc_per_node: 2  # Number of GPU processes per node
+      engine_args:  # vLLM engine configuration
+        gpu_memory_utilization: 0.4
+        max_model_len: 1024
+      adapter_config:  # Adapter lifecycle management
+        per_token_adapter_limit: 30  # Max LoRA adapters per user
+        adapter_timeout: 1800  # Seconds before an idle adapter is unloaded
+      device_group:
+        name: sampler
+        ranks: [0,1]  # GPU rank indices to use
+        device_type: cuda
+      device_mesh:
+        device_type: cuda
+        mesh: [0,1]
+        mesh_dim_names: ['dp']
+    deployments:
+      - name: SamplerManagement
+        autoscaling_config:
+          min_replicas: 1
+          max_replicas: 1
+          target_ongoing_requests: 16
+        ray_actor_options:
+          num_cpus: 0.1
\ No newline at end of file
diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py
index 674e86c7..747ffb3c 100644
--- a/src/twinkle/server/tinker/sampler.py
+++ b/src/twinkle/server/tinker/sampler.py
@@ -30,37 +30,6 @@
 logger = get_logger()
 
 
-def parse_adapter_uri(adapter_uri: str, token: str) -> tuple:
-    """Parse adapter URI to extract user_id and resolved lora_path.
-
-    Args:
-        adapter_uri: The adapter URI, supports formats:
-            - twinkle://{training_run_id}/weights/{checkpoint_name} or sampler_weights/{name}
-            - Local filesystem path
-        token: User token for resolving the base output directory
-
-    Returns:
-        Tuple of (user_id, lora_path) where lora_path is the resolved filesystem path
-    """
-    if adapter_uri.startswith('twinkle://'):
-        # Use CheckpointManager to parse and resolve the path
-        checkpoint_manager = create_checkpoint_manager(token)
-        parsed = checkpoint_manager.parse_path(adapter_uri)
-        if parsed:
-            # Get the filesystem path using get_ckpt_dir
-            lora_path = str(checkpoint_manager.get_ckpt_dir(
-                parsed.training_run_id, parsed.checkpoint_id
-            ))
-            return parsed.training_run_id, lora_path
-        else:
-            # Fallback: parse manually for non-standard formats
-            suffix = adapter_uri[len('twinkle://'):]
-            return 'default', suffix
-    else:
-        # Local path
-        return 'default', adapter_uri
-
-
 def build_sampler_app(model_id: str,
                       nproc_per_node: int,
                       device_group: Dict[str, Any],
@@ -189,7 +158,8 @@ async def _do_sample():
             user_id = None
             if model_path:
                 token = request.state.token
-                user_id, adapter_uri = parse_adapter_uri(model_path, token)
+                checkpoint_manager = create_checkpoint_manager(token)
+                user_id, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path)
 
             # Convert tinker SamplingParams to twinkle SamplingParams if needed
             sampling_params = None
diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py
index 3693ecab..793a66dc 100644
--- a/src/twinkle/server/twinkle/sampler.py
+++ b/src/twinkle/server/twinkle/sampler.py
@@ -1,35 +1,138 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-import os
-import threading
-import time
+"""
+Twinkle sampler (inference) server.
+
+This module provides a Ray Serve deployment for distributed text generation/inference.
+It supports:
+1. VLLM and Torch sampler backends
+2. LoRA adapter loading via adapter URIs (twinkle:// paths or local paths)
+3. Multi-user inference with adapter lifecycle management
+4. Flexible sampling parameters
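+
+Example request body for POST /sample (illustrative values; the full field
+set is defined by SampleRequest below):
+
+    {
+        "inputs": [{"messages": [{"role": "user", "content": "Who are you?"}]}],
+        "sampling_params": {"max_tokens": 128, "temperature": 1.0},
+        "num_samples": 2
+    }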
+"""
+import traceback
 from typing import Dict, Any, List, Optional, Union
 
-from fastapi import FastAPI
-from fastapi import Request
-from peft import LoraConfig
+from fastapi import FastAPI, Request
+from pydantic import BaseModel, ConfigDict, Field
 from ray import serve
 
 import twinkle
 from twinkle import DeviceGroup, DeviceMesh
 from twinkle.data_format import Trajectory, InputFeature
-from twinkle.sampler import VLLMSampler, Sampler
-from twinkle.server.utils.validation import verify_request_token
+from twinkle.sampler.types import SamplingParams
+from twinkle.server.utils.adapter_manager import AdapterManagerMixin
+from twinkle.server.utils.validation import verify_request_token, get_token_from_request
 from twinkle.server.utils.state import get_server_state, ServerStateProxy
-from twinkle.sampler.types import SamplingParams, SampleResponse
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+
+# ----- Request/Response Models -----
+
+class SampleRequest(BaseModel):
+    """Request body for the /sample endpoint."""
+    inputs: Any = Field(..., description="List of Trajectory or InputFeature dicts")
+    sampling_params: Optional[Dict[str, Any]] = Field(
+        None, description="Sampling parameters (max_tokens, temperature, etc.)")
+    adapter_name: str = Field('', description="Adapter name for LoRA inference")
+    adapter_uri: Optional[str] = Field(
+        None, description="Adapter URI (twinkle:// path or local path) for LoRA inference")
+    num_samples: int = Field(1, description="Number of completions to generate per prompt")
+
+
+class SampleResponseModel(BaseModel):
+    """Response body for the /sample endpoint."""
+    sequences: List[Dict[str, Any]] = Field(
+        ..., description="List of sampled sequences, each with tokens, logprobs, stop_reason")
+    prompt_logprobs: Optional[List[Optional[float]]] = None
+    topk_prompt_logprobs: Optional[List[Optional[List]]] = None
+
+
+class SetTemplateRequest(BaseModel):
+    """Request body for the /set_template endpoint."""
+    template_cls: str = Field(..., description="Template class name (e.g. 'Template')")
+    adapter_name: str = Field('', description="Adapter name to associate the template with")
+
+    # Allow extra fields so arbitrary template kwargs can be forwarded
+    model_config = ConfigDict(extra="allow")
+
+
+class SetTemplateResponse(BaseModel):
+    """Response body for the /set_template endpoint."""
+    status: str = "ok"
+
+
+class AddAdapterRequest(BaseModel):
+    """Request body for the /add_adapter_to_sampler endpoint."""
+    adapter_name: str = Field(..., description="Name of the adapter to add")
+    config: Any = Field(..., description="LoRA configuration dict")
+
+
+class AddAdapterResponse(BaseModel):
+    """Response body for the /add_adapter_to_sampler endpoint."""
+    status: str = "ok"
+    adapter_name: str
+
+
+class SyncWeightsRequest(BaseModel):
+    """Request body for the /sync_weights endpoint."""
+    state_dict: Dict[str, Any] = Field(..., description="Model state dict to sync")
+    adapter_name: str = Field('', description="Adapter name for LoRA weight sync")
+
+
+class SyncWeightsResponse(BaseModel):
+    """Response body for the /sync_weights endpoint."""
+    status: str = "ok"
+
+
+class HeartbeatRequest(BaseModel):
+    """Request body for the /heartbeat endpoint."""
+    adapter_name: str = Field(..., description="Adapter name to keep alive")
+
+
+class HeartbeatResponse(BaseModel):
+    """Response body for the /heartbeat endpoint."""
+    status: str = "ok"
+
+
+class CreateResponse(BaseModel):
+    """Response body for the /create endpoint."""
+    status: str = "ok"
+
+
+# ----- Application Builder -----
+
 def build_sampler_app(model_id: str,
-                      device_group: Dict[str, Any],
-                      device_mesh: Dict[str, Any],
-                      deploy_options: Dict[str, Any],
+                      nproc_per_node: int = 1,
+                      device_group: Dict[str, Any] = None,
+                      device_mesh: Dict[str, Any] = None,
+                      deploy_options: Dict[str, Any] = None,
+                      sampler_type: str = 'vllm',
+                      engine_args: Optional[Dict[str, Any]] = None,
+                      adapter_config: Optional[Dict[str, Any]] = None,
                       **kwargs):
-    app = FastAPI()
-    device_group = DeviceGroup(**device_group)
-    twinkle.initialize(mode='ray',
-                       groups=[device_group],
-                       lazy_collect=False)
+    """Build a sampler application for text generation inference.
- device_mesh = DeviceMesh(**device_mesh) + Args: + model_id: Model identifier (e.g., "Qwen/Qwen2.5-7B-Instruct") + nproc_per_node: Number of GPU processes per node + device_group: Device group configuration dict + device_mesh: Device mesh configuration dict for parallelism + deploy_options: Ray Serve deployment options + sampler_type: Type of sampler to use ('vllm' or 'torch') + engine_args: Additional engine arguments for the sampler + adapter_config: Adapter lifecycle config (adapter_timeout, per_token_adapter_limit) + **kwargs: Additional arguments passed to the sampler + + Returns: + Ray Serve deployment bound with configuration + """ + app = FastAPI( + title="Twinkle Sampler", + description="REST API for distributed text generation inference", + version="1.0.0" + ) @app.middleware("http") async def verify_token(request: Request, call_next): @@ -37,106 +140,186 @@ async def verify_token(request: Request, call_next): @serve.deployment(name="SamplerManagement") @serve.ingress(app) - class SamplerManagement: - - COUNT_DOWN = 60 * 30 - - def __init__(self): - self.sampler = VLLMSampler(model_id=model_id, - device_mesh=device_mesh, - remote_group=device_group.name, - **kwargs) - self.adapter_records: Dict[str, int] = {} - self.hb_thread = threading.Thread(target=self.countdown, daemon=True) - self.hb_thread.start() - self.state: ServerStateProxy = get_server_state() - self.per_token_sampler_limit = int(os.environ.get("TWINKLE_PER_USER_SAMPLER_LIMIT", "3")) - self.key_token_dict = {} - - def countdown(self): - while True: - time.sleep(1) - for key in list(self.adapter_records.keys()): - self.adapter_records[key] += 1 - if self.adapter_records[key] > self.COUNT_DOWN: - self.sampler.remove_adapter(key) - self.adapter_records.pop(key, None) - if key in self.key_token_dict: - self.handle_adapter_count(self.key_token_dict[key], False) - - def handle_adapter_count(self, token, add: bool): - user_key = token + '_' + 'sampler_adapter' - cur_count = self.state.get_config(user_key) or 0 - if add: - if cur_count < self.per_token_sampler_limit: - self.state.add_config(user_key, cur_count + 1) - else: - raise RuntimeError(f'Model adapter count limitation reached: {self.per_token_sampler_limit}') + class SamplerManagement(AdapterManagerMixin): + """Sampler management service for text generation inference. 
+ + Manages: + - VLLM or Torch sampler initialization and lifecycle + - Adapter lifecycle via AdapterManagerMixin + - Inference requests with LoRA adapter support + - Template configuration for trajectory encoding + """ + + def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], + device_mesh: Dict[str, Any], sampler_type: str = 'vllm', + engine_args: Optional[Dict[str, Any]] = None, + adapter_config: Optional[Dict[str, Any]] = None, **kwargs): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize(mode='ray', + nproc_per_node=nproc_per_node, + groups=[self.device_group], + lazy_collect=False) + self.device_mesh = DeviceMesh(**device_mesh) + self.sampler_type = sampler_type + + # Initialize sampler based on type + if sampler_type == 'vllm': + from twinkle.sampler import VLLMSampler + sampler_kwargs = engine_args or {} + self.sampler = VLLMSampler( + model_id=model_id, + engine_args=sampler_kwargs, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + **{k: v for k, v in kwargs.items() if k not in ['engine_args']} + ) else: - if cur_count > 0: - cur_count -= 1 - self.state.add_config(user_key, cur_count) - if cur_count <= 0: - self.state.pop_config(user_key) + from twinkle.sampler import TorchSampler + self.sampler = TorchSampler( + model_id=model_id, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + **kwargs + ) - def assert_adapter_exists(self, adapter_name): - assert adapter_name and adapter_name in self.adapter_records + # Initialize state and adapter manager + self.state: ServerStateProxy = get_server_state() + _adapter_config = adapter_config or {} + self._init_adapter_manager(**_adapter_config) + self.start_adapter_countdown() - def assert_adapter_valid(self, adapter_name): - assert adapter_name == '' or adapter_name in self.adapter_records + def _on_adapter_expired(self, adapter_name: str, token: str) -> None: + """Handle expired adapters by removing them from the sampler.""" + try: + self.sampler.remove_adapter(adapter_name) + logger.info(f"Removed expired adapter {adapter_name}") + self.check_adapter_limit(token, False) + except Exception as e: + logger.warning(f"Failed to remove expired adapter {adapter_name}: {e}") @staticmethod - def get_adapter_name(request, adapter_name): + def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: + if adapter_name is None or adapter_name == '': + return None return request.state.request_id + '-' + adapter_name - @app.post("/create") - def create(self, *args, **kwargs): - return '' - - @app.post("/sample") - def sample( - self, - request, - *, - inputs: Union[List[Trajectory], List[InputFeature]], - sampling_params: Optional[Dict[str, Any]] = None, - adapter_name: str = '' - ) -> SampleResponse: - self.assert_adapter_valid(adapter_name) - full_adapter_name = self.get_adapter_name(request, adapter_name=adapter_name) - - params = None - if sampling_params: - params = SamplingParams.from_dict(sampling_params) - - return self.sampler.sample(inputs, params, full_adapter_name) - - @app.post("/add_adapter_to_sampler") - def add_adapter_to_sampler(self, request, *, adapter_name: str, config): - assert adapter_name, 'You need to specify a valid `adapter_name`' - self.handle_adapter_count(request.state.token, True) - full_adapter_name = self.get_adapter_name(request, adapter_name=adapter_name) - config = LoraConfig(**config) - self.sampler.add_adapter_to_sampler(full_adapter_name, config) - self.adapter_records[full_adapter_name] = 0 - 
self.key_token_dict[full_adapter_name] = request.state.token - - # TODO: check if this is needed - @app.post("/sync_weights") - def sync_weights(self, request, *, state_dict: Dict[str, Any], adapter_name: str = ''): - self.assert_adapter_valid(adapter_name) - full_adapter_name = self.get_adapter_name(request, adapter_name=adapter_name) - return self.sampler.sync_weights(state_dict, full_adapter_name) - - @app.post("/heartbeat") - def heartbeat(self, request, *, adapter_name: str): - self.assert_adapter_exists(adapter_name=adapter_name) - full_adapter_name = self.get_adapter_name(request, adapter_name=adapter_name) - self.adapter_records[full_adapter_name] = 0 - - @app.post("/set_template") - def set_template(self, request, *, template_cls: str, adapter_name: str = '', **kwargs): - full_adapter_name = self.get_adapter_name(request, adapter_name=adapter_name) - return self.sampler.set_template(template_cls, adapter_name=full_adapter_name, **kwargs) - - return SamplerManagement.options(**deploy_options).bind() + @app.post("/create", response_model=CreateResponse) + def create(self, request: Request) -> CreateResponse: + """Health check / session creation endpoint.""" + return CreateResponse() + + @app.post("/sample", response_model=SampleResponseModel) + def sample(self, request: Request, body: SampleRequest) -> SampleResponseModel: + """Sample completions from the model. + + Supports: + - Trajectory inputs (messages-based, requires template to be set) + - InputFeature inputs (pre-tokenized input_ids) + - LoRA adapter via adapter_name or adapter_uri (twinkle:// path) + - Multiple completions per prompt via num_samples + """ + try: + # Resolve adapter + adapter_path = None + adapter_name = body.adapter_name or '' + full_adapter_name = self._get_adapter_name(request, adapter_name) or '' + + if body.adapter_uri: + from .common.io_utils import create_checkpoint_manager + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token) + _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) + + # Parse inputs + inputs = body.inputs + if isinstance(inputs, list) and inputs: + first = inputs[0] + if isinstance(first, dict) and 'input_ids' in first: + inputs = [InputFeature(**item) for item in inputs] + else: + inputs = [Trajectory(**item) for item in inputs] + elif isinstance(inputs, dict): + if 'input_ids' in inputs: + inputs = [InputFeature(**inputs)] + else: + inputs = [Trajectory(**inputs)] + + # Build sampling params + params = None + if body.sampling_params: + params = SamplingParams.from_dict(body.sampling_params) + + # Call sampler + response = self.sampler.sample( + inputs, + params, + adapter_name=full_adapter_name, + adapter_path=adapter_path, + num_samples=body.num_samples, + ) + if callable(response): + response = response() + + # Convert to response model + sequences = [] + for seq in response.sequences: + sequences.append({ + 'stop_reason': seq.stop_reason, + 'tokens': list(seq.tokens), + 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, + }) + + return SampleResponseModel( + sequences=sequences, + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs, + ) + except Exception: + logger.error(traceback.format_exc()) + raise + + @app.post("/set_template", response_model=SetTemplateResponse) + def set_template(self, request: Request, body: SetTemplateRequest) -> SetTemplateResponse: + """Set the chat template for encoding Trajectory inputs.""" + full_adapter_name = 
self._get_adapter_name(request, body.adapter_name) or ''
+        extra_kwargs = body.model_extra or {}
+        self.sampler.set_template(body.template_cls, adapter_name=full_adapter_name, **extra_kwargs)
+        return SetTemplateResponse()
+
+    @app.post("/add_adapter_to_sampler", response_model=AddAdapterResponse)
+    def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> AddAdapterResponse:
+        """Add a LoRA adapter to the sampler."""
+        assert body.adapter_name, 'You need to specify a valid `adapter_name`'
+        full_adapter_name = self._get_adapter_name(request, body.adapter_name)
+        token = get_token_from_request(request)
+
+        # Enforce the per-token adapter limit before loading the adapter, so a
+        # rejected request does not leave a stray adapter behind in the sampler
+        allowed, reason = self.check_adapter_limit(token, True)
+        if not allowed:
+            raise RuntimeError(reason)
+
+        from peft import LoraConfig
+        config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config
+
+        with self._adapter_lock:
+            self.sampler.add_adapter_to_sampler(full_adapter_name, config)
+
+        self.register_adapter(full_adapter_name, token)
+
+        return AddAdapterResponse(adapter_name=full_adapter_name)
+
+    @app.post("/sync_weights", response_model=SyncWeightsResponse)
+    def sync_weights(self, request: Request, body: SyncWeightsRequest) -> SyncWeightsResponse:
+        """Synchronize model weights to the sampler."""
+        full_adapter_name = self._get_adapter_name(request, body.adapter_name) or ''
+        self.sampler.sync_weights(body.state_dict, full_adapter_name)
+        return SyncWeightsResponse()
+
+    @app.post("/heartbeat", response_model=HeartbeatResponse)
+    def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatResponse:
+        """Keep an adapter alive by resetting its inactivity timer."""
+        full_adapter_name = self._get_adapter_name(request, body.adapter_name)
+        self.assert_adapter_exists(adapter_name=full_adapter_name)
+        self.touch_adapter(full_adapter_name)
+        return HeartbeatResponse()
+
+    return SamplerManagement.options(**(deploy_options or {})).bind(
+        nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, **kwargs)
diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/io_utils.py
index 7a471288..45f5e1b7 100644
--- a/src/twinkle/server/utils/io_utils.py
+++ b/src/twinkle/server/utils/io_utils.py
@@ -818,6 +818,33 @@ def _get_weights_info_from_hub(self, hub_model_id: str) -> Optional[Any]:
         except Exception:
             return None
 
+    def parse_adapter_uri(self, adapter_uri: str) -> tuple:
+        """Parse adapter URI to extract user_id and resolved lora_path.
+
+        Args:
+            adapter_uri: The adapter URI, supports formats:
+                - twinkle://{training_run_id}/weights/{checkpoint_name} or sampler_weights/{name}
+                - Local filesystem path
+
+        Returns:
+            Tuple of (user_id, lora_path) where lora_path is the resolved filesystem path
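+
+        Example (illustrative paths; the resolved directory depends on the
+        configured checkpoint root):
+
+            parse_adapter_uri('twinkle://run-123/weights/ckpt-1')
+            # -> ('run-123', '<ckpt_root>/run-123/ckpt-1')
+            parse_adapter_uri('/tmp/my_lora')
+            # -> ('default', '/tmp/my_lora')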
+        """
+        if adapter_uri.startswith(self.path_prefix):
+            parsed = self.parse_path(adapter_uri)
+            if parsed:
+                # Get the filesystem path using get_ckpt_dir
+                lora_path = str(self.get_ckpt_dir(
+                    parsed.training_run_id, parsed.checkpoint_id
+                ))
+                return parsed.training_run_id, lora_path
+            else:
+                # Fallback: parse manually for non-standard formats
+                suffix = adapter_uri[len(self.path_prefix):]
+                return 'default', suffix
+        else:
+            # Local path
+            return 'default', adapter_uri
+
     def resolve_load_path(self, path: str, validate_exists: bool = True) -> ResolvedLoadPath:
         """
         Resolve a checkpoint load path.
diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py
index e8c6302e..51146eb6 100644
--- a/src/twinkle_client/sampler/vllm_sampler.py
+++ b/src/twinkle_client/sampler/vllm_sampler.py
@@ -35,7 +35,7 @@ def __init__(self, model_id: str, **kwargs):
         self.adapter_name = None
         if '://' in model_id:
             model_id = model_id.split('://')[1]
-        self.server_url = f'{self.server_url}/models/{model_id}'
+        self.server_url = f'{self.server_url}/samplers/{model_id}'
         response = http_post(
             url=f'{self.server_url}/create',
             json_data=kwargs
@@ -84,25 +84,34 @@ def sample(
         self,
         inputs: Union[List[Trajectory], List[InputFeature]],
         sampling_params: Optional[Dict[str, Any]] = None,
-        adapter_name: str = ''
-    ) -> SampleResponse:
+        adapter_name: str = '',
+        adapter_uri: Optional[str] = None,
+        num_samples: int = 1,
+    ) -> Dict[str, Any]:
         """Sample from the model.
 
         Args:
             inputs: List of Trajectory or InputFeature to sample from.
             sampling_params: Sampling parameters dict.
-            adapter_name: Adapter name.
+            adapter_name: Adapter name for LoRA inference.
+            adapter_uri: Adapter URI (twinkle:// path or local path) for LoRA inference.
+            num_samples: Number of completions to generate per prompt.
 
         Returns:
-            SampleResponse with sampled sequences.
+            Dict with 'sequences' list, each containing tokens, logprobs, stop_reason.
         """
+        json_data = {
+            'inputs': inputs,
+            'sampling_params': sampling_params,
+            'adapter_name': adapter_name,
+            'num_samples': num_samples,
+        }
+        if adapter_uri is not None:
+            json_data['adapter_uri'] = adapter_uri
+
         response = http_post(
             url=f'{self.server_url}/sample',
-            json_data={
-                'inputs': inputs,
-                'sampling_params': sampling_params,
-                'adapter_name': adapter_name
-            }
+            json_data=json_data
         )
         response.raise_for_status()
         return response.json()

From 7ce9d395b47936f3d1d11e0b1c56e327664be9c5 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sat, 7 Feb 2026 18:49:36 +0800
Subject: [PATCH 2/9] update sampler

---
 .../\346\234\215\345\212\241\347\253\257.md" | 109 ++++++++++++++++++
 1 file changed, 109 insertions(+)

diff --git "a/docs/source/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
index 4d56229a..d194159d 100644
--- "a/docs/source/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
+++ "b/docs/source/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
@@ -1,5 +1,114 @@
 # Server
 
+## Ray Cluster Configuration
+
+Before launching the Server, **the Ray nodes must be started and configured first**. Only when the Ray nodes are configured correctly can the Server allocate and occupy resources (GPUs, CPUs, etc.) properly.
+
+### Starting the Ray Nodes
+
+A Ray cluster is made up of multiple nodes, and each node can be configured with different resources. The startup steps are as follows:
+
+#### 1. Start the Head node (the first GPU node)
+
+```bash
+# Stop any existing Ray cluster (if present)
+ray stop
+
+# Start the Head node using GPUs 0-3 (4 GPUs in total)
+CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --num-gpus=4 --port=6379
+```
+
+#### 2. Start the Worker nodes
+
+```bash
+# Second GPU node, using GPUs 4-7 (4 GPUs in total)
+CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4
+
+# CPU node (for CPU workloads such as the Processor)
+ray start --address=10.28.252.9:6379 --num-gpus=0
+```
+
+**Notes:**
+- `--head`: marks this node as the Head node (the primary node of the cluster)
+- `--port=6379`: the port the Head node listens on
+- `--address=<head_ip>:<port>`: the Head node address that Worker nodes connect to
+- `--num-gpus=N`: the number of GPUs available on this node
+- `CUDA_VISIBLE_DEVICES`: restricts which GPU devices are visible to this node
+
+#### 3. Complete example: a 3-node cluster
+
+```bash
+# Stop the old cluster and start a new one
+ray stop && \
+CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --num-gpus=4 --port=6379 && \
+CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4 && \
+ray start --address=10.28.252.9:6379 --num-gpus=0
+```
+
+This configuration starts 3 nodes:
+- **Node 0** (Head): 4 GPUs (devices 0-3)
+- **Node 1** (Worker): 4 GPUs (devices 4-7)
+- **Node 2** (Worker): a CPU-only node
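+
+Once all the nodes are up, you can verify the cluster from any node before launching the Server (a quick sanity check; the exact resource listing depends on your hardware):
+
+```bash
+# Lists the nodes that have joined and their registered CPU/GPU resources
+ray status
+```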
+
+### Node Rank in the YAML Configuration
+
+In the YAML configuration file, **each component needs to occupy a separate Node**, and the `ranks` values are numbered starting from 0 within each component's own Node.
+
+**Example configuration:**
+
+```yaml
+applications:
+  # The model service occupies Node 0 (Head node, GPUs 0-3)
+  - name: models-Qwen2.5-7B-Instruct
+    route_prefix: /models/Qwen/Qwen2.5-7B-Instruct
+    import_path: model
+    args:
+      nproc_per_node: 4
+      device_group:
+        name: model
+        ranks: [0, 1, 2, 3]  # GPU indices within Node 0
+        device_type: cuda
+      device_mesh:
+        device_type: cuda
+        mesh: [0, 1, 2, 3]
+        mesh_dim_names: ['dp']
+
+  # The Sampler service occupies Node 1 (Worker node, GPUs 4-7)
+  - name: sampler-Qwen2.5-7B-Instruct
+    route_prefix: /samplers/Qwen/Qwen2.5-7B-Instruct
+    import_path: sampler
+    args:
+      nproc_per_node: 2
+      device_group:
+        name: sampler
+        ranks: [0, 1]  # GPU indices within Node 1 (physical GPUs 4-5)
+        device_type: cuda
+      device_mesh:
+        device_type: cuda
+        mesh: [0, 1]
+        mesh_dim_names: ['dp']
+
+  # The Processor service occupies Node 2 (the CPU node)
+  - name: processor
+    route_prefix: /processors
+    import_path: processor
+    args:
+      ncpu_proc_per_node: 4
+      device_group:
+        name: processor
+        ranks: 0  # CPU index within Node 2
+        device_type: CPU
+      device_mesh:
+        device_type: CPU
+        mesh: [0, 1, 2, 3]
+        mesh_dim_names: ['dp']
+```
+
+**Important notes:**
+- Each component's `ranks` setting is relative to the Ray Node that component occupies
+- Different components are automatically assigned to different Nodes
+- Ray schedules each component onto a suitable Node according to its resource requirements (`num_gpus` and `num_cpus` in `ray_actor_options`)
+
 ## Launching
 
 The Server is always launched through the `launch_server` function or the CLI command, together with a YAML configuration file.

From bd2e607f03e8dae27fd63206c69f7ef5e6fe6e58 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 8 Feb 2026 17:38:47 +0800
Subject: [PATCH 3/9] update

---
 .../twinkle/transformer/server_config.yaml | 16 +++----
 src/twinkle_client/processor/grpo.py       | 45 +++++++++++++++++++
 2 files changed, 53 insertions(+), 8 deletions(-)
 create mode 100644 src/twinkle_client/processor/grpo.py

diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml
index bc742282..fbb980c9 100644
--- a/cookbook/client/twinkle/transformer/server_config.yaml
+++ b/cookbook/client/twinkle/transformer/server_config.yaml
@@ -33,8 +33,8 @@ applications:
 
   # 2. Model Service - Hosts the base model for training
   #    This is the actual model worker that performs forward/backward passes.
-  - name: models-Qwen2.5-7B-Instruct
-    route_prefix: /models/Qwen/Qwen2.5-7B-Instruct  # REST path for this model
+  - name: models-Qwen2.5-0.5B-Instruct
+    route_prefix: /models/Qwen/Qwen2.5-0.5B-Instruct  # REST path for this model
     import_path: model
     args:
       use_megatron: false  # Use HuggingFace Transformers (not Megatron)
@@ -87,13 +87,13 @@ applications:
 
   # 4. Sampler Service - Handles text generation inference
   #    Uses vLLM for efficient batched generation with optional LoRA adapters.
- - name: sampler-Qwen2.5-7B-Instruct - route_prefix: /samplers/Qwen/Qwen2.5-7B-Instruct # REST path for this sampler + - name: sampler-Qwen2.5-0.5B-Instruct + route_prefix: /samplers/Qwen/Qwen2.5-0.5B-Instruct # REST path for this sampler import_path: sampler args: - model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier to load + model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" # ModelScope model identifier to load sampler_type: vllm # Sampler backend (vllm or torch) - nproc_per_node: 2 # Number of GPU processes per node + nproc_per_node: 1 # Number of GPU processes per node engine_args: # vLLM engine configuration gpu_memory_utilization: 0.4 max_model_len: 1024 @@ -102,11 +102,11 @@ applications: adapter_timeout: 1800 # Seconds before idle adapter is unloaded device_group: name: sampler - ranks: [0,1] # GPU rank indices to use + ranks: [0] # GPU rank indices to use device_type: cuda device_mesh: device_type: cuda - mesh: [0,1] + mesh: [0] mesh_dim_names: ['dp'] deployments: - name: SamplerManagement diff --git a/src/twinkle_client/processor/grpo.py b/src/twinkle_client/processor/grpo.py new file mode 100644 index 00000000..d1dfe292 --- /dev/null +++ b/src/twinkle_client/processor/grpo.py @@ -0,0 +1,45 @@ +from typing import Optional +from twinkle_client.http import TWINKLE_SERVER_URL +from twinkle_client.http import http_post, heartbeat_manager +from twinkle import DeviceMesh +from twinkle.data_format import InputFeature +from .base import InputProcessor + +class GRPOLossProcessor(InputProcessor): + """Client wrapper for GRPOLossProcessor that calls server HTTP endpoints.""" + + def __init__(self, device_mesh: Optional[DeviceMesh] = None, ignore_index: int = -100, **kwargs): + from twinkle_client.http import get_base_url + self.server_url = get_base_url() + + response = http_post( + url=f'{self.server_url}/processors/create', + json_data={ + 'processor_type': 'processor', + 'class_type': 'GRPOLossProcessor', + **{'device_mesh': device_mesh, 'ignore_index': ignore_index}, **kwargs + } + ) + response.raise_for_status() + self.processor_id = response.json()['processor_id'] + heartbeat_manager.register_processor(self.processor_id) + + def __del__(self): + try: + heartbeat_manager.unregister_processor(self.processor_id) + except: + pass + + + def prepare_inputs(self, inputs: InputFeature): + response = http_post( + url=f'{self.server_url}/processors/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'prepare_inputs', + **{'inputs': inputs}, + } + ) + response.raise_for_status() + return response.json()["result"] + \ No newline at end of file From de16401a9f7801645ec950b5356f7587e818bb71 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 8 Feb 2026 21:18:50 +0800 Subject: [PATCH 4/9] update sampler --- cookbook/client/tinker/transformer/grpo.py | 259 ++++++++++++++++++++ cookbook/client/twinkle/transformer/grpo.py | 238 ++++++++++++++++++ 2 files changed, 497 insertions(+) create mode 100644 cookbook/client/tinker/transformer/grpo.py create mode 100644 cookbook/client/twinkle/transformer/grpo.py diff --git a/cookbook/client/tinker/transformer/grpo.py b/cookbook/client/tinker/transformer/grpo.py new file mode 100644 index 00000000..d9a6bb89 --- /dev/null +++ b/cookbook/client/tinker/transformer/grpo.py @@ -0,0 +1,259 @@ +# Tinker-Compatible Client - GRPO (Group Relative Policy Optimization) Training Example +# +# This script demonstrates GRPO reinforcement learning training using the +# Tinker-compatible client API with save_weights_for_sampler for weight sync. 
+# Instead of calling sync_weights directly, it periodically saves weights and +# creates a sampling client for generation. +# +# Flow: +# 1. Prepare Countdown dataset (client-side) +# 2. Initialize Tinker-compatible training & sampling clients +# 3. Training loop: +# a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client +# b. Sample completions from the sampling client +# c. Compute rewards and advantages (client-side) +# d. Train on sampled data weighted by advantages +# e. Optimizer step +# +# The server must be running first (see server.py and server_config.yaml). +# Requires both model and sampler services to be configured. + +import gc +import numpy as np +from typing import List, Tuple + +from tinker import types +from twinkle_client import init_tinker_compat_client +from twinkle import get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.metric import CompletionRewardMetric +from twinkle.server.tinker.common import input_feature_to_datum +from modelscope import AutoTokenizer + +logger = get_logger() + +# ========== Configuration ========== +BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct" +NUM_GENERATIONS = 8 +MAX_NEW_TOKENS = 1024 +LEARNING_RATE = 1e-5 +MAX_STEPS = 2000 +BATCH_SIZE = 4 +TEMPERATURE = 1.0 +SYNC_INTERVAL = 1 # Save weights for sampler every N steps +LORA_RANK = 8 + + +def create_countdown_dataset(): + """Create Countdown Game dataset for GRPO training.""" + from twinkle.preprocessor import CountdownProcessor + dataset = Dataset(DatasetMeta( + "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000))) + dataset.set_template( + "Template", model_id=f'ms://{BASE_MODEL}', max_length=8192) + dataset.map(CountdownProcessor()) + dataset.encode(add_generation_prompt=True) + return dataset + + +def compute_rewards( + trajectories: List[dict], +) -> Tuple[List[float], List[float], List[float]]: + """Compute format and accuracy rewards for Countdown game.""" + from twinkle.reward import CountDownAccuracy, FormatReward + format_rewards = FormatReward()(trajectories, []) + accuracy_rewards = CountDownAccuracy()(trajectories, []) + total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)] + return total_rewards, format_rewards, accuracy_rewards + + +def main(): + # Step 1: Prepare dataset and dataloader (client-side) + dataset = create_countdown_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + tokenizer = AutoTokenizer.from_pretrained( + BASE_MODEL, trust_remote_code=True) + + # Step 2: Initialize the Tinker-compatible client + service_client = init_tinker_compat_client( + base_url='http://localhost:8000') + + # Create a LoRA training client for GRPO + training_client = service_client.create_lora_training_client( + base_model=BASE_MODEL, + rank=LORA_RANK, + ) + + # Step 3: Setup metrics and advantage function + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + + sampling_params = types.SamplingParams( + max_tokens=MAX_NEW_TOKENS, + temperature=TEMPERATURE, + top_p=0.95, + ) + + # The sampling client is created on-demand via save_weights_for_sampler + sampling_client = None + + step = 0 + for batch in dataloader: + if step >= MAX_STEPS: + break + + metrics.reset() + prompts = batch if isinstance(batch, list) else [batch] + + # ========== 1. 
Save weights for sampler (instead of sync_weights) ========== + if step % SYNC_INTERVAL == 0: + logger.info(f"Step {step}: Saving weights for sampler...") + sampling_client = ( + training_client.save_weights_and_get_sampling_client( + name=f'grpo-step-{step}')) + logger.info(f"Step {step}: Sampling client ready") + + if sampling_client is None: + logger.warning("No sampling client available, skipping step") + step += 1 + continue + + # ========== 2. Sample completions ========== + # Convert input features to token prompts for the sampling client + all_sequences = [] + for prompt_feature in prompts: + input_ids = prompt_feature['input_ids'] + if hasattr(input_ids, 'tolist'): + input_ids = input_ids.tolist() + prompt = types.ModelInput.from_ints(input_ids) + future = sampling_client.sample( + prompt=prompt, + sampling_params=sampling_params, + num_samples=NUM_GENERATIONS, + ) + result = future.result() + all_sequences.extend(result.sequences) + + if not all_sequences: + logger.warning(f"Step {step}: No valid samples, skipping") + step += 1 + continue + + # ========== 3. Build trajectories and collect logprobs ========== + trajectories = [] + old_logps_list = [] + completion_lengths = [] + + for seq in all_sequences: + decoded_text = tokenizer.decode(seq.tokens, skip_special_tokens=True) + trajectories.append({ + 'messages': [{'role': 'assistant', 'content': decoded_text}] + }) + old_logps_list.append( + [lp for lp in seq.logprobs] if seq.logprobs else []) + completion_lengths.append(len(seq.tokens)) + + # ========== 4. Compute rewards ========== + total_rewards, format_rewards, accuracy_rewards = compute_rewards( + trajectories) + metrics.accumulate( + None, None, + completion_lengths=completion_lengths, + rewards={ + 'total': total_rewards, + 'format': format_rewards, + 'accuracy': accuracy_rewards, + }) + + # ========== 5. Compute advantages ========== + advantages = advantage_fn( + total_rewards, + num_generations=NUM_GENERATIONS, + scale='group', + ).tolist() + + frac_zero_std = ( + 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0) + if frac_zero_std == 1.0: + logger.info( + f"Step {step}: All advantages are zero, skipping training") + step += 1 + continue + + # ========== 6. 
Training step ========== + # Select samples with positive advantages for training + # Weight them by their advantage value for GRPO-style optimization + training_data = [] + for i, seq in enumerate(all_sequences): + if advantages[i] <= 0: + continue + # Build a Datum from the completion tokens + # Prompt tokens: weight=0 (don't compute loss on prompt) + # Completion tokens: weight=advantage (advantage-weighted SFT) + prompt_feature = prompts[i // NUM_GENERATIONS] + prompt_ids = prompt_feature['input_ids'] + if hasattr(prompt_ids, 'tolist'): + prompt_ids = prompt_ids.tolist() + + full_tokens = prompt_ids + list(seq.tokens) + prompt_weights = [0.0] * len(prompt_ids) + # Scale completion weights by normalized advantage + completion_weights = [float(advantages[i])] * len(seq.tokens) + + # Shift by one for next-token prediction + input_tokens = full_tokens[:-1] + target_tokens = full_tokens[1:] + weights = (prompt_weights + completion_weights)[1:] + + datum = types.Datum( + model_input=types.ModelInput.from_ints(input_tokens), + loss_fn_inputs={ + 'target_tokens': target_tokens, + 'weights': weights, + }, + ) + training_data.append(datum) + + if not training_data: + logger.info( + f"Step {step}: No positive-advantage samples, skipping") + step += 1 + continue + + # Forward-backward pass with cross-entropy on advantage-weighted data + fwdbwd_future = training_client.forward_backward( + training_data, "cross_entropy") + optim_future = training_client.optim_step( + types.AdamParams(learning_rate=LEARNING_RATE)) + + fwdbwd_result = fwdbwd_future.result() + optim_result = optim_future.result() + + # Compute weighted average loss for monitoring + logprobs = np.concatenate( + [output['logprobs'].tolist() + for output in fwdbwd_result.loss_fn_outputs]) + weights = np.concatenate( + [d.loss_fn_inputs['weights'].tolist() for d in training_data]) + loss_per_token = -np.dot(logprobs, weights) / max(weights.sum(), 1e-8) + + gc.collect() + + # ========== 7. Log ========== + log_dict = metrics.calculate() + log_dict['train/loss_per_token'] = loss_per_token + log_dict['train/frac_reward_zero_std'] = frac_zero_std + log_dict['train/num_training_samples'] = len(training_data) + logger.info(f"Step {step}: {log_dict}") + step += 1 + + # Save final checkpoint + save_future = training_client.save_state("grpo-countdown-final") + save_result = save_future.result() + logger.info(f"Saved final checkpoint to {save_result.path}") + + +if __name__ == '__main__': + main() diff --git a/cookbook/client/twinkle/transformer/grpo.py b/cookbook/client/twinkle/transformer/grpo.py new file mode 100644 index 00000000..71b7887f --- /dev/null +++ b/cookbook/client/twinkle/transformer/grpo.py @@ -0,0 +1,238 @@ +# Twinkle Client - GRPO (Group Relative Policy Optimization) Training Example +# +# This script demonstrates GRPO reinforcement learning training using the +# Twinkle client API with model.save() + adapter_uri for weight sync. +# Instead of calling sync_weights directly, it periodically saves model weights +# and passes the checkpoint path to the sampler as adapter_uri. +# +# Flow: +# 1. Prepare Countdown dataset (client-side) +# 2. Initialize Twinkle client, model, and sampler +# 3. Configure model with GRPOLoss, optimizer, LR scheduler +# 4. Training loop: +# a. Every SYNC_INTERVAL steps: model.save() → get twinkle_path +# b. sampler.sample(inputs, adapter_uri=twinkle_path, num_samples=N) +# c. Compute rewards and advantages (client-side) +# d. model.forward_backward(inputs, advantages, old_logps) +# e. 
Optimizer step +# +# The server must be running first (see server.py and server_config.yaml). +# Requires both model and sampler services to be configured. + +import dotenv +dotenv.load_dotenv('.env') + +import gc +import os +from typing import List, Tuple + +from peft import LoraConfig + +from twinkle import get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.dataset import DatasetMeta +from twinkle.metric import CompletionRewardMetric +from twinkle_client import init_twinkle_client +from twinkle_client.dataloader import DataLoader +from twinkle_client.dataset import Dataset +from twinkle_client.model import MultiLoraTransformersModel +from twinkle_client.sampler import VLLMSampler + +logger = get_logger() + +# ========== Configuration ========== +MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct' +NUM_GENERATIONS = 8 +MAX_NEW_TOKENS = 1024 +LEARNING_RATE = 1e-5 +MAX_STEPS = 2000 +BATCH_SIZE = 4 +TEMPERATURE = 1.0 +SYNC_INTERVAL = 1 # Save weights for sampler every N steps +GRADIENT_ACCUMULATION_STEPS = 4 + + +def create_countdown_dataset(): + """Create Countdown Game dataset for GRPO training.""" + from twinkle.preprocessor import CountdownProcessor + + dataset = Dataset(dataset_meta=DatasetMeta( + "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000))) + dataset.set_template( + 'Template', model_id=MODEL_ID, max_length=8192) + dataset.map(CountdownProcessor()) + dataset.encode(add_generation_prompt=True, batched=True) + return dataset + + +def compute_rewards( + trajectories: List[dict], +) -> Tuple[List[float], List[float], List[float]]: + """Compute format and accuracy rewards for Countdown game.""" + from twinkle.reward import CountDownAccuracy, FormatReward + format_rewards = FormatReward()(trajectories, []) + accuracy_rewards = CountDownAccuracy()(trajectories, []) + total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)] + return total_rewards, format_rewards, accuracy_rewards + + +def train(): + # Step 1: Initialize the Twinkle client + client = init_twinkle_client( + base_url='http://127.0.0.1:8000', + api_key=os.environ.get('MODELSCOPE_SDK_TOKEN'), + ) + + # Step 2: Prepare dataset and dataloader + dataset = create_countdown_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + # Step 3: Configure the training model + model = MultiLoraTransformersModel(model_id=MODEL_ID) + + lora_config = LoraConfig( + target_modules='all-linear', + r=8, + lora_alpha=32, + lora_dropout=0.05, + ) + model.add_adapter_to_model( + 'default', lora_config, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + ) + + # Set GRPO loss (the key difference from SFT training) + model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0) + + # Set optimizer and LR scheduler + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler( + 'CosineWarmupScheduler', + num_warmup_steps=500, + num_training_steps=MAX_STEPS, + ) + + # Set processor and template for encoding inputs + model.set_processor('InputProcessor') + model.set_template('Template', model_id=MODEL_ID) + + # Step 4: Configure the sampler + sampler = VLLMSampler(model_id=MODEL_ID) + sampler.set_template('Template', model_id=MODEL_ID) + + # Step 5: Setup metrics and advantage function + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + + sampling_params = { + 'max_tokens': MAX_NEW_TOKENS, + 'temperature': TEMPERATURE, + 'top_p': 0.95, + } + + # Track the current adapter path for sampling + current_adapter_uri = None + + step = 0 + for batch in dataloader: + 
if step >= MAX_STEPS: + break + + metrics.reset() + prompts = batch if isinstance(batch, list) else [batch] + + # ========== 1. Save weights and update adapter_uri ========== + # Instead of sync_weights, save the model checkpoint and pass + # the resulting path to the sampler as adapter_uri + if step % SYNC_INTERVAL == 0: + logger.info(f"Step {step}: Saving weights for sampler...") + twinkle_path = model.save( + name=f'grpo-sampler-step-{step}', + save_optimizer=False, + ) + current_adapter_uri = twinkle_path + logger.info( + f"Step {step}: Saved weights to {current_adapter_uri}") + + # ========== 2. Sample completions ========== + sample_response = sampler.sample( + inputs=prompts, + sampling_params=sampling_params, + adapter_uri=current_adapter_uri, + num_samples=NUM_GENERATIONS, + ) + + input_features = [] + old_logps_list = [] + completion_lengths = [] + + sequences = sample_response.get('sequences', []) + for seq in sequences: + input_features.append(seq.get('new_input_feature', seq)) + old_logps_list.append(seq.get('logprobs', [])) + completion_lengths.append(len(seq.get('tokens', []))) + + if not input_features: + logger.warning(f"Step {step}: No valid samples, skipping") + step += 1 + continue + + # ========== 3. Compute rewards ========== + total_rewards, format_rewards, accuracy_rewards = compute_rewards( + input_features) + metrics.accumulate( + None, None, + completion_lengths=completion_lengths, + rewards={ + 'total': total_rewards, + 'format': format_rewards, + 'accuracy': accuracy_rewards, + }) + + # ========== 4. Compute advantages ========== + advantages = advantage_fn( + total_rewards, + num_generations=NUM_GENERATIONS, + scale='group', + ).tolist() + + frac_zero_std = ( + 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0) + if frac_zero_std == 1.0: + logger.info( + f"Step {step}: All advantages are zero, skipping training") + step += 1 + continue + + # ========== 5. Training step (GRPO) ========== + # forward_backward with GRPO loss: passes advantages and old_logps + # to the server-side GRPOLoss for proper policy optimization + model.forward_backward( + inputs=input_features, + advantages=advantages, + old_logps=old_logps_list, + ) + + # Gradient clipping and optimizer step + model.clip_grad_norm(1.0) + model.step() + model.zero_grad() + model.lr_step() + + gc.collect() + + # ========== 6. 
Log ========== + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric()) + log_dict['train/frac_reward_zero_std'] = frac_zero_std + logger.info(f"Step {step}: {log_dict}") + step += 1 + + # Save final checkpoint + twinkle_path = model.save( + name='grpo-countdown-final', save_optimizer=True) + logger.info(f"Saved final checkpoint: {twinkle_path}") + + +if __name__ == '__main__': + train() From c0014339bea5972536cae8739faec437547a606a Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 8 Feb 2026 21:33:04 +0800 Subject: [PATCH 5/9] update cpu env --- src/twinkle/infra/_ray/ray_helper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py index bb3326d3..2e68baeb 100644 --- a/src/twinkle/infra/_ray/ray_helper.py +++ b/src/twinkle/infra/_ray/ray_helper.py @@ -336,7 +336,8 @@ def create_workers(worker_cls: Type[T], group: str, execute: Literal['all', 'pee else: world_size = len(ranks) workers = [] - _visible_device_env = {Platform.get_platform(device_type_upper).visible_device_env(): ''} + # For CPU case, don't set visible device environment variables + _visible_device_env = {} for rank, (deploy_pg, index) in enumerate(zip(placement_groups, list(range(world_size)))): deploy_pg: Dict cluster_name = group From c3f36e18d212f6a2115343e010f702102da938ae Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 8 Feb 2026 22:18:53 +0800 Subject: [PATCH 6/9] update compat --- src/twinkle/model/megatron/multi_lora_megatron.py | 2 ++ src/twinkle/model/transformers/multi_lora_transformers.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index f5b7cdc4..607fefb4 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -97,6 +97,8 @@ def __init__(self, self.model = self.strategy.wrap_model(self.model) self._model_wrapped = True self.multi_adapter.save_initial_weights() + # Active group for compatibility with single adapter + self.active_group = None def _check_adapter_valid(self, adapter_name: str): assert adapter_name and adapter_name in self.optimizer_group, f'Use a valid adapter_name first, current is: {adapter_name}' diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 21a853d1..a9aef3a0 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -55,6 +55,8 @@ def __init__(self, # noqa self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None) self.model = self.strategy.wrap_model(self.model) self.multi_adapter.save_initial_weights() + # Active group for compatibility with single adapter + self.active_group = None def _check_adapter_valid(self, adapter_name: str): assert adapter_name and adapter_name in self.optimizer_group, f'Use a valid adapter_name first, current is: {adapter_name}' From b69b88861078a025052a8f18f83e00eb582174fc Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 8 Feb 2026 22:41:42 +0800 Subject: [PATCH 7/9] update --- .../server/twinkle/common/serialize.py | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/twinkle/server/twinkle/common/serialize.py b/src/twinkle/server/twinkle/common/serialize.py index 2f53bf86..ae38d2a3 100644 --- a/src/twinkle/server/twinkle/common/serialize.py +++ 
b/src/twinkle/server/twinkle/common/serialize.py @@ -23,10 +23,38 @@ basic_types = (*primitive_types, *container_types) +def _serialize_data_slice(data_slice): + """Serialize data_slice (Iterable) into a JSON-compatible dict.""" + if data_slice is None: + return None + if isinstance(data_slice, range): + return {'_slice_type_': 'range', 'start': data_slice.start, 'stop': data_slice.stop, 'step': data_slice.step} + if isinstance(data_slice, (list, tuple)): + return {'_slice_type_': 'list', 'values': list(data_slice)} + raise ValueError( + f'Http mode does not support data_slice of type {type(data_slice).__name__}. ' + 'Supported types: range, list, tuple.' + ) + + +def _deserialize_data_slice(data_slice): + """Deserialize a dict back into the original data_slice object.""" + if data_slice is None: + return None + if not isinstance(data_slice, dict) or '_slice_type_' not in data_slice: + return data_slice + slice_type = data_slice['_slice_type_'] + if slice_type == 'range': + return range(data_slice['start'], data_slice['stop'], data_slice['step']) + if slice_type == 'list': + return data_slice['values'] + raise ValueError(f'Unsupported data_slice type: {slice_type}') + + def serialize_object(obj) -> str: if isinstance(obj, DatasetMeta): - assert obj.data_slice is None, 'Http mode does not support data_slice' - data = obj.__dict__ + data = obj.__dict__.copy() + data['data_slice'] = _serialize_data_slice(data.get('data_slice')) data['_TWINKLE_TYPE_'] = 'DatasetMeta' return json.dumps(data, ensure_ascii=False) elif isinstance(obj, LoraConfig): @@ -54,6 +82,7 @@ def deserialize_object(data: str) -> Any: if '_TWINKLE_TYPE_' in data: _type = data.pop('_TWINKLE_TYPE_') if _type == 'DatasetMeta': + data['data_slice'] = _deserialize_data_slice(data.get('data_slice')) return DatasetMeta(**data) elif _type == 'LoraConfig': return LoraConfig(**data) From 546ae098d82c16bd1fd7ed10bb570e479b3b756e Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 8 Feb 2026 22:57:57 +0800 Subject: [PATCH 8/9] update --- src/twinkle/server/twinkle/common/serialize.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/twinkle/server/twinkle/common/serialize.py b/src/twinkle/server/twinkle/common/serialize.py index ae38d2a3..839a3fbd 100644 --- a/src/twinkle/server/twinkle/common/serialize.py +++ b/src/twinkle/server/twinkle/common/serialize.py @@ -1,14 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
From 546ae098d82c16bd1fd7ed10bb570e479b3b756e Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 8 Feb 2026 22:57:57 +0800
Subject: [PATCH 8/9] remove invalid TypedDict isinstance check

---
 src/twinkle/server/twinkle/common/serialize.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/src/twinkle/server/twinkle/common/serialize.py b/src/twinkle/server/twinkle/common/serialize.py
index ae38d2a3..839a3fbd 100644
--- a/src/twinkle/server/twinkle/common/serialize.py
+++ b/src/twinkle/server/twinkle/common/serialize.py
@@ -1,14 +1,8 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import json
-import sys
 from numbers import Number
 from typing import Mapping, Any
 
-if sys.version_info[:2] <= (3, 11):
-    from typing_extensions import TypedDict
-else:
-    from typing import TypedDict
-
 from peft import LoraConfig
 
 from twinkle.dataset import DatasetMeta
@@ -65,7 +59,7 @@ def serialize_object(obj) -> str:
         }
         filtered_dict['_TWINKLE_TYPE_'] = 'LoraConfig'
         return json.dumps(filtered_dict, ensure_ascii=False)
-    elif isinstance(obj, (Mapping, TypedDict)):
+    elif isinstance(obj, Mapping):
         return json.dumps(obj, ensure_ascii=False)
     elif isinstance(obj, basic_types):
         return obj
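The dropped check was a latent crash, not just style: TypedDict classes reject runtime isinstance() (it only fired for non-Mapping inputs, since the tuple is checked left to right), and a TypedDict value is a plain dict that the Mapping branch already catches. A minimal demonstration:

from collections.abc import Mapping
from typing import TypedDict

class Point(TypedDict):
    x: int
    y: int

p: Point = {'x': 1, 'y': 2}
assert isinstance(p, Mapping)   # a TypedDict value is an ordinary dict
try:
    isinstance(p, Point)        # raises at runtime
except TypeError as e:
    print(e)  # TypedDict does not support instance and class checks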
From fe82fda7f41758f48283fd8f391572f7155a460f Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 8 Feb 2026 23:14:55 +0800
Subject: [PATCH 9/9] clean up generated clients and cookbook examples

---
 client_tools/client_generator.py              |   8 -
 cookbook/client/tinker/megatron/lora.py       | 158 ------
 cookbook/client/tinker/transformer/grpo.py    |  14 +-
 cookbook/client/twinkle/transformer/grpo.py   |  11 +-
 cookbook/client/twinkle/transformer/lora.py   |  63 +++----
 .../client/twinkle/transformer/sampler.py     |   6 +-
 .../sampler/vllm_sampler/vllm_engine.py       |   2 +-
 src/twinkle/server/__init__.py                |  12 +-
 src/twinkle/server/twinkle/sampler.py         |   3 +-
 src/twinkle_client/dataloader/dataloader.py   |   8 +-
 src/twinkle_client/dataset/base.py            |   5 +-
 .../dataset/iterable_dataset.py               |   1 -
 .../dataset/iterable_packing_dataset.py       |   1 -
 src/twinkle_client/dataset/lazy_dataset.py    |   1 -
 src/twinkle_client/dataset/packing_dataset.py |   1 -
 .../model/multi_lora_transformers.py          |   1 -
 src/twinkle_client/processor/base.py          |   1 -
 src/twinkle_client/sampler/vllm_sampler.py    |   6 -
 18 files changed, 58 insertions(+), 244 deletions(-)
 delete mode 100644 cookbook/client/tinker/megatron/lora.py

diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py
index 9e6b1bd0..e892ce2c 100644
--- a/client_tools/client_generator.py
+++ b/client_tools/client_generator.py
@@ -240,7 +240,6 @@ def build_imports() -> Tuple[List[str], str]:
     if typing_imports:
         lines.append(f"from typing import {', '.join(sorted(typing_imports))}")
     lines.extend([
-        "from twinkle_client.http import TWINKLE_SERVER_URL",
         "from twinkle_client.http import http_post, heartbeat_manager",
     ])
     lines.extend(sorted(twinkle_imports))
@@ -447,7 +446,6 @@ def generate_models():
     model_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, Union, Type, Dict, Literal, List
 import uuid
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature, Trajectory
@@ -724,15 +722,10 @@ def generate_samplers():
     client_module_path.mkdir(parents=True, exist_ok=True)
 
     sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, List, Dict, Union
-import uuid
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.sampler.base import Sampler
-from twinkle.sampler.types import SamplingParams, SampleResponse
-from twinkle import DeviceMesh
 from peft import PeftConfig
 from twinkle.data_format import Trajectory, InputFeature
-import json
 
 
 class VLLMSampler(Sampler):
@@ -756,7 +749,6 @@ def __init__(self, model_id: str, **kwargs):
             json_data=kwargs
         )
         response.raise_for_status()
-        return response.json()
 
     def _send_adapter_heartbeat(self):
         """Internal method to send adapter heartbeat."""
diff --git a/cookbook/client/tinker/megatron/lora.py b/cookbook/client/tinker/megatron/lora.py
deleted file mode 100644
index 91815bb7..00000000
--- a/cookbook/client/tinker/megatron/lora.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Tinker-Compatible Client - Megatron LoRA Training & Sampling Example
-#
-# This script demonstrates end-to-end LoRA fine-tuning and inference using the
-# Tinker-compatible client API with a Megatron backend.
-# It covers: connecting to the server, preparing data manually with tokenizers,
-# running a training loop, saving checkpoints, and sampling from the model.
-# The server must be running first (see server.py and server_config.yaml).
-
-from twinkle_client import init_tinker_compat_client
-
-# Step 1: Initialize the Tinker-compatible client to communicate with the server.
-service_client = init_tinker_compat_client(base_url='http://localhost:8000')
-
-# Step 2: List models available on the server to verify the connection
-print("Available models:")
-for item in service_client.get_server_capabilities().supported_models:
-    print("- " + item.model_name)
-
-
-# Step 3: Create a REST client for querying training runs and checkpoints.
-# This is useful for inspecting previous training sessions or resuming training.
-rest_client = service_client.create_rest_client()
-
-future = rest_client.list_training_runs(limit=50)
-response = future.result()
-
-# You can resume from a twinkle:// path. Example:
-# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
-resume_path = ""
-
-print(f"Found {len(response.training_runs)} training runs")
-for tr in response.training_runs:
-    print(tr.model_dump_json(indent=2))
-
-    chpts = rest_client.list_checkpoints(tr.training_run_id).result()
-    for chpt in chpts.checkpoints:
-        print("  " + chpt.model_dump_json(indent=2))
-        # Uncomment the line below to resume from the last checkpoint:
-        # resume_path = chpt.tinker_path
-
-# Step 4: Create or resume a training client.
-# If resume_path is set, it restores both model weights and optimizer state.
-base_model = "Qwen/Qwen2.5-0.5B-Instruct"
-if not resume_path:
-    training_client = service_client.create_lora_training_client(
-        base_model=base_model
-    )
-else:
-    training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path)
-
-# Step 5: Prepare training data manually
-#
-# This example teaches the model to translate English into Pig Latin.
-# Each example has an "input" (English phrase) and "output" (Pig Latin).
-examples = [
-    {"input": "banana split", "output": "anana-bay plit-say"},
-    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
-    {"input": "donut shop", "output": "onut-day op-shay"},
-    {"input": "pickle jar", "output": "ickle-pay ar-jay"},
-    {"input": "space exploration", "output": "ace-spay exploration-way"},
-    {"input": "rubber duck", "output": "ubber-ray uck-day"},
-    {"input": "coding wizard", "output": "oding-cay izard-way"},
-]
-
-from tinker import types
-from modelscope import AutoTokenizer
-
-# Load the tokenizer locally (avoids a network call to HuggingFace)
-tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
-
-def process_example(example: dict, tokenizer) -> types.Datum:
-    """Convert a raw example dict into a Datum suitable for the training API.
-
-    The Datum contains:
-    - model_input: the token IDs fed into the LLM
-    - loss_fn_inputs: target tokens and per-token weights (0 = ignore, 1 = train)
-    """
-    # Build a simple prompt template
-    prompt = f"English: {example['input']}\nPig Latin:"
-
-    # Tokenize the prompt; weights=0 means the loss ignores these tokens
-    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
-    prompt_weights = [0] * len(prompt_tokens)
-
-    # Tokenize the completion; weights=1 means the loss is computed on these tokens
-    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
-    completion_weights = [1] * len(completion_tokens)
-
-    # Concatenate prompt + completion
-    tokens = prompt_tokens + completion_tokens
-    weights = prompt_weights + completion_weights
-
-    # Shift by one: input is tokens[:-1], target is tokens[1:] (next-token prediction)
-    input_tokens = tokens[:-1]
-    target_tokens = tokens[1:]
-    weights = weights[1:]
-
-    return types.Datum(
-        model_input=types.ModelInput.from_ints(tokens=input_tokens),
-        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
-    )
-
-# Process all examples into Datum objects
-processed_examples = [process_example(ex, tokenizer) for ex in examples]
-
-# Visualize the first example to verify tokenization and weight alignment
-datum0 = processed_examples[0]
-print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
-print("-" * 50)
-for i, (inp, tgt, wgt) in enumerate(zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(), datum0.loss_fn_inputs['weights'].tolist())):
-    print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}")
-
-# Step 6: Run the training loop
-#
-# For each epoch, iterate over multiple batches:
-# - forward_backward: sends data to the server, computes loss & gradients
-# - optim_step: updates model weights using Adam optimizer
-import numpy as np
-for epoch in range(2):
-    for batch in range(5):
-        # Send training data and get back logprobs (asynchronous futures)
-        fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
-        optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
-
-        # Wait for results from the server
-        fwdbwd_result = fwdbwd_future.result()
-        optim_result = optim_future.result()
-
-        # Compute the weighted average log-loss per token for monitoring
-        print(f"Epoch {epoch}, Batch {batch}: ", end="")
-        logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
-        weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
-        print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")
-
-    # Save checkpoint (model weights + optimizer state) after each epoch
-    save_future = training_client.save_state(f"pig-latin-lora-epoch-{epoch}")
-    save_result = save_future.result()
-    print(f"Saved checkpoint for epoch {epoch} to {save_result.path}")
-
-# Step 7: Sample from the trained model
-#
-# Save the current weights and create a sampling client to generate text.
-sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model')
-
-# Prepare a prompt and sampling parameters
-prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
-params = types.SamplingParams(
-    max_tokens=20,     # Maximum number of tokens to generate
-    temperature=0.0,   # Greedy sampling (deterministic)
-    stop=["\n"]        # Stop at newline
-)
-
-# Generate 8 completions and print the results
-future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
-result = future.result()
-print("Responses:")
-for i, seq in enumerate(result.sequences):
-    print(f"{i}: {repr(tokenizer.decode(seq.tokens))}")
diff --git a/cookbook/client/tinker/transformer/grpo.py b/cookbook/client/tinker/transformer/grpo.py
index d9a6bb89..67b498b6 100644
--- a/cookbook/client/tinker/transformer/grpo.py
+++ b/cookbook/client/tinker/transformer/grpo.py
@@ -35,25 +35,25 @@
 
 logger = get_logger()
 
 # ========== Configuration ==========
-BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
+MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct'
 NUM_GENERATIONS = 8
 MAX_NEW_TOKENS = 1024
 LEARNING_RATE = 1e-5
-MAX_STEPS = 2000
+MAX_STEPS = 10
 BATCH_SIZE = 4
 TEMPERATURE = 1.0
-SYNC_INTERVAL = 1  # Save weights for sampler every N steps
-LORA_RANK = 8
+SYNC_INTERVAL = 5  # Save weights for sampler every N steps
+GRADIENT_ACCUMULATION_STEPS = 4
 
 
 def create_countdown_dataset():
     """Create Countdown Game dataset for GRPO training."""
-    from twinkle.preprocessor import CountdownProcessor
+
     dataset = Dataset(DatasetMeta(
-        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000)))
+        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(500)))
-    dataset.set_template(
-        "Template", model_id=f'ms://{BASE_MODEL}', max_length=8192)
+    dataset.set_template(
+        "Template", model_id=MODEL_ID, max_length=8192)
-    dataset.map(CountdownProcessor())
+    dataset.map('CountdownProcessor')
     dataset.encode(add_generation_prompt=True)
     return dataset
diff --git a/cookbook/client/twinkle/transformer/grpo.py b/cookbook/client/twinkle/transformer/grpo.py
index 71b7887f..63f79437 100644
--- a/cookbook/client/twinkle/transformer/grpo.py
+++ b/cookbook/client/twinkle/transformer/grpo.py
@@ -41,26 +41,25 @@
 
 logger = get_logger()
 
 # ========== Configuration ==========
-MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct'
+MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct'
 NUM_GENERATIONS = 8
 MAX_NEW_TOKENS = 1024
 LEARNING_RATE = 1e-5
-MAX_STEPS = 2000
+MAX_STEPS = 10
 BATCH_SIZE = 4
 TEMPERATURE = 1.0
-SYNC_INTERVAL = 1  # Save weights for sampler every N steps
+SYNC_INTERVAL = 5  # Save weights for sampler every N steps
 GRADIENT_ACCUMULATION_STEPS = 4
 
 
 def create_countdown_dataset():
     """Create Countdown Game dataset for GRPO training."""
-    from twinkle.preprocessor import CountdownProcessor
 
     dataset = Dataset(dataset_meta=DatasetMeta(
-        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(50000)))
+        "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(500)))
     dataset.set_template(
         'Template', model_id=MODEL_ID, max_length=8192)
-    dataset.map(CountdownProcessor())
+    dataset.map('CountdownProcessor')
     dataset.encode(add_generation_prompt=True, batched=True)
     return dataset
diff --git a/cookbook/client/twinkle/transformer/lora.py b/cookbook/client/twinkle/transformer/lora.py
index 5e2d292a..3d13aee3 100644
--- a/cookbook/client/twinkle/transformer/lora.py
+++ b/cookbook/client/twinkle/transformer/lora.py
@@ -46,27 +46,27 @@ def train():
 
     # Step 4: Prepare the dataset
     # Load the self-cognition dataset from ModelScope
-    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition'))
+    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
 
     # Apply a chat template so the data matches the model's expected input format
     dataset.set_template(
-        'Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct', max_length=512)
+        'Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=512)
 
     # Replace placeholder names in the dataset with custom model/author names
     dataset.map('SelfCognitionProcessor', init_args={
-        'model_name': 'twinkle模型', 'model_author': 'twinkle团队'})
+        'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'})
 
     # Tokenize and encode the dataset into model-ready input features
     dataset.encode(batched=True)
 
-    # Wrap the dataset into a DataLoader that yields batches of size 8
-    dataloader = DataLoader(dataset=dataset, batch_size=8)
+    # Wrap the dataset into a DataLoader that yields batches of size 4
+    dataloader = DataLoader(dataset=dataset, batch_size=4)
 
     # Step 5: Configure the model
     # Create a multi-LoRA Transformers model pointing to the base model on ModelScope
     model = MultiLoraTransformersModel(
-        model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+        model_id='ms://Qwen/Qwen2.5-0.5B-Instruct')
 
     # Define LoRA configuration: apply low-rank adapters to all linear layers
     lora_config = LoraConfig(
@@ -102,38 +102,41 @@ def train():
 
     # Step 7: Run the training loop
     logger.info(model.get_train_configs())
-    for step, batch in enumerate(dataloader):
-        # Forward pass + backward pass (computes gradients)
-        output = model.forward_backward(inputs=batch)
+    for epoch in range(3):
+        logger.info(f'Starting epoch {epoch}')
+        for step, batch in enumerate(dataloader):
+            # Forward pass + backward pass (computes gradients)
+            output = model.forward_backward(inputs=batch)
 
-        # Log the loss every 2 steps (aligned with gradient accumulation)
-        if step % 2 == 0:
-            logger.info(f'Current is step {step // 2}, loss: {output}')
+            # Log the loss every 2 steps (aligned with gradient accumulation)
+            if step % 2 == 0:
+                logger.info(f'Current is step {step // 2}, loss: {output}')
 
-        # Clip gradients to prevent exploding gradients (max norm = 1.0)
-        model.clip_grad_norm(1.0)
+            # Clip gradients to prevent exploding gradients (max norm = 1.0)
+            model.clip_grad_norm(1.0)
 
-        # Perform one optimizer step (update model weights)
-        model.step()
+            # Perform one optimizer step (update model weights)
+            model.step()
 
-        # Reset gradients to zero for the next iteration
-        model.zero_grad()
+            # Reset gradients to zero for the next iteration
+            model.zero_grad()
 
-        # Advance the learning rate scheduler by one step
-        model.lr_step()
+            # Advance the learning rate scheduler by one step
+            model.lr_step()
 
-    # Step 8: Save the trained checkpoint
-    twinkle_path = model.save(name=f'step-{step}', save_optimizer=True)
-    logger.info(f"Saved checkpoint: {twinkle_path}")
+        # Step 8: Save the trained checkpoint
+        twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True)
+        logger.info(f"Saved checkpoint: {twinkle_path}")
 
     # Step 9: Upload the checkpoint to ModelScope Hub
-    hub_model_id = 'AlexEz/twinkle-self-cognition'
-    model.upload_to_hub(
-        checkpoint_dir=twinkle_path,
-        hub_model_id=hub_model_id,
-        async_upload=False
-    )
-    logger.info(f"Uploaded checkpoint to hub: {hub_model_id}")
+    # YOUR_USER_NAME = "your_username"
+    # hub_model_id = f'{YOUR_USER_NAME}/twinkle-self-cognition'
+    # model.upload_to_hub(
+    #     checkpoint_dir=twinkle_path,
+    #     hub_model_id=hub_model_id,
+    #     async_upload=False
+    # )
+    # logger.info(f"Uploaded checkpoint to hub: {hub_model_id}")
 
 
 if __name__ == '__main__':
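The twinkle:// path that lora.py logs from model.save() is exactly what the sampler script below expects in ADAPTER_URI. A sketch of wiring the two together programmatically instead of hard-coding the string (reuses names from the two cookbook scripts; illustrative only):

# In lora.py, after training:
twinkle_path = model.save(name='twinkle-epoch-2', save_optimizer=True)
# e.g. 'twinkle://20260208_224851-fa3cdd11-default/weights/twinkle-epoch-2'

# In sampler.py, pass that path as the LoRA adapter for inference:
response = sampler.sample(
    inputs=[trajectory],
    sampling_params={'max_tokens': 128, 'temperature': 1.0},
    adapter_uri=twinkle_path,   # adapter produced by the training run
    num_samples=1,
)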
diff --git a/cookbook/client/twinkle/transformer/sampler.py b/cookbook/client/twinkle/transformer/sampler.py
index 138d1529..7419f167 100644
--- a/cookbook/client/twinkle/transformer/sampler.py
+++ b/cookbook/client/twinkle/transformer/sampler.py
@@ -21,14 +21,14 @@
 
 logger = get_logger()
 
-MODEL_ID = 'Qwen/Qwen2.5-7B-Instruct'
+MODEL_ID = 'Qwen/Qwen2.5-0.5B-Instruct'
 
 # Optional: adapter URI for LoRA inference
 # This can be a twinkle:// path from a training run checkpoint
 # or None to use the base model
-ADAPTER_URI = None
+# ADAPTER_URI = None
 # Example:
-# ADAPTER_URI = "twinkle://tml-EMPTY_TOKEN/20260203_211942-Qwen_Qwen2_5-7B-Instruct-11cdabc7/weights/twinkle-lora-2"
+ADAPTER_URI = "twinkle://20260208_224851-fa3cdd11-default/weights/twinkle-epoch-2"
 
 
 def sample():
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
index a4dd648e..975456a8 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
@@ -374,7 +374,7 @@ async def _get_or_load_lora(self, lora_path: str, user_id: Optional[str] = None)
         # Verify it's still loaded in engine
         loaded_loras = await self.engine.list_loras()
         if lora_int_id in loaded_loras:
-            if lora_path != self._user_lora_paths[user_id]:
+            if self._user_lora_paths.get(user_id) != lora_path:
                 # reload the lora
                 await self.remove_adapter(user_id)
                 lora_request = await self._get_or_load_lora(lora_path, user_id)
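The vllm_engine change is a one-line hardening: indexing _user_lora_paths raises KeyError for a user with no recorded adapter, while .get() turns that case into the ordinary "path changed, reload" branch. In miniature:

_user_lora_paths = {'user-a': '/ckpts/adapter-a'}

try:
    _user_lora_paths['user-b'] != '/ckpts/adapter-b'   # old code: KeyError
except KeyError:
    pass
assert _user_lora_paths.get('user-b') != '/ckpts/adapter-b'  # new code: just reload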
diff --git a/src/twinkle/server/__init__.py b/src/twinkle/server/__init__.py
index d57f97f7..f522ff19 100644
--- a/src/twinkle/server/__init__.py
+++ b/src/twinkle/server/__init__.py
@@ -1,16 +1,10 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-from .twinkle.sampler import build_sampler_app as _build_sampler_app
-from .twinkle.model import build_model_app as _build_model_app
-from .twinkle.processor import build_processor_app as _build_processor_app
+from .twinkle.sampler import build_sampler_app
+from .twinkle.model import build_model_app
+from .twinkle.processor import build_processor_app
 from .twinkle.server import build_server_app
-from .utils import wrap_builder_with_device_group_env
 from .launcher import ServerLauncher, launch_server
-
-build_model_app = wrap_builder_with_device_group_env(_build_model_app)
-build_processor_app = wrap_builder_with_device_group_env(_build_processor_app)
-build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app)
-
 __all__ = [
     'build_model_app',
     'build_processor_app',
diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py
index 793a66dc..0c0bc1fb 100644
--- a/src/twinkle/server/twinkle/sampler.py
+++ b/src/twinkle/server/twinkle/sampler.py
@@ -18,8 +18,7 @@
 
 import twinkle
 from twinkle import DeviceGroup, DeviceMesh
-from twinkle.data_format import Trajectory, InputFeature
-from twinkle.sampler.types import SamplingParams, SampleResponse, SampledSequence
+from twinkle.data_format import Trajectory, InputFeature, SamplingParams
 from twinkle.server.utils.adapter_manager import AdapterManagerMixin
 from twinkle.server.utils.validation import verify_request_token, get_token_from_request
 from twinkle.server.utils.state import get_server_state, ServerStateProxy
diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py
index f43b87fd..02cc205c 100644
--- a/src/twinkle_client/dataloader/dataloader.py
+++ b/src/twinkle_client/dataloader/dataloader.py
@@ -9,17 +9,15 @@
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
 
-from typing import Callable, Optional, Type, Union
-from twinkle_client.http import TWINKLE_SERVER_URL
+from typing import Callable, Type, Union
 from twinkle_client.http import http_post, heartbeat_manager
-from twinkle import DeviceMesh
 from twinkle.dataset import Dataset
 from twinkle.processor import InputProcessor
 
 
 class DataLoader(object):
     """Client wrapper for DataLoader that calls server HTTP endpoints."""
-    def __init__(self, dataset: Union[Dataset, Callable], device_mesh: Optional[DeviceMesh] = None, **kwargs):
+    def __init__(self, dataset: Union[Dataset, Callable], **kwargs):
         from twinkle_client.http import get_base_url
         self.server_url = get_base_url()
 
@@ -28,7 +26,7 @@ def __init__(self, dataset: Union[Dataset, Callable], device_mesh: Optional[Devi
             json_data={
                 'processor_type': 'dataloader',
                 'class_type': 'DataLoader',
-                **{'dataset': dataset, 'device_mesh': device_mesh}, **kwargs
+                **{'dataset': dataset}, **kwargs
             }
         )
         response.raise_for_status()
diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py
index 8417af98..63097582 100644
--- a/src/twinkle_client/dataset/base.py
+++ b/src/twinkle_client/dataset/base.py
@@ -10,7 +10,6 @@
 # ============================================================================
 
 from typing import Any, Callable, Dict, Type, Union
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.dataset import Dataset
 from twinkle.dataset import DatasetMeta
@@ -58,13 +57,13 @@ def set_template(self, template_func: Union[Template, Type[Template], str], **kw
 
         return response.json()["result"]
 
-    def encode(self, **kwargs):
+    def encode(self, add_generation_prompt: bool = False, **kwargs):
         response = http_post(
             url=f'{self.server_url}/processors/call',
             json_data={
                 'processor_id': self.processor_id,
                 'function': 'encode',
-                **{},
+                **{'add_generation_prompt': add_generation_prompt},
                 **kwargs
             }
         )
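With the new keyword, an encode() call on the generated client maps onto a single POST whose JSON body flattens the kwargs. Assumed payload shape, following the json_data construction above (the processor id here is hypothetical):

payload = {
    'processor_id': 'dataset-1a2b3c',   # assigned when the dataset is created
    'function': 'encode',
    'add_generation_prompt': True,
    'batched': True,
}
# http_post(url=f'{server_url}/processors/call', json_data=payload)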
diff --git a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py
index 2edad39a..fb7658f5 100644
--- a/src/twinkle_client/dataset/iterable_dataset.py
+++ b/src/twinkle_client/dataset/iterable_dataset.py
@@ -9,7 +9,6 @@
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
 
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.dataset import Dataset
 from twinkle.dataset import DatasetMeta
diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py
index de2d7509..0917c87b 100644
--- a/src/twinkle_client/dataset/iterable_packing_dataset.py
+++ b/src/twinkle_client/dataset/iterable_packing_dataset.py
@@ -10,7 +10,6 @@
 # ============================================================================
 
 from typing import Type, Union
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.dataset import Dataset
 from twinkle.dataset import DatasetMeta
diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py
index b6238942..419a1a8f 100644
--- a/src/twinkle_client/dataset/lazy_dataset.py
+++ b/src/twinkle_client/dataset/lazy_dataset.py
@@ -9,7 +9,6 @@
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
 
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.dataset import Dataset
 from twinkle.dataset import DatasetMeta
diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py
index 23102dd5..1187e817 100644
--- a/src/twinkle_client/dataset/packing_dataset.py
+++ b/src/twinkle_client/dataset/packing_dataset.py
@@ -9,7 +9,6 @@
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
 
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.dataset import Dataset
 from twinkle.dataset import DatasetMeta
diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py
index c0762ded..bf792dca 100644
--- a/src/twinkle_client/model/multi_lora_transformers.py
+++ b/src/twinkle_client/model/multi_lora_transformers.py
@@ -10,7 +10,6 @@
 # ============================================================================
 from typing import Any, Optional, Union, Type, Dict, Literal, List
 import uuid
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature, Trajectory
diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py
index 6de80530..79856944 100644
--- a/src/twinkle_client/processor/base.py
+++ b/src/twinkle_client/processor/base.py
@@ -10,7 +10,6 @@
 # ============================================================================
 
 from typing import List, Literal, Optional, Union
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature
diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py
index 51146eb6..591ec8fc 100644
--- a/src/twinkle_client/sampler/vllm_sampler.py
+++ b/src/twinkle_client/sampler/vllm_sampler.py
@@ -9,15 +9,10 @@
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
 from typing import Any, Optional, List, Dict, Union
-import uuid
-from twinkle_client.http import TWINKLE_SERVER_URL
 from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.sampler.base import Sampler
-from twinkle.sampler.types import SamplingParams, SampleResponse
-from twinkle import DeviceMesh
 from peft import PeftConfig
 from twinkle.data_format import Trajectory, InputFeature
-import json
 
 
 class VLLMSampler(Sampler):
@@ -41,7 +36,6 @@ def __init__(self, model_id: str, **kwargs):
             json_data=kwargs
         )
         response.raise_for_status()
-        return response.json()
 
     def _send_adapter_heartbeat(self):
         """Internal method to send adapter heartbeat."""
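Dropping `return response.json()` from the generated __init__ is a real fix rather than cleanup: Python forbids a non-None return from __init__, so the old generated client crashed at construction time. A minimal reproduction:

class Broken:
    def __init__(self):
        self.ok = True
        return {'status': 'created'}   # looks harmless, but...

try:
    Broken()
except TypeError as e:
    print(e)  # __init__() should return None, not 'dict'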