From 6beeb7591aa7ef016c7337ae3fd8bd47fd6a87f8 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 26 Feb 2026 14:05:12 +0800 Subject: [PATCH 1/3] update state --- src/twinkle/server/utils/adapter_manager.py | 2 - src/twinkle/server/utils/state.py | 609 ------------------ src/twinkle/server/utils/state/__init__.py | 29 + src/twinkle/server/utils/state/base.py | 68 ++ .../server/utils/state/config_manager.py | 53 ++ .../server/utils/state/future_manager.py | 95 +++ .../server/utils/state/model_manager.py | 47 ++ src/twinkle/server/utils/state/models.py | 56 ++ .../server/utils/state/sampling_manager.py | 47 ++ .../server/utils/state/server_state.py | 451 +++++++++++++ .../server/utils/state/session_manager.py | 77 +++ 11 files changed, 923 insertions(+), 611 deletions(-) delete mode 100644 src/twinkle/server/utils/state.py create mode 100644 src/twinkle/server/utils/state/__init__.py create mode 100644 src/twinkle/server/utils/state/base.py create mode 100644 src/twinkle/server/utils/state/config_manager.py create mode 100644 src/twinkle/server/utils/state/future_manager.py create mode 100644 src/twinkle/server/utils/state/model_manager.py create mode 100644 src/twinkle/server/utils/state/models.py create mode 100644 src/twinkle/server/utils/state/sampling_manager.py create mode 100644 src/twinkle/server/utils/state/server_state.py create mode 100644 src/twinkle/server/utils/state/session_manager.py diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index c24ce466..06bdbfc3 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -67,8 +67,6 @@ def _init_adapter_manager( # Dict mapping adapter_name -> # {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} self._adapter_records: dict[str, dict[str, Any]] = {} - # Track adapter count per token - self._adapter_counts: dict[str, int] = {} # Countdown thread self._adapter_countdown_thread: threading.Thread | None = None diff --git a/src/twinkle/server/utils/state.py b/src/twinkle/server/utils/state.py deleted file mode 100644 index e191d80a..00000000 --- a/src/twinkle/server/utils/state.py +++ /dev/null @@ -1,609 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - -import asyncio -import ray -import re -import time -import uuid -from datetime import datetime -from typing import Any, Dict, Optional - -from twinkle.utils.logger import get_logger - -logger = get_logger() - - -class ServerState: - """ - Unified server state management class. - - This class combines the functionality of: - 1. Session management (create, touch, heartbeat) - 2. Model registration and tracking - 3. Sampling session management - 4. Async future storage and retrieval - 5. Configuration storage - - All methods are designed to be used with Ray actors for distributed state. - """ - - def __init__( - self, - expiration_timeout: float = 86400.0, # 24 hours in seconds - cleanup_interval: float = 3600.0, - **kwargs) -> None: # 1 hour in seconds - # Session tracking - self.sessions: dict[str, dict[str, Any]] = {} - # Model registration - self.models: dict[str, dict[str, Any]] = {} - # Sampling session tracking - self.sampling_sessions: dict[str, dict[str, Any]] = {} - # Async future results - self.futures: dict[str, dict[str, Any]] = {} - # Configuration storage - self.config: dict[str, Any] = {} - - # Cleanup configuration - self.expiration_timeout = expiration_timeout - self.cleanup_interval = cleanup_interval - self._cleanup_task: asyncio.Task | None = None - self._cleanup_running = False - - # ----- Session Management ----- - - def create_session(self, payload: dict[str, Any]) -> str: - """ - Create a new session with the given payload. - - Args: - payload: Session configuration containing optional session_id, tags, etc. - - Returns: - The session_id for the created session - """ - session_id = payload.get('session_id') or f'session_{uuid.uuid4().hex}' - self.sessions[session_id] = { - 'tags': list(payload.get('tags') or []), - 'user_metadata': payload.get('user_metadata') or {}, - 'sdk_version': payload.get('sdk_version'), - 'created_at': datetime.now().isoformat(), - 'last_heartbeat': time.time(), - } - return session_id - - def touch_session(self, session_id: str) -> bool: - """ - Update session heartbeat timestamp. - - Args: - session_id: The session to touch - - Returns: - True if session exists and was touched, False otherwise - """ - if session_id not in self.sessions: - return False - self.sessions[session_id]['last_heartbeat'] = time.time() - return True - - def get_session_last_heartbeat(self, session_id: str) -> float | None: - """ - Get the last heartbeat timestamp for a session. - - Args: - session_id: The session ID to query - - Returns: - Last heartbeat timestamp, or None if session doesn't exist - """ - session_info = self.sessions.get(session_id) - if not session_info: - return None - return session_info.get('last_heartbeat') - - # ----- Model Registration ----- - - def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: - """ - Register a new model with the server state. - - Args: - payload: Model configuration containing base_model, lora_config, etc. - model_id: Optional explicit model_id, otherwise auto-generated - token: Optional user token for tracking ownership - - Returns: - The model_id for the registered model - """ - _time = datetime.now().strftime('%Y%m%d_%H%M%S') - _model_id: str = model_id or payload.get( - 'model_id') or f"{_time}-{payload.get('base_model', 'model')}-{uuid.uuid4().hex[:8]}" - _model_id = re.sub(r'[^\w\-]', '_', _model_id) - - self.models[_model_id] = { - 'session_id': payload.get('session_id'), - 'model_seq_id': payload.get('model_seq_id'), - 'base_model': payload.get('base_model'), - 'user_metadata': payload.get('user_metadata') or {}, - 'lora_config': payload.get('lora_config'), - 'token': token, # Store token for adapter cleanup integration - 'created_at': datetime.now().isoformat(), - } - return _model_id - - def unload_model(self, model_id: str) -> bool: - """ - Remove a model from the registry. - - Args: - model_id: The model to unload - - Returns: - True if model was found and removed, False otherwise - """ - return self.models.pop(model_id, None) is not None - - def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: - """Get metadata for a registered model.""" - return self.models.get(model_id) - - # ----- Sampling Session Management ----- - - def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: - """ - Create a new sampling session. - - Args: - payload: Session configuration - sampling_session_id: Optional explicit ID - - Returns: - The sampling_session_id - """ - _sampling_session_id: str = sampling_session_id or payload.get( - 'sampling_session_id') or f'sampling_{uuid.uuid4().hex}' - self.sampling_sessions[_sampling_session_id] = { - 'session_id': payload.get('session_id'), - 'seq_id': payload.get('sampling_session_seq_id'), - 'base_model': payload.get('base_model'), - 'model_path': payload.get('model_path'), - 'created_at': datetime.now().isoformat(), - } - return _sampling_session_id - - def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: - """Get a sampling session by ID.""" - return self.sampling_sessions.get(sampling_session_id) - - # ----- Future Management ----- - - def get_future(self, request_id: str) -> dict[str, Any] | None: - """Retrieve a stored future result.""" - return self.futures.get(request_id) - - def store_future_status( - self, - request_id: str, - status: str, - model_id: str | None, - reason: str | None = None, - result: Any = None, - queue_state: str | None = None, - queue_state_reason: str | None = None, - ) -> None: - """Store task status with optional result. - - This method supports the full task lifecycle: - - PENDING: Task created, waiting to be processed - - QUEUED: Task in queue waiting for execution - - RUNNING: Task currently executing - - COMPLETED: Task completed successfully (result required) - - FAILED: Task failed with error (result contains error payload) - - RATE_LIMITED: Task rejected due to rate limiting (reason required) - - Args: - request_id: Unique identifier for the request. - status: Task status string (pending/queued/running/completed/failed/rate_limited). - model_id: Optional associated model_id. - reason: Optional reason string (used for rate_limited status). - result: Optional result data (used for completed/failed status). - queue_state: Optional queue state for tinker client (active/paused_rate_limit/paused_capacity). - queue_state_reason: Optional reason for the queue state. - """ - # Serialize result if it has model_dump method - if result is not None and hasattr(result, 'model_dump'): - result = result.model_dump() - - future_data: dict[str, Any] = { - 'status': status, - 'model_id': model_id, - 'updated_at': datetime.now().isoformat(), - } - - # Include reason for rate_limited status - if reason is not None: - future_data['reason'] = reason - - # Include result for completed/failed status - if result is not None: - future_data['result'] = result - - # Include queue_state and queue_state_reason for tinker client compatibility - if queue_state is not None: - future_data['queue_state'] = queue_state - if queue_state_reason is not None: - future_data['queue_state_reason'] = queue_state_reason - - # Update or create the future entry - if request_id in self.futures: - self.futures[request_id].update(future_data) - else: - future_data['created_at'] = datetime.now().isoformat() - self.futures[request_id] = future_data - - # ----- Config Management (from ConfigRegistry) ----- - - def add_config(self, key: str, value: Any): - """ - Add or update a configuration value. - - Args: - key: Configuration key - value: Configuration value - """ - self.config[key] = value - - def add_or_get(self, key: str, value: Any) -> Any: - """ - Add a config if not exists, otherwise return existing value. - - Args: - key: Configuration key - value: Value to add if key doesn't exist - - Returns: - The existing or newly added value - """ - if key in self.config: - return self.config[key] - self.config[key] = value - return value - - def get_config(self, key: str) -> Any | None: - """Get a configuration value by key.""" - return self.config.get(key) - - def pop_config(self, key: str) -> Any | None: - """Remove and return a configuration value.""" - return self.config.pop(key, None) - - def clear_config(self): - """Clear all configuration values.""" - self.config.clear() - - # ----- Resource Cleanup ----- - - def _parse_timestamp(self, timestamp_str: str) -> float: - """Parse ISO format timestamp to unix timestamp. - - Args: - timestamp_str: ISO format timestamp string - - Returns: - Unix timestamp (seconds since epoch) - """ - try: - dt = datetime.fromisoformat(timestamp_str) - return dt.timestamp() - except (ValueError, AttributeError): - # If parsing fails, return current time to avoid keeping invalid entries - return time.time() - - def cleanup_expired_resources(self) -> dict[str, int]: - """Clean up expired sessions, models, sampling_sessions, and futures. - - Resources are considered expired if they haven't been accessed for longer - than the expiration_timeout period. For sessions, we check last_heartbeat - (or created_at if no heartbeat exists). For other resources, we check created_at. - - Returns: - Dict with counts of cleaned up resources by type - """ - current_time = time.time() - cutoff_time = current_time - self.expiration_timeout - - cleanup_stats = { - 'sessions': 0, - 'models': 0, - 'sampling_sessions': 0, - 'futures': 0, - } - - # Clean up expired sessions - expired_session_ids = [] - for session_id, session_data in self.sessions.items(): - # Use last_heartbeat if available, otherwise created_at - last_activity = session_data.get('last_heartbeat') - if last_activity is None: - created_at_str = session_data.get('created_at') - if created_at_str: - last_activity = self._parse_timestamp(created_at_str) - else: - last_activity = 0 - - if last_activity < cutoff_time: - expired_session_ids.append(session_id) - - for session_id in expired_session_ids: - del self.sessions[session_id] - cleanup_stats['sessions'] += 1 - - # Clean up expired models (check by session_id association or created_at) - expired_model_ids = [] - for model_id, model_data in self.models.items(): - # First check if the model's session has been cleaned up - session_id = model_data.get('session_id') - if session_id and session_id in expired_session_ids: - expired_model_ids.append(model_id) - else: - # Check if model itself is expired by created_at - created_at_str = model_data.get('created_at') - if created_at_str: - created_at = self._parse_timestamp(created_at_str) - if created_at < cutoff_time: - expired_model_ids.append(model_id) - - for model_id in expired_model_ids: - del self.models[model_id] - cleanup_stats['models'] += 1 - - # Clean up expired sampling sessions - expired_sampling_ids = [] - for sampling_id, sampling_data in self.sampling_sessions.items(): - # Check by session_id association or created_at - session_id = sampling_data.get('session_id') - if session_id and session_id in expired_session_ids: - expired_sampling_ids.append(sampling_id) - else: - created_at_str = sampling_data.get('created_at') - if created_at_str: - created_at = self._parse_timestamp(created_at_str) - if created_at < cutoff_time: - expired_sampling_ids.append(sampling_id) - - for sampling_id in expired_sampling_ids: - del self.sampling_sessions[sampling_id] - cleanup_stats['sampling_sessions'] += 1 - - # Clean up expired futures (use created_at or updated_at) - expired_future_ids = [] - for request_id, future_data in self.futures.items(): - # Use updated_at if available, otherwise created_at - timestamp_str = future_data.get('updated_at') or future_data.get('created_at') - if timestamp_str: - timestamp = self._parse_timestamp(timestamp_str) - if timestamp < cutoff_time: - expired_future_ids.append(request_id) - - for request_id in expired_future_ids: - del self.futures[request_id] - cleanup_stats['futures'] += 1 - - return cleanup_stats - - async def _cleanup_loop(self) -> None: - """Background task that periodically cleans up expired resources. - - This task runs continuously and triggers cleanup at regular intervals - defined by cleanup_interval. - """ - while self._cleanup_running: - try: - await asyncio.sleep(self.cleanup_interval) - stats = self.cleanup_expired_resources() - # Log cleanup stats (in production, you might want to use proper logging) - if any(stats.values()): - logger.debug(f'[ServerState Cleanup] Removed expired resources: {stats}') - except asyncio.CancelledError: - break - except Exception as e: - # Log but don't crash the cleanup task - logger.warning(f'[ServerState Cleanup] Error during cleanup: {e}') - continue - - def start_cleanup_task(self) -> bool: - """Start the background cleanup task. - - Returns: - True if task was started, False if already running - """ - if self._cleanup_running: - return False - - self._cleanup_running = True - self._cleanup_task = asyncio.create_task(self._cleanup_loop()) - return True - - def stop_cleanup_task(self) -> bool: - """Stop the background cleanup task. - - Returns: - True if task was stopped, False if not running - """ - if not self._cleanup_running: - return False - - self._cleanup_running = False - if self._cleanup_task: - self._cleanup_task.cancel() - self._cleanup_task = None - return True - - def get_cleanup_stats(self) -> dict[str, Any]: - """Get current cleanup configuration and status. - - Returns: - Dict with cleanup configuration and task status - """ - return { - 'expiration_timeout': self.expiration_timeout, - 'cleanup_interval': self.cleanup_interval, - 'cleanup_running': self._cleanup_running, - 'resource_counts': { - 'sessions': len(self.sessions), - 'models': len(self.models), - 'sampling_sessions': len(self.sampling_sessions), - 'futures': len(self.futures), - } - } - - -class ServerStateProxy: - """ - Proxy for interacting with ServerState Ray actor. - - This class wraps Ray remote calls to provide a synchronous-looking API - for interacting with the distributed ServerState actor. - """ - - def __init__(self, actor_handle): - self._actor = actor_handle - - # ----- Session Management ----- - - def create_session(self, payload: dict[str, Any]) -> str: - return ray.get(self._actor.create_session.remote(payload)) - - def touch_session(self, session_id: str) -> bool: - return ray.get(self._actor.touch_session.remote(session_id)) - - def get_session_last_heartbeat(self, session_id: str) -> float | None: - return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) - - # ----- Model Registration ----- - - def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: - return ray.get(self._actor.register_model.remote(payload, model_id, token)) - - def unload_model(self, model_id: str) -> bool: - return ray.get(self._actor.unload_model.remote(model_id)) - - def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: - return ray.get(self._actor.get_model_metadata.remote(model_id)) - - # ----- Sampling Session Management ----- - - def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: - return ray.get(self._actor.create_sampling_session.remote(payload, sampling_session_id)) - - def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: - """Get a sampling session by ID.""" - return ray.get(self._actor.get_sampling_session.remote(sampling_session_id)) - - # ----- Future Management ----- - - def get_future(self, request_id: str) -> dict[str, Any] | None: - return ray.get(self._actor.get_future.remote(request_id)) - - def store_future_status( - self, - request_id: str, - status: str, - model_id: str | None, - reason: str | None = None, - result: Any = None, - queue_state: str | None = None, - queue_state_reason: str | None = None, - ) -> None: - """Store task status with optional result (synchronous).""" - ray.get( - self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, - queue_state_reason)) - - # ----- Config Management ----- - - def add_config(self, key: str, value: Any): - return ray.get(self._actor.add_config.remote(key, value)) - - def add_or_get(self, key: str, value: Any) -> Any: - return ray.get(self._actor.add_or_get.remote(key, value)) - - def get_config(self, key: str) -> Any | None: - return ray.get(self._actor.get_config.remote(key)) - - def pop_config(self, key: str) -> Any | None: - return ray.get(self._actor.pop_config.remote(key)) - - def clear_config(self): - return ray.get(self._actor.clear_config.remote()) - - # ----- Resource Cleanup ----- - - def cleanup_expired_resources(self) -> dict[str, int]: - """Manually trigger cleanup of expired resources. - - Returns: - Dict with counts of cleaned up resources by type - """ - return ray.get(self._actor.cleanup_expired_resources.remote()) - - def start_cleanup_task(self) -> bool: - """Start the background cleanup task. - - Returns: - True if task was started, False if already running - """ - return ray.get(self._actor.start_cleanup_task.remote()) - - def stop_cleanup_task(self) -> bool: - """Stop the background cleanup task. - - Returns: - True if task was stopped, False if not running - """ - return ray.get(self._actor.stop_cleanup_task.remote()) - - def get_cleanup_stats(self) -> dict[str, Any]: - """Get current cleanup configuration and status. - - Returns: - Dict with cleanup configuration and task status - """ - return ray.get(self._actor.get_cleanup_stats.remote()) - - -def get_server_state(actor_name: str = 'twinkle_server_state', - auto_start_cleanup: bool = True, - **server_state_kwargs) -> ServerStateProxy: - """ - Get or create the ServerState Ray actor. - - This function ensures only one ServerState actor exists with the given name. - It uses a detached actor so the state persists across driver restarts. - - Args: - actor_name: Name for the Ray actor (default: 'twinkle_server_state') - auto_start_cleanup: Whether to automatically start the cleanup task (default: True) - **server_state_kwargs: Additional keyword arguments passed to ServerState constructor - (e.g., expiration_timeout, cleanup_interval, per_token_adapter_limit) - - Returns: - A ServerStateProxy for interacting with the actor - """ - try: - actor = ray.get_actor(actor_name) - except ValueError: - try: - _ServerState = ray.remote(ServerState) - actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**server_state_kwargs) - # Start cleanup task for newly created actor - if auto_start_cleanup: - try: - ray.get(actor.start_cleanup_task.remote()) - except Exception as e: - logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') - except ValueError: - actor = ray.get_actor(actor_name) - assert actor is not None - return ServerStateProxy(actor) diff --git a/src/twinkle/server/utils/state/__init__.py b/src/twinkle/server/utils/state/__init__.py new file mode 100644 index 00000000..0e34697a --- /dev/null +++ b/src/twinkle/server/utils/state/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from .base import BaseManager +from .config_manager import ConfigManager +from .future_manager import FutureManager +from .model_manager import ModelManager +from .models import FutureRecord, ModelRecord, SamplingSessionRecord, SessionRecord +from .sampling_manager import SamplingSessionManager +from .server_state import ServerState, ServerStateProxy, get_server_state +from .session_manager import SessionManager + +__all__ = [ + # Pydantic record models + 'SessionRecord', + 'ModelRecord', + 'SamplingSessionRecord', + 'FutureRecord', + # Base + 'BaseManager', + # Resource managers + 'SessionManager', + 'ModelManager', + 'SamplingSessionManager', + 'FutureManager', + 'ConfigManager', + # Server state + 'ServerState', + 'ServerStateProxy', + 'get_server_state', +] diff --git a/src/twinkle/server/utils/state/base.py b/src/twinkle/server/utils/state/base.py new file mode 100644 index 00000000..c7480ec7 --- /dev/null +++ b/src/twinkle/server/utils/state/base.py @@ -0,0 +1,68 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from datetime import datetime +from pydantic import BaseModel +from typing import Generic, TypeVar + +T = TypeVar('T', bound=BaseModel) + + +class BaseManager(ABC, Generic[T]): + """ + Abstract base class for resource managers. + + Provides common CRUD operations and timestamp parsing. + Subclasses must implement `cleanup_expired`. + """ + + def __init__(self, expiration_timeout: float) -> None: + self._store: dict[str, T] = {} + self.expiration_timeout = expiration_timeout + + # ----- CRUD ----- + + def add(self, resource_id: str, record: T) -> None: + """Store a record under the given ID.""" + self._store[resource_id] = record + + def get(self, resource_id: str) -> T | None: + """Return the record for the given ID, or None.""" + return self._store.get(resource_id) + + def remove(self, resource_id: str) -> bool: + """Remove a record by ID. Returns True if it existed.""" + return self._store.pop(resource_id, None) is not None + + def count(self) -> int: + """Return the number of stored records.""" + return len(self._store) + + # ----- Cleanup ----- + + @abstractmethod + def cleanup_expired(self, cutoff_time: float) -> int: + """ + Remove all records older than cutoff_time. + + Args: + cutoff_time: Unix timestamp; records with activity before this are removed. + + Returns: + Number of records removed. + """ + + # ----- Helpers ----- + + def _parse_timestamp(self, timestamp_str: str) -> float: + """Parse an ISO-format timestamp string to a Unix timestamp. + + Falls back to the current time so that unparseable entries are + never accidentally kept alive forever. + """ + try: + return datetime.fromisoformat(timestamp_str).timestamp() + except (ValueError, AttributeError): + return time.time() diff --git a/src/twinkle/server/utils/state/config_manager.py b/src/twinkle/server/utils/state/config_manager.py new file mode 100644 index 00000000..e1aa3bce --- /dev/null +++ b/src/twinkle/server/utils/state/config_manager.py @@ -0,0 +1,53 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +from typing import Any + + +class ConfigManager: + """ + Manages key-value configuration entries. + + Configuration entries have no expiry; they persist until explicitly removed + or cleared. This manager does not inherit from BaseManager because config + values are arbitrary Python objects rather than Pydantic models. + """ + + def __init__(self) -> None: + self._store: dict[str, Any] = {} + + # ----- CRUD ----- + + def add(self, key: str, value: Any) -> None: + """Add or overwrite a configuration value.""" + self._store[key] = value + + def add_or_get(self, key: str, value: Any) -> Any: + """Add a value if the key does not exist; otherwise return the existing value. + + Args: + key: Configuration key. + value: Value to store if the key is absent. + + Returns: + The existing or newly stored value. + """ + if key not in self._store: + self._store[key] = value + return self._store[key] + + def get(self, key: str) -> Any | None: + """Return the configuration value for key, or None.""" + return self._store.get(key) + + def pop(self, key: str) -> Any | None: + """Remove and return the configuration value for key, or None.""" + return self._store.pop(key, None) + + def clear(self) -> None: + """Remove all configuration entries.""" + self._store.clear() + + def count(self) -> int: + """Return the number of stored configuration entries.""" + return len(self._store) diff --git a/src/twinkle/server/utils/state/future_manager.py b/src/twinkle/server/utils/state/future_manager.py new file mode 100644 index 00000000..0af069a8 --- /dev/null +++ b/src/twinkle/server/utils/state/future_manager.py @@ -0,0 +1,95 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from .base import BaseManager +from .models import FutureRecord + + +class FutureManager(BaseManager[FutureRecord]): + """ + Manages async task futures / request statuses. + + Expiry is based on `updated_at` (falls back to `created_at`). + """ + + # ----- Future-specific operations ----- + + def store_status( + self, + request_id: str, + status: str, + model_id: str | None, + reason: str | None = None, + result: Any = None, + queue_state: str | None = None, + queue_state_reason: str | None = None, + ) -> None: + """Create or update a future record with the latest status. + + If the result object has a `model_dump` method (i.e. it is a Pydantic + model) it is serialized to a plain dict before storage. + + Args: + request_id: Unique identifier for the request. + status: Task status string (pending/queued/running/completed/failed/rate_limited). + model_id: Optional associated model_id. + reason: Optional reason string (used for rate_limited status). + result: Optional result data (used for completed/failed status). + queue_state: Optional queue state (active/paused_rate_limit/paused_capacity). + queue_state_reason: Optional reason for the queue state. + """ + if result is not None and hasattr(result, 'model_dump'): + result = result.model_dump() + + now = datetime.now().isoformat() + existing = self._store.get(request_id) + + if existing is not None: + existing.status = status + existing.model_id = model_id + existing.updated_at = now + if reason is not None: + existing.reason = reason + if result is not None: + existing.result = result + if queue_state is not None: + existing.queue_state = queue_state + if queue_state_reason is not None: + existing.queue_state_reason = queue_state_reason + else: + self._store[request_id] = FutureRecord( + status=status, + model_id=model_id, + reason=reason, + result=result, + queue_state=queue_state, + queue_state_reason=queue_state_reason, + created_at=now, + updated_at=now, + ) + + # ----- Cleanup ----- + + def cleanup_expired(self, cutoff_time: float) -> int: + """Remove futures whose last update is older than cutoff_time. + + Args: + cutoff_time: Unix timestamp threshold. + + Returns: + Number of futures removed. + """ + expired_ids = [] + for request_id, record in self._store.items(): + timestamp_str = record.updated_at or record.created_at + timestamp = self._parse_timestamp(timestamp_str) + if timestamp < cutoff_time: + expired_ids.append(request_id) + + for request_id in expired_ids: + del self._store[request_id] + + return len(expired_ids) diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/utils/state/model_manager.py new file mode 100644 index 00000000..433a607b --- /dev/null +++ b/src/twinkle/server/utils/state/model_manager.py @@ -0,0 +1,47 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +from .base import BaseManager +from .models import ModelRecord + + +class ModelManager(BaseManager[ModelRecord]): + """ + Manages registered models. + + Expiry is based on `created_at`. A model is also considered expired if + its owning session has already been removed (cascade expiry). + """ + + # ----- Cleanup ----- + + def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None) -> int: + """Remove models that are older than cutoff_time, or whose owning + session has already been expired. + + Args: + cutoff_time: Unix timestamp threshold. + expired_session_ids: Optional list of session IDs that have just + been expired; any model belonging to one of these sessions will + also be removed regardless of its own age. + + Returns: + Number of models removed. + """ + session_set = set(expired_session_ids or []) + expired_ids = [] + + for model_id, record in self._store.items(): + # Cascade: owner session was expired + if record.session_id and record.session_id in session_set: + expired_ids.append(model_id) + continue + # Own age + created_at = self._parse_timestamp(record.created_at) + if created_at < cutoff_time: + expired_ids.append(model_id) + + for model_id in expired_ids: + del self._store[model_id] + + return len(expired_ids) diff --git a/src/twinkle/server/utils/state/models.py b/src/twinkle/server/utils/state/models.py new file mode 100644 index 00000000..343c4998 --- /dev/null +++ b/src/twinkle/server/utils/state/models.py @@ -0,0 +1,56 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import time +from datetime import datetime +from pydantic import BaseModel, Field +from typing import Any + + +def _now_iso() -> str: + return datetime.now().isoformat() + + +class SessionRecord(BaseModel): + """Represents a client session.""" + + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + sdk_version: str | None = None + created_at: str = Field(default_factory=_now_iso) + last_heartbeat: float = Field(default_factory=time.time) + + +class ModelRecord(BaseModel): + """Represents a registered model.""" + + session_id: str | None = None + model_seq_id: Any = None + base_model: str | None = None + user_metadata: dict[str, Any] = Field(default_factory=dict) + lora_config: Any = None + token: str | None = None + created_at: str = Field(default_factory=_now_iso) + + +class SamplingSessionRecord(BaseModel): + """Represents a sampling session.""" + + session_id: str | None = None + seq_id: Any = None + base_model: str | None = None + model_path: str | None = None + created_at: str = Field(default_factory=_now_iso) + + +class FutureRecord(BaseModel): + """Represents an async task future / request status.""" + + status: str + model_id: str | None = None + reason: str | None = None + result: Any = None + queue_state: str | None = None + queue_state_reason: str | None = None + created_at: str = Field(default_factory=_now_iso) + updated_at: str = Field(default_factory=_now_iso) diff --git a/src/twinkle/server/utils/state/sampling_manager.py b/src/twinkle/server/utils/state/sampling_manager.py new file mode 100644 index 00000000..ff3111a6 --- /dev/null +++ b/src/twinkle/server/utils/state/sampling_manager.py @@ -0,0 +1,47 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +from .base import BaseManager +from .models import SamplingSessionRecord + + +class SamplingSessionManager(BaseManager[SamplingSessionRecord]): + """ + Manages sampling sessions. + + Expiry is based on `created_at`. A sampling session is also considered + expired if its owning session has already been removed (cascade expiry). + """ + + # ----- Cleanup ----- + + def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None) -> int: + """Remove sampling sessions that are older than cutoff_time, or whose + owning session has already been expired. + + Args: + cutoff_time: Unix timestamp threshold. + expired_session_ids: Optional list of session IDs that have just + been expired; any sampling session belonging to one of these + sessions will also be removed regardless of its own age. + + Returns: + Number of sampling sessions removed. + """ + session_set = set(expired_session_ids or []) + expired_ids = [] + + for sampling_id, record in self._store.items(): + # Cascade: owner session was expired + if record.session_id and record.session_id in session_set: + expired_ids.append(sampling_id) + continue + # Own age + created_at = self._parse_timestamp(record.created_at) + if created_at < cutoff_time: + expired_ids.append(sampling_id) + + for sampling_id in expired_ids: + del self._store[sampling_id] + + return len(expired_ids) diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py new file mode 100644 index 00000000..3ac52163 --- /dev/null +++ b/src/twinkle/server/utils/state/server_state.py @@ -0,0 +1,451 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import asyncio +import ray +import re +import time +import uuid +from datetime import datetime +from typing import Any + +from twinkle.utils.logger import get_logger +from .config_manager import ConfigManager +from .future_manager import FutureManager +from .model_manager import ModelManager +from .models import ModelRecord, SamplingSessionRecord, SessionRecord +from .sampling_manager import SamplingSessionManager +from .session_manager import SessionManager + +logger = get_logger() + + +class ServerState: + """ + Unified server state management class. + + Composes five resource managers: + - SessionManager — client sessions + - ModelManager — registered models + - SamplingSessionManager — sampling sessions + - FutureManager — async task futures + - ConfigManager — key-value configuration + + All methods are designed to be used with Ray actors for distributed state. + """ + + def __init__( + self, + expiration_timeout: float = 86400.0, # 24 hours in seconds + cleanup_interval: float = 3600.0, # 1 hour in seconds + **kwargs) -> None: + self._session_mgr = SessionManager(expiration_timeout) + self._model_mgr = ModelManager(expiration_timeout) + self._sampling_mgr = SamplingSessionManager(expiration_timeout) + self._future_mgr = FutureManager(expiration_timeout) + self._config_mgr = ConfigManager() + + self.expiration_timeout = expiration_timeout + self.cleanup_interval = cleanup_interval + self._cleanup_task: asyncio.Task | None = None + self._cleanup_running = False + + # ----- Session Management ----- + + def create_session(self, payload: dict[str, Any]) -> str: + """Create a new session with the given payload. + + Args: + payload: Session configuration containing optional session_id, tags, etc. + + Returns: + The session_id for the created session. + """ + session_id = payload.get('session_id') or f'session_{uuid.uuid4().hex}' + record = SessionRecord( + tags=list(payload.get('tags') or []), + user_metadata=payload.get('user_metadata') or {}, + sdk_version=payload.get('sdk_version'), + ) + self._session_mgr.add(session_id, record) + return session_id + + def touch_session(self, session_id: str) -> bool: + """Update session heartbeat timestamp. + + Returns: + True if the session exists and was touched, False otherwise. + """ + return self._session_mgr.touch(session_id) + + def get_session_last_heartbeat(self, session_id: str) -> float | None: + """Get the last heartbeat timestamp for a session. + + Returns: + Last heartbeat timestamp, or None if the session does not exist. + """ + return self._session_mgr.get_last_heartbeat(session_id) + + # ----- Model Registration ----- + + def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: + """Register a new model with the server state. + + Args: + payload: Model configuration containing base_model, lora_config, etc. + model_id: Optional explicit model_id; otherwise auto-generated. + token: Optional user token for tracking ownership. + + Returns: + The model_id for the registered model. + """ + _time = datetime.now().strftime('%Y%m%d_%H%M%S') + _model_id: str = model_id or payload.get( + 'model_id') or f"{_time}-{payload.get('base_model', 'model')}-{uuid.uuid4().hex[:8]}" + _model_id = re.sub(r'[^\w\-]', '_', _model_id) + + record = ModelRecord( + session_id=payload.get('session_id'), + model_seq_id=payload.get('model_seq_id'), + base_model=payload.get('base_model'), + user_metadata=payload.get('user_metadata') or {}, + lora_config=payload.get('lora_config'), + token=token, + ) + self._model_mgr.add(_model_id, record) + return _model_id + + def unload_model(self, model_id: str) -> bool: + """Remove a model from the registry. + + Returns: + True if the model was found and removed, False otherwise. + """ + return self._model_mgr.remove(model_id) + + def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: + """Get metadata for a registered model as a plain dict.""" + record = self._model_mgr.get(model_id) + return record.model_dump() if record is not None else None + + # ----- Sampling Session Management ----- + + def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: + """Create a new sampling session. + + Args: + payload: Session configuration. + sampling_session_id: Optional explicit ID. + + Returns: + The sampling_session_id. + """ + _sampling_session_id: str = sampling_session_id or payload.get( + 'sampling_session_id') or f'sampling_{uuid.uuid4().hex}' + record = SamplingSessionRecord( + session_id=payload.get('session_id'), + seq_id=payload.get('sampling_session_seq_id'), + base_model=payload.get('base_model'), + model_path=payload.get('model_path'), + ) + self._sampling_mgr.add(_sampling_session_id, record) + return _sampling_session_id + + def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: + """Get a sampling session by ID as a plain dict.""" + record = self._sampling_mgr.get(sampling_session_id) + return record.model_dump() if record is not None else None + + # ----- Future Management ----- + + def get_future(self, request_id: str) -> dict[str, Any] | None: + """Retrieve a stored future result as a plain dict.""" + record = self._future_mgr.get(request_id) + return record.model_dump() if record is not None else None + + def store_future_status( + self, + request_id: str, + status: str, + model_id: str | None, + reason: str | None = None, + result: Any = None, + queue_state: str | None = None, + queue_state_reason: str | None = None, + ) -> None: + """Store task status with optional result. + + Supports the full task lifecycle: + - PENDING: Task created, waiting to be processed + - QUEUED: Task in queue waiting for execution + - RUNNING: Task currently executing + - COMPLETED: Task completed successfully (result required) + - FAILED: Task failed with error (result contains error payload) + - RATE_LIMITED: Task rejected due to rate limiting (reason required) + + Args: + request_id: Unique identifier for the request. + status: Task status string (pending/queued/running/completed/failed/rate_limited). + model_id: Optional associated model_id. + reason: Optional reason string (used for rate_limited status). + result: Optional result data (used for completed/failed status). + queue_state: Optional queue state for tinker client (active/paused_rate_limit/paused_capacity). + queue_state_reason: Optional reason for the queue state. + """ + self._future_mgr.store_status( + request_id=request_id, + status=status, + model_id=model_id, + reason=reason, + result=result, + queue_state=queue_state, + queue_state_reason=queue_state_reason, + ) + + # ----- Config Management ----- + + def add_config(self, key: str, value: Any) -> None: + """Add or update a configuration value.""" + self._config_mgr.add(key, value) + + def add_or_get(self, key: str, value: Any) -> Any: + """Add a config value if the key does not exist; otherwise return the existing value.""" + return self._config_mgr.add_or_get(key, value) + + def get_config(self, key: str) -> Any | None: + """Get a configuration value by key.""" + return self._config_mgr.get(key) + + def pop_config(self, key: str) -> Any | None: + """Remove and return a configuration value.""" + return self._config_mgr.pop(key) + + def clear_config(self) -> None: + """Clear all configuration values.""" + self._config_mgr.clear() + + # ----- Resource Cleanup ----- + + def cleanup_expired_resources(self) -> dict[str, int]: + """Clean up expired sessions, models, sampling_sessions, and futures. + + Sessions expire based on last_heartbeat (or created_at). Models and + sampling sessions are also cascade-expired when their owning session + expires. Futures expire based on updated_at (or created_at). + + Returns: + Dict with counts of cleaned up resources by type. + """ + current_time = time.time() + cutoff_time = current_time - self.expiration_timeout + + # Collect expired session IDs first for cascade logic + expired_session_ids = self._session_mgr.get_expired_ids(cutoff_time) + + # Perform actual cleanup in dependency order + sessions_removed = self._session_mgr.cleanup_expired(cutoff_time) + models_removed = self._model_mgr.cleanup_expired(cutoff_time, expired_session_ids) + samplings_removed = self._sampling_mgr.cleanup_expired(cutoff_time, expired_session_ids) + futures_removed = self._future_mgr.cleanup_expired(cutoff_time) + + return { + 'sessions': sessions_removed, + 'models': models_removed, + 'sampling_sessions': samplings_removed, + 'futures': futures_removed, + } + + async def _cleanup_loop(self) -> None: + """Background task that periodically cleans up expired resources.""" + while self._cleanup_running: + try: + await asyncio.sleep(self.cleanup_interval) + stats = self.cleanup_expired_resources() + if any(stats.values()): + logger.debug(f'[ServerState Cleanup] Removed expired resources: {stats}') + except asyncio.CancelledError: + break + except Exception as e: + logger.warning(f'[ServerState Cleanup] Error during cleanup: {e}') + continue + + def start_cleanup_task(self) -> bool: + """Start the background cleanup task. + + Returns: + True if task was started, False if already running. + """ + if self._cleanup_running: + return False + self._cleanup_running = True + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + return True + + def stop_cleanup_task(self) -> bool: + """Stop the background cleanup task. + + Returns: + True if task was stopped, False if not running. + """ + if not self._cleanup_running: + return False + self._cleanup_running = False + if self._cleanup_task: + self._cleanup_task.cancel() + self._cleanup_task = None + return True + + def get_cleanup_stats(self) -> dict[str, Any]: + """Get current cleanup configuration and resource counts. + + Returns: + Dict with cleanup configuration and task status. + """ + return { + 'expiration_timeout': self.expiration_timeout, + 'cleanup_interval': self.cleanup_interval, + 'cleanup_running': self._cleanup_running, + 'resource_counts': { + 'sessions': self._session_mgr.count(), + 'models': self._model_mgr.count(), + 'sampling_sessions': self._sampling_mgr.count(), + 'futures': self._future_mgr.count(), + }, + } + + +# --------------------------------------------------------------------------- +# Ray proxy +# --------------------------------------------------------------------------- + + +class ServerStateProxy: + """ + Proxy for interacting with a ServerState Ray actor. + + Wraps Ray remote calls to provide a synchronous-looking API for + interacting with the distributed ServerState actor. + """ + + def __init__(self, actor_handle) -> None: + self._actor = actor_handle + + # ----- Session Management ----- + + def create_session(self, payload: dict[str, Any]) -> str: + return ray.get(self._actor.create_session.remote(payload)) + + def touch_session(self, session_id: str) -> bool: + return ray.get(self._actor.touch_session.remote(session_id)) + + def get_session_last_heartbeat(self, session_id: str) -> float | None: + return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) + + # ----- Model Registration ----- + + def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: + return ray.get(self._actor.register_model.remote(payload, model_id, token)) + + def unload_model(self, model_id: str) -> bool: + return ray.get(self._actor.unload_model.remote(model_id)) + + def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: + return ray.get(self._actor.get_model_metadata.remote(model_id)) + + # ----- Sampling Session Management ----- + + def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: + return ray.get(self._actor.create_sampling_session.remote(payload, sampling_session_id)) + + def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None: + return ray.get(self._actor.get_sampling_session.remote(sampling_session_id)) + + # ----- Future Management ----- + + def get_future(self, request_id: str) -> dict[str, Any] | None: + return ray.get(self._actor.get_future.remote(request_id)) + + def store_future_status( + self, + request_id: str, + status: str, + model_id: str | None, + reason: str | None = None, + result: Any = None, + queue_state: str | None = None, + queue_state_reason: str | None = None, + ) -> None: + """Store task status with optional result (synchronous).""" + ray.get( + self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, + queue_state_reason)) + + # ----- Config Management ----- + + def add_config(self, key: str, value: Any): + return ray.get(self._actor.add_config.remote(key, value)) + + def add_or_get(self, key: str, value: Any) -> Any: + return ray.get(self._actor.add_or_get.remote(key, value)) + + def get_config(self, key: str) -> Any | None: + return ray.get(self._actor.get_config.remote(key)) + + def pop_config(self, key: str) -> Any | None: + return ray.get(self._actor.pop_config.remote(key)) + + def clear_config(self): + return ray.get(self._actor.clear_config.remote()) + + # ----- Resource Cleanup ----- + + def cleanup_expired_resources(self) -> dict[str, int]: + return ray.get(self._actor.cleanup_expired_resources.remote()) + + def start_cleanup_task(self) -> bool: + return ray.get(self._actor.start_cleanup_task.remote()) + + def stop_cleanup_task(self) -> bool: + return ray.get(self._actor.stop_cleanup_task.remote()) + + def get_cleanup_stats(self) -> dict[str, Any]: + return ray.get(self._actor.get_cleanup_stats.remote()) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def get_server_state(actor_name: str = 'twinkle_server_state', + auto_start_cleanup: bool = True, + **server_state_kwargs) -> ServerStateProxy: + """Get or create the ServerState Ray actor. + + Ensures only one ServerState actor exists with the given name. Uses a + detached actor so the state persists across driver restarts. + + Args: + actor_name: Name for the Ray actor (default: 'twinkle_server_state'). + auto_start_cleanup: Whether to automatically start the cleanup task (default: True). + **server_state_kwargs: Additional keyword arguments passed to ServerState constructor + (e.g., expiration_timeout, cleanup_interval). + + Returns: + A ServerStateProxy for interacting with the actor. + """ + try: + actor = ray.get_actor(actor_name) + except ValueError: + try: + _ServerState = ray.remote(ServerState) + actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**server_state_kwargs) + if auto_start_cleanup: + try: + ray.get(actor.start_cleanup_task.remote()) + except Exception as e: + logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') + except ValueError: + actor = ray.get_actor(actor_name) + assert actor is not None + return ServerStateProxy(actor) diff --git a/src/twinkle/server/utils/state/session_manager.py b/src/twinkle/server/utils/state/session_manager.py new file mode 100644 index 00000000..e7b154cb --- /dev/null +++ b/src/twinkle/server/utils/state/session_manager.py @@ -0,0 +1,77 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import time + +from .base import BaseManager +from .models import SessionRecord + + +class SessionManager(BaseManager[SessionRecord]): + """ + Manages client sessions. + + Expiry is based on `last_heartbeat`; falls back to `created_at` if no + heartbeat has been recorded yet. + """ + + # ----- Session-specific operations ----- + + def touch(self, session_id: str) -> bool: + """Update the heartbeat timestamp for a session. + + Returns: + True if the session exists and was updated, False otherwise. + """ + record = self._store.get(session_id) + if record is None: + return False + record.last_heartbeat = time.time() + return True + + def get_last_heartbeat(self, session_id: str) -> float | None: + """Return the last heartbeat timestamp, or None if the session does not exist.""" + record = self._store.get(session_id) + if record is None: + return None + return record.last_heartbeat + + # ----- Cleanup ----- + + def cleanup_expired(self, cutoff_time: float) -> int: + """Remove sessions whose last activity is older than cutoff_time. + + Args: + cutoff_time: Unix timestamp threshold. + + Returns: + Number of sessions removed. + """ + expired_ids = [] + for session_id, record in self._store.items(): + last_activity = record.last_heartbeat + if last_activity == 0.0: + # Fallback: parse created_at + last_activity = self._parse_timestamp(record.created_at) + if last_activity < cutoff_time: + expired_ids.append(session_id) + + for session_id in expired_ids: + del self._store[session_id] + + return len(expired_ids) + + def get_expired_ids(self, cutoff_time: float) -> list[str]: + """Return IDs of sessions that would be removed at the given cutoff. + + Used by ServerState to cascade-expire dependent resources before + actually deleting the sessions. + """ + expired_ids = [] + for session_id, record in self._store.items(): + last_activity = record.last_heartbeat + if last_activity == 0.0: + last_activity = self._parse_timestamp(record.created_at) + if last_activity < cutoff_time: + expired_ids.append(session_id) + return expired_ids From 6cfaac036f645812f2099da7be8655375c04ce7a Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 26 Feb 2026 15:39:18 +0800 Subject: [PATCH 2/3] update state --- .../tinker/megatron/server_config_7b.yaml | 69 ++++++++++--------- cookbook/client/tinker/self_congnition.py | 1 - src/twinkle/server/tinker/model.py | 15 ++-- src/twinkle/server/tinker/server.py | 1 - src/twinkle/server/twinkle/model.py | 2 +- src/twinkle/server/utils/adapter_manager.py | 34 --------- .../server/utils/state/model_manager.py | 51 +++++++++++++- src/twinkle/server/utils/state/models.py | 2 +- .../server/utils/state/server_state.py | 29 ++++---- 9 files changed, 109 insertions(+), 95 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml index 0c8c0550..1727b1b8 100644 --- a/cookbook/client/tinker/megatron/server_config_7b.yaml +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -21,6 +21,8 @@ applications: route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible) import_path: server # Python module to import args: + server_config: + per_token_adapter_limit: 1 # Maximum number of adapters per token (globally) supported_models: - Qwen/Qwen2.5-7B-Instruct deployments: @@ -56,7 +58,6 @@ applications: adapter_config: adapter_timeout: 30 # Seconds before idle adapter unload adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) - per_token_adapter_limit: 30 deployments: - name: ModelManagement autoscaling_config: @@ -71,36 +72,36 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - - name: sampler-Qwen2.5-7B-Instruct - route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct - import_path: sampler - args: - model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node - sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - engine_args: # vLLM engine-specific settings - max_model_len: 4096 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - enable_lora: true # Allow loading LoRA adapters during inference - logprobs_mode: processed_logprobs # Logprobs mode for sampling results - device_group: # Logical device group for the sampler - name: sampler - ranks: [2] # GPU rank indices to use - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + # - name: sampler-Qwen2.5-7B-Instruct + # route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier + # nproc_per_node: 2 # Number of GPU processes per node + # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + # engine_args: # vLLM engine-specific settings + # max_model_len: 4096 # Maximum sequence length the engine supports + # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + # enable_lora: true # Allow loading LoRA adapters during inference + # logprobs_mode: processed_logprobs # Logprobs mode for sampling results + # device_group: # Logical device group for the sampler + # name: sampler + # ranks: [2] # GPU rank indices to use + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "0" diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py index 240c25c6..326a6f78 100644 --- a/cookbook/client/tinker/self_congnition.py +++ b/cookbook/client/tinker/self_congnition.py @@ -6,7 +6,6 @@ # 2. eval(): Load a trained checkpoint and sample from it to verify # that the model has learned the custom identity. # The server must be running first (see server.py and server_config.yaml). -import numpy as np import os from tqdm import tqdm from tinker import types diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 64bfacab..55d7e3bd 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -188,18 +188,21 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) - Returns: UntypedAPIFuture wrapping CreateModelResponse with model_id """ - # Register a new model_id for each create_model call - model_id = self.state.register_model(body.model_dump(), token=request.state.token) async def _create_adapter(): + model_id = None try: + # Register a new model_id for each create_model call + model_id = self.state.register_model(body.model_dump(), token=request.state.token) + + # Create a new LoRA adapter for the model if body.lora_config: # TODO: support more lora config parameters, train_unembed, etc. lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') adapter_name = self.get_adapter_name(adapter_name=model_id) - # Register adapter FIRST (limit check happens inside register_adapter) + # Register adapter FIRST self.register_adapter(adapter_name, request.state.token, session_id=body.session_id) # Create adapter AFTER successful registration @@ -218,8 +221,9 @@ async def _create_adapter(): return types.CreateModelResponse(model_id=model_id) except Exception: # Ensure we don't leave stale grad state. - adapter_name = self.get_adapter_name(adapter_name=model_id) - self._cleanup_adapter(adapter_name) + if model_id: + adapter_name = self.get_adapter_name(adapter_name=model_id) + self._cleanup_adapter(adapter_name) logger.error(traceback.format_exc()) return types.RequestFailedResponse( @@ -229,7 +233,6 @@ async def _create_adapter(): return await self.schedule_task( _create_adapter, - model_id=model_id, token=request.state.token, task_type='create_model', ) diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 3c9f4493..808f23c1 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -12,7 +12,6 @@ from __future__ import annotations import asyncio -import dataclasses import httpx import logging import os diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 1fcf6f8a..858d0716 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -516,7 +516,7 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): token = request.state.token training_run_manager = create_training_run_manager(token) - # Register adapter FIRST (limit check happens inside register_adapter) + # Register adapter FIRST self.register_adapter(adapter_name, token) # Create adapter AFTER successful registration diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 06bdbfc3..8337ed6b 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -43,7 +43,6 @@ class AdapterManagerMixin: def _init_adapter_manager( self, adapter_timeout: float = 1800.0, - per_token_adapter_limit: int = 30, adapter_max_lifetime: float = 12 * 60 * 60, ) -> None: """Initialize the adapter manager. @@ -54,13 +53,10 @@ def _init_adapter_manager( adapter_timeout: Timeout in seconds for inactive adapters and session-based expiration. Default is 1800.0 (30 minutes). Adapters linked to sessions will expire when their session hasn't been touched for this duration. - per_token_adapter_limit: Maximum number of adapters per user token. - Default is 30. adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation. Default is 43200.0 (12 hours). If <= 0, lifetime enforcement is disabled. """ self._adapter_timeout = adapter_timeout - self._per_token_adapter_limit = per_token_adapter_limit self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking @@ -80,15 +76,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str | None token: User token that owns this adapter. session_id: Optional session ID to associate with this adapter. If provided, adapter will expire when the session expires. - - Raises: - RuntimeError: If adapter limit is exceeded for this token. """ - # Check adapter limit BEFORE registering - allowed, reason = self.check_adapter_limit(token) - if not allowed: - raise RuntimeError(reason) - current_time = time.time() self._adapter_records[adapter_name] = { 'token': token, @@ -351,25 +339,3 @@ def stop_adapter_countdown(self) -> None: # Wait for thread to finish (it checks the flag every second) self._adapter_countdown_thread.join(timeout=2.0) logger.debug('[AdapterManager] Countdown thread stopped') - - def check_adapter_limit(self, token: str) -> tuple[bool, str | None]: - """Check adapter count for a user token. - - This method enforces per-user adapter limits to prevent resource exhaustion. - Counts adapters directly from _adapter_records instead of using state storage. - - Args: - token: User token to check. - - Returns: - Tuple of (allowed: bool, reason: Optional[str]). - If allowed is False, reason contains the explanation. - """ - # Count adapters directly from _adapter_records - current_count = sum(1 for record in self._adapter_records.values() - if record.get('token') == token and not record.get('expiring', False)) - - # Check if current count exceeds limit - if current_count >= self._per_token_adapter_limit: - return False, f'Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters' - return True, None diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/utils/state/model_manager.py index 433a607b..2eb98b7f 100644 --- a/src/twinkle/server/utils/state/model_manager.py +++ b/src/twinkle/server/utils/state/model_manager.py @@ -11,8 +11,52 @@ class ModelManager(BaseManager[ModelRecord]): Expiry is based on `created_at`. A model is also considered expired if its owning session has already been removed (cascade expiry). + + Enforces a per-token model limit across all model instances (server-global). """ + def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) -> None: + super().__init__(expiration_timeout) + self._per_token_model_limit = per_token_model_limit + # token -> set of model_ids owned by that token + self._token_models: dict[str, set[str]] = {} + + # ----- CRUD ----- + + def add(self, model_id: str, record: ModelRecord) -> None: + """Store a record under the given ID. + + Args: + model_id: Unique identifier for the model. + record: ModelRecord to store. + + Raises: + RuntimeError: If the token has reached per_token_model_limit. + """ + token = record.token + current_ids = self._token_models.get(token, set()) + if len(current_ids) >= self._per_token_model_limit: + raise RuntimeError(f'Model limit exceeded for token {token[:8]}...: ' + f'{len(current_ids)}/{self._per_token_model_limit} models') + self._token_models.setdefault(token, set()).add(model_id) + self._store[model_id] = record + + def remove(self, model_id: str) -> bool: + """Remove a record by ID and clean up token ownership. + + Returns: + True if the record existed and was removed, False otherwise. + """ + record = self._store.pop(model_id, None) + if record is None: + return False + token = record.token + if token and token in self._token_models: + self._token_models[token].discard(model_id) + if not self._token_models[token]: + del self._token_models[token] + return True + # ----- Cleanup ----- def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | None = None) -> int: @@ -42,6 +86,11 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N expired_ids.append(model_id) for model_id in expired_ids: - del self._store[model_id] + record = self._store.pop(model_id) + token = record.token + if token and token in self._token_models: + self._token_models[token].discard(model_id) + if not self._token_models[token]: + del self._token_models[token] return len(expired_ids) diff --git a/src/twinkle/server/utils/state/models.py b/src/twinkle/server/utils/state/models.py index 343c4998..d8499ff8 100644 --- a/src/twinkle/server/utils/state/models.py +++ b/src/twinkle/server/utils/state/models.py @@ -24,12 +24,12 @@ class SessionRecord(BaseModel): class ModelRecord(BaseModel): """Represents a registered model.""" + token: str session_id: str | None = None model_seq_id: Any = None base_model: str | None = None user_metadata: dict[str, Any] = Field(default_factory=dict) lora_config: Any = None - token: str | None = None created_at: str = Field(default_factory=_now_iso) diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 3ac52163..82605410 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -38,9 +38,10 @@ def __init__( self, expiration_timeout: float = 86400.0, # 24 hours in seconds cleanup_interval: float = 3600.0, # 1 hour in seconds + per_token_model_limit: int = 30, **kwargs) -> None: self._session_mgr = SessionManager(expiration_timeout) - self._model_mgr = ModelManager(expiration_timeout) + self._model_mgr = ModelManager(expiration_timeout, per_token_model_limit) self._sampling_mgr = SamplingSessionManager(expiration_timeout) self._future_mgr = FutureManager(expiration_timeout) self._config_mgr = ConfigManager() @@ -88,13 +89,13 @@ def get_session_last_heartbeat(self, session_id: str) -> float | None: # ----- Model Registration ----- - def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: + def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None) -> str: """Register a new model with the server state. Args: payload: Model configuration containing base_model, lora_config, etc. + token: User token that owns this model. Required. model_id: Optional explicit model_id; otherwise auto-generated. - token: Optional user token for tracking ownership. Returns: The model_id for the registered model. @@ -343,8 +344,8 @@ def get_session_last_heartbeat(self, session_id: str) -> float | None: # ----- Model Registration ----- - def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: - return ray.get(self._actor.register_model.remote(payload, model_id, token)) + def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None) -> str: + return ray.get(self._actor.register_model.remote(payload, token, model_id)) def unload_model(self, model_id: str) -> bool: return ray.get(self._actor.unload_model.remote(model_id)) @@ -417,9 +418,7 @@ def get_cleanup_stats(self) -> dict[str, Any]: # --------------------------------------------------------------------------- -def get_server_state(actor_name: str = 'twinkle_server_state', - auto_start_cleanup: bool = True, - **server_state_kwargs) -> ServerStateProxy: +def get_server_state(actor_name: str = 'twinkle_server_state', **kwargs) -> ServerStateProxy: """Get or create the ServerState Ray actor. Ensures only one ServerState actor exists with the given name. Uses a @@ -427,8 +426,7 @@ def get_server_state(actor_name: str = 'twinkle_server_state', Args: actor_name: Name for the Ray actor (default: 'twinkle_server_state'). - auto_start_cleanup: Whether to automatically start the cleanup task (default: True). - **server_state_kwargs: Additional keyword arguments passed to ServerState constructor + **kwargs: Additional keyword arguments passed to ServerState constructor (e.g., expiration_timeout, cleanup_interval). Returns: @@ -439,12 +437,11 @@ def get_server_state(actor_name: str = 'twinkle_server_state', except ValueError: try: _ServerState = ray.remote(ServerState) - actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**server_state_kwargs) - if auto_start_cleanup: - try: - ray.get(actor.start_cleanup_task.remote()) - except Exception as e: - logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') + actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**kwargs) + try: + ray.get(actor.start_cleanup_task.remote()) + except Exception as e: + logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}') except ValueError: actor = ray.get_actor(actor_name) assert actor is not None From 6e823c9083c1c06e200d85e40f80fcb26c24811f Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 26 Feb 2026 15:59:45 +0800 Subject: [PATCH 3/3] update doc --- cookbook/client/tinker/megatron/server_config.yaml | 3 ++- cookbook/client/tinker/megatron/server_config_7b.yaml | 2 +- cookbook/client/tinker/transformer/server_config.yaml | 4 ++-- cookbook/client/twinkle/megatron/server_config.yaml | 4 ++-- cookbook/client/twinkle/transformer/server_config.yaml | 5 ++--- docs/source_en/Usage Guide/Server and Client/Server.md | 3 ++- .../\346\234\215\345\212\241\347\253\257.md" | 3 ++- src/twinkle/server/utils/state/model_manager.py | 2 +- 8 files changed, 14 insertions(+), 12 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index 74c0e717..18b0c1d2 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -21,6 +21,8 @@ applications: route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible) import_path: server # Python module to import args: + server_config: + per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) deployments: - name: TinkerCompatServer @@ -95,7 +97,6 @@ applications: rps_limit: 20 # Max requests per second tps_limit: 16000 # Max tokens per second adapter_config: - per_token_adapter_limit: 3 # Max concurrent LoRA adapters adapter_timeout: 30 # Seconds before idle adapter unload adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) deployments: diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml index 1727b1b8..dc47d796 100644 --- a/cookbook/client/tinker/megatron/server_config_7b.yaml +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -22,7 +22,7 @@ applications: import_path: server # Python module to import args: server_config: - per_token_adapter_limit: 1 # Maximum number of adapters per token (globally) + per_token_model_limit: 1 # Maximum number of models (adapters) per token (server-globally enforced) supported_models: - Qwen/Qwen2.5-7B-Instruct deployments: diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml index 20d25f52..f9c7a690 100644 --- a/cookbook/client/tinker/transformer/server_config.yaml +++ b/cookbook/client/tinker/transformer/server_config.yaml @@ -21,7 +21,8 @@ applications: route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible) import_path: server # Python module to import args: - + server_config: + per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) deployments: - name: TinkerCompatServer autoscaling_config: @@ -52,7 +53,6 @@ applications: rps_limit: 100 # Max requests per second tps_limit: 100000 # Max tokens per second adapter_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters adapter_timeout: 1800 # Seconds before idle adapter unload deployments: - name: ModelManagement diff --git a/cookbook/client/twinkle/megatron/server_config.yaml b/cookbook/client/twinkle/megatron/server_config.yaml index bb67bcfb..f431bb21 100644 --- a/cookbook/client/twinkle/megatron/server_config.yaml +++ b/cookbook/client/twinkle/megatron/server_config.yaml @@ -21,7 +21,8 @@ applications: route_prefix: /server # API endpoint prefix import_path: server # Python module to import args: - + server_config: + per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) deployments: - name: TwinkleServer autoscaling_config: @@ -50,7 +51,6 @@ applications: mesh: [0,1] # Device indices in the mesh mesh_dim_names: ['dp'] # Mesh dimension names: 'dp' = data parallel adapter_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters adapter_timeout: 1800 # Seconds before idle adapter unload deployments: - name: ModelManagement diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml index 787f0a0b..3e9e1472 100644 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ b/cookbook/client/twinkle/transformer/server_config.yaml @@ -21,7 +21,8 @@ applications: route_prefix: /server # API endpoint prefix import_path: server # Python module to import args: - + server_config: + per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) deployments: - name: TwinkleServer autoscaling_config: @@ -40,7 +41,6 @@ applications: use_megatron: false # Use HuggingFace Transformers (not Megatron) model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier to load adapter_config: - per_token_adapter_limit: 30 # Max LoRA adapters that can be active simultaneously adapter_timeout: 1800 # Seconds before an idle adapter is unloaded nproc_per_node: 2 # Number of GPU processes per node device_group: # Logical device group for this model @@ -103,7 +103,6 @@ applications: gpu_memory_utilization: 0.4 max_model_len: 1024 adapter_config: # Adapter lifecycle management - per_token_adapter_limit: 30 # Max LoRA adapters per user adapter_timeout: 1800 # Seconds before idle adapter is unloaded device_group: name: sampler diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md index a82002a9..e5b80023 100644 --- a/docs/source_en/Usage Guide/Server and Client/Server.md +++ b/docs/source_en/Usage Guide/Server and Client/Server.md @@ -259,7 +259,6 @@ applications: use_megatron: false # Use Transformers backend model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier adapter_config: # LoRA adapter configuration - per_token_adapter_limit: 30 # Maximum number of LoRAs that can be activated simultaneously adapter_timeout: 1800 # Idle adapter timeout unload time (seconds) nproc_per_node: 2 # Number of GPU processes per node device_group: # Logical device group @@ -354,6 +353,8 @@ applications: route_prefix: /api/v1 # Tinker protocol API prefix import_path: server args: + server_config: + per_token_model_limit: 30 # Maximum number of models (adapters) per token (server-global) deployments: - name: TinkerCompatServer autoscaling_config: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" index fd1ef94e..73915e8c 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" @@ -202,7 +202,6 @@ applications: use_megatron: false # 使用 Transformers 后端 model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope 模型标识 adapter_config: # LoRA 适配器配置 - per_token_adapter_limit: 30 # 同时可激活的最大 LoRA 数量 adapter_timeout: 1800 # 空闲适配器超时卸载时间(秒) nproc_per_node: 2 # 每节点 GPU 进程数 device_group: # 逻辑设备组 @@ -297,6 +296,8 @@ applications: route_prefix: /api/v1 # Tinker 协议 API 前缀 import_path: server args: + server_config: + per_token_model_limit: 30 # 每个 token 最多可创建的模型(适配器)数量(服务器全局生效) deployments: - name: TinkerCompatServer autoscaling_config: diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/utils/state/model_manager.py index 2eb98b7f..9e0d02b8 100644 --- a/src/twinkle/server/utils/state/model_manager.py +++ b/src/twinkle/server/utils/state/model_manager.py @@ -36,7 +36,7 @@ def add(self, model_id: str, record: ModelRecord) -> None: token = record.token current_ids = self._token_models.get(token, set()) if len(current_ids) >= self._per_token_model_limit: - raise RuntimeError(f'Model limit exceeded for token {token[:8]}...: ' + raise RuntimeError(f'Model limit exceeded: ' f'{len(current_ids)}/{self._per_token_model_limit} models') self._token_models.setdefault(token, set()).add(model_id) self._store[model_id] = record