From 0c99fd34596a0b9215cd431a3968e03cb9f51bd6 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 5 Feb 2026 15:27:25 +0800 Subject: [PATCH 1/9] add push lock --- cookbook/client/tinker/transformer/lora.py | 12 +++++-- src/twinkle/server/tinker/common/io_utils.py | 34 ++++++++++++++++++++ src/twinkle/server/tinker/model.py | 2 +- src/twinkle/server/tinker/server.py | 24 ++++++++------ src/twinkle_client/utils/patch_tinker.py | 30 +++++++++++++++++ 5 files changed, 89 insertions(+), 13 deletions(-) diff --git a/cookbook/client/tinker/transformer/lora.py b/cookbook/client/tinker/transformer/lora.py index 42ca9f5f..5d29c250 100644 --- a/cookbook/client/tinker/transformer/lora.py +++ b/cookbook/client/tinker/transformer/lora.py @@ -1,6 +1,10 @@ #%% +import dotenv +dotenv.load_dotenv('.env') + +import os from twinkle_client import init_tinker_compat_client -service_client = init_tinker_compat_client(base_url='http://localhost:8000') +service_client = init_tinker_compat_client(base_url='http://localhost:8000', api_key=os.environ.get('MODELSCOPE_SDK_TOKEN')) print("Available models:") for item in service_client.get_server_capabilities().supported_models: @@ -110,5 +114,7 @@ def process_example(example: dict, tokenizer) -> types.Datum: save_result = save_future.result() print(f"Saved checkpoint for epoch {epoch} to {save_result.path}") -# sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model') - \ No newline at end of file +# NOTE: Need to set your modelscope token as api_key when initializing the service client +# model name is {run_id}_{checkpoint_name} +rest_client.publish_checkpoint_from_tinker_path(save_result.path).result() +print("Published checkpoint") diff --git a/src/twinkle/server/tinker/common/io_utils.py b/src/twinkle/server/tinker/common/io_utils.py index b3d55df1..94ff6835 100644 --- a/src/twinkle/server/tinker/common/io_utils.py +++ b/src/twinkle/server/tinker/common/io_utils.py @@ -57,8 +57,36 @@ def 
_create_training_run(self, model_id: str, run_config: types.CreateModelReque def _parse_training_run(self, data: Dict[str, Any]) -> types.TrainingRun: """Parse training run data into TrainingRun model.""" + # Transform checkpoint data to ensure tinker_path field exists + data = self._transform_checkpoint_fields(data) return types.TrainingRun(**data) + def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Transform checkpoint data to ensure compatibility with tinker types. + + Handles cases where: + - last_checkpoint/last_sampler_checkpoint might have twinkle_path instead of tinker_path + - Missing path field that needs to be constructed from other data + """ + data = data.copy() + for field in ['last_checkpoint', 'last_sampler_checkpoint']: + if field in data and data[field] is not None: + ckpt = data[field].copy() + # If twinkle_path exists but tinker_path doesn't, use twinkle_path + if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt: + ckpt['tinker_path'] = ckpt.pop('twinkle_path') + # If neither exists, try to construct from checkpoint_id + elif 'tinker_path' not in ckpt: + # Try to get path from any available path field + path = ckpt.get('path') or ckpt.get('twinkle_path') + if path: + ckpt['tinker_path'] = path + elif 'checkpoint_id' in ckpt and 'training_run_id' in data: + # Construct path from components + ckpt['tinker_path'] = f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}" + data[field] = ckpt + return data + def _create_training_runs_response( self, runs: List[types.TrainingRun], limit: int, offset: int, total: int ) -> types.TrainingRunsResponse: @@ -99,6 +127,12 @@ def _create_checkpoint( def _parse_checkpoint(self, data: Dict[str, Any]) -> types.Checkpoint: """Parse checkpoint data into Checkpoint model.""" + data = data.copy() + # Transform twinkle_path to tinker_path if needed + if 'twinkle_path' in data and 'tinker_path' not in data: + data['tinker_path'] = data.pop('twinkle_path') + elif 
'tinker_path' not in data and 'path' in data: + data['tinker_path'] = data.pop('path') return types.Checkpoint(**data) def _create_checkpoints_response(self, checkpoints: List[types.Checkpoint]) -> types.CheckpointsListResponse: diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 8a288ffa..5d84ec97 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -24,7 +24,6 @@ from twinkle.server.utils.state import get_server_state, ServerStateProxy from twinkle.utils.logger import get_logger -from .common import TwinkleCompatTransformersModel from .common.task_queue import TaskQueueMixin, TaskQueueConfig from .common.adapter_manager import AdapterManagerMixin from .common.io_utils import create_training_run_manager, create_checkpoint_manager @@ -113,6 +112,7 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], **kwargs ) else: + from .common import TwinkleCompatTransformersModel self.model = TwinkleCompatTransformersModel( model_id=model_id, device_mesh=self.device_mesh, diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 749856d9..09c54c56 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -95,6 +95,8 @@ def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None types.SupportedModel(model_name="Qwen/Qwen2.5-7B-Instruct"), types.SupportedModel(model_name="Qwen/Qwen2.5-72B-Instruct"), ] + # Lock for ModelScope config file operations (login writes, get_user_info reads) + self._modelscope_config_lock = asyncio.Lock() def _validate_base_model(self, base_model: str) -> None: """Validate that base_model is in supported_models list. 
@@ -544,15 +546,19 @@ async def publish_checkpoint( # Generate hub_model_id from checkpoint content and user token # Format: {username}/{run_id}_{checkpoint_name} - try: - from modelscope.hub.api import HubApi, ModelScopeConfig - hub_api = HubApi(token=token) - hub_api.login() # Save user info to local - username = ModelScopeConfig.get_user_info()[0] - except Exception: - # Fallback to using sanitized token as username - import re - username = re.sub(r'[^\w\-]', '_', token)[:20] + # Use lock to prevent race conditions when multiple requests access ModelScope config file + async with self._modelscope_config_lock: + try: + from modelscope.hub.api import HubApi, ModelScopeConfig + hub_api = HubApi(token=token) + hub_api.login() # Save user info to local + username = ModelScopeConfig.get_user_info()[0] + except Exception as e: + logger.error(f"Failed to get username from ModelScope: {e}") + raise HTTPException( + status_code=401, + detail="Failed to get username from ModelScope. Please ensure your token is valid." 
+ ) # Extract checkpoint name from checkpoint_id (e.g., "weights/step-8" -> "step-8") checkpoint_name = checkpoint_id.split('/')[-1] diff --git a/src/twinkle_client/utils/patch_tinker.py b/src/twinkle_client/utils/patch_tinker.py index aca3862a..fcfcdb80 100644 --- a/src/twinkle_client/utils/patch_tinker.py +++ b/src/twinkle_client/utils/patch_tinker.py @@ -89,6 +89,31 @@ def _patched_async_tinker_init( self._idempotency_header = "X-Idempotency-Key" +def _patched_from_tinker_path(cls, tinker_path: str) -> Any: + """Patched version that supports both 'tinker://' and 'twinkle://' prefixes.""" + prefix = None + if tinker_path.startswith("tinker://"): + prefix = "tinker://" + elif tinker_path.startswith("twinkle://"): + prefix = "twinkle://" + + if prefix is None: + raise ValueError(f"Invalid tinker path: {tinker_path}") + + parts = tinker_path[len(prefix):].split("/") + if len(parts) != 3: + raise ValueError(f"Invalid tinker path: {tinker_path}") + if parts[1] not in ["weights", "sampler_weights"]: + raise ValueError(f"Invalid tinker path: {tinker_path}") + checkpoint_type = "training" if parts[1] == "weights" else "sampler" + return cls( + tinker_path=tinker_path, + training_run_id=parts[0], + checkpoint_type=checkpoint_type, + checkpoint_id="/".join(parts[1:]), + ) + + def patch_tinker(): """ Apply patches to tinker library. @@ -96,6 +121,7 @@ def patch_tinker(): This function patches: 1. InternalClientHolder._create_sampling_session to bypass 'tinker://' prefix validation 2. AsyncTinker.__init__ to bypass 'tml-' prefix validation for api_key + 3. ParsedCheckpointTinkerPath.from_tinker_path to support both 'tinker://' and 'twinkle://' prefixes This patch is idempotent - calling it multiple times has no additional effect. 
""" @@ -112,6 +138,10 @@ def patch_tinker(): from tinker._client import AsyncTinker AsyncTinker.__init__ = _patched_async_tinker_init + # Patch 3: support both tinker:// and twinkle:// prefixes for checkpoint paths + from tinker.types.checkpoint import ParsedCheckpointTinkerPath + ParsedCheckpointTinkerPath.from_tinker_path = classmethod(_patched_from_tinker_path) + _patched = True except ImportError: # tinker not installed, skip patching From b1f6746f3f60623ab1036026a191642758dc763c Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 5 Feb 2026 16:20:46 +0800 Subject: [PATCH 2/9] add load remote --- cookbook/client/tinker/transformer/lora.py | 9 +- src/twinkle/server/tinker/common/__init__.py | 8 -- src/twinkle/server/tinker/common/io_utils.py | 16 ---- .../server/tinker/common/megatron_model.py | 26 +++--- .../tinker/common/transformers_model.py | 26 ++---- src/twinkle/server/tinker/model.py | 2 +- src/twinkle/server/twinkle/common/__init__.py | 20 ----- src/twinkle/server/twinkle/common/io_utils.py | 21 ----- src/twinkle/server/twinkle/model.py | 43 +++------- src/twinkle/server/utils/io_utils.py | 82 +++++++++++++++++++ 10 files changed, 117 insertions(+), 136 deletions(-) diff --git a/cookbook/client/tinker/transformer/lora.py b/cookbook/client/tinker/transformer/lora.py index 5d29c250..1289ad15 100644 --- a/cookbook/client/tinker/transformer/lora.py +++ b/cookbook/client/tinker/transformer/lora.py @@ -17,7 +17,8 @@ future = rest_client.list_training_runs(limit=50) response = future.result() # resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1" -resume_path = "" +resume_path = "AlexEz/20260205_152451-Qwen_Qwen2_5-7B-Instruct-104b022e_pig-latin-lora-epoch-1" +# resume_path = "" print(f"Found {len(response.training_runs)} training runs") for tr in response.training_runs: print(tr.model_dump_json(indent=2)) @@ -34,6 +35,7 @@ base_model=base_model ) else: + print("Resuming from " + resume_path) 
training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path) #%% @@ -110,11 +112,12 @@ def process_example(example: dict, tokenizer) -> types.Datum: weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples]) print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}") + # Save the model and optimizer state save_future = training_client.save_state(f"pig-latin-lora-epoch-{epoch}") save_result = save_future.result() print(f"Saved checkpoint for epoch {epoch} to {save_result.path}") # NOTE: Need to set your modelscope token as api_key when initializing the service client # model name is {run_id}_{checkpoint_name} -rest_client.publish_checkpoint_from_tinker_path(save_result.path).result() -print("Published checkpoint") +# rest_client.publish_checkpoint_from_tinker_path(save_result.path).result() +# print("Published checkpoint") diff --git a/src/twinkle/server/tinker/common/__init__.py b/src/twinkle/server/tinker/common/__init__.py index b0cf29af..f4e96a81 100644 --- a/src/twinkle/server/tinker/common/__init__.py +++ b/src/twinkle/server/tinker/common/__init__.py @@ -1,14 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .datum import datum_to_input_feature, input_feature_to_datum -from .transformers_model import TwinkleCompatTransformersModel from twinkle.utils import exists, requires - -if exists('megatron_core'): - from .megatron_model import TwinkleCompatMegatronModel -else: - class TwinkleCompatMegatronModel: # pragma: no cover - only used when megatron_core is missing - def __init__(self, *args, **kwargs): - requires('megatron_core') from .rate_limiter import RateLimiter from .task_queue import ( TaskStatus, diff --git a/src/twinkle/server/tinker/common/io_utils.py b/src/twinkle/server/tinker/common/io_utils.py index 94ff6835..047a9e81 100644 --- a/src/twinkle/server/tinker/common/io_utils.py +++ b/src/twinkle/server/tinker/common/io_utils.py @@ -11,13 +11,10 @@ from tinker import types from twinkle.server.utils.io_utils import ( - TWINKLE_DEFAULT_SAVE_DIR, CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, BaseTrainingRunManager, BaseCheckpointManager, - validate_user_path, - validate_ownership, ) @@ -172,16 +169,3 @@ def create_checkpoint_manager(token: str) -> CheckpointManager: training_run_manager = TrainingRunManager(token) return CheckpointManager(token, training_run_manager) - -# Re-export for backward compatibility -__all__ = [ - 'TWINKLE_DEFAULT_SAVE_DIR', - 'TRAIN_RUN_INFO_FILENAME', - 'CHECKPOINT_INFO_FILENAME', - 'TrainingRunManager', - 'CheckpointManager', - 'validate_user_path', - 'validate_ownership', - 'create_training_run_manager', - 'create_checkpoint_manager', -] diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py index 3296ab59..cac49c05 100644 --- a/src/twinkle/server/tinker/common/megatron_model.py +++ b/src/twinkle/server/tinker/common/megatron_model.py @@ -170,7 +170,7 @@ def load(self, checkpoint_dir: str, **kwargs): Load checkpoint with token-based isolation support. 
Args: - checkpoint_dir: The twinkle:// path to the checkpoint + checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID **kwargs: Additional keyword arguments including optional 'token' """ # Extract token from kwargs if provided (for user isolation) @@ -181,21 +181,15 @@ def load(self, checkpoint_dir: str, **kwargs): # Create checkpoint manager with the token checkpoint_manager = create_checkpoint_manager(token) - # handle twinkle checkpoint format - tinker_path = checkpoint_manager.parse_tinker_path(checkpoint_dir) - if not tinker_path: - raise ValueError(f"Invalid twinkle checkpoint path: {checkpoint_dir}") - - # check adapter files with token-based path - weight_path = checkpoint_manager.get_ckpt_dir( - tinker_path.training_run_id, - tinker_path.checkpoint_id - ) - if not weight_path or not weight_path.exists(): - raise ValueError(f"Checkpoint not found at {weight_path}") - - # Load using parent class method - return super().load(name=weight_path.name, output_dir=str(weight_path.parent), **kwargs) + # Use resolve_load_path to handle path resolution + resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) + + if resolved.is_twinkle_path: + # Load from twinkle checkpoint + return super().load(name=resolved.checkpoint_name, output_dir=str(resolved.checkpoint_dir), **kwargs) + else: + # Load from hub + return super().load(name=resolved.checkpoint_name, **kwargs) @staticmethod def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]: diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py index 095bec76..280e4a3c 100644 --- a/src/twinkle/server/tinker/common/transformers_model.py +++ b/src/twinkle/server/tinker/common/transformers_model.py @@ -89,7 +89,7 @@ def load(self, checkpoint_dir: str, **kwargs): Load checkpoint with token-based isolation support. 
Args: - checkpoint_dir: The twinkle:// path to the checkpoint + checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID **kwargs: Additional keyword arguments including optional 'token' """ # Extract token from kwargs if provided (for user isolation) @@ -100,25 +100,15 @@ def load(self, checkpoint_dir: str, **kwargs): # Create checkpoint manager with the token checkpoint_manager = create_checkpoint_manager(token) - # handle twinkle checkpoint format - tinker_path = checkpoint_manager.parse_tinker_path(checkpoint_dir) - if not tinker_path: - raise ValueError(f"Invalid twinkle checkpoint path: {checkpoint_dir}") + # Use resolve_load_path to handle path resolution + resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) - # check adapter files with token-based path - weight_path = checkpoint_manager.get_ckpt_dir( - tinker_path.training_run_id, - tinker_path.checkpoint_id - ) - if not weight_path or not weight_path.exists(): - raise ValueError(f"Checkpoint not found at {weight_path}") - - if (weight_path / 'adapter_config.json').exists(): - return super().load(name=weight_path.name, output_dir=weight_path.parent, **kwargs) - elif (weight_path / tinker_path.training_run_id / 'adapter_config.json').exists(): - return super().load(name=weight_path.name, output_dir=weight_path.parent, **kwargs) + if resolved.is_twinkle_path: + # Load from twinkle checkpoint + return super().load(name=resolved.checkpoint_name, output_dir=str(resolved.checkpoint_dir), **kwargs) else: - raise ValueError(f"Adapter files not found in {weight_path}") + # Load from hub + return super().load(name=resolved.checkpoint_name, **kwargs) @staticmethod def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]: diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 5d84ec97..231f1330 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -112,7 +112,7 @@ def __init__(self, 
nproc_per_node: int, device_group: Dict[str, Any], **kwargs ) else: - from .common import TwinkleCompatTransformersModel + from .common.transformers_model import TwinkleCompatTransformersModel self.model = TwinkleCompatTransformersModel( model_id=model_id, device_mesh=self.device_mesh, diff --git a/src/twinkle/server/twinkle/common/__init__.py b/src/twinkle/server/twinkle/common/__init__.py index ae8c5069..85b3e739 100644 --- a/src/twinkle/server/twinkle/common/__init__.py +++ b/src/twinkle/server/twinkle/common/__init__.py @@ -1,21 +1 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .io_utils import ( - TrainingRunManager, - CheckpointManager, - TrainingRun, - TrainingRunsResponse, - Checkpoint, - CheckpointsListResponse, - Cursor, - WeightsInfoResponse, - LoraConfig, - CreateModelRequest, - ParsedCheckpointTwinklePath, - validate_user_path, - validate_ownership, - TWINKLE_DEFAULT_SAVE_DIR, - TRAIN_RUN_INFO_FILENAME, - create_training_run_manager, - create_checkpoint_manager, -) -from twinkle.server.utils.io_utils import BaseFileManager as FileManager diff --git a/src/twinkle/server/twinkle/common/io_utils.py b/src/twinkle/server/twinkle/common/io_utils.py index e5d03608..677f18fe 100644 --- a/src/twinkle/server/twinkle/common/io_utils.py +++ b/src/twinkle/server/twinkle/common/io_utils.py @@ -208,24 +208,3 @@ def create_checkpoint_manager(token: str) -> CheckpointManager: return CheckpointManager(token, training_run_manager) -# Re-export for backward compatibility -__all__ = [ - 'TWINKLE_DEFAULT_SAVE_DIR', - 'TRAIN_RUN_INFO_FILENAME', - 'CHECKPOINT_INFO_FILENAME', - 'Cursor', - 'Checkpoint', - 'TrainingRun', - 'TrainingRunsResponse', - 'CheckpointsListResponse', - 'ParsedCheckpointTwinklePath', - 'WeightsInfoResponse', - 'LoraConfig', - 'CreateModelRequest', - 'TrainingRunManager', - 'CheckpointManager', - 'validate_user_path', - 'validate_ownership', - 'create_training_run_manager', - 'create_checkpoint_manager', -] diff --git 
a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 4661be5c..0acba943 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -428,47 +428,24 @@ def load(self, request: Request, body: LoadRequest): # Extract token for directory isolation token = request.state.token + checkpoint_manager = create_checkpoint_manager(token) - # Check if body.name is a twinkle:// path or a simple checkpoint name - if body.name.startswith("twinkle://"): - # Parse twinkle:// path - checkpoint_manager = create_checkpoint_manager(token) - parsed = checkpoint_manager.parse_twinkle_path(body.name) - if not parsed: - raise ValueError(f"Invalid twinkle path format: {body.name}") - - # Extract the checkpoint name from the parsed path - # parsed.checkpoint_id is like "weights/step-8" - checkpoint_id = parsed.checkpoint_id - checkpoint_name = parsed.checkpoint_id.split('/')[-1] # Extract "step-8" - - # Use the training_run_id from the path as the model_id - model_id_to_load = parsed.training_run_id - - # Verify checkpoint exists and user has access - checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id) - if not checkpoint: - raise ValueError( - f"Checkpoint not found or access denied: {body.name}" - ) - - # Get the actual directory path - output_dir = checkpoint_manager.get_save_dir( - model_id=model_id_to_load, - is_sampler=False - ) - + # Use resolve_load_path to handle path resolution + resolved = checkpoint_manager.resolve_load_path(body.name) + + if resolved.is_twinkle_path: + # Load from twinkle checkpoint directory ret = self.model.load( - name=checkpoint_name, - output_dir=output_dir, + name=resolved.checkpoint_name, + output_dir=resolved.checkpoint_dir, adapter_name=adapter_name, load_optimizer=body.load_optimizer, **extra_kwargs ) else: - # No twinkle checkpoint name provided - load from modelscope + # Load from hub (checkpoint_dir is None) ret = self.model.load( - name=body.name, + 
name=resolved.checkpoint_name, adapter_name=adapter_name, load_optimizer=body.load_optimizer, token=token, diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/io_utils.py index 5eb4c44a..0211e8c2 100644 --- a/src/twinkle/server/utils/io_utils.py +++ b/src/twinkle/server/utils/io_utils.py @@ -76,6 +76,23 @@ class BaseParsedCheckpointPath(BaseModel): checkpoint_id: str +class ResolvedLoadPath(BaseModel): + """Result of resolving a load path. + + Attributes: + checkpoint_name: The name of the checkpoint (e.g., 'step-8' or hub model id) + checkpoint_dir: The directory containing the checkpoint, or None if loading from hub + is_twinkle_path: Whether the path was a twinkle:// path + training_run_id: The training run ID (only set for twinkle:// paths) + checkpoint_id: The checkpoint ID (only set for twinkle:// paths) + """ + checkpoint_name: str + checkpoint_dir: Optional[str] = None + is_twinkle_path: bool = False + training_run_id: Optional[str] = None + checkpoint_id: Optional[str] = None + + class BaseWeightsInfoResponse(BaseModel): """Base model for weights info response.""" training_run_id: str @@ -725,3 +742,68 @@ def get_weights_info(self, checkpoint_path: str) -> Optional[Any]: return None return self._create_weights_info(run_info) + + def resolve_load_path(self, path: str, validate_exists: bool = True) -> ResolvedLoadPath: + """ + Resolve a checkpoint load path. + + This method handles two types of paths: + 1. twinkle:// paths: Parse, validate permissions, return checkpoint_name and checkpoint_dir + 2. 
Hub model IDs: Return the path as checkpoint_name with checkpoint_dir=None + + Args: + path: The path to resolve (either twinkle:// format or hub model ID) + validate_exists: Whether to validate that the checkpoint exists (default: True) + + Returns: + ResolvedLoadPath with checkpoint_name and checkpoint_dir + + Raises: + ValueError: If the path format is invalid or checkpoint not found + """ + # Check if path starts with twinkle:// prefix + if path.startswith(self.path_prefix): + # Parse the twinkle:// path + parsed = self.parse_path(path) + if not parsed: + raise ValueError(f"Invalid {self.path_prefix} path format: {path}") + + # Extract components + training_run_id = parsed.training_run_id + checkpoint_id = parsed.checkpoint_id + checkpoint_name = checkpoint_id.split('/')[-1] # Extract name from "weights/step-8" + + if validate_exists: + # Verify checkpoint exists and user has access + checkpoint = self.get(training_run_id, checkpoint_id) + if not checkpoint: + raise ValueError( + f"Checkpoint not found or access denied: {path}" + ) + + # Get the checkpoint directory + checkpoint_dir = str(self.get_ckpt_dir(training_run_id, checkpoint_id)) + + if validate_exists: + # Verify the directory exists + from pathlib import Path as PathLib + if not PathLib(checkpoint_dir).exists(): + raise ValueError(f"Checkpoint directory not found: {checkpoint_dir}") + + return ResolvedLoadPath( + checkpoint_name=checkpoint_name, + checkpoint_dir=checkpoint_dir, + is_twinkle_path=True, + training_run_id=training_run_id, + checkpoint_id=checkpoint_id + ) + else: + # Not a twinkle:// path - treat as hub model ID + # Return the path as checkpoint_name with no checkpoint_dir + return ResolvedLoadPath( + checkpoint_name=path, + checkpoint_dir=None, + is_twinkle_path=False, + training_run_id=None, + checkpoint_id=None + ) From 3f40ddd0b3d61908976082c1504141a29379715b Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 5 Feb 2026 17:00:20 +0800 Subject: [PATCH 3/9] add tinker resume from 
remote --- cookbook/client/tinker/transformer/lora.py | 3 +- src/twinkle/hub/hub.py | 178 +++++++++++++----- src/twinkle/server/tinker/common/io_utils.py | 27 ++- src/twinkle/server/twinkle/common/io_utils.py | 19 +- src/twinkle/server/utils/io_utils.py | 109 +++++++++-- 5 files changed, 272 insertions(+), 64 deletions(-) diff --git a/cookbook/client/tinker/transformer/lora.py b/cookbook/client/tinker/transformer/lora.py index 1289ad15..00181ff9 100644 --- a/cookbook/client/tinker/transformer/lora.py +++ b/cookbook/client/tinker/transformer/lora.py @@ -16,8 +16,9 @@ future = rest_client.list_training_runs(limit=50) response = future.result() +# Support resume from twinkle path or model id # resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1" -resume_path = "AlexEz/20260205_152451-Qwen_Qwen2_5-7B-Instruct-104b022e_pig-latin-lora-epoch-1" +resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1" # resume_path = "" print(f"Found {len(response.training_runs)} training runs") for tr in response.training_runs: diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py index 26a046f0..62887363 100644 --- a/src/twinkle/hub/hub.py +++ b/src/twinkle/hub/hub.py @@ -59,6 +59,24 @@ def remove_source_type(resource_name: str): else: return parts[-1] + @classmethod + def _get_hub_class(cls, resource_name: str) -> type: + """Get the appropriate Hub class based on resource name prefix. 
+ + Args: + resource_name: The resource name with optional prefix (hf:// or ms://) + + Returns: + The Hub class (HFHub or MSHub) + """ + source = cls.source_type(resource_name) + if source == 'hf': + return HFHub + elif source == 'ms': + return MSHub + else: + raise NotImplementedError(f'Unknown source type: {source}') + @classmethod def try_login(cls, token: Optional[str] = None) -> bool: """Try to log in to the hub @@ -69,12 +87,8 @@ def try_login(cls, token: Optional[str] = None) -> bool: Returns: bool: Whether login is successful """ - if cls.source_type(token) == 'hf': - return HFHub.try_login(cls.remove_source_type(token)) - elif cls.source_type(token) == 'ms': - return MSHub.try_login(cls.remove_source_type(token)) - else: - raise NotImplementedError + hub = cls._get_hub_class(token) + return hub.try_login(cls.remove_source_type(token)) @classmethod def create_model_repo(cls, repo_id: str, token: Optional[str] = None, private: bool = False): @@ -85,12 +99,8 @@ def create_model_repo(cls, repo_id: str, token: Optional[str] = None, private: b token: The hub token to use private: If is a private repo """ - if cls.source_type(repo_id) == 'hf': - return HFHub.create_model_repo(cls.remove_source_type(repo_id), token, private) - elif cls.source_type(repo_id) == 'ms': - return MSHub.create_model_repo(cls.remove_source_type(repo_id), token, private) - else: - raise NotImplementedError + hub = cls._get_hub_class(repo_id) + return hub.create_model_repo(cls.remove_source_type(repo_id), token, private) @classmethod def push_to_hub(cls, @@ -117,14 +127,10 @@ def push_to_hub(cls, revision: The revision to push to ignore_patterns: The ignore file patterns """ - if cls.source_type(repo_id) == 'hf': - return HFHub.push_to_hub(cls.remove_source_type(repo_id), folder_path, path_in_repo, commit_message, - commit_description, token, private, revision, ignore_patterns, **kwargs) - elif cls.source_type(repo_id) == 'ms': - return MSHub.push_to_hub(cls.remove_source_type(repo_id), 
folder_path, path_in_repo, commit_message, - commit_description, token, private, revision, ignore_patterns, **kwargs) - else: - raise NotImplementedError + hub = cls._get_hub_class(repo_id) + return hub.push_to_hub( + cls.remove_source_type(repo_id), folder_path, path_in_repo, commit_message, + commit_description, token, private, revision, ignore_patterns, **kwargs) @classmethod def async_push_to_hub(cls, @@ -174,12 +180,8 @@ def load_dataset(cls, Returns: The Dataset instance """ - if cls.source_type(dataset_id) == 'hf': - return HFHub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision) - elif cls.source_type(dataset_id) == 'ms': - return MSHub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision) - else: - raise NotImplementedError() + hub = cls._get_hub_class(dataset_id) + return hub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision) @classmethod def download_model(cls, @@ -208,22 +210,40 @@ def download_model(cls, ) | set(large_file_pattern) if os.path.exists(model_id_or_path): return model_id_or_path - if cls.source_type(model_id_or_path) == 'hf': - return HFHub.download_model( - model_id_or_path=cls.remove_source_type(model_id_or_path), - revision=revision, - ignore_patterns=ignore_patterns, - token=token, - **kwargs) - elif cls.source_type(model_id_or_path) == 'ms': - return MSHub.download_model( - model_id_or_path=cls.remove_source_type(model_id_or_path), - revision=revision, - ignore_patterns=ignore_patterns, - token=token, - **kwargs) - else: - raise NotImplementedError + hub = cls._get_hub_class(model_id_or_path) + return hub.download_model( + model_id_or_path=cls.remove_source_type(model_id_or_path), + revision=revision, + ignore_patterns=ignore_patterns, + token=token, + **kwargs) + + @classmethod + def download_file(cls, + repo_id: str, + repo_type: str = 'model', + allow_patterns: Optional[Union[List[str], str]] = None, + token: 
Optional[str] = None, + **kwargs) -> str: + """Download specific files from the hub + + Args: + repo_id: The repository id + repo_type: The type of repository, default is 'model' + allow_patterns: Patterns to filter which files to download + token: The hub token + **kwargs: Additional arguments passed to the download function + + Returns: + The local directory path containing downloaded files + """ + hub = cls._get_hub_class(repo_id) + return hub.download_file( + repo_id=cls.remove_source_type(repo_id), + repo_type=repo_type, + allow_patterns=allow_patterns, + token=token, + **kwargs) class MSHub(HubOperation): @@ -432,6 +452,51 @@ def download_model(cls, return snapshot_download(**download_kwargs) + @classmethod + def download_file(cls, + repo_id: str, + repo_type: str = 'model', + allow_patterns: Optional[Union[List[str], str]] = None, + token: Optional[str] = None, + **kwargs) -> str: + """Download specific files from ModelScope hub + + Args: + repo_id: The repository id + repo_type: The type of repository, default is 'model' + allow_patterns: Patterns to filter which files to download + token: The hub token + **kwargs: Additional arguments passed to _snapshot_download + + Returns: + The local directory path containing downloaded files + """ + requires('modelscope') + cls.try_login(token) + from modelscope.hub.snapshot_download import _snapshot_download + import inspect + + # Build download arguments + download_kwargs = { + 'repo_id': repo_id, + 'repo_type': repo_type, + 'allow_patterns': allow_patterns, + **kwargs + } + + # Add token parameter only if supported by the function signature + if token is not None: + sig = inspect.signature(_snapshot_download) + if 'token' in sig.parameters: + download_kwargs['token'] = token + else: + print( + 'Token parameter is not supported by current modelscope version. ' + 'Please upgrade to modelscope >= 1.34.0 for token-based authentication.' 
+ ) + + return _snapshot_download(**download_kwargs) + @staticmethod def add_patterns_to_file(repo, file_name: str, @@ -574,3 +639,32 @@ def download_model(cls, token=token, **kwargs ) + + @classmethod + def download_file(cls, + repo_id: str, + repo_type: str = 'model', + allow_patterns: Optional[Union[List[str], str]] = None, + token: Optional[str] = None, + **kwargs) -> str: + """Download specific files from HuggingFace hub + + Args: + repo_id: The repository id + repo_type: The type of repository, default is 'model' + allow_patterns: Patterns to filter which files to download + token: The hub token + **kwargs: Additional arguments passed to snapshot_download + + Returns: + The local directory path containing downloaded files + """ + requires('huggingface_hub') + from huggingface_hub import snapshot_download + return snapshot_download( + repo_id=repo_id, + repo_type=repo_type, + allow_patterns=allow_patterns, + token=token, + **kwargs + ) diff --git a/src/twinkle/server/tinker/common/io_utils.py b/src/twinkle/server/tinker/common/io_utils.py index 047a9e81..f6f6a05d 100644 --- a/src/twinkle/server/tinker/common/io_utils.py +++ b/src/twinkle/server/tinker/common/io_utils.py @@ -11,10 +11,14 @@ from tinker import types from twinkle.server.utils.io_utils import ( + TWINKLE_DEFAULT_SAVE_DIR, CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, BaseTrainingRunManager, BaseCheckpointManager, + ResolvedLoadPath, + validate_user_path, + validate_ownership, ) @@ -109,9 +113,17 @@ def path_field_name(self) -> str: def _create_checkpoint( self, checkpoint_id: str, checkpoint_type: str, - path: str, size_bytes: int, public: bool + path: str, size_bytes: int, public: bool, + base_model: Optional[str] = None, + is_lora: bool = False, + lora_rank: Optional[int] = None, + train_unembed: Optional[bool] = None, + train_mlp: Optional[bool] = None, + train_attn: Optional[bool] = None, + user_metadata: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """Create checkpoint data.""" + 
# Create base checkpoint using tinker types checkpoint = types.Checkpoint( checkpoint_id=checkpoint_id, checkpoint_type=checkpoint_type, @@ -120,7 +132,18 @@ def _create_checkpoint( size_bytes=size_bytes, public=public ) - return checkpoint.model_dump(mode='json') + result = checkpoint.model_dump(mode='json') + + # Add training run info fields (may not be supported by external types.Checkpoint) + result['base_model'] = base_model + result['is_lora'] = is_lora + result['lora_rank'] = lora_rank + result['train_unembed'] = train_unembed + result['train_mlp'] = train_mlp + result['train_attn'] = train_attn + result['user_metadata'] = user_metadata + + return result def _parse_checkpoint(self, data: Dict[str, Any]) -> types.Checkpoint: """Parse checkpoint data into Checkpoint model.""" diff --git a/src/twinkle/server/twinkle/common/io_utils.py b/src/twinkle/server/twinkle/common/io_utils.py index 677f18fe..87349553 100644 --- a/src/twinkle/server/twinkle/common/io_utils.py +++ b/src/twinkle/server/twinkle/common/io_utils.py @@ -22,6 +22,7 @@ BaseWeightsInfoResponse, BaseTrainingRunManager, BaseCheckpointManager, + ResolvedLoadPath, validate_user_path, validate_ownership, ) @@ -146,7 +147,14 @@ def path_field_name(self) -> str: def _create_checkpoint( self, checkpoint_id: str, checkpoint_type: str, - path: str, size_bytes: int, public: bool + path: str, size_bytes: int, public: bool, + base_model: Optional[str] = None, + is_lora: bool = False, + lora_rank: Optional[int] = None, + train_unembed: Optional[bool] = None, + train_mlp: Optional[bool] = None, + train_attn: Optional[bool] = None, + user_metadata: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """Create checkpoint data.""" checkpoint = Checkpoint( @@ -155,7 +163,14 @@ def _create_checkpoint( time=datetime.now(), twinkle_path=path, size_bytes=size_bytes, - public=public + public=public, + base_model=base_model, + is_lora=is_lora, + lora_rank=lora_rank, + train_unembed=train_unembed, + train_mlp=train_mlp, + 
train_attn=train_attn, + user_metadata=user_metadata ) return checkpoint.model_dump(mode='json') diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/io_utils.py index 0211e8c2..7a471288 100644 --- a/src/twinkle/server/utils/io_utils.py +++ b/src/twinkle/server/utils/io_utils.py @@ -17,6 +17,8 @@ from pydantic import BaseModel +from twinkle.hub import HubOperation + TWINKLE_DEFAULT_SAVE_DIR = os.environ.get('TWINKLE_DEFAULT_SAVE_DIR', './outputs') CHECKPOINT_INFO_FILENAME = 'checkpoint_metadata.json' @@ -37,6 +39,14 @@ class BaseCheckpoint(BaseModel): time: datetime size_bytes: int public: bool = False + # Training run info (stored for hub downloads) + base_model: Optional[str] = None + is_lora: bool = False + lora_rank: Optional[int] = None + train_unembed: Optional[bool] = None + train_mlp: Optional[bool] = None + train_attn: Optional[bool] = None + user_metadata: Optional[Dict[str, Any]] = None class BaseTrainingRun(BaseModel): @@ -435,11 +445,32 @@ def path_field_name(self) -> str: @abstractmethod def _create_checkpoint( self, checkpoint_id: str, checkpoint_type: str, - path: str, size_bytes: int, public: bool + path: str, size_bytes: int, public: bool, + base_model: Optional[str] = None, + is_lora: bool = False, + lora_rank: Optional[int] = None, + train_unembed: Optional[bool] = None, + train_mlp: Optional[bool] = None, + train_attn: Optional[bool] = None, + user_metadata: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """ Create checkpoint data. 
+ Args: + checkpoint_id: The checkpoint identifier + checkpoint_type: Type of checkpoint ('training' or 'sampler') + path: The twinkle:// path to the checkpoint + size_bytes: Size of the checkpoint in bytes + public: Whether the checkpoint is public + base_model: The base model name/path + is_lora: Whether this is a LoRA checkpoint + lora_rank: The LoRA rank if applicable + train_unembed: Whether unembed layers are trained + train_mlp: Whether MLP layers are trained + train_attn: Whether attention layers are trained + user_metadata: User-provided metadata + Returns: Dictionary with checkpoint data """ @@ -594,12 +625,22 @@ def save(self, model_id: str, name: str, path = f"{self.path_prefix}{model_id}/{checkpoint_id}" checkpoint_path = self.get_ckpt_dir(model_id, checkpoint_id) + # Read training run info to include in checkpoint metadata + run_info = self.training_run_manager._read_info(model_id) + ckpt_data = self._create_checkpoint( checkpoint_id=checkpoint_id, checkpoint_type=checkpoint_type, path=path, size_bytes=self.get_dir_size(checkpoint_path), - public=public + public=public, + base_model=run_info.get('base_model'), + is_lora=run_info.get('is_lora', False), + lora_rank=run_info.get('lora_rank'), + train_unembed=run_info.get('train_unembed'), + train_mlp=run_info.get('train_mlp'), + train_attn=run_info.get('train_attn'), + user_metadata=run_info.get('user_metadata') ) self._write_ckpt_info(model_id, checkpoint_id, ckpt_data) @@ -718,30 +759,64 @@ def get_weights_info(self, checkpoint_path: str) -> Optional[Any]: """ Get weights info. + Supports both twinkle:// paths (local checkpoints) and hub model IDs. + For hub model IDs, downloads checkpoint_metadata.json from ModelScope. 
+ Args: - checkpoint_path: The path + checkpoint_path: The twinkle:// path or hub model ID Returns: WeightsInfoResponse or None if not found """ - parsed_path = self.parse_path(checkpoint_path) - if not parsed_path: - return None - - ckpt_info = self.get(parsed_path.training_run_id, parsed_path.checkpoint_id) - if not ckpt_info: + # Use resolve_load_path to determine if this is a twinkle path or hub path + try: + resolved = self.resolve_load_path(checkpoint_path, validate_exists=False) + except ValueError: return None - # Weight info is stored in the training run info - run_info = self.training_run_manager._read_info(parsed_path.training_run_id) - if not run_info: - return None + if resolved.is_twinkle_path: + # Local twinkle:// path - read from local checkpoint metadata + ckpt_data = self._read_ckpt_info(resolved.training_run_id, resolved.checkpoint_id) + if not ckpt_data or not ckpt_data.get('base_model'): + return None + return self._create_weights_info(ckpt_data) + else: + # Hub model ID - download checkpoint_metadata.json from ModelScope + return self._get_weights_info_from_hub(checkpoint_path) + + def _get_weights_info_from_hub(self, hub_model_id: str) -> Optional[Any]: + """ + Download and parse checkpoint_metadata.json from hub. 
- # Validate ownership - if not validate_ownership(self.token, run_info.get('model_owner', '')): + Args: + hub_model_id: The hub model ID (e.g., 'user/model-name') + + Returns: + WeightsInfoResponse or None if not found or failed to download + """ + try: + # Download only the checkpoint_metadata.json file from hub + local_dir = HubOperation.download_file( + repo_id=hub_model_id, + allow_patterns=[CHECKPOINT_INFO_FILENAME], + token=self.token + ) + + # Read and parse the metadata + metadata_path = os.path.join(local_dir, CHECKPOINT_INFO_FILENAME) + if not os.path.exists(metadata_path): + return None + + with open(metadata_path, 'r') as f: + ckpt_data = json.load(f) + + if not ckpt_data.get('base_model'): + return None + + return self._create_weights_info(ckpt_data) + + except Exception: return None - - return self._create_weights_info(run_info) def resolve_load_path(self, path: str, validate_exists: bool = True) -> ResolvedLoadPath: """ From 42e84d845b9339dedd6463c20915b4111eedc3ff Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 5 Feb 2026 17:55:16 +0800 Subject: [PATCH 4/9] add manager --- cookbook/client/twinkle/transformer/lora.py | 5 +- .../twinkle/transformer/server_config.yaml | 3 + src/twinkle/server/tinker/common/__init__.py | 7 -- src/twinkle/server/tinker/model.py | 4 +- src/twinkle/server/twinkle/common/io_utils.py | 27 +++++ src/twinkle/server/twinkle/model.py | 101 ++++++------------ src/twinkle/server/utils/__init__.py | 8 ++ .../common => utils}/adapter_manager.py | 0 .../{tinker/common => utils}/rate_limiter.py | 0 .../{tinker/common => utils}/task_queue.py | 0 10 files changed, 75 insertions(+), 80 deletions(-) rename src/twinkle/server/{tinker/common => utils}/adapter_manager.py (100%) rename src/twinkle/server/{tinker/common => utils}/rate_limiter.py (100%) rename src/twinkle/server/{tinker/common => utils}/task_queue.py (100%) diff --git a/cookbook/client/twinkle/transformer/lora.py b/cookbook/client/twinkle/transformer/lora.py index 
994a76a1..f83a18c7 100644 --- a/cookbook/client/twinkle/transformer/lora.py +++ b/cookbook/client/twinkle/transformer/lora.py @@ -56,6 +56,7 @@ def train(): if resume_path: logger.info(f'Resuming training from {resume_path}') model.load(resume_path, load_optimizer=True) + # Start training logger.info(model.get_train_configs()) for step, batch in enumerate(dataloader): output = model.forward_backward(inputs=batch) @@ -66,10 +67,12 @@ def train(): model.zero_grad() model.lr_step() + # Save the model twinkle_path = model.save(name=f'step-{step}', save_optimizer=True) logger.info(f"Saved checkpoint: {twinkle_path}") - hub_model_id = 'AlexEz/twinkle-self-cognition-2' + # Upload the model to ModelScope + hub_model_id = 'AlexEz/twinkle-self-cognition' model.upload_to_hub( checkpoint_dir=twinkle_path, hub_model_id=hub_model_id, diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml index 52d4a567..1c987498 100644 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ b/cookbook/client/twinkle/transformer/server_config.yaml @@ -26,6 +26,9 @@ applications: args: use_megatron: false model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" + adapter_config: + per_token_adapter_limit: 30 + adapter_timeout: 1800 nproc_per_node: 2 device_group: name: model diff --git a/src/twinkle/server/tinker/common/__init__.py b/src/twinkle/server/tinker/common/__init__.py index f4e96a81..80c2ca8a 100644 --- a/src/twinkle/server/tinker/common/__init__.py +++ b/src/twinkle/server/tinker/common/__init__.py @@ -1,10 +1,3 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .datum import datum_to_input_feature, input_feature_to_datum from twinkle.utils import exists, requires -from .rate_limiter import RateLimiter -from .task_queue import ( - TaskStatus, - QueueState, - TaskQueueConfig, - TaskQueueMixin, -) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 231f1330..59d12f1a 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -24,8 +24,8 @@ from twinkle.server.utils.state import get_server_state, ServerStateProxy from twinkle.utils.logger import get_logger -from .common.task_queue import TaskQueueMixin, TaskQueueConfig -from .common.adapter_manager import AdapterManagerMixin +from twinkle.server.utils.task_queue import TaskQueueMixin, TaskQueueConfig +from twinkle.server.utils.adapter_manager import AdapterManagerMixin from .common.io_utils import create_training_run_manager, create_checkpoint_manager logger = get_logger() diff --git a/src/twinkle/server/twinkle/common/io_utils.py b/src/twinkle/server/twinkle/common/io_utils.py index 87349553..d3614b44 100644 --- a/src/twinkle/server/twinkle/common/io_utils.py +++ b/src/twinkle/server/twinkle/common/io_utils.py @@ -176,8 +176,35 @@ def _create_checkpoint( def _parse_checkpoint(self, data: Dict[str, Any]) -> Checkpoint: """Parse checkpoint data into Checkpoint model.""" + data = data.copy() + # Transform tinker_path to twinkle_path if needed + if 'tinker_path' in data and 'twinkle_path' not in data: + data['twinkle_path'] = data.pop('tinker_path') + elif 'twinkle_path' not in data and 'path' in data: + data['twinkle_path'] = data.pop('path') return Checkpoint(**data) + def get(self, model_id: str, checkpoint_id: str) -> Optional[Checkpoint]: + """ + Get checkpoint metadata with backwards compatibility. 
+ + Args: + model_id: The model identifier + checkpoint_id: The checkpoint identifier + + Returns: + Checkpoint object or None if not found + """ + data = self._read_ckpt_info(model_id, checkpoint_id) + if not data: + return None + # Handle backwards compatibility: construct twinkle_path if missing + if 'twinkle_path' not in data and 'tinker_path' not in data and 'path' not in data: + if 'checkpoint_id' in data: + data = data.copy() + data['twinkle_path'] = f"{self.path_prefix}{model_id}/{data['checkpoint_id']}" + return self._parse_checkpoint(data) + def _create_checkpoints_response(self, checkpoints: List[Checkpoint]) -> CheckpointsListResponse: """Create a checkpoints list response.""" return CheckpointsListResponse(checkpoints=checkpoints, cursor=None) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 0acba943..949f335b 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -1,7 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import os -import threading -import time from typing import Dict, Any, Optional from fastapi import FastAPI, Request @@ -11,10 +9,11 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh -from twinkle.model.base import TwinkleModel from twinkle.data_format import InputFeature, Trajectory +from twinkle.server.utils.adapter_manager import AdapterManagerMixin from twinkle.server.utils.validation import verify_request_token from twinkle.server.utils.state import get_server_state, ServerStateProxy +from twinkle.utils.logger import get_logger from .common.serialize import deserialize_object from .common.io_utils import ( CreateModelRequest, @@ -23,6 +22,8 @@ create_checkpoint_manager, ) +logger = get_logger() + class CreateRequest(BaseModel): class Config: @@ -138,6 +139,7 @@ def build_model_app(model_id: str, device_mesh: Dict[str, Any], deploy_options: Dict[str, Any], use_megatron: bool = False, + adapter_config: Dict[str, Any] = None, **kwargs): app = FastAPI() @@ -147,15 +149,14 @@ async def verify_token(request: Request, call_next): @serve.deployment(name="ModelManagement") @serve.ingress(app) - class ModelManagement: - - COUNT_DOWN = 60 * 30 + class ModelManagement(AdapterManagerMixin): def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mesh: Dict[str, Any]): self.device_group = DeviceGroup(**device_group) twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) self.device_mesh = DeviceMesh(**device_mesh) if use_megatron: + from twinkle.model import MultiLoraMegatronModel self.model = MultiLoraMegatronModel( model_id=model_id, device_mesh=self.device_mesh, @@ -170,53 +171,18 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes remote_group=self.device_group.name, **kwargs ) - self.adapter_records: Dict[str, int] = {} - self.hb_thread = threading.Thread(target=self.countdown, daemon=True) - self.hb_thread.start() - self.adapter_lock = 
threading.Lock() + + # Initialize state before adapter manager (mixin needs self.state) self.state: ServerStateProxy = get_server_state() - self.per_token_model_limit = int(os.environ.get("TWINKLE_PER_USER_MODEL_LIMIT", 3)) - self.key_token_dict = {} - - def countdown(self): - while True: - time.sleep(1) - for key in list(self.adapter_records.keys()): - self.adapter_records[key] += 1 - if self.adapter_records[key] > self.COUNT_DOWN: - with self.adapter_lock: - self.model.remove_adapter(key) - self.adapter_records.pop(key, None) - token = self.key_token_dict.pop(key, None) - if token: - self.handle_adapter_count(token, False) - - def handle_adapter_count(self, token: str, add: bool): - user_key = token + '_' + 'model_adapter' - cur_count = self.state.get_config(user_key) or 0 - if add: - if cur_count < self.per_token_model_limit: - self.state.add_config(user_key, cur_count + 1) - else: - raise RuntimeError(f'Model adapter count limitation reached: {self.per_token_model_limit}') - else: - if cur_count > 0: - cur_count -= 1 - self.state.add_config(user_key, cur_count) - if cur_count <= 0: - self.state.pop_config(user_key) + + # Initialize adapter manager from mixin + self._init_adapter_manager(**adapter_config) + self.start_adapter_countdown() @app.post("/create") def create(self, request: Request, body: CreateRequest): return {'status': 'ok'} - def assert_adapter_exists(self, adapter_name: str): - assert adapter_name and adapter_name in self.adapter_records, f"Adapter {adapter_name} not found" - - def assert_adapter_valid(self, adapter_name: Optional[str]): - assert adapter_name is None or adapter_name == '' or adapter_name in self.adapter_records, \ - f"Adapter {adapter_name} is invalid" - @staticmethod def get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: if adapter_name is None or adapter_name == '': @@ -433,24 +399,15 @@ def load(self, request: Request, body: LoadRequest): # Use resolve_load_path to handle path resolution resolved 
= checkpoint_manager.resolve_load_path(body.name) - if resolved.is_twinkle_path: - # Load from twinkle checkpoint directory - ret = self.model.load( - name=resolved.checkpoint_name, - output_dir=resolved.checkpoint_dir, - adapter_name=adapter_name, - load_optimizer=body.load_optimizer, - **extra_kwargs - ) - else: - # Load from hub (checkpoint_dir is None) - ret = self.model.load( - name=resolved.checkpoint_name, - adapter_name=adapter_name, - load_optimizer=body.load_optimizer, - token=token, - **extra_kwargs - ) + # Load from twinkle checkpoint directory + ret = self.model.load( + name=resolved.checkpoint_name, + output_dir=resolved.checkpoint_dir, + adapter_name=adapter_name, + load_optimizer=body.load_optimizer, + token=token, + **extra_kwargs + ) return {'result': ret} @@ -536,12 +493,16 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): token = request.state.token training_run_manager = create_training_run_manager(token) - with self.adapter_lock: + with self._adapter_lock: self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - self.adapter_records[adapter_name] = 0 - self.key_token_dict[adapter_name] = token - self.handle_adapter_count(token, True) + # Register adapter for lifecycle tracking + self.register_adapter(adapter_name, token) + + # Check adapter limit (raises if exceeded) + allowed, reason = self.check_adapter_limit(token, True) + if not allowed: + raise RuntimeError(reason) # Save training run metadata (similar to tinker's create_model) # Create a training run config from the adapter configuration @@ -585,7 +546,7 @@ def set_processor(self, request: Request, body: SetProcessorRequest): def heartbeat(self, request: Request, body: HeartbeatRequest): adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) self.assert_adapter_exists(adapter_name=adapter_name) - self.adapter_records[adapter_name] = 0 + self.touch_adapter(adapter_name) return {'status': 'ok'} @app.post("/calculate_metric") 
diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py index 98abbd67..5f61d75d 100644 --- a/src/twinkle/server/utils/__init__.py +++ b/src/twinkle/server/utils/__init__.py @@ -7,3 +7,11 @@ TRAIN_RUN_INFO_FILENAME, ) from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env +from .rate_limiter import RateLimiter +from .task_queue import ( + TaskStatus, + QueueState, + TaskQueueConfig, + TaskQueueMixin, +) +from .adapter_manager import AdapterManagerMixin diff --git a/src/twinkle/server/tinker/common/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py similarity index 100% rename from src/twinkle/server/tinker/common/adapter_manager.py rename to src/twinkle/server/utils/adapter_manager.py diff --git a/src/twinkle/server/tinker/common/rate_limiter.py b/src/twinkle/server/utils/rate_limiter.py similarity index 100% rename from src/twinkle/server/tinker/common/rate_limiter.py rename to src/twinkle/server/utils/rate_limiter.py diff --git a/src/twinkle/server/tinker/common/task_queue.py b/src/twinkle/server/utils/task_queue.py similarity index 100% rename from src/twinkle/server/tinker/common/task_queue.py rename to src/twinkle/server/utils/task_queue.py From fe54791948aa21fbcf965c8bafb0cccec63608bb Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 5 Feb 2026 18:00:18 +0800 Subject: [PATCH 5/9] update import --- src/twinkle/server/tinker/__init__.py | 6 +++--- src/twinkle/server/tinker/sampler.py | 2 +- src/twinkle/server/tinker/server.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/twinkle/server/tinker/__init__.py b/src/twinkle/server/tinker/__init__.py index b864eac4..4174d255 100644 --- a/src/twinkle/server/tinker/__init__.py +++ b/src/twinkle/server/tinker/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
-from .model import build_model_app as _build_model_app -from .sampler import build_sampler_app as _build_sampler_app -from .server import build_server_app +from .model import build_model_app as tinker_build_model_app +from .sampler import build_sampler_app as tinker_build_sampler_app +from .server import build_server_app as tinker_build_server_app from ..utils import wrap_builder_with_device_group_env diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index 99d97eed..cc99a9cd 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -21,10 +21,10 @@ from twinkle import DeviceGroup, DeviceMesh from twinkle.server.utils.validation import verify_request_token from twinkle.server.utils.state import get_server_state, ServerStateProxy +from twinkle.server.utils.task_queue import TaskQueueMixin, TaskQueueConfig from twinkle.sampler.types import SamplingParams as TwinkleSamplingParams from twinkle.utils.logger import get_logger -from .common.task_queue import TaskQueueMixin, TaskQueueConfig logger = get_logger() diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 09c54c56..e4da1c6e 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -24,9 +24,9 @@ from twinkle.server.utils.validation import verify_request_token, get_token_from_request from twinkle.server.utils.state import get_server_state +from twinkle.server.utils.task_queue import QueueState from twinkle.hub import HubOperation from .common.io_utils import create_training_run_manager, create_checkpoint_manager -from .common.task_queue import QueueState logger = logging.getLogger(__name__) From 4dcc4525f5d4ff43614f155a3a36a2b14f4e5f76 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 6 Feb 2026 13:14:06 +0800 Subject: [PATCH 6/9] update server --- cookbook/client/tinker/megatron/server.py | 48 +-- .../client/tinker/megatron/server_config.yaml | 31 +- 
cookbook/client/tinker/transformer/lora.py | 6 +- .../tinker/transformer/self_congnition.py | 2 +- cookbook/client/tinker/transformer/server.py | 51 +-- .../tinker/transformer/server_config.yaml | 66 +-- cookbook/client/twinkle/megatron/server.py | 49 +-- .../twinkle/megatron/server_config.yaml | 7 +- cookbook/client/twinkle/transformer/server.py | 49 +-- .../twinkle/transformer/server_config.yaml | 7 +- cookbook/remote/tinker/ascend/server.py | 61 +-- .../remote/tinker/ascend/server_config.yaml | 7 +- cookbook/remote/tinker/server.py | 48 +-- cookbook/remote/tinker/server_config.yaml | 5 +- cookbook/remote/twinkle/server.py | 50 +-- cookbook/remote/twinkle/server_config.yaml | 7 +- src/twinkle/server/__init__.py | 10 + src/twinkle/server/__main__.py | 143 +++++++ src/twinkle/server/launcher.py | 384 ++++++++++++++++++ src/twinkle/server/tinker/__init__.py | 14 +- 20 files changed, 615 insertions(+), 430 deletions(-) create mode 100644 src/twinkle/server/__main__.py create mode 100644 src/twinkle/server/launcher.py diff --git a/cookbook/client/tinker/megatron/server.py b/cookbook/client/tinker/megatron/server.py index 2fe976e0..1fd179f1 100644 --- a/cookbook/client/tinker/megatron/server.py +++ b/cookbook/client/tinker/megatron/server.py @@ -1,51 +1,9 @@ import os os.environ['RAY_DEBUG'] = '1' -import ray -from omegaconf import OmegaConf -from ray import serve -from twinkle.server.tinker import build_model_app, build_server_app -ray.init(namespace="twinkle_cluster") -serve.shutdown() -import time -time.sleep(5) +from twinkle.server import launch_server file_dir = os.path.abspath(os.path.dirname(__file__)) -config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml')) +config_path = os.path.join(file_dir, 'server_config.yaml') -# Start Ray Serve with http_options from config -http_options = OmegaConf.to_container(config.http_options, resolve=True) -serve.start(http_options=http_options) - -APP_BUILDERS = { - 'main:build_server_app': build_server_app, - 
'main:build_model_app': build_model_app, - # 'main:build_sampler_app': build_sampler_app, -} - -for app_config in config.applications: - print(f"Starting {app_config.name} at {app_config.route_prefix}...") - - builder = APP_BUILDERS[app_config.import_path] - args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {} - - deploy_options = {} - deploy_config = app_config.deployments[0] - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = OmegaConf.to_container(deploy_config.autoscaling_config) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = OmegaConf.to_container(deploy_config.ray_actor_options) - - app = builder( - deploy_options=deploy_options, - **{k: v for k, v in args.items()} - ) - - serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix) - -print("\nAll applications started!") -print("Endpoints:") -for app_config in config.applications: - print(f" - http://localhost:8000{app_config.route_prefix}") - -input("\nPress Enter to stop the server...") \ No newline at end of file +launch_server(config_path=config_path) \ No newline at end of file diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index c5d48885..3c2a9265 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -1,3 +1,4 @@ +server_type: tinker proxy_location: EveryNode http_options: host: 0.0.0.0 @@ -6,7 +7,7 @@ http_options: applications: - name: server route_prefix: /api/v1 - import_path: main:build_server_app + import_path: server args: deployments: @@ -22,7 +23,7 @@ applications: - name: models-Qwen2.5-0.5B-Instruct route_prefix: /api/v1/model/Qwen/Qwen2.5-0.5B-Instruct - import_path: main:build_model_app + import_path: model args: use_megatron: true model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" @@ -46,29 +47,3 @@ applications: logging_config: log_level: 
DEBUG - # Example: Add more models as needed - # - name: models-Qwen2.5-7B-Instruct - # route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct - # import_path: main:build_model_app - # args: - # model_id: "ms://Qwen/Qwen2.5-7B-Instruct" - # nproc_per_node: 4 - # device_group: - # name: model7b - # ranks: [2, 3, 4, 5] - # device_type: cuda - # device_mesh: - # device_type: cuda - # mesh: [2, 3, 4, 5] - # mesh_dim_names: ['dp'] - # deployments: - # - name: ModelManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # logging_config: - # log_level: DEBUG - diff --git a/cookbook/client/tinker/transformer/lora.py b/cookbook/client/tinker/transformer/lora.py index 00181ff9..0f1a096b 100644 --- a/cookbook/client/tinker/transformer/lora.py +++ b/cookbook/client/tinker/transformer/lora.py @@ -18,8 +18,8 @@ response = future.result() # Support resume from twinkle path or model id # resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1" -resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1" -# resume_path = "" +# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1" +resume_path = "" print(f"Found {len(response.training_runs)} training runs") for tr in response.training_runs: print(tr.model_dump_json(indent=2)) @@ -30,7 +30,7 @@ # resume_path = chpt.tinker_path # Just get the last one for demo purposes #%% -base_model = "Qwen/Qwen2.5-7B-Instruct" +base_model = "Qwen/Qwen2.5-0.5B-Instruct" if not resume_path: training_client = service_client.create_lora_training_client( base_model=base_model diff --git a/cookbook/client/tinker/transformer/self_congnition.py b/cookbook/client/tinker/transformer/self_congnition.py index 3e644d7d..14a3963b 100644 --- a/cookbook/client/tinker/transformer/self_congnition.py +++ 
b/cookbook/client/tinker/transformer/self_congnition.py @@ -8,7 +8,7 @@ from twinkle.server.tinker.common import input_feature_to_datum from modelscope import AutoTokenizer -base_model = "Qwen/Qwen2.5-7B-Instruct" +base_model = "Qwen/Qwen2.5-0.5B-Instruct" def train(): # process data diff --git a/cookbook/client/tinker/transformer/server.py b/cookbook/client/tinker/transformer/server.py index be9123e1..1fd179f1 100644 --- a/cookbook/client/tinker/transformer/server.py +++ b/cookbook/client/tinker/transformer/server.py @@ -1,54 +1,9 @@ import os os.environ['RAY_DEBUG'] = '1' -import ray -from omegaconf import OmegaConf -from ray import serve -from twinkle.server.tinker import build_model_app, build_sampler_app, build_server_app -ray.init(namespace="twinkle_cluster") -serve.shutdown() -import time -time.sleep(5) +from twinkle.server import launch_server file_dir = os.path.abspath(os.path.dirname(__file__)) -config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml')) +config_path = os.path.join(file_dir, 'server_config.yaml') -# Start Ray Serve with http_options from config -http_options = OmegaConf.to_container(config.http_options, resolve=True) -serve.start(http_options=http_options) - -APP_BUILDERS = { - 'main:build_server_app': build_server_app, - 'main:build_model_app': build_model_app, - # 'main:build_sampler_app': build_sampler_app, -} - -for app_config in config.applications: - print(f"Starting {app_config.name} at {app_config.route_prefix}...") - - if app_config.import_path not in APP_BUILDERS: - continue - - builder = APP_BUILDERS[app_config.import_path] - args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {} - - deploy_options = {} - deploy_config = app_config.deployments[0] - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = OmegaConf.to_container(deploy_config.autoscaling_config) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = 
OmegaConf.to_container(deploy_config.ray_actor_options) - - app = builder( - deploy_options=deploy_options, - **{k: v for k, v in args.items()} - ) - - serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix) - -print("\nAll applications started!") -print("Endpoints:") -for app_config in config.applications: - print(f" - http://localhost:8000{app_config.route_prefix}") - -input("\nPress Enter to stop the server...") \ No newline at end of file +launch_server(config_path=config_path) \ No newline at end of file diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml index 9f5f86bc..94943356 100644 --- a/cookbook/client/tinker/transformer/server_config.yaml +++ b/cookbook/client/tinker/transformer/server_config.yaml @@ -1,3 +1,4 @@ +server_type: tinker proxy_location: EveryNode http_options: host: 0.0.0.0 @@ -6,7 +7,7 @@ http_options: applications: - name: server route_prefix: /api/v1 - import_path: main:build_server_app + import_path: server args: deployments: @@ -21,11 +22,11 @@ applications: log_level: DEBUG - name: models-Qwen2.5-0.5B-Instruct - route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct - import_path: main:build_model_app + route_prefix: /api/v1/model/Qwen/Qwen2.5-0.5B-Instruct + import_path: model args: use_megatron: false - model_id: "ms://Qwen/Qwen2.5-7B-Instruct" + model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" nproc_per_node: 2 device_group: name: model @@ -52,59 +53,34 @@ applications: logging_config: log_level: DEBUG - - name: sampler-Qwen2.5-0.5B-Instruct - route_prefix: /api/v1/sampler/Qwen/Qwen2.5-0.5B-Instruct - import_path: main:build_sampler_app - args: - model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" - nproc_per_node: 1 - sampler_type: vllm # or 'torch' for TorchSampler - engine_args: - max_model_len: 4096 - gpu_memory_utilization: 0.5 - enable_lora: false - device_group: - name: sampler - ranks: [0] - device_type: cuda - device_mesh: - device_type: cuda - 
mesh: [0] - mesh_dim_names: ['dp'] - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - num_gpus: 1 - logging_config: - log_level: DEBUG - - # Example: Add more models as needed - # - name: models-Qwen2.5-7B-Instruct - # route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct - # import_path: main:build_model_app + # - name: sampler-Qwen2.5-0.5B-Instruct + # route_prefix: /api/v1/sampler/Qwen/Qwen2.5-0.5B-Instruct + # import_path: sampler # args: - # model_id: "ms://Qwen/Qwen2.5-7B-Instruct" - # nproc_per_node: 4 + # model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" + # nproc_per_node: 1 + # sampler_type: vllm # or 'torch' for TorchSampler + # engine_args: + # max_model_len: 4096 + # gpu_memory_utilization: 0.5 + # enable_lora: false # device_group: - # name: model7b - # ranks: [2, 3, 4, 5] + # name: sampler + # ranks: [0] # device_type: cuda # device_mesh: # device_type: cuda - # mesh: [2, 3, 4, 5] + # mesh: [0] # mesh_dim_names: ['dp'] # deployments: - # - name: ModelManagement + # - name: SamplerManagement # autoscaling_config: # min_replicas: 1 # max_replicas: 1 # target_ongoing_requests: 16 # ray_actor_options: # num_cpus: 0.1 + # num_gpus: 1 # logging_config: # log_level: DEBUG + diff --git a/cookbook/client/twinkle/megatron/server.py b/cookbook/client/twinkle/megatron/server.py index 3ae52629..1fd179f1 100644 --- a/cookbook/client/twinkle/megatron/server.py +++ b/cookbook/client/twinkle/megatron/server.py @@ -1,52 +1,9 @@ import os os.environ['RAY_DEBUG'] = '1' -import ray -from omegaconf import OmegaConf -from ray import serve -from twinkle.server import build_processor_app, build_sampler_app, build_model_app, build_server_app -ray.init() -serve.shutdown() -import time -time.sleep(5) +from twinkle.server import launch_server file_dir = os.path.abspath(os.path.dirname(__file__)) -config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml')) 
+config_path = os.path.join(file_dir, 'server_config.yaml') -# Start Ray Serve with http_options from config -http_options = OmegaConf.to_container(config.http_options, resolve=True) -serve.start(http_options=http_options) - -APP_BUILDERS = { - 'main:model_qwen25_7B': build_model_app, - # 'main:build_sampler_app': build_sampler_app, - 'main:processor_app': build_processor_app, - 'main:build_server_app': build_server_app, -} - -for app_config in config.applications: - print(f"Starting {app_config.name} at {app_config.route_prefix}...") - - builder = APP_BUILDERS[app_config.import_path] - args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {} - - deploy_options = {} - deploy_config = app_config.deployments[0] - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = OmegaConf.to_container(deploy_config.autoscaling_config) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = OmegaConf.to_container(deploy_config.ray_actor_options) - - app = builder( - deploy_options=deploy_options, - **{k: v for k, v in args.items()} - ) - - serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix) - -print("\nAll applications started!") -print("Endpoints:") -for app_config in config.applications: - print(f" - http://localhost:8000{app_config.route_prefix}") - -input("\nPress Enter to stop the server...") \ No newline at end of file +launch_server(config_path=config_path) \ No newline at end of file diff --git a/cookbook/client/twinkle/megatron/server_config.yaml b/cookbook/client/twinkle/megatron/server_config.yaml index 98c02a4d..004c1e86 100644 --- a/cookbook/client/twinkle/megatron/server_config.yaml +++ b/cookbook/client/twinkle/megatron/server_config.yaml @@ -1,3 +1,4 @@ +server_type: twinkle proxy_location: EveryNode http_options: host: 0.0.0.0 @@ -6,7 +7,7 @@ http_options: applications: - name: server route_prefix: /server - import_path: main:build_server_app + 
import_path: server args: deployments: @@ -22,7 +23,7 @@ applications: - name: models-Qwen2.5-7B-Instruct route_prefix: /models/Qwen/Qwen2.5-7B-Instruct - import_path: main:model_qwen25_7B + import_path: model args: use_megatron: true model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" @@ -46,7 +47,7 @@ applications: - name: processor route_prefix: /processors - import_path: main:processor_app + import_path: processor args: nproc_per_node: 2 ncpu_proc_per_node: 2 diff --git a/cookbook/client/twinkle/transformer/server.py b/cookbook/client/twinkle/transformer/server.py index 3ae52629..1fd179f1 100644 --- a/cookbook/client/twinkle/transformer/server.py +++ b/cookbook/client/twinkle/transformer/server.py @@ -1,52 +1,9 @@ import os os.environ['RAY_DEBUG'] = '1' -import ray -from omegaconf import OmegaConf -from ray import serve -from twinkle.server import build_processor_app, build_sampler_app, build_model_app, build_server_app -ray.init() -serve.shutdown() -import time -time.sleep(5) +from twinkle.server import launch_server file_dir = os.path.abspath(os.path.dirname(__file__)) -config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml')) +config_path = os.path.join(file_dir, 'server_config.yaml') -# Start Ray Serve with http_options from config -http_options = OmegaConf.to_container(config.http_options, resolve=True) -serve.start(http_options=http_options) - -APP_BUILDERS = { - 'main:model_qwen25_7B': build_model_app, - # 'main:build_sampler_app': build_sampler_app, - 'main:processor_app': build_processor_app, - 'main:build_server_app': build_server_app, -} - -for app_config in config.applications: - print(f"Starting {app_config.name} at {app_config.route_prefix}...") - - builder = APP_BUILDERS[app_config.import_path] - args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {} - - deploy_options = {} - deploy_config = app_config.deployments[0] - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = 
OmegaConf.to_container(deploy_config.autoscaling_config) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = OmegaConf.to_container(deploy_config.ray_actor_options) - - app = builder( - deploy_options=deploy_options, - **{k: v for k, v in args.items()} - ) - - serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix) - -print("\nAll applications started!") -print("Endpoints:") -for app_config in config.applications: - print(f" - http://localhost:8000{app_config.route_prefix}") - -input("\nPress Enter to stop the server...") \ No newline at end of file +launch_server(config_path=config_path) \ No newline at end of file diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml index 1c987498..7a82bfe2 100644 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ b/cookbook/client/twinkle/transformer/server_config.yaml @@ -1,3 +1,4 @@ +server_type: twinkle proxy_location: EveryNode http_options: host: 0.0.0.0 @@ -6,7 +7,7 @@ http_options: applications: - name: server route_prefix: /server - import_path: main:build_server_app + import_path: server args: deployments: @@ -22,7 +23,7 @@ applications: - name: models-Qwen2.5-7B-Instruct route_prefix: /models/Qwen/Qwen2.5-7B-Instruct - import_path: main:model_qwen25_7B + import_path: model args: use_megatron: false model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" @@ -49,7 +50,7 @@ applications: - name: processor route_prefix: /processors - import_path: main:processor_app + import_path: processor args: nproc_per_node: 2 ncpu_proc_per_node: 2 diff --git a/cookbook/remote/tinker/ascend/server.py b/cookbook/remote/tinker/ascend/server.py index eb437480..b7515dbf 100644 --- a/cookbook/remote/tinker/ascend/server.py +++ b/cookbook/remote/tinker/ascend/server.py @@ -1,64 +1,9 @@ import os os.environ['RAY_DEBUG'] = '1' -import ray -from omegaconf import OmegaConf -from ray import serve -from twinkle.server.tinker 
import build_model_app, build_server_app, build_sampler_app -ray.init(namespace="twinkle_cluster") -serve.shutdown() -import time -import sys -time.sleep(5) +from twinkle.server import launch_server file_dir = os.path.abspath(os.path.dirname(__file__)) -config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml')) +config_path = os.path.join(file_dir, 'server_config.yaml') -# Start Ray Serve with http_options from config -http_options = OmegaConf.to_container(config.http_options, resolve=True) -serve.start(http_options=http_options) - -APP_BUILDERS = { - 'main:build_server_app': build_server_app, - 'main:model_qwen25_7B': build_model_app, - 'main:build_sampler_app': build_sampler_app, -} - -for app_config in config.applications: - print(f"Starting {app_config.name} at {app_config.route_prefix}...") - - builder = APP_BUILDERS[app_config.import_path] - args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {} - deploy_options = {} - deploy_config = app_config.deployments[0] - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = OmegaConf.to_container(deploy_config.autoscaling_config) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = OmegaConf.to_container(deploy_config.ray_actor_options) - # Pass through any additional deployment options (e.g., request_timeout_s). 
- for key in deploy_config: - if key in {"name", "autoscaling_config", "ray_actor_options"}: - continue - value = getattr(deploy_config, key) - if OmegaConf.is_config(value): - value = OmegaConf.to_container(value, resolve=True) - deploy_options[key] = value - - app = builder( - deploy_options=deploy_options, - **{k: v for k, v in args.items()} - ) - - serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix) - -print("\nAll applications started!") -print("Endpoints:") -for app_config in config.applications: - print(f" - http://localhost:8000{app_config.route_prefix}") - -# In non-interactive runs (e.g., run_server.sh redirect), avoid blocking on stdin. -if sys.stdin.isatty(): - input("\nPress Enter to stop the server...") -else: - while True: - time.sleep(60) +launch_server(config_path=config_path) diff --git a/cookbook/remote/tinker/ascend/server_config.yaml b/cookbook/remote/tinker/ascend/server_config.yaml index 3959c49a..0ced7d62 100644 --- a/cookbook/remote/tinker/ascend/server_config.yaml +++ b/cookbook/remote/tinker/ascend/server_config.yaml @@ -1,3 +1,4 @@ +server_type: tinker proxy_location: EveryNode http_options: host: 0.0.0.0 @@ -6,7 +7,7 @@ http_options: applications: - name: server route_prefix: /api/v1 - import_path: main:build_server_app + import_path: server args: supported_models: - model_name: "Qwen/Qwen3-0.6B" @@ -25,7 +26,7 @@ applications: - name: models-Qwen2.5-7B-Instruct # route_prefix 需要是 HTTP 路径前缀;不要用本地文件路径 route_prefix: /api/v1/model/Qwen/Qwen3-0.6B - import_path: main:model_qwen25_7B + import_path: model args: model_id: "/home/zyh/model/Qwen3-0.6B" nproc_per_node: 2 @@ -51,7 +52,7 @@ applications: - name: sampler-Qwen3-0.6B route_prefix: /api/v1/sampler/Qwen/Qwen3-0.6B - import_path: main:build_sampler_app + import_path: sampler args: model_id: "/home/zyh/model/Qwen3-0.6B" nproc_per_node: 1 diff --git a/cookbook/remote/tinker/server.py b/cookbook/remote/tinker/server.py index 2fe976e0..1fd179f1 100644 --- 
a/cookbook/remote/tinker/server.py +++ b/cookbook/remote/tinker/server.py @@ -1,51 +1,9 @@ import os os.environ['RAY_DEBUG'] = '1' -import ray -from omegaconf import OmegaConf -from ray import serve -from twinkle.server.tinker import build_model_app, build_server_app -ray.init(namespace="twinkle_cluster") -serve.shutdown() -import time -time.sleep(5) +from twinkle.server import launch_server file_dir = os.path.abspath(os.path.dirname(__file__)) -config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml')) +config_path = os.path.join(file_dir, 'server_config.yaml') -# Start Ray Serve with http_options from config -http_options = OmegaConf.to_container(config.http_options, resolve=True) -serve.start(http_options=http_options) - -APP_BUILDERS = { - 'main:build_server_app': build_server_app, - 'main:build_model_app': build_model_app, - # 'main:build_sampler_app': build_sampler_app, -} - -for app_config in config.applications: - print(f"Starting {app_config.name} at {app_config.route_prefix}...") - - builder = APP_BUILDERS[app_config.import_path] - args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {} - - deploy_options = {} - deploy_config = app_config.deployments[0] - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = OmegaConf.to_container(deploy_config.autoscaling_config) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = OmegaConf.to_container(deploy_config.ray_actor_options) - - app = builder( - deploy_options=deploy_options, - **{k: v for k, v in args.items()} - ) - - serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix) - -print("\nAll applications started!") -print("Endpoints:") -for app_config in config.applications: - print(f" - http://localhost:8000{app_config.route_prefix}") - -input("\nPress Enter to stop the server...") \ No newline at end of file +launch_server(config_path=config_path) \ No newline at end of file diff --git 
a/cookbook/remote/tinker/server_config.yaml b/cookbook/remote/tinker/server_config.yaml index 2375271d..e8f4ec77 100644 --- a/cookbook/remote/tinker/server_config.yaml +++ b/cookbook/remote/tinker/server_config.yaml @@ -1,3 +1,4 @@ +server_type: tinker proxy_location: EveryNode http_options: host: 0.0.0.0 @@ -6,7 +7,7 @@ http_options: applications: - name: server route_prefix: /api/v1 - import_path: main:build_server_app + import_path: server args: deployments: @@ -22,7 +23,7 @@ applications: - name: models-Qwen2.5-0.5B-Instruct route_prefix: /api/v1/model/Qwen/Qwen2.5-0.5B-Instruct - import_path: main:build_model_app + import_path: model args: model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" nproc_per_node: 2 diff --git a/cookbook/remote/twinkle/server.py b/cookbook/remote/twinkle/server.py index 0f6f761a..b7515dbf 100644 --- a/cookbook/remote/twinkle/server.py +++ b/cookbook/remote/twinkle/server.py @@ -1,53 +1,9 @@ import os os.environ['RAY_DEBUG'] = '1' -import ray -from omegaconf import OmegaConf -from ray import serve -from twinkle.server import build_processor_app, build_sampler_app, build_model_app, build_server_app - -ray.init() -serve.shutdown() -import time -time.sleep(5) +from twinkle.server import launch_server file_dir = os.path.abspath(os.path.dirname(__file__)) -config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml')) - -# Start Ray Serve with http_options from config -http_options = OmegaConf.to_container(config.http_options, resolve=True) -serve.start(http_options=http_options) - -APP_BUILDERS = { - 'main:model_qwen25_7B': build_model_app, - # 'main:build_sampler_app': build_sampler_app, - 'main:processor_app': build_processor_app, - 'main:build_server_app': build_server_app, -} - -for app_config in config.applications: - print(f"Starting {app_config.name} at {app_config.route_prefix}...") - - builder = APP_BUILDERS[app_config.import_path] - args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {} - - 
deploy_options = {} - deploy_config = app_config.deployments[0] - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = OmegaConf.to_container(deploy_config.autoscaling_config) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = OmegaConf.to_container(deploy_config.ray_actor_options) - - app = builder( - deploy_options=deploy_options, - **{k: v for k, v in args.items()} - ) - - serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix) - -print("\nAll applications started!") -print("Endpoints:") -for app_config in config.applications: - print(f" - http://{config.http_options.host}:{config.http_options.port}{app_config.route_prefix}") +config_path = os.path.join(file_dir, 'server_config.yaml') -input("\nPress Enter to stop the server...") +launch_server(config_path=config_path) diff --git a/cookbook/remote/twinkle/server_config.yaml b/cookbook/remote/twinkle/server_config.yaml index d07a96d6..4adcb9a5 100644 --- a/cookbook/remote/twinkle/server_config.yaml +++ b/cookbook/remote/twinkle/server_config.yaml @@ -1,3 +1,4 @@ +server_type: twinkle proxy_location: EveryNode http_options: host: 0.0.0.0 @@ -6,7 +7,7 @@ http_options: applications: - name: server route_prefix: /server - import_path: main:build_server_app + import_path: server args: deployments: @@ -22,7 +23,7 @@ applications: - name: models-Qwen2.5-7B-Instruct route_prefix: /models/Qwen/Qwen2.5-7B-Instruct - import_path: main:model_qwen25_7B + import_path: model args: model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" nproc_per_node: 2 @@ -45,7 +46,7 @@ applications: - name: processor route_prefix: /processors - import_path: main:processor_app + import_path: processor args: nproc_per_node: 2 ncpu_proc_per_node: 2 diff --git a/src/twinkle/server/__init__.py b/src/twinkle/server/__init__.py index a5c24817..d57f97f7 100644 --- a/src/twinkle/server/__init__.py +++ b/src/twinkle/server/__init__.py @@ -4,8 +4,18 @@ from .twinkle.processor import 
build_processor_app as _build_processor_app from .twinkle.server import build_server_app from .utils import wrap_builder_with_device_group_env +from .launcher import ServerLauncher, launch_server build_model_app = wrap_builder_with_device_group_env(_build_model_app) build_processor_app = wrap_builder_with_device_group_env(_build_processor_app) build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app) + +__all__ = [ + 'build_model_app', + 'build_processor_app', + 'build_sampler_app', + 'build_server_app', + 'ServerLauncher', + 'launch_server', +] diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py new file mode 100644 index 00000000..636fc676 --- /dev/null +++ b/src/twinkle/server/__main__.py @@ -0,0 +1,143 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +CLI entry point for Twinkle Server. + +Usage: + # From config file + python -m twinkle.server --config server_config.yaml + + # With server type override + python -m twinkle.server --config server_config.yaml --server-type tinker + + # Quick start with minimal args + python -m twinkle.server --server-type tinker --port 8000 --model-id "Qwen/Qwen2.5-7B-Instruct" +""" +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +from twinkle import get_logger + +logger = get_logger() + +def create_parser() -> argparse.ArgumentParser: + """Create the argument parser.""" + parser = argparse.ArgumentParser( + prog="python -m twinkle.server", + description="Twinkle Server Launcher - Unified launcher for tinker and twinkle servers", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Start server from YAML config file + python -m twinkle.server --config server_config.yaml + + # Start tinker server with specific config + python -m twinkle.server -c config.yaml -t tinker + + # Run in background (daemon mode) + python -m twinkle.server -c config.yaml --no-wait + """, + ) + + # Config file option 
+    parser.add_argument(
+        "-c", "--config",
+        type=str,
+        required=True,
+        metavar="PATH",
+        help="Path to YAML configuration file (required)",
+    )
+
+    # Server type
+    parser.add_argument(
+        "-t", "--server-type",
+        type=str,
+        default="twinkle",
+        choices=["tinker", "twinkle"],
+        metavar="TYPE",
+        help="Server type: 'tinker' or 'twinkle' (default: twinkle)",
+    )
+
+    # Ray options
+    parser.add_argument(
+        "--namespace",
+        type=str,
+        metavar="NS",
+        help="Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle)",
+    )
+
+    # Runtime options
+    parser.add_argument(
+        "--no-wait",
+        action="store_true",
+        help="Don't block waiting for Enter (daemon mode)",
+    )
+    parser.add_argument(
+        "--log-level",
+        type=str,
+        default="INFO",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
+        metavar="LEVEL",
+        help="Logging level (default: INFO)",
+    )
+
+    return parser
+
+
+def main(args: list[str] | None = None) -> int:
+    """
+    Main entry point for the CLI.
+
+    Args:
+        args: Command line arguments (uses sys.argv if None)
+
+    Returns:
+        Exit code (0 for success, non-zero for error)
+    """
+    parser = create_parser()
+    parsed_args = parser.parse_args(args)
+
+    # Configure the root log level via stdlib logging (no project helper exists).
+    import logging
+    logging.basicConfig(level=parsed_args.log_level)
+
+    try:
+        from twinkle.server.launcher import launch_server
+
+        # Config file mode
+        config_path = Path(parsed_args.config)
+        if not config_path.exists():
+            logger.error(f"Config file not found: {config_path}")
+            return 1
+
+        launch_server(
+            config_path=config_path,
+            server_type=parsed_args.server_type,
+            ray_namespace=parsed_args.namespace,
+            wait=not parsed_args.no_wait,
+        )
+
+        return 0
+
+    except KeyboardInterrupt:
+        logger.info("Server stopped by user")
+        return 0
+    except FileNotFoundError as e:
+        logger.error(f"File not found: {e}")
+        return 1
+    except ValueError as e:
+        logger.error(f"Configuration error: {e}")
+        return 1
+    except ImportError as e:
+        logger.error(f"Import error: {e}")
+        logger.error("Make sure all required dependencies are 
installed") + return 1 + except Exception as e: + logger.exception(f"Unexpected error: {e}") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py new file mode 100644 index 00000000..189efa90 --- /dev/null +++ b/src/twinkle/server/launcher.py @@ -0,0 +1,384 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified Server Launcher for Twinkle. + +This module provides a unified way to launch both tinker and twinkle servers +with support for YAML config files, Python dict config, and CLI. + +Usage: + # From YAML config + from twinkle.server import launch_server + launch_server(config_path="server_config.yaml") + + # From Python dict + launch_server(config={ + "server_type": "tinker", + "http_options": {"host": "0.0.0.0", "port": 8000}, + "applications": [...] + }) + + # CLI + python -m twinkle.server --config server_config.yaml +""" +from __future__ import annotations + +import os +import time +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +from twinkle import get_logger + +logger = get_logger() + + +class ServerLauncher: + """ + Unified server launcher for tinker and twinkle servers. + + This class handles Ray/Serve initialization and application deployment + for both tinker and twinkle server types. 
+ + Attributes: + server_type: The type of server ('tinker' or 'twinkle') + config: The server configuration dictionary + ray_namespace: The Ray namespace for the cluster + """ + + # Mapping of simplified import_path names to actual builder functions + # These will be populated lazily to avoid circular imports + _TINKER_BUILDERS: Dict[str, str] = { + 'server': 'build_server_app', + 'model': 'build_model_app', + 'sampler': 'build_sampler_app', + } + + _TWINKLE_BUILDERS: Dict[str, str] = { + 'server': 'build_server_app', + 'model': 'build_model_app', + 'sampler': 'build_sampler_app', + 'processor': 'build_processor_app', + } + + def __init__( + self, + server_type: str = "twinkle", + config: Optional[Dict[str, Any]] = None, + ray_namespace: Optional[str] = None, + ): + """ + Initialize the server launcher. + + Args: + server_type: Server type ('tinker' or 'twinkle') + config: Configuration dictionary + ray_namespace: Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle) + """ + if server_type not in ("tinker", "twinkle"): + raise ValueError(f"server_type must be 'tinker' or 'twinkle', got '{server_type}'") + + self.server_type = server_type + self.config = config or {} + self.ray_namespace = ray_namespace + self._builders: Dict[str, Callable] = {} + self._ray_initialized = False + self._serve_started = False + + def _get_builders(self) -> Dict[str, Callable]: + """ + Get the appropriate builder functions for the server type. 
+ + Returns: + Dictionary mapping import_path names to builder functions + """ + if self._builders: + return self._builders + + if self.server_type == "tinker": + from twinkle.server.tinker import ( + build_model_app, + build_sampler_app, + build_server_app, + ) + self._builders = { + 'build_server_app': build_server_app, + 'build_model_app': build_model_app, + 'build_sampler_app': build_sampler_app, + } + else: # twinkle + from twinkle.server import ( + build_model_app, + build_processor_app, + build_sampler_app, + build_server_app, + ) + self._builders = { + 'build_server_app': build_server_app, + 'build_model_app': build_model_app, + 'build_sampler_app': build_sampler_app, + 'build_processor_app': build_processor_app, + } + + return self._builders + + def _resolve_builder(self, import_path: str) -> Callable: + """ + Resolve an import_path to a builder function. + + Args: + import_path: The import path from config (e.g., 'server', 'main:build_server_app') + + Returns: + The builder function + + Raises: + ValueError: If the import_path cannot be resolved + """ + builders = self._get_builders() + builder_map = self._TINKER_BUILDERS if self.server_type == "tinker" else self._TWINKLE_BUILDERS + + # Try to resolve through the mapping + if import_path in builder_map: + builder_name = builder_map[import_path] + if builder_name in builders: + return builders[builder_name] + + # Direct builder name + if import_path in builders: + return builders[import_path] + + raise ValueError( + f"Unknown import_path '{import_path}' for server_type '{self.server_type}'. 
" + f"Available: {list(builder_map.keys())}" + ) + + def _init_ray(self) -> None: + """Initialize Ray if not already initialized.""" + if self._ray_initialized: + return + + import ray + + # Determine namespace + namespace = self.ray_namespace + if namespace is None: + namespace = self.config.get('ray_namespace') + if namespace is None and self.server_type == "tinker": + namespace = "twinkle_cluster" + + init_kwargs = {} + if namespace: + init_kwargs['namespace'] = namespace + + if not ray.is_initialized(): + ray.init(**init_kwargs) + logger.info(f"Ray initialized with namespace={namespace}") + + self._ray_initialized = True + + def _start_serve(self) -> None: + """Start Ray Serve with http_options from config.""" + if self._serve_started: + return + + from ray import serve + + # Shutdown any existing serve instance + try: + serve.shutdown() + time.sleep(2) # Wait for cleanup + except Exception: + pass + + # Get http_options from config + http_options = self.config.get('http_options', {}) + if isinstance(http_options, dict): + http_options = dict(http_options) + else: + # Handle OmegaConf or other config objects + http_options = dict(http_options) if http_options else {} + + serve.start(http_options=http_options) + logger.info(f"Ray Serve started with http_options={http_options}") + + self._serve_started = True + + def _deploy_application(self, app_config: Dict[str, Any]) -> None: + """ + Deploy a single application. 
+ + Args: + app_config: Application configuration dictionary + """ + from ray import serve + + name = app_config.get('name', 'app') + route_prefix = app_config.get('route_prefix', '/') + import_path = app_config.get('import_path', 'server') + args = app_config.get('args', {}) or {} + deployments = app_config.get('deployments', []) + + logger.info(f"Starting {name} at {route_prefix}...") + + # Resolve builder function + builder = self._resolve_builder(import_path) + + # Build deploy_options from deployments config + deploy_options = {} + if deployments: + deploy_config = deployments[0] + if isinstance(deploy_config, dict): + if 'autoscaling_config' in deploy_config: + deploy_options['autoscaling_config'] = dict(deploy_config['autoscaling_config']) + if 'ray_actor_options' in deploy_config: + deploy_options['ray_actor_options'] = dict(deploy_config['ray_actor_options']) + + # Build and deploy the application + app = builder( + deploy_options=deploy_options, + **{k: v for k, v in args.items()} + ) + + serve.run(app, name=name, route_prefix=route_prefix) + logger.info(f"Deployed {name} at {route_prefix}") + + def launch(self, wait: bool = True) -> None: + """ + Launch the server with all configured applications. 
+ + Args: + wait: If True, block and wait for Enter to stop the server + """ + self._init_ray() + self._start_serve() + + applications = self.config.get('applications', []) + if not applications: + logger.warning("No applications configured") + return + + # Deploy each application + for app_config in applications: + if isinstance(app_config, dict): + self._deploy_application(app_config) + else: + # Handle OmegaConf or other config objects + self._deploy_application(dict(app_config)) + + # Print endpoints + http_options = self.config.get('http_options', {}) + host = http_options.get('host', 'localhost') + port = http_options.get('port', 8000) + + print("\nAll applications started!") + print("Endpoints:") + for app_config in applications: + route_prefix = app_config.get('route_prefix', '/') if isinstance(app_config, dict) else app_config.route_prefix + print(f" - http://{host}:{port}{route_prefix}") + + if wait: + input("\nPress Enter to stop the server...") + + @classmethod + def from_yaml( + cls, + config_path: Union[str, Path], + server_type: str = "twinkle", + ray_namespace: Optional[str] = None, + ) -> "ServerLauncher": + """ + Create a ServerLauncher from a YAML config file. 
+ + Args: + config_path: Path to the YAML config file + server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' + ray_namespace: Override Ray namespace from config + + Returns: + Configured ServerLauncher instance + """ + from omegaconf import OmegaConf + + config_path = Path(config_path) + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + config = OmegaConf.load(config_path) + config_dict = OmegaConf.to_container(config, resolve=True) + + # Override server_type from config if specified + if 'server_type' in config_dict: + server_type = config_dict['server_type'] + + return cls( + server_type=server_type, + config=config_dict, + ray_namespace=ray_namespace or config_dict.get('ray_namespace'), + ) + + + + + +def launch_server( + config: Optional[Dict[str, Any]] = None, + config_path: Optional[Union[str, Path]] = None, + server_type: str = "twinkle", + ray_namespace: Optional[str] = None, + wait: bool = True, +) -> ServerLauncher: + """ + Launch a twinkle server with flexible configuration options. + + This is the main entry point for launching servers programmatically. + + Args: + config: Configuration dictionary (takes precedence over config_path) + config_path: Path to YAML config file + server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' + ray_namespace: Ray namespace + wait: If True, block and wait for Enter to stop the server + + Returns: + The ServerLauncher instance + + Raises: + ValueError: If neither config nor config_path is provided + + Examples: + # From YAML config (twinkle mode) + launch_server(config_path="server_config.yaml") + + # From YAML config (tinker mode) + launch_server(config_path="server_config.yaml", server_type="tinker") + + # From Python dict + launch_server(config={ + "server_type": "tinker", + "http_options": {"host": "0.0.0.0", "port": 8000}, + "applications": [...] 
+ }) + """ + if config is None and config_path is None: + raise ValueError("Either 'config' or 'config_path' must be provided") + + launcher: ServerLauncher + + if config is not None: + # From Python dict config - override with config's server_type if specified + final_server_type = config.get('server_type', server_type) + launcher = ServerLauncher( + server_type=final_server_type, + config=config, + ray_namespace=ray_namespace or config.get('ray_namespace'), + ) + else: + # From YAML config file + launcher = ServerLauncher.from_yaml( + config_path=config_path, + server_type=server_type, + ray_namespace=ray_namespace, + ) + + launcher.launch(wait=wait) + return launcher diff --git a/src/twinkle/server/tinker/__init__.py b/src/twinkle/server/tinker/__init__.py index 4174d255..54dc848d 100644 --- a/src/twinkle/server/tinker/__init__.py +++ b/src/twinkle/server/tinker/__init__.py @@ -1,10 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .model import build_model_app as tinker_build_model_app -from .sampler import build_sampler_app as tinker_build_sampler_app -from .server import build_server_app as tinker_build_server_app +from .model import build_model_app as _build_model_app +from .sampler import build_sampler_app as _build_sampler_app +from .server import build_server_app from ..utils import wrap_builder_with_device_group_env build_model_app = wrap_builder_with_device_group_env(_build_model_app) -build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app) \ No newline at end of file +build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app) + +__all__ = [ + 'build_model_app', + 'build_sampler_app', + 'build_server_app', +] \ No newline at end of file From 35f349d802e43e6657719dfd0e0c2f63ac1b5741 Mon Sep 17 00:00:00 2001 From: Yunlin Mao Date: Fri, 6 Feb 2026 14:05:50 +0800 Subject: [PATCH 7/9] Update src/twinkle/server/launcher.py Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/twinkle/server/launcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 189efa90..a1f89b24 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -223,15 +223,15 @@ def _deploy_application(self, app_config: Dict[str, Any]) -> None: # Resolve builder function builder = self._resolve_builder(import_path) + # Build deploy_options from deployments config # Build deploy_options from deployments config deploy_options = {} if deployments: deploy_config = deployments[0] if isinstance(deploy_config, dict): - if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = dict(deploy_config['autoscaling_config']) - if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = dict(deploy_config['ray_actor_options']) + # Copy all deployment options from the config, except 'name'. 
+ deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} + # Build and deploy the application app = builder( From 00e8819e4acedfb1425c7e712a11939fc8b31ea2 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 6 Feb 2026 14:06:15 +0800 Subject: [PATCH 8/9] update server --- src/twinkle/server/__main__.py | 29 +++---- src/twinkle/server/launcher.py | 139 +++++++++++++++++---------------- 2 files changed, 83 insertions(+), 85 deletions(-) diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py index 636fc676..2e0e91c5 100644 --- a/src/twinkle/server/__main__.py +++ b/src/twinkle/server/__main__.py @@ -22,6 +22,7 @@ logger = get_logger() + def create_parser() -> argparse.ArgumentParser: """Create the argument parser.""" parser = argparse.ArgumentParser( @@ -40,7 +41,7 @@ def create_parser() -> argparse.ArgumentParser: python -m twinkle.server -c config.yaml --no-wait """, ) - + # Config file option parser.add_argument( "-c", "--config", @@ -49,7 +50,7 @@ def create_parser() -> argparse.ArgumentParser: metavar="PATH", help="Path to YAML configuration file (required)", ) - + # Server type parser.add_argument( "-t", "--server-type", @@ -59,7 +60,7 @@ def create_parser() -> argparse.ArgumentParser: metavar="TYPE", help="Server type: 'tinker' or 'twinkle' (default: twinkle)", ) - + # Ray options parser.add_argument( "--namespace", @@ -67,7 +68,7 @@ def create_parser() -> argparse.ArgumentParser: metavar="NS", help="Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle)", ) - + # Runtime options parser.add_argument( "--no-wait", @@ -82,45 +83,41 @@ def create_parser() -> argparse.ArgumentParser: metavar="LEVEL", help="Logging level (default: INFO)", ) - + return parser def main(args: list[str] | None = None) -> int: """ Main entry point for the CLI. 
- + Args: args: Command line arguments (uses sys.argv if None) - + Returns: Exit code (0 for success, non-zero for error) """ parser = create_parser() parsed_args = parser.parse_args(args) - - # Setup logging - setup_logging(parsed_args.log_level) - - + try: from twinkle.server.launcher import launch_server - + # Config file mode config_path = Path(parsed_args.config) if not config_path.exists(): logger.error(f"Config file not found: {config_path}") return 1 - + launch_server( config_path=config_path, server_type=parsed_args.server_type, ray_namespace=parsed_args.namespace, wait=not parsed_args.no_wait, ) - + return 0 - + except KeyboardInterrupt: logger.info("Server stopped by user") return 0 diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 189efa90..f3e0f2a9 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -35,16 +35,16 @@ class ServerLauncher: """ Unified server launcher for tinker and twinkle servers. - + This class handles Ray/Serve initialization and application deployment for both tinker and twinkle server types. - + Attributes: server_type: The type of server ('tinker' or 'twinkle') config: The server configuration dictionary ray_namespace: The Ray namespace for the cluster """ - + # Mapping of simplified import_path names to actual builder functions # These will be populated lazily to avoid circular imports _TINKER_BUILDERS: Dict[str, str] = { @@ -52,14 +52,14 @@ class ServerLauncher: 'model': 'build_model_app', 'sampler': 'build_sampler_app', } - + _TWINKLE_BUILDERS: Dict[str, str] = { 'server': 'build_server_app', 'model': 'build_model_app', 'sampler': 'build_sampler_app', 'processor': 'build_processor_app', } - + def __init__( self, server_type: str = "twinkle", @@ -68,32 +68,33 @@ def __init__( ): """ Initialize the server launcher. 
- + Args: server_type: Server type ('tinker' or 'twinkle') config: Configuration dictionary ray_namespace: Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle) """ if server_type not in ("tinker", "twinkle"): - raise ValueError(f"server_type must be 'tinker' or 'twinkle', got '{server_type}'") - + raise ValueError( + f"server_type must be 'tinker' or 'twinkle', got '{server_type}'") + self.server_type = server_type self.config = config or {} self.ray_namespace = ray_namespace self._builders: Dict[str, Callable] = {} self._ray_initialized = False self._serve_started = False - + def _get_builders(self) -> Dict[str, Callable]: """ Get the appropriate builder functions for the server type. - + Returns: Dictionary mapping import_path names to builder functions """ if self._builders: return self._builders - + if self.server_type == "tinker": from twinkle.server.tinker import ( build_model_app, @@ -118,78 +119,78 @@ def _get_builders(self) -> Dict[str, Callable]: 'build_sampler_app': build_sampler_app, 'build_processor_app': build_processor_app, } - + return self._builders - + def _resolve_builder(self, import_path: str) -> Callable: """ Resolve an import_path to a builder function. - + Args: import_path: The import path from config (e.g., 'server', 'main:build_server_app') - + Returns: The builder function - + Raises: ValueError: If the import_path cannot be resolved """ builders = self._get_builders() builder_map = self._TINKER_BUILDERS if self.server_type == "tinker" else self._TWINKLE_BUILDERS - + # Try to resolve through the mapping if import_path in builder_map: builder_name = builder_map[import_path] if builder_name in builders: return builders[builder_name] - + # Direct builder name if import_path in builders: return builders[import_path] - + raise ValueError( f"Unknown import_path '{import_path}' for server_type '{self.server_type}'. 
" f"Available: {list(builder_map.keys())}" ) - + def _init_ray(self) -> None: """Initialize Ray if not already initialized.""" if self._ray_initialized: return - + import ray - + # Determine namespace namespace = self.ray_namespace if namespace is None: namespace = self.config.get('ray_namespace') if namespace is None and self.server_type == "tinker": namespace = "twinkle_cluster" - + init_kwargs = {} if namespace: init_kwargs['namespace'] = namespace - + if not ray.is_initialized(): ray.init(**init_kwargs) logger.info(f"Ray initialized with namespace={namespace}") - + self._ray_initialized = True - + def _start_serve(self) -> None: """Start Ray Serve with http_options from config.""" if self._serve_started: return - + from ray import serve - + # Shutdown any existing serve instance try: serve.shutdown() time.sleep(2) # Wait for cleanup except Exception: pass - + # Get http_options from config http_options = self.config.get('http_options', {}) if isinstance(http_options, dict): @@ -197,66 +198,68 @@ def _start_serve(self) -> None: else: # Handle OmegaConf or other config objects http_options = dict(http_options) if http_options else {} - + serve.start(http_options=http_options) logger.info(f"Ray Serve started with http_options={http_options}") - + self._serve_started = True - + def _deploy_application(self, app_config: Dict[str, Any]) -> None: """ Deploy a single application. 
- + Args: app_config: Application configuration dictionary """ from ray import serve - + name = app_config.get('name', 'app') route_prefix = app_config.get('route_prefix', '/') import_path = app_config.get('import_path', 'server') args = app_config.get('args', {}) or {} deployments = app_config.get('deployments', []) - + logger.info(f"Starting {name} at {route_prefix}...") - + # Resolve builder function builder = self._resolve_builder(import_path) - + # Build deploy_options from deployments config deploy_options = {} if deployments: deploy_config = deployments[0] if isinstance(deploy_config, dict): if 'autoscaling_config' in deploy_config: - deploy_options['autoscaling_config'] = dict(deploy_config['autoscaling_config']) + deploy_options['autoscaling_config'] = dict( + deploy_config['autoscaling_config']) if 'ray_actor_options' in deploy_config: - deploy_options['ray_actor_options'] = dict(deploy_config['ray_actor_options']) - + deploy_options['ray_actor_options'] = dict( + deploy_config['ray_actor_options']) + # Build and deploy the application app = builder( deploy_options=deploy_options, **{k: v for k, v in args.items()} ) - + serve.run(app, name=name, route_prefix=route_prefix) logger.info(f"Deployed {name} at {route_prefix}") - + def launch(self, wait: bool = True) -> None: """ Launch the server with all configured applications. 
- + Args: wait: If True, block and wait for Enter to stop the server """ self._init_ray() self._start_serve() - + applications = self.config.get('applications', []) if not applications: logger.warning("No applications configured") return - + # Deploy each application for app_config in applications: if isinstance(app_config, dict): @@ -264,21 +267,22 @@ def launch(self, wait: bool = True) -> None: else: # Handle OmegaConf or other config objects self._deploy_application(dict(app_config)) - + # Print endpoints http_options = self.config.get('http_options', {}) host = http_options.get('host', 'localhost') port = http_options.get('port', 8000) - + print("\nAll applications started!") print("Endpoints:") for app_config in applications: - route_prefix = app_config.get('route_prefix', '/') if isinstance(app_config, dict) else app_config.route_prefix + route_prefix = app_config.get( + 'route_prefix', '/') if isinstance(app_config, dict) else app_config.route_prefix print(f" - http://{host}:{port}{route_prefix}") - + if wait: input("\nPress Enter to stop the server...") - + @classmethod def from_yaml( cls, @@ -288,36 +292,33 @@ def from_yaml( ) -> "ServerLauncher": """ Create a ServerLauncher from a YAML config file. 
- + Args: config_path: Path to the YAML config file server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' ray_namespace: Override Ray namespace from config - + Returns: Configured ServerLauncher instance """ from omegaconf import OmegaConf - + config_path = Path(config_path) if not config_path.exists(): raise FileNotFoundError(f"Config file not found: {config_path}") - + config = OmegaConf.load(config_path) config_dict = OmegaConf.to_container(config, resolve=True) - + # Override server_type from config if specified if 'server_type' in config_dict: server_type = config_dict['server_type'] - + return cls( server_type=server_type, config=config_dict, ray_namespace=ray_namespace or config_dict.get('ray_namespace'), ) - - - def launch_server( @@ -329,29 +330,29 @@ def launch_server( ) -> ServerLauncher: """ Launch a twinkle server with flexible configuration options. - + This is the main entry point for launching servers programmatically. - + Args: config: Configuration dictionary (takes precedence over config_path) config_path: Path to YAML config file server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' ray_namespace: Ray namespace wait: If True, block and wait for Enter to stop the server - + Returns: The ServerLauncher instance - + Raises: ValueError: If neither config nor config_path is provided - + Examples: # From YAML config (twinkle mode) launch_server(config_path="server_config.yaml") - + # From YAML config (tinker mode) launch_server(config_path="server_config.yaml", server_type="tinker") - + # From Python dict launch_server(config={ "server_type": "tinker", @@ -361,9 +362,9 @@ def launch_server( """ if config is None and config_path is None: raise ValueError("Either 'config' or 'config_path' must be provided") - + launcher: ServerLauncher - + if config is not None: # From Python dict config - override with config's server_type if specified final_server_type = config.get('server_type', server_type) @@ -379,6 +380,6 @@ def 
launch_server( server_type=server_type, ray_namespace=ray_namespace, ) - + launcher.launch(wait=wait) return launcher From 5054ef6e00a396fd51dabab050533d8c74dfa7cc Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sat, 7 Feb 2026 11:24:53 +0800 Subject: [PATCH 9/9] update --- .../tinker/transformer/self_congnition.py | 10 +++- .../tinker/transformer/server_config.yaml | 60 +++++++++---------- src/twinkle/sampler/vllm_engine.py | 44 +++++--------- src/twinkle/server/tinker/common/io_utils.py | 3 +- src/twinkle/server/tinker/sampler.py | 52 ++++++++++++++-- src/twinkle/server/tinker/server.py | 29 ++------- 6 files changed, 105 insertions(+), 93 deletions(-) diff --git a/cookbook/client/tinker/transformer/self_congnition.py b/cookbook/client/tinker/transformer/self_congnition.py index 14a3963b..e4112628 100644 --- a/cookbook/client/tinker/transformer/self_congnition.py +++ b/cookbook/client/tinker/transformer/self_congnition.py @@ -46,7 +46,7 @@ def train(): print(f"Saved checkpoint to {save_result.path}") def eval(): - weight_path = "twinkle://20260203_194633-Qwen_Qwen2_5-0_5B-Instruct-03aa3f06/weights/twinkle-lora" + weight_path = "twinkle://20260207_110850-Qwen_Qwen2_5-0_5B-Instruct-ce7e819f/weights/twinkle-lora-2" service_client = init_tinker_compat_client(base_url='http://localhost:8000') sampling_client = service_client.create_sampling_client( @@ -56,6 +56,10 @@ def eval(): tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) inputs = [ + { + 'role': 'system', + 'content': 'You are a helpful assistant.' + }, { 'role': 'user', 'content': 'what is your name?' 
@@ -78,5 +82,5 @@ def eval(): print(f"{i}: {repr(tokenizer.decode(seq.tokens))}") if __name__ == "__main__": - train() - # eval() + # train() + eval() diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml index 94943356..df56f60a 100644 --- a/cookbook/client/tinker/transformer/server_config.yaml +++ b/cookbook/client/tinker/transformer/server_config.yaml @@ -53,34 +53,34 @@ applications: logging_config: log_level: DEBUG - # - name: sampler-Qwen2.5-0.5B-Instruct - # route_prefix: /api/v1/sampler/Qwen/Qwen2.5-0.5B-Instruct - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" - # nproc_per_node: 1 - # sampler_type: vllm # or 'torch' for TorchSampler - # engine_args: - # max_model_len: 4096 - # gpu_memory_utilization: 0.5 - # enable_lora: false - # device_group: - # name: sampler - # ranks: [0] - # device_type: cuda - # device_mesh: - # device_type: cuda - # mesh: [0] - # mesh_dim_names: ['dp'] - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # num_gpus: 1 - # logging_config: - # log_level: DEBUG + - name: sampler-Qwen2.5-0.5B-Instruct + route_prefix: /api/v1/sampler/Qwen/Qwen2.5-0.5B-Instruct + import_path: sampler + args: + model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct" + nproc_per_node: 1 + sampler_type: vllm # or 'torch' for TorchSampler + engine_args: + max_model_len: 4096 + gpu_memory_utilization: 0.5 + enable_lora: false + device_group: + name: sampler + ranks: [0] + device_type: cuda + device_mesh: + device_type: cuda + mesh: [0] + mesh_dim_names: ['dp'] + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + num_gpus: 1 + logging_config: + log_level: DEBUG diff --git a/src/twinkle/sampler/vllm_engine.py 
b/src/twinkle/sampler/vllm_engine.py index 84e56adc..2bf035b6 100644 --- a/src/twinkle/sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_engine.py @@ -180,7 +180,8 @@ async def sample( logprobs: bool = True, include_prompt_logprobs: bool = False, topk_prompt_logprobs: int = 0, - adapter_uri: Optional[str] = None, + adapter_path: Optional[str] = None, + adapter_user_id: Optional[str] = None, request_id: Optional[str] = None, priority: int = 0, *, @@ -199,8 +200,8 @@ async def sample( logprobs: Whether to return log probabilities for generated tokens. include_prompt_logprobs: Whether to compute logprobs on prompt tokens. topk_prompt_logprobs: If > 0, returns top-k logprobs for each prompt token. - adapter_uri: URI of LoRA adapter to use (for multi-tenant mode). - Format: twinkle://{model_id}/lora/{user_id} + adapter_path: Resolved filesystem path to LoRA adapter directory. + adapter_user_id: User identifier for the adapter (for tracking loaded adapters). request_id: Optional request ID for tracking. priority: Request priority (higher = more urgent). images: Optional list of images for multimodal models. 
@@ -243,10 +244,10 @@ async def sample( else: prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - # Build LoRA request if adapter_uri provided + # Build LoRA request if adapter_path provided lora_request = None - if adapter_uri and self.enable_lora: - lora_request = await self._get_or_load_lora(adapter_uri) + if adapter_path and self.enable_lora: + lora_request = await self._get_or_load_lora(adapter_path, adapter_user_id) # Generate generator = self.engine.generate( @@ -337,42 +338,27 @@ def _generate_lora_id(self) -> int: self._next_lora_id += 1 return lora_id - def _parse_adapter_uri(self, adapter_uri: str) -> tuple: - if adapter_uri.startswith('twinkle://'): - # Format: twinkle://{user_id}/{relative_path} - suffix = adapter_uri[len('twinkle://'):] - parts = suffix.split('/', 1) - if len(parts) == 2: - user_id, relative_path = parts - # Hardcoded base path for debug - will be replaced with actual storage logic - base_path = "/mnt/nas2/yunlin.myl/twinkle/outputs" - lora_path = os.path.join(base_path, user_id, relative_path) - return user_id, lora_path - else: - return 'default', suffix - else: - # Local path - return 'default', adapter_uri - - async def _get_or_load_lora(self, adapter_uri: str): + async def _get_or_load_lora(self, lora_path: str, user_id: Optional[str] = None): """ - Get or load a LoRA adapter from URI, return LoRARequest for sampling. + Get or load a LoRA adapter from path, return LoRARequest for sampling. This method: - 1. Parses the URI to get user_id and path + 1. Uses the provided user_id for tracking (or 'default' if not provided) 2. Checks if already loaded for this user 3. Loads if needed 4. Returns the LoRARequest for vLLM Args: - adapter_uri: The adapter URI (twinkle://... 
or local path) + lora_path: Resolved filesystem path to the LoRA adapter directory + user_id: User identifier for tracking loaded adapters Returns: LoRARequest or None if loading fails """ from vllm.lora.request import LoRARequest - user_id, lora_path = self._parse_adapter_uri(adapter_uri) + if user_id is None: + user_id = 'default' # Check if already loaded for this user if user_id in self._user_lora_ids: @@ -383,7 +369,7 @@ async def _get_or_load_lora(self, adapter_uri: str): if lora_path != self._user_lora_paths[user_id]: # reload the lora await self.remove_adapter(user_id) - lora_request = await self._get_or_load_lora(adapter_uri) + lora_request = await self._get_or_load_lora(lora_path, user_id) return lora_request return LoRARequest( lora_name=f"lora_{user_id}", diff --git a/src/twinkle/server/tinker/common/io_utils.py b/src/twinkle/server/tinker/common/io_utils.py index f6f6a05d..cbf4de0e 100644 --- a/src/twinkle/server/tinker/common/io_utils.py +++ b/src/twinkle/server/tinker/common/io_utils.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -""" -Tinker-specific IO utilities for managing training runs and checkpoints. +"""Tinker-specific IO utilities for managing training runs and checkpoints. This module extends the base IO utilities with Tinker-specific implementations. It uses types from the tinker package for compatibility with the Tinker API. 
diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index cc99a9cd..674e86c7 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -19,16 +19,48 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh -from twinkle.server.utils.validation import verify_request_token +from twinkle.server.utils.validation import verify_request_token, get_token_from_request from twinkle.server.utils.state import get_server_state, ServerStateProxy from twinkle.server.utils.task_queue import TaskQueueMixin, TaskQueueConfig from twinkle.sampler.types import SamplingParams as TwinkleSamplingParams from twinkle.utils.logger import get_logger +from .common.io_utils import create_checkpoint_manager logger = get_logger() +def parse_adapter_uri(adapter_uri: str, token: str) -> tuple: + """Parse adapter URI to extract user_id and resolved lora_path. + + Args: + adapter_uri: The adapter URI, supports formats: + - twinkle://{training_run_id}/weights/{checkpoint_name} or sampler_weights/{name} + - Local filesystem path + token: User token for resolving the base output directory + + Returns: + Tuple of (user_id, lora_path) where lora_path is the resolved filesystem path + """ + if adapter_uri.startswith('twinkle://'): + # Use CheckpointManager to parse and resolve the path + checkpoint_manager = create_checkpoint_manager(token) + parsed = checkpoint_manager.parse_path(adapter_uri) + if parsed: + # Get the filesystem path using get_ckpt_dir + lora_path = str(checkpoint_manager.get_ckpt_dir( + parsed.training_run_id, parsed.checkpoint_id + )) + return parsed.training_run_id, lora_path + else: + # Fallback: parse manually for non-standard formats + suffix = adapter_uri[len('twinkle://'):] + return 'default', suffix + else: + # Local path + return 'default', adapter_uri + + def build_sampler_app(model_id: str, nproc_per_node: int, device_group: Dict[str, Any], @@ -145,8 +177,19 @@ async def _do_sample(): # Extract 
prompt token IDs from ModelInput prompt_token_ids = body.prompt.to_ints() - # Determine adapter URI from model_path - adapter_uri = body.model_path if body.model_path else None + # Get model_path: use body.model_path or look up from sampling session + model_path = body.model_path + if not model_path and body.sampling_session_id: + session = self.state.get_sampling_session(body.sampling_session_id) + if session: + model_path = session.get('model_path') + + # Parse and resolve adapter URI from model_path + adapter_uri = None + user_id = None + if model_path: + token = request.state.token + user_id, adapter_uri = parse_adapter_uri(model_path, token) # Convert tinker SamplingParams to twinkle SamplingParams if needed sampling_params = None @@ -169,7 +212,8 @@ async def _do_sample(): logprobs=want_logprobs, include_prompt_logprobs=body.prompt_logprobs or False, topk_prompt_logprobs=body.topk_prompt_logprobs or 0, - adapter_uri=adapter_uri, + adapter_path=adapter_uri, + adapter_user_id=user_id, ) # Convert twinkle SampleResponse to tinker types.SampleResponse diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index e4da1c6e..2b5b4b9d 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -216,17 +216,6 @@ async def _proxy_to_sampler(self, request: Request, endpoint: str, base_model: s """ return await self._proxy_request(request, endpoint, base_model, 'sampler') - @staticmethod - def _sample_output() -> types.SampleResponse: - """Generate a sample output for testing purposes. 
- - Returns: - A mock SampleResponse with dummy data - """ - sequence = types.SampledSequence(stop_reason="stop", tokens=[ - 1, 2, 3], logprobs=[-0.1, -0.2, -0.3]) - return types.SampleResponse(sequences=[sequence]) - # --- Endpoints --------------------------------------------------------- @app.get("/healthz") @@ -682,8 +671,8 @@ async def load_weights(self, request: Request, body: types.LoadWeightsRequest) - async def asample(self, request: Request, body: types.SampleRequest) -> Any: """Execute text generation (inference). - This endpoint first tries to use a local sampler if available. - Otherwise, it proxies the request to the sampler service. + Proxies the request to the sampler service based on base_model. + The sampler handles model_path resolution from sampling session. Args: body: Sample request with prompt and sampling parameters @@ -691,24 +680,14 @@ async def asample(self, request: Request, body: types.SampleRequest) -> Any: Returns: Proxied response from sampler service """ - model_path = body.model_path base_model = body.base_model - # If both are None, look up from sampling session - if not model_path and not base_model and body.sampling_session_id: + # If base_model not provided, look up from sampling session + if not base_model and body.sampling_session_id: session = self.state.get_sampling_session(body.sampling_session_id) if session: - model_path = session.get('model_path') base_model = session.get('base_model') - # Extract base_model from model_path if needed - if model_path and not base_model: - # Format: twinkle://Qwen/Qwen2.5-0.5B-Instruct/lora/xxx -> Qwen/Qwen2.5-0.5B-Instruct - path = model_path.replace("twinkle://", "").replace("tinker://", "") - parts = path.split("/") - if len(parts) >= 2: - base_model = f"{parts[0]}/{parts[1]}" - return await self._proxy_to_sampler(request, "asample", base_model) @app.post("/save_weights_for_sampler")