From d6fbd89bc1317ca11beebd819ce90081797ff2b1 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Feb 2026 20:31:11 +0800 Subject: [PATCH 01/22] update load --- src/twinkle/server/tinker/common/megatron_model.py | 4 ++-- src/twinkle/server/tinker/common/transformers_model.py | 2 +- src/twinkle/server/utils/io_utils.py | 10 ++++------ src/twinkle/server/utils/validation.py | 2 +- src/twinkle_client/__init__.py | 1 + src/twinkle_client/http/http_utils.py | 1 + 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py index 1ac9f428..7a436403 100644 --- a/src/twinkle/server/tinker/common/megatron_model.py +++ b/src/twinkle/server/tinker/common/megatron_model.py @@ -175,13 +175,13 @@ def load(self, checkpoint_dir: str, **kwargs): # Create checkpoint manager with the token checkpoint_manager = create_checkpoint_manager(token) - + # Use resolve_load_path to handle path resolution resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) if resolved.is_twinkle_path: # Load from twinkle checkpoint - return super().load(name=resolved.checkpoint_name, output_dir=str(resolved.checkpoint_dir), **kwargs) + return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) else: # Load from hub return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py index eecc2a7f..feff9036 100644 --- a/src/twinkle/server/tinker/common/transformers_model.py +++ b/src/twinkle/server/tinker/common/transformers_model.py @@ -137,7 +137,7 @@ def load(self, checkpoint_dir: str, **kwargs): if resolved.is_twinkle_path: # Load from twinkle checkpoint - return super().load(name=resolved.checkpoint_name, output_dir=str(resolved.checkpoint_dir), **kwargs) + return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) else: # Load from hub return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/io_utils.py index 8a4e9e7b..3926bd5a 100644 --- a/src/twinkle/server/utils/io_utils.py +++ b/src/twinkle/server/utils/io_utils.py @@ -912,18 +912,16 @@ def resolve_load_path(self, path: str, validate_exists: bool = True) -> Resolved f"Checkpoint not found or access denied: {path}" ) - # Get the checkpoint directory - checkpoint_dir = str(self.get_ckpt_dir(training_run_id, checkpoint_id)) + # Get the checkpoint directory parent path (no checkpoint name in the path) + checkpoint_dir = self.get_ckpt_dir(training_run_id, checkpoint_id).parent if validate_exists: - # Verify the directory exists - from pathlib import Path as PathLib - if not PathLib(checkpoint_dir).exists(): + if not checkpoint_dir.exists(): raise ValueError(f"Checkpoint directory not found: {checkpoint_dir}") return ResolvedLoadPath( checkpoint_name=checkpoint_name, - checkpoint_dir=checkpoint_dir, + checkpoint_dir=checkpoint_dir.as_posix(), is_twinkle_path=True, training_run_id=training_run_id, checkpoint_id=checkpoint_id diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index 1f553da5..1f63b44c 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -22,7 +22,7 @@ async def verify_request_token(request: Request, call_next): Returns: JSONResponse with error if validation fails, otherwise the response from 
call_next """ - authorization = request.headers.get("Authorization") + authorization = request.headers.get("Twinkle-Authorization") token = authorization[7:] if authorization and authorization.startswith("Bearer ") else authorization if not is_token_valid(token): return JSONResponse(status_code=403, content={"detail": "Invalid token"}) diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index 782956c1..1ad6812d 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -29,6 +29,7 @@ def init_tinker_compat_client(base_url: Optional[str] = None, api_key: Optional[ default_headers = { "X-Ray-Serve-Request-Id": get_request_id(), "Authorization": 'Bearer ' + api_key, + "Twinkle-Authorization": 'Bearer ' + api_key, # For server compatibility } | kwargs.pop("default_headers", {}) service_client = ServiceClient(base_url=base_url, api_key=api_key, default_headers=default_headers, **kwargs) diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 1e927c2a..0743ca2c 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -17,6 +17,7 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[ headers = { "X-Ray-Serve-Request-Id": get_request_id(), "Authorization": 'Bearer ' + get_api_key(), + "Twinkle-Authorization": 'Bearer ' + get_api_key(), # For server compatibility } if additional_headers: From 67111e5f4b75eb3fa1553ad797c2a5b12b541821 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 12:24:55 +0800 Subject: [PATCH 02/22] update --- src/twinkle/server/tinker/model.py | 61 ++++++++-- src/twinkle/server/tinker/sampler.py | 3 +- src/twinkle/server/utils/task_queue.py | 161 ++++++++++++++----------- 3 files changed, 145 insertions(+), 80 deletions(-) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 671d5197..545e445f 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -125,12 +125,21 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], self.base_model = model_id self.state: ServerStateProxy = get_server_state() + # Training state: require at least one forward_backward before optim_step. + # Keyed by adapter_name. + self._grad_ready: Dict[str, bool] = {} + # Initialize task queue self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() + def _on_adapter_expired(self, adapter_name: str, token: str) -> None: + # Called from AdapterManagerMixin's countdown thread. + self._grad_ready.pop(adapter_name, None) + super()._on_adapter_expired(adapter_name, token) + @app.post('/create_model') async def create_model( self, request: Request, @@ -182,12 +191,18 @@ async def _create_adapter(): self.model.set_optimizer('Adam', adapter_name=adapter_name) + # Fresh adapter has no accumulated gradients. + self._grad_ready[adapter_name] = False + training_run_manager = create_training_run_manager( request.state.token) training_run_manager.save(model_id, body) return types.CreateModelResponse(model_id=model_id) except Exception: + # Ensure we don't leave stale grad state. 
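+                # create_model can fail after the optimizer is attached but
+                # before the training-run record is saved, so make sure a
+                # later adapter reusing this name does not inherit a stale flag.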
+ adapter_name = self.get_adapter_name(adapter_name=model_id) + self._grad_ready.pop(adapter_name, None) # If adapter creation fails, decrement the count self.check_adapter_limit(request.state.token, False) @@ -198,9 +213,11 @@ async def _create_adapter(): ) return await self.schedule_task( - _create_adapter(), + _create_adapter, model_id=model_id, token=request.state.token, + adapter_name=model_id, + task_type='create_model', ) @@ -257,6 +274,7 @@ async def _do_unload(): # Only remove adapter, not the base model adapter_name = self.get_adapter_name( adapter_name=body.model_id) + self._grad_ready.pop(adapter_name, None) if self.get_adapter_info(adapter_name): self.model.remove_adapter(adapter_name) # Unregister adapter from rate limiter @@ -267,9 +285,11 @@ async def _do_unload(): return types.UnloadModelResponse(model_id=body.model_id) return await self.schedule_task( - _do_unload(), + _do_unload, model_id=body.model_id, token=request.state.token, + adapter_name=body.model_id, + task_type='unload_model', ) @app.post('/forward') @@ -324,10 +344,12 @@ async def _do_forward(): len(d.model_input.to_ints()) for d in body.forward_input.data ) return await self.schedule_task( - _do_forward(), + _do_forward, model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, + adapter_name=body.model_id, + task_type='forward', ) @app.post('/forward_backward') @@ -371,6 +393,8 @@ async def _do_forward_backward(): loss_fn=loss_fn, **loss_fn_config) output_type = 'ImportanceSamplingLossReturn' if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn' + # Mark gradients as ready after a successful forward_backward. + self._grad_ready[adapter_name] = True return types.ForwardBackwardOutput( loss_fn_output_type=output_type, loss_fn_outputs=output, @@ -388,10 +412,12 @@ async def _do_forward_backward(): len(d.model_input.to_ints()) for d in body.forward_backward_input.data ) return await self.schedule_task( - _do_forward_backward(), + _do_forward_backward, model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, + adapter_name=body.model_id, + task_type='forward_backward', ) @app.post('/optim_step') @@ -416,11 +442,19 @@ async def _do_optim(): adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) + # Disallow empty step (must have at least one forward_backward since last step) + if not self._grad_ready.get(adapter_name, False): + raise RuntimeError( + f"No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step" + ) + # Touch adapter to reset inactivity counter self.touch_adapter(adapter_name) self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) + # Clear grad-ready after a successful step. 
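+                # The guard at the top of _do_optim now rejects another
+                # optim_step until a fresh forward_backward marks gradients
+                # ready again.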
+ self._grad_ready[adapter_name] = False metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) return types.OptimStepResponse(metrics=metrics) except Exception: @@ -431,9 +465,11 @@ async def _do_optim(): ) return await self.schedule_task( - _do_optim(), + _do_optim, model_id=body.model_id, token=request.state.token, + adapter_name=body.model_id, + task_type='optim_step', ) @app.post('/save_weights') @@ -492,9 +528,11 @@ async def _do_save(): ) return await self.schedule_task( - _do_save(), + _do_save, model_id=body.model_id, token=request.state.token, + adapter_name=body.model_id, + task_type='save_weights', ) @app.post('/save_weights_for_sampler') @@ -566,9 +604,11 @@ async def _do_save_for_sampler(): ) return await self.schedule_task( - _do_save_for_sampler(), + _do_save_for_sampler, model_id=body.model_id, token=request.state.token, + adapter_name=body.model_id, + task_type='save_weights_for_sampler', ) @app.post('/load_weights') @@ -609,6 +649,9 @@ async def _do_load(): load_optimizer=load_optimizer, adapter_name=adapter_name, token=token) + + # Loading a checkpoint should reset step readiness. + self._grad_ready[adapter_name] = False return types.LoadWeightsResponse(path=body.path, type='load_weights') except Exception: @@ -619,9 +662,11 @@ async def _do_load(): ) return await self.schedule_task( - _do_load(), + _do_load, model_id=body.model_id, token=request.state.token, + adapter_name=body.model_id, + task_type='load_weights', ) return ModelManagement.options(**deploy_options).bind( diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index 6ac3a6c3..a17eb777 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -220,9 +220,10 @@ async def _do_sample(): # Calculate input tokens for rate limiting input_tokens = len(body.prompt.to_ints()) return await self.schedule_task( - _do_sample(), + _do_sample, token=request.state.token, input_tokens=input_tokens, + task_type='sample', ) return SamplerManagement.options(**deploy_options).bind( diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index eca31052..55bc4537 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -12,9 +12,9 @@ import asyncio import traceback import uuid -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional from twinkle.utils.logger import get_logger from .rate_limiter import RateLimiter @@ -132,7 +132,7 @@ async def my_endpoint(self, request, body): async def _do_work(): return await some_operation() return await self.schedule_task( - _do_work(), + _do_work, model_id=body.model_id, token=request.state.token, input_tokens=len(body.tokens) @@ -149,7 +149,9 @@ def _init_task_queue(self, config: Optional[TaskQueueConfig] = None) -> None: config: Optional TaskQueueConfig. If None, uses default config. """ self._task_queue_config = config or TaskQueueConfig() - self._task_queue: asyncio.Queue = asyncio.Queue() + # Per-key queues to avoid cross-adapter head-of-line blocking. + # Key is usually adapter_name; falls back to token/default. 
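+        # e.g. {"adapter:run-a": Queue(), "token:abc123": Queue()}
+        # (keys illustrative; see _queue_key below for the exact scheme)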
+ self._task_queues: Dict[str, asyncio.Queue] = {} # Initialize rate limiter for RPS/TPS control self._rate_limiter = RateLimiter( @@ -162,66 +164,64 @@ def _init_task_queue(self, config: Optional[TaskQueueConfig] = None) -> None: # Start the rate limiter cleanup task self._rate_limiter.start_cleanup_task() - self._worker_task: Optional[asyncio.Task] = None - self._worker_started = False - self._worker_start_lock = asyncio.Lock() + # Per-key workers + self._queue_workers: Dict[str, asyncio.Task] = {} + self._queue_worker_start_lock = asyncio.Lock() + + @staticmethod + def _queue_key( + adapter_name: Optional[str], + token: Optional[str], + ) -> str: + if adapter_name: + return f"adapter:{adapter_name}" + if token: + return f"token:{token}" + return "default" + + async def _ensure_worker_started(self, queue_key: str) -> None: + """Ensure a background worker is running for a queue key.""" + if queue_key in self._queue_workers and not self._queue_workers[queue_key].done(): + return - async def _ensure_worker_started(self) -> None: - """Ensure the background worker is running. + async with self._queue_worker_start_lock: + worker = self._queue_workers.get(queue_key) + if worker is not None and not worker.done(): + return - Thread-safe: Uses asyncio.Lock to prevent race conditions when - multiple concurrent requests try to start the worker simultaneously. - """ - # Fast path: avoid lock if already started - if self._worker_started: - return + if queue_key not in self._task_queues: + self._task_queues[queue_key] = asyncio.Queue() - # Slow path: acquire lock to safely check and start - async with self._worker_start_lock: - # Double-check after acquiring lock (another coroutine might have started it) - if not self._worker_started: - logger.debug(f"[TaskQueue] Starting background worker...") - self._worker_task = asyncio.create_task(self._queue_worker()) - self._worker_started = True - logger.debug( - f"[TaskQueue] Background worker started: {self._worker_task}") - - async def _queue_worker(self) -> None: - """Background worker that processes tasks from the queue serially. - - This worker runs indefinitely, pulling tasks from the queue and - executing them one at a time. This ensures thread-safe execution - of model operations that cannot be parallelized. - """ - logger.debug(f"[TaskQueue] Worker started") + self._queue_workers[queue_key] = asyncio.create_task( + self._queue_worker(queue_key) + ) + + async def _queue_worker(self, queue_key: str) -> None: + """Background worker that processes tasks from a specific queue serially.""" + print(f"[TaskQueue] Worker started for {queue_key}") + q = self._task_queues[queue_key] while True: try: - # Wait for a task from the queue - logger.debug( - f"[TaskQueue] Waiting for task... (queue size: {self._task_queue.qsize()})") - request_id, coro, model_id = await self._task_queue.get() + print( + f"[TaskQueue] Waiting for task... 
key={queue_key} (queue size: {q.qsize()})") + request_id, coro_factory, model_id = await q.get() - logger.debug(f"[TaskQueue] Processing task {request_id}") + print(f"[TaskQueue] Processing task {request_id} key={queue_key}") try: - # Update status to RUNNING self.state.store_future_status( request_id, TaskStatus.RUNNING.value, model_id, queue_state=QueueState.ACTIVE.value ) - # Execute the task + coro = coro_factory() result = await coro - logger.debug( - f"[TaskQueue] Task {request_id} completed successfully") - # Store completed result + print(f"[TaskQueue] Task {request_id} completed successfully") self.state.store_future_status( request_id, TaskStatus.COMPLETED.value, model_id, result=result ) except Exception: - # Store error result - logger.debug( - f"[TaskQueue] Task {request_id} failed with error") + print(f"[TaskQueue] Task {request_id} failed with error") error_payload = { 'error': traceback.format_exc(), 'category': 'Server' @@ -230,22 +230,23 @@ async def _queue_worker(self) -> None: request_id, TaskStatus.FAILED.value, model_id, result=error_payload ) finally: - self._task_queue.task_done() + q.task_done() except asyncio.CancelledError: - logger.warning(f"[TaskQueue] Worker cancelled") + logger.warning(f"[TaskQueue] Worker cancelled key={queue_key}") break except Exception: - # Log but don't crash the worker - logger.warning("Error in task queue worker") + logger.warning(f"Error in task queue worker key={queue_key}") continue async def schedule_task( self, - coro: Coroutine, + coro_factory: Callable[[], Coroutine], model_id: Optional[str] = None, token: Optional[str] = None, input_tokens: int = 0, + adapter_name: Optional[str] = None, + task_type: Optional[str] = None, ) -> Dict[str, Any]: """Schedule an async task with rate limiting and status tracking. @@ -259,17 +260,21 @@ async def schedule_task( 3. Execute tasks serially through a queue Args: - coro: The coroutine to execute. + coro_factory: Factory that creates the coroutine to execute. The coroutine + will be created only after passing rate limiting and when it's time + to execute the queued task. model_id: Optional model_id to associate with the result. token: Optional user token for rate limiting. input_tokens: Number of input tokens for tps rate limiting. + adapter_name: Optional adapter name used for per-adapter queueing. + task_type: Optional task type for logging/observability. Returns: Dict containing request_id and model_id for future retrieval. """ request_id = f"req_{uuid.uuid4().hex}" - logger.debug( + print( f"[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}") # 1. Register PENDING status FIRST (fixes race condition) @@ -280,11 +285,11 @@ async def schedule_task( # 2. Check rate limiting if enabled and token provided if self._task_queue_config.enabled and token: - logger.debug( + print( f"[TaskQueue] Checking rate limit for token={token[:8]}... input_tokens={input_tokens}") allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) if not allowed: - logger.debug(f"[TaskQueue] Rate limited: {reason}") + print(f"[TaskQueue] Rate limited: {reason}") self.state.store_future_status( request_id, TaskStatus.RATE_LIMITED.value, model_id, reason=reason, @@ -292,21 +297,25 @@ async def schedule_task( queue_state_reason=reason ) return {'request_id': request_id, 'model_id': model_id} - logger.debug(f"[TaskQueue] Rate limit check passed") + print(f"[TaskQueue] Rate limit check passed") - # 3. 
Ensure worker is started - await self._ensure_worker_started() + # 3. Route to per-adapter/per-token queue + queue_key = self._queue_key(adapter_name=adapter_name, token=token) - # 4. Put task in queue and update status - logger.debug( - f"[TaskQueue] Adding task {request_id} to queue (current size: {self._task_queue.qsize()})") - await self._task_queue.put((request_id, coro, model_id)) + # 4. Ensure worker is started for this queue + await self._ensure_worker_started(queue_key) + + # 5. Put task in queue and update status + q = self._task_queues[queue_key] + print( + f"[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}") + await q.put((request_id, coro_factory, model_id)) self.state.store_future_status( request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value ) - logger.debug( - f"[TaskQueue] Task {request_id} queued, new queue size: {self._task_queue.qsize()}") + print( + f"[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}") return {'request_id': request_id, 'model_id': model_id} @@ -317,8 +326,10 @@ def get_queue_stats(self) -> Dict[str, Any]: Dict with queue size and worker status. """ return { - 'queue_size': self._task_queue.qsize(), - 'worker_running': self._worker_started and self._worker_task is not None, + 'queue_size': sum(q.qsize() for q in self._task_queues.values()), + 'queue_count': len(self._task_queues), + 'worker_running': any((t is not None and not t.done()) for t in self._queue_workers.values()), + 'worker_count': len(self._queue_workers), 'rate_limit_config': { 'rps_limit': self._task_queue_config.rps_limit, 'tps_limit': self._task_queue_config.tps_limit, @@ -354,12 +365,20 @@ async def shutdown_task_queue(self) -> None: # Stop the rate limiter cleanup task await self._rate_limiter.stop_cleanup_task() - # Cancel the worker task if running - if self._worker_task and not self._worker_task.done(): - self._worker_task.cancel() + # Cancel all queue workers if running + for task in list(self._queue_workers.values()): + if task and not task.done(): + task.cancel() + + for task in list(self._queue_workers.values()): + if not task: + continue try: - await self._worker_task + await task except asyncio.CancelledError: pass - logger.debug("[TaskQueue] Task queue shutdown complete") + self._queue_workers.clear() + self._task_queues.clear() + + print("[TaskQueue] Task queue shutdown complete") From 947a3e48ea776e5743147e8ab88f5fbaf2aa8a67 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 13:43:38 +0800 Subject: [PATCH 03/22] update --- src/twinkle/server/tinker/model.py | 35 +- src/twinkle/server/utils/adapter_manager.py | 47 +++ src/twinkle/server/utils/task_queue.py | 351 +++++++++++++++----- 3 files changed, 320 insertions(+), 113 deletions(-) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 545e445f..2bbab8a3 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -58,11 +58,6 @@ def build_model_app(model_id: str, Returns: Configured Ray Serve deployment bound with parameters """ - # adapter_config can be None; expanding with ** would raise TypeError and break Serve init. - # Normalize to {} so AdapterManagerMixin uses its default timeout/limits. 
- if adapter_config is None: - adapter_config = {} - app = FastAPI() @app.middleware('http') @@ -125,10 +120,6 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], self.base_model = model_id self.state: ServerStateProxy = get_server_state() - # Training state: require at least one forward_backward before optim_step. - # Keyed by adapter_name. - self._grad_ready: Dict[str, bool] = {} - # Initialize task queue self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) @@ -137,7 +128,9 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], def _on_adapter_expired(self, adapter_name: str, token: str) -> None: # Called from AdapterManagerMixin's countdown thread. - self._grad_ready.pop(adapter_name, None) + self.clear_adapter_state(adapter_name) + # Fail any pending tasks for this adapter/model. + self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') super()._on_adapter_expired(adapter_name, token) @app.post('/create_model') @@ -192,7 +185,7 @@ async def _create_adapter(): adapter_name=adapter_name) # Fresh adapter has no accumulated gradients. - self._grad_ready[adapter_name] = False + self.set_adapter_state(adapter_name, 'grad_ready', False) training_run_manager = create_training_run_manager( request.state.token) @@ -202,7 +195,7 @@ async def _create_adapter(): except Exception: # Ensure we don't leave stale grad state. adapter_name = self.get_adapter_name(adapter_name=model_id) - self._grad_ready.pop(adapter_name, None) + self.clear_adapter_state(adapter_name) # If adapter creation fails, decrement the count self.check_adapter_limit(request.state.token, False) @@ -216,7 +209,6 @@ async def _create_adapter(): _create_adapter, model_id=model_id, token=request.state.token, - adapter_name=model_id, task_type='create_model', ) @@ -274,7 +266,7 @@ async def _do_unload(): # Only remove adapter, not the base model adapter_name = self.get_adapter_name( adapter_name=body.model_id) - self._grad_ready.pop(adapter_name, None) + self.clear_adapter_state(adapter_name) if self.get_adapter_info(adapter_name): self.model.remove_adapter(adapter_name) # Unregister adapter from rate limiter @@ -288,7 +280,6 @@ async def _do_unload(): _do_unload, model_id=body.model_id, token=request.state.token, - adapter_name=body.model_id, task_type='unload_model', ) @@ -348,7 +339,6 @@ async def _do_forward(): model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, - adapter_name=body.model_id, task_type='forward', ) @@ -394,7 +384,7 @@ async def _do_forward_backward(): **loss_fn_config) output_type = 'ImportanceSamplingLossReturn' if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn' # Mark gradients as ready after a successful forward_backward. 
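+                # grad_ready now lives in the adapter record owned by
+                # AdapterManagerMixin, so it is dropped automatically when
+                # the adapter expires or is unloaded.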
- self._grad_ready[adapter_name] = True + self.set_adapter_state(adapter_name, 'grad_ready', True) return types.ForwardBackwardOutput( loss_fn_output_type=output_type, loss_fn_outputs=output, @@ -416,7 +406,6 @@ async def _do_forward_backward(): model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, - adapter_name=body.model_id, task_type='forward_backward', ) @@ -443,7 +432,7 @@ async def _do_optim(): self.assert_adapter_exists(adapter_name=adapter_name) # Disallow empty step (must have at least one forward_backward since last step) - if not self._grad_ready.get(adapter_name, False): + if not self.get_adapter_state(adapter_name, 'grad_ready', False): raise RuntimeError( f"No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step" ) @@ -454,7 +443,7 @@ async def _do_optim(): self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) # Clear grad-ready after a successful step. - self._grad_ready[adapter_name] = False + self.set_adapter_state(adapter_name, 'grad_ready', False) metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) return types.OptimStepResponse(metrics=metrics) except Exception: @@ -468,7 +457,6 @@ async def _do_optim(): _do_optim, model_id=body.model_id, token=request.state.token, - adapter_name=body.model_id, task_type='optim_step', ) @@ -531,7 +519,6 @@ async def _do_save(): _do_save, model_id=body.model_id, token=request.state.token, - adapter_name=body.model_id, task_type='save_weights', ) @@ -607,7 +594,6 @@ async def _do_save_for_sampler(): _do_save_for_sampler, model_id=body.model_id, token=request.state.token, - adapter_name=body.model_id, task_type='save_weights_for_sampler', ) @@ -651,7 +637,7 @@ async def _do_load(): token=token) # Loading a checkpoint should reset step readiness. - self._grad_ready[adapter_name] = False + self.set_adapter_state(adapter_name, 'grad_ready', False) return types.LoadWeightsResponse(path=body.path, type='load_weights') except Exception: @@ -665,7 +651,6 @@ async def _do_load(): _do_load, model_id=body.model_id, token=request.state.token, - adapter_name=body.model_id, task_type='load_weights', ) diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 02056d1e..846ca4c5 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -82,6 +82,7 @@ def register_adapter(self, adapter_name: str, token: str) -> None: 'last_activity': current_time, 'created_at': current_time, 'inactivity_counter': 0, + 'state': {}, } logger.debug( f"[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...") @@ -104,6 +105,48 @@ def unregister_adapter(self, adapter_name: str) -> bool: return True return False + def set_adapter_state(self, adapter_name: str, key: str, value: Any) -> None: + """Set a per-adapter state value. + + This is intentionally generic so higher-level services can store + adapter-scoped state (e.g., training readiness) without maintaining + separate side maps. 
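+
+        Example (illustrative; assumes an adapter named 'demo' is registered):
+
+            self.set_adapter_state('demo', 'grad_ready', True)
+            self.get_adapter_state('demo', 'grad_ready')   # -> True
+            self.pop_adapter_state('demo', 'grad_ready')   # -> True (and removed)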
+ """ + with self._adapter_lock: + info = self._adapter_records.get(adapter_name) + if info is None: + return + state = info.setdefault('state', {}) + state[key] = value + + def get_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: + """Get a per-adapter state value.""" + with self._adapter_lock: + info = self._adapter_records.get(adapter_name) + if info is None: + return default + state = info.get('state') or {} + return state.get(key, default) + + def pop_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: + """Pop a per-adapter state value.""" + with self._adapter_lock: + info = self._adapter_records.get(adapter_name) + if info is None: + return default + state = info.get('state') + if not isinstance(state, dict): + return default + return state.pop(key, default) + + def clear_adapter_state(self, adapter_name: str) -> None: + """Clear all per-adapter state values.""" + with self._adapter_lock: + info = self._adapter_records.get(adapter_name) + if info is None: + return + info['state'] = {} + def touch_adapter(self, adapter_name: str) -> bool: """Update adapter activity timestamp to prevent timeout. @@ -161,6 +204,10 @@ def _on_adapter_expired(self, adapter_name: str, token: str) -> None: token: User token that owns this adapter. """ try: + # Best-effort cleanup of adapter state + with self._adapter_lock: + if adapter_name in self._adapter_records: + self._adapter_records[adapter_name]['state'] = {} # Remove adapter from model self.model.remove_adapter(adapter_name) logger.info( diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index 55bc4537..48158368 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -10,11 +10,13 @@ from __future__ import annotations import asyncio +import time import traceback import uuid +from collections import deque from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Deque, Dict, Optional from twinkle.utils.logger import get_logger from .rate_limiter import RateLimiter @@ -108,6 +110,57 @@ def from_dict(cls, config_dict: Optional[Dict[str, Any]] = None) -> 'TaskQueueCo config_dict['token_cleanup_interval']) return config +@dataclass +class _QueuedTask: + request_id: str + coro_factory: Callable[[], Coroutine] + model_id: Optional[str] + token: Optional[str] + input_tokens: int + task_type: Optional[str] + created_at: float + first_rate_limited_at: Optional[float] = None + + +class _DequeTaskQueue: + """Unbounded async queue backed by deque, with put_left() support. + + Only implements the subset of asyncio.Queue APIs used in this module. 
+ """ + def __init__(self) -> None: + self._q: Deque[Any] = deque() + self._unfinished_tasks: int = 0 + self._finished: asyncio.Event = asyncio.Event() + self._finished.set() + + def qsize(self) -> int: + return len(self._q) + + async def put(self, item: Any) -> None: + self._q.append(item) + self._unfinished_tasks += 1 + self._finished.clear() + + async def put_left(self, item: Any) -> None: + self._q.appendleft(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def get_nowait(self) -> Any: + if not self._q: + raise asyncio.QueueEmpty + return self._q.popleft() + + def task_done(self) -> None: + if self._unfinished_tasks <= 0: + raise ValueError("task_done() called too many times") + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + async def join(self) -> None: + await self._finished.wait() + class TaskQueueMixin: """Mixin providing task queue management, rate limiting, and status tracking. @@ -149,9 +202,10 @@ def _init_task_queue(self, config: Optional[TaskQueueConfig] = None) -> None: config: Optional TaskQueueConfig. If None, uses default config. """ self._task_queue_config = config or TaskQueueConfig() - # Per-key queues to avoid cross-adapter head-of-line blocking. - # Key is usually adapter_name; falls back to token/default. - self._task_queues: Dict[str, asyncio.Queue] = {} + # Per-key queues, but executed by a single global worker. + self._task_queues: Dict[str, _DequeTaskQueue] = {} + self._queue_order: Deque[str] = deque() + self._new_task_event: asyncio.Event = asyncio.Event() # Initialize rate limiter for RPS/TPS control self._rate_limiter = RateLimiter( @@ -164,88 +218,213 @@ def _init_task_queue(self, config: Optional[TaskQueueConfig] = None) -> None: # Start the rate limiter cleanup task self._rate_limiter.start_cleanup_task() - # Per-key workers - self._queue_workers: Dict[str, asyncio.Task] = {} - self._queue_worker_start_lock = asyncio.Lock() + # Single worker to ensure model operations remain serial. 
+ self._worker_task: Optional[asyncio.Task] = None + self._worker_started = False + self._worker_start_lock = asyncio.Lock() + + # Event loop reference for thread-safe callbacks (e.g., adapter expiration thread) + self._event_loop: Optional[asyncio.AbstractEventLoop] = None @staticmethod def _queue_key( - adapter_name: Optional[str], + model_id: Optional[str], token: Optional[str], ) -> str: - if adapter_name: - return f"adapter:{adapter_name}" + if model_id: + return f"model:{model_id}" if token: return f"token:{token}" return "default" - async def _ensure_worker_started(self, queue_key: str) -> None: - """Ensure a background worker is running for a queue key.""" - if queue_key in self._queue_workers and not self._queue_workers[queue_key].done(): + async def _ensure_worker_started(self) -> None: + """Ensure the single background worker is running.""" + if self._worker_started and self._worker_task is not None and not self._worker_task.done(): return - async with self._queue_worker_start_lock: - worker = self._queue_workers.get(queue_key) - if worker is not None and not worker.done(): + async with self._worker_start_lock: + if self._worker_started and self._worker_task is not None and not self._worker_task.done(): return + self._worker_task = asyncio.create_task(self._queue_worker()) + self._worker_started = True - if queue_key not in self._task_queues: - self._task_queues[queue_key] = asyncio.Queue() + def _ensure_queue_registered(self, queue_key: str) -> None: + if queue_key not in self._task_queues: + self._task_queues[queue_key] = _DequeTaskQueue() + if queue_key not in self._queue_order: + self._queue_order.append(queue_key) - self._queue_workers[queue_key] = asyncio.create_task( - self._queue_worker(queue_key) - ) + async def _queue_worker(self) -> None: + """Single background worker that processes tasks serially across all queues. - async def _queue_worker(self, queue_key: str) -> None: - """Background worker that processes tasks from a specific queue serially.""" - print(f"[TaskQueue] Worker started for {queue_key}") - q = self._task_queues[queue_key] + Selection policy: round-robin across queue keys. If a task is rate-limited + at execution time, it is requeued and the worker tries other queues. + """ + print("[TaskQueue] Worker started") while True: try: - print( - f"[TaskQueue] Waiting for task... 
key={queue_key} (queue size: {q.qsize()})") - request_id, coro_factory, model_id = await q.get() - - print(f"[TaskQueue] Processing task {request_id} key={queue_key}") - try: + # Wait until there is at least one queue with a task + while True: + if any(q.qsize() > 0 for q in self._task_queues.values()): + break + self._new_task_event.clear() + await self._new_task_event.wait() + + executed_any = False + # Try each queue at most once per loop for fairness + for _ in range(len(self._queue_order)): + queue_key = self._queue_order[0] + self._queue_order.rotate(-1) + + q = self._task_queues.get(queue_key) + if q is None: + continue + + try: + task: _QueuedTask = q.get_nowait() + except asyncio.QueueEmpty: + continue + + now = time.monotonic() + + # Global queue timeout + if (now - task.created_at) > self._task_queue_config.queue_timeout: + error_payload = { + 'error': f"Queue timeout exceeded: waited {now - task.created_at:.2f}s", + 'category': 'Server' + } + self.state.store_future_status( + task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, + queue_state=QueueState.PAUSED_CAPACITY.value, + queue_state_reason=error_payload['error'], + ) + q.task_done() + continue + + # Rate limiting at execution time (requeue on limit) + if self._task_queue_config.enabled and task.token: + allowed, reason = await self._rate_limiter.check_and_record( + task.token, task.input_tokens + ) + if not allowed: + if task.first_rate_limited_at is None: + task.first_rate_limited_at = now + # If a task cannot get a slot within a window, fail it. + if (now - task.first_rate_limited_at) > self._task_queue_config.window_seconds: + error_payload = { + 'error': f"Rate limit wait exceeded window: {reason}", + 'category': 'Server' + } + self.state.store_future_status( + task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, + queue_state=QueueState.PAUSED_RATE_LIMIT.value, + queue_state_reason=reason, + ) + q.task_done() + continue + + # Put back to FRONT to preserve order, then try other queues + self.state.store_future_status( + task.request_id, TaskStatus.QUEUED.value, task.model_id, + queue_state=QueueState.PAUSED_RATE_LIMIT.value, + queue_state_reason=reason, + ) + await q.put_left(task) + q.task_done() + continue + + # Execute + executed_any = True self.state.store_future_status( - request_id, TaskStatus.RUNNING.value, model_id, + task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value ) - coro = coro_factory() - result = await coro - - print(f"[TaskQueue] Task {request_id} completed successfully") - self.state.store_future_status( - request_id, TaskStatus.COMPLETED.value, model_id, result=result - ) - except Exception: - print(f"[TaskQueue] Task {request_id} failed with error") - error_payload = { - 'error': traceback.format_exc(), - 'category': 'Server' - } - self.state.store_future_status( - request_id, TaskStatus.FAILED.value, model_id, result=error_payload - ) - finally: - q.task_done() + try: + coro = task.coro_factory() + result = await coro + self.state.store_future_status( + task.request_id, TaskStatus.COMPLETED.value, task.model_id, result=result, + queue_state=QueueState.ACTIVE.value + ) + except Exception: + error_payload = { + 'error': traceback.format_exc(), + 'category': 'Server' + } + self.state.store_future_status( + task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, + queue_state=QueueState.ACTIVE.value + ) + finally: + q.task_done() + + # Keep serial semantics: execute at most one 
runnable task per loop + break + + if not executed_any: + # All available tasks were rate-limited; avoid busy looping. + await asyncio.sleep(min(self._task_queue_config.window_seconds, 0.1)) except asyncio.CancelledError: - logger.warning(f"[TaskQueue] Worker cancelled key={queue_key}") + logger.warning("[TaskQueue] Worker cancelled") break except Exception: - logger.warning(f"Error in task queue worker key={queue_key}") + logger.warning("Error in task queue worker") continue + async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: + q = self._task_queues.get(queue_key) + if q is None: + return + + drained: list[_QueuedTask] = [] + while True: + try: + drained.append(q.get_nowait()) + except asyncio.QueueEmpty: + break + + for task in drained: + error_payload = { + 'error': reason, + 'category': 'Server' + } + self.state.store_future_status( + task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=reason, + ) + q.task_done() + + # Remove queue structures + self._task_queues.pop(queue_key, None) + try: + while queue_key in self._queue_order: + self._queue_order.remove(queue_key) + except ValueError: + pass + + def fail_pending_tasks_for_model(self, model_id: str, reason: str) -> None: + """Fail and drop queued tasks for a model. Safe to call from non-async threads.""" + queue_key = self._queue_key(model_id=model_id, token=None) + if self._event_loop is None: + # Best-effort: nothing we can do safely without a loop. + logger.warning( + f"[TaskQueue] fail_pending_tasks_for_model called without event loop: {queue_key}") + return + + def _schedule() -> None: + asyncio.create_task(self._fail_queue_tasks_async(queue_key, reason)) + + self._event_loop.call_soon_threadsafe(_schedule) + async def schedule_task( self, coro_factory: Callable[[], Coroutine], model_id: Optional[str] = None, token: Optional[str] = None, input_tokens: int = 0, - adapter_name: Optional[str] = None, task_type: Optional[str] = None, ) -> Dict[str, Any]: """Schedule an async task with rate limiting and status tracking. @@ -266,7 +445,6 @@ async def schedule_task( model_id: Optional model_id to associate with the result. token: Optional user token for rate limiting. input_tokens: Number of input tokens for tps rate limiting. - adapter_name: Optional adapter name used for per-adapter queueing. task_type: Optional task type for logging/observability. Returns: @@ -274,6 +452,9 @@ async def schedule_task( """ request_id = f"req_{uuid.uuid4().hex}" + if self._event_loop is None: + self._event_loop = asyncio.get_running_loop() + print( f"[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}") @@ -283,33 +464,28 @@ async def schedule_task( queue_state=QueueState.ACTIVE.value ) - # 2. Check rate limiting if enabled and token provided - if self._task_queue_config.enabled and token: - print( - f"[TaskQueue] Checking rate limit for token={token[:8]}... input_tokens={input_tokens}") - allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) - if not allowed: - print(f"[TaskQueue] Rate limited: {reason}") - self.state.store_future_status( - request_id, TaskStatus.RATE_LIMITED.value, model_id, - reason=reason, - queue_state=QueueState.PAUSED_RATE_LIMIT.value, - queue_state_reason=reason - ) - return {'request_id': request_id, 'model_id': model_id} - print(f"[TaskQueue] Rate limit check passed") + # 2. 
Route to per-model/per-token queue + queue_key = self._queue_key(model_id=model_id, token=token) + self._ensure_queue_registered(queue_key) - # 3. Route to per-adapter/per-token queue - queue_key = self._queue_key(adapter_name=adapter_name, token=token) + # 3. Ensure worker is started + await self._ensure_worker_started() - # 4. Ensure worker is started for this queue - await self._ensure_worker_started(queue_key) - - # 5. Put task in queue and update status + # 4. Put task in queue and update status q = self._task_queues[queue_key] print( f"[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}") - await q.put((request_id, coro_factory, model_id)) + await q.put( + _QueuedTask( + request_id=request_id, + coro_factory=coro_factory, + model_id=model_id, + token=token, + input_tokens=input_tokens, + task_type=task_type, + created_at=time.monotonic(), + ) + ) self.state.store_future_status( request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value @@ -317,6 +493,8 @@ async def schedule_task( print( f"[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}") + self._new_task_event.set() + return {'request_id': request_id, 'model_id': model_id} def get_queue_stats(self) -> Dict[str, Any]: @@ -328,8 +506,7 @@ def get_queue_stats(self) -> Dict[str, Any]: return { 'queue_size': sum(q.qsize() for q in self._task_queues.values()), 'queue_count': len(self._task_queues), - 'worker_running': any((t is not None and not t.done()) for t in self._queue_workers.values()), - 'worker_count': len(self._queue_workers), + 'worker_running': self._worker_task is not None and not self._worker_task.done(), 'rate_limit_config': { 'rps_limit': self._task_queue_config.rps_limit, 'tps_limit': self._task_queue_config.tps_limit, @@ -365,20 +542,18 @@ async def shutdown_task_queue(self) -> None: # Stop the rate limiter cleanup task await self._rate_limiter.stop_cleanup_task() - # Cancel all queue workers if running - for task in list(self._queue_workers.values()): - if task and not task.done(): - task.cancel() - - for task in list(self._queue_workers.values()): - if not task: - continue + # Cancel the worker task if running + if self._worker_task and not self._worker_task.done(): + self._worker_task.cancel() try: - await task + await self._worker_task except asyncio.CancelledError: pass - self._queue_workers.clear() + self._worker_task = None + self._worker_started = False + self._task_queues.clear() + self._queue_order.clear() print("[TaskQueue] Task queue shutdown complete") From 4e7f30d6b4e433d7a4fdf1153c5eb49393c08261 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 15:39:17 +0800 Subject: [PATCH 04/22] update --- src/twinkle/server/tinker/model.py | 4 +- src/twinkle/server/tinker/sampler.py | 8 ++ src/twinkle/server/utils/adapter_manager.py | 121 +++++++++++--------- 3 files changed, 80 insertions(+), 53 deletions(-) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 2bbab8a3..51e783d2 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -126,12 +126,12 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() - def _on_adapter_expired(self, adapter_name: str, token: str) -> None: + def _on_adapter_expired(self, adapter_name: str) -> None: # Called from AdapterManagerMixin's countdown thread. 
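        # fail_pending_tasks_for_model is safe to call from this non-async
        # thread: it hands the work to the event loop via
        # call_soon_threadsafe rather than touching asyncio state here.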
self.clear_adapter_state(adapter_name) # Fail any pending tasks for this adapter/model. self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') - super()._on_adapter_expired(adapter_name, token) + super()._on_adapter_expired(adapter_name) @app.post('/create_model') async def create_model( diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index a17eb777..7a6b5fa1 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -9,6 +9,7 @@ 3. Multi-user inference with rate limiting 4. Flexible sampling parameters """ +import os import traceback from typing import Any, Dict, Optional @@ -166,6 +167,13 @@ async def _do_sample(): checkpoint_manager = create_checkpoint_manager(token) adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) + # Validate adapter URI existence if provided + if not adapter_uri or not os.path.exists(adapter_uri): + return types.RequestFailedResponse( + error=f"Adapter URI {model_path} does not exist. Please check the model_path.", + category=types.RequestErrorCategory.User, + ) + # Convert tinker SamplingParams to twinkle SamplingParams if needed sampling_params = None if body.sampling_params: diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 846ca4c5..3c0e2f2b 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -43,7 +43,12 @@ class AdapterManagerMixin: state: 'ServerStateProxy' model: 'TwinkleModel' - def _init_adapter_manager(self, adapter_timeout: float = 1800.0, per_token_adapter_limit: int = 30) -> None: + def _init_adapter_manager( + self, + adapter_timeout: float = 1800.0, + per_token_adapter_limit: int = 30, + adapter_max_lifetime: float = 12 * 60 * 60, + ) -> None: """Initialize the adapter manager. This should be called in the __init__ of the inheriting class. @@ -53,9 +58,12 @@ def _init_adapter_manager(self, adapter_timeout: float = 1800.0, per_token_adapt Default is 1800.0 (30 minutes). per_token_adapter_limit: Maximum number of adapters per user token. Default is 30. + adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation. + Default is 43200.0 (12 hours). If <= 0, lifetime enforcement is disabled. """ self._adapter_timeout = adapter_timeout self._per_token_adapter_limit = per_token_adapter_limit + self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking # Dict mapping adapter_name -> {'token': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} @@ -83,6 +91,7 @@ def register_adapter(self, adapter_name: str, token: str) -> None: 'created_at': current_time, 'inactivity_counter': 0, 'state': {}, + 'expiring': False, } logger.debug( f"[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...") @@ -157,12 +166,14 @@ def touch_adapter(self, adapter_name: str) -> bool: True if adapter was found and touched, False otherwise. 
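
        Touching is refused once an adapter is marked 'expiring', so a
        request racing the countdown thread cannot resurrect an adapter
        that is mid-teardown.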
""" with self._adapter_lock: - if adapter_name in self._adapter_records: - self._adapter_records[adapter_name]['last_activity'] = time.time( - ) - self._adapter_records[adapter_name]['inactivity_counter'] = 0 - return True - return False + info = self._adapter_records.get(adapter_name) + if not info: + return False + if info.get('expiring'): + return False + info['last_activity'] = time.time() + info['inactivity_counter'] = 0 + return True def get_adapter_info(self, adapter_name: str) -> Optional[Dict[str, Any]]: """Get information about a registered adapter. @@ -193,31 +204,25 @@ def list_adapters(self, token: Optional[str] = None) -> List[str]: if info.get('token') == token ] - def _on_adapter_expired(self, adapter_name: str, token: str) -> None: + def _on_adapter_expired(self, adapter_name: str) -> None: """Hook method called when an adapter expires. - Default implementation removes the adapter from the model and updates adapter count. + Default implementation removes the adapter from the model. This is called from the countdown thread, so be careful with blocking operations. Args: adapter_name: Name of the expired adapter. - token: User token that owns this adapter. """ try: # Best-effort cleanup of adapter state with self._adapter_lock: - if adapter_name in self._adapter_records: - self._adapter_records[adapter_name]['state'] = {} - # Remove adapter from model + info = self._adapter_records.get(adapter_name) + if info is not None: + info['state'] = {} self.model.remove_adapter(adapter_name) - logger.info( - f"[AdapterManager] Removed expired adapter {adapter_name} for token {token[:8]}...") - - # Decrement adapter count - self.check_adapter_limit(token, False) + logger.info(f"[AdapterManager] Removed expired adapter {adapter_name}") except Exception as e: - logger.warning( - f"[AdapterManager] Failed to remove expired adapter {adapter_name}: {e}") + logger.warning(f"[AdapterManager] Failed to remove expired adapter {adapter_name}: {e}") @staticmethod def get_adapter_name(adapter_name: str) -> str: @@ -234,16 +239,11 @@ def get_adapter_name(adapter_name: str) -> str: return adapter_name def assert_adapter_exists(self, adapter_name: str) -> None: - """Validate that an adapter exists. - - Args: - adapter_name: The adapter name to check - - Raises: - AssertionError: If adapter doesn't exist - """ - assert adapter_name and self.get_adapter_info(adapter_name) is not None, \ - f"Adapter {adapter_name} not found" + """Validate that an adapter exists and is not expiring.""" + with self._adapter_lock: + info = self._adapter_records.get(adapter_name) + assert adapter_name and info is not None and not info.get('expiring'), \ + f"Adapter {adapter_name} not found" def assert_adapter_valid(self, adapter_name: Optional[str]) -> None: """Validate that an adapter name is valid. 
@@ -271,38 +271,56 @@ def _adapter_countdown_loop(self) -> None: while self._adapter_countdown_running: try: time.sleep(1) + now = time.time() - # Find and process expired adapters - expired_adapters = [] + expired_adapters: List[Tuple[str, Optional[str]]] = [] with self._adapter_lock: for adapter_name, info in list(self._adapter_records.items()): - # Increment inactivity counter - info['inactivity_counter'] = info.get( - 'inactivity_counter', 0) + 1 + if info.get('expiring'): + continue + + created_at = info.get("created_at") or now + exceeded_ttl = ( + self._adapter_max_lifetime + and self._adapter_max_lifetime > 0 + and (now - created_at) > self._adapter_max_lifetime + ) + + info["inactivity_counter"] = info.get("inactivity_counter", 0) + 1 + exceeded_inactivity = info["inactivity_counter"] > self._adapter_timeout - # Check if adapter has timed out - if info['inactivity_counter'] > self._adapter_timeout: + if exceeded_ttl or exceeded_inactivity: + info['expiring'] = True + info['state'] = {} # best-effort clear token = info.get('token') expired_adapters.append((adapter_name, token)) - self._adapter_records.pop(adapter_name, None) logger.debug( - f"[AdapterManager] Adapter {adapter_name} timed out after " - f"{info['inactivity_counter']}s of inactivity" + f"[AdapterManager] Adapter {adapter_name} expired " + f"(ttl={exceeded_ttl}, inactivity={exceeded_inactivity})" ) - # Call hook method outside the lock for adapter_name, token in expired_adapters: + success = False try: - self._on_adapter_expired(adapter_name, token) + self._on_adapter_expired(adapter_name) + if token: + self.check_adapter_limit(token, False) + success = True except Exception as e: logger.warning( - f"[AdapterManager] Error in _on_adapter_expired() " - f"for {adapter_name}: {e}" + f"[AdapterManager] Error while expiring adapter {adapter_name}: {e}" ) + finally: + with self._adapter_lock: + if success: + self._adapter_records.pop(adapter_name, None) + else: + info = self._adapter_records.get(adapter_name) + if info is not None: + info['expiring'] = False except Exception as e: - logger.warning( - f"[AdapterManager] Error in countdown loop: {e}") + logger.warning(f"[AdapterManager] Error in countdown loop: {e}") continue logger.debug("[AdapterManager] Countdown thread stopped") @@ -342,11 +360,12 @@ def get_adapter_stats(self) -> Dict[str, Any]: """ with self._adapter_lock: return { - 'registered_adapters': len(self._adapter_records), - 'tracked_adapter_counts': len(self._adapter_counts), - 'countdown_running': self._adapter_countdown_running, - 'adapter_timeout_seconds': self._adapter_timeout, - 'per_token_adapter_limit': self._per_token_adapter_limit, + "registered_adapters": len(self._adapter_records), + "tracked_adapter_counts": len(self._adapter_counts), + "countdown_running": self._adapter_countdown_running, + "adapter_timeout_seconds": self._adapter_timeout, + "adapter_max_lifetime_seconds": self._adapter_max_lifetime, + "per_token_adapter_limit": self._per_token_adapter_limit, } def check_adapter_limit(self, token: str, add: bool) -> Tuple[bool, Optional[str]]: From 51b89e0bba3a3c8f2e9d614c016327184804b90a Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 16:08:00 +0800 Subject: [PATCH 05/22] update --- .../server/tinker/common/compat_base.py | 40 ++++++++++++++++--- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py index 9b573ff8..256141d8 100644 --- 
a/src/twinkle/server/tinker/common/compat_base.py +++ b/src/twinkle/server/tinker/common/compat_base.py @@ -67,15 +67,43 @@ def collect_forward_backward_results(results, device_mesh: DeviceMesh): def clean_metrics(metrics: dict) -> dict: + import re + from numbers import Number + + def _to_float(v): + # python numeric / numpy scalar + if isinstance(v, (float, int, Number, np.generic, str)): + try: + return float(v) + except Exception: + return None + # 0-d torch tensor + if isinstance(v, torch.Tensor) and v.numel() == 1: + try: + return float(v.item()) + except Exception: + return None + return None + cleaned = {} for key, value in metrics.items(): + fv = _to_float(value) + if fv is not None: + cleaned[key] = fv + continue + + # handle common metric strings: "123 seconds", "1.23 iters/s" if isinstance(value, str): - import re - match = re.match(r'^([+-]?\d*\.?\d+)', value.strip()) - if match: - cleaned[key] = float(match.group(1)) - else: - cleaned[key] = value + s = value.strip() + if s: + try: + head, unit = s.split() # ignore unit/tail + cleaned[f'{key}/{unit}'] = float(head) + except Exception: + m = re.match(r"^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)", s) + if m: + cleaned[key] = float(m.group(1)) + return cleaned From 077e206f60a8007ee15a77e7f952807c93c8098f Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 17:19:46 +0800 Subject: [PATCH 06/22] update --- src/twinkle/server/tinker/model.py | 22 +-- src/twinkle/server/utils/task_queue.py | 178 ++++++++++++++----------- 2 files changed, 110 insertions(+), 90 deletions(-) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 51e783d2..5835f8f3 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -309,9 +309,6 @@ async def _do_forward(): self.touch_adapter(adapter_name) datum_list = body.forward_input.data - assert len( - datum_list) >= self.device_mesh.data_world_size, f"Batch size {len(datum_list)} must be greater than data world size {self.device_mesh.data_world_size}" - loss_fn_config = body.forward_input.loss_fn_config or {} output = self.model.forward_only(inputs=datum_list, @@ -330,15 +327,19 @@ async def _do_forward(): category=types.RequestErrorCategory.Server, ) - # Calculate input tokens for rate limiting + # Calculate input tokens and batch size for validation + datum_list = body.forward_input.data input_tokens = sum( - len(d.model_input.to_ints()) for d in body.forward_input.data + len(d.model_input.to_ints()) for d in datum_list ) + batch_size = len(datum_list) return await self.schedule_task( _do_forward, model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, task_type='forward', ) @@ -371,9 +372,6 @@ async def _do_forward_backward(): self.touch_adapter(adapter_name) datum_list = body.forward_backward_input.data - assert len( - datum_list) >= self.device_mesh.data_world_size, f"Batch size {len(datum_list)} must be greater than data world size {self.device_mesh.data_world_size}" - loss_fn = body.forward_backward_input.loss_fn loss_fn_config = body.forward_backward_input.loss_fn_config or {} @@ -397,15 +395,19 @@ async def _do_forward_backward(): category=types.RequestErrorCategory.Server, ) - # Calculate input tokens for rate limiting + # Calculate input tokens and batch size for validation + datum_list = body.forward_backward_input.data input_tokens = sum( - len(d.model_input.to_ints()) for d in body.forward_backward_input.data + 
len(d.model_input.to_ints()) for d in datum_list ) + batch_size = len(datum_list) return await self.schedule_task( _do_forward_backward, model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, task_type='forward_backward', ) diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index 48158368..a1e8666a 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -61,8 +61,7 @@ class TaskQueueConfig: enabled: Whether rate limiting is enabled. token_cleanup_multiplier: Multiplier for token cleanup threshold. token_cleanup_interval: How often to run cleanup task (seconds). - per_token_adapter_limit: Maximum number of adapters per user token. - adapter_timeout: Timeout in seconds for inactive adapters (default 30 minutes). + max_input_tokens: Maximum allowed input tokens per request (default 10000). """ rps_limit: float = 100.0 # 10 requests per second tps_limit: float = 10000.0 # 10000 input tokens per second @@ -72,6 +71,7 @@ class TaskQueueConfig: # Remove tokens after 10x window inactivity token_cleanup_multiplier: float = 10.0 token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds + max_input_tokens: int = 10000 # Maximum input tokens per request @classmethod def from_dict(cls, config_dict: Optional[Dict[str, Any]] = None) -> 'TaskQueueConfig': @@ -86,6 +86,7 @@ def from_dict(cls, config_dict: Optional[Dict[str, Any]] = None) -> 'TaskQueueCo - enabled: whether rate limiting is enabled - token_cleanup_multiplier: multiplier for token cleanup threshold - token_cleanup_interval: cleanup task interval in seconds + - max_input_tokens: maximum input tokens per request Returns: TaskQueueConfig instance with values from dict merged with defaults. @@ -108,6 +109,8 @@ def from_dict(cls, config_dict: Optional[Dict[str, Any]] = None) -> 'TaskQueueCo if 'token_cleanup_interval' in config_dict: config.token_cleanup_interval = float( config_dict['token_cleanup_interval']) + if 'max_input_tokens' in config_dict: + config.max_input_tokens = int(config_dict['max_input_tokens']) return config @dataclass @@ -122,46 +125,6 @@ class _QueuedTask: first_rate_limited_at: Optional[float] = None -class _DequeTaskQueue: - """Unbounded async queue backed by deque, with put_left() support. - - Only implements the subset of asyncio.Queue APIs used in this module. - """ - def __init__(self) -> None: - self._q: Deque[Any] = deque() - self._unfinished_tasks: int = 0 - self._finished: asyncio.Event = asyncio.Event() - self._finished.set() - - def qsize(self) -> int: - return len(self._q) - - async def put(self, item: Any) -> None: - self._q.append(item) - self._unfinished_tasks += 1 - self._finished.clear() - - async def put_left(self, item: Any) -> None: - self._q.appendleft(item) - self._unfinished_tasks += 1 - self._finished.clear() - - def get_nowait(self) -> Any: - if not self._q: - raise asyncio.QueueEmpty - return self._q.popleft() - - def task_done(self) -> None: - if self._unfinished_tasks <= 0: - raise ValueError("task_done() called too many times") - self._unfinished_tasks -= 1 - if self._unfinished_tasks == 0: - self._finished.set() - - async def join(self) -> None: - await self._finished.wait() - - class TaskQueueMixin: """Mixin providing task queue management, rate limiting, and status tracking. 
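
With _DequeTaskQueue gone, the per-key queues are plain asyncio.Queue instances and the tunables live entirely in TaskQueueConfig. A minimal usage sketch of the new per-request cap (the dict keys mirror the queue_config block in the cookbook YAML; the values are illustrative, not recommendations, and the import path is inferred from this file's location):

    from twinkle.server.utils.task_queue import TaskQueueConfig

    cfg = TaskQueueConfig.from_dict({
        'rps_limit': 100,           # requests/second per user token
        'tps_limit': 10000,         # input tokens/second per user token
        'max_input_tokens': 10000,  # per-request cap enforced at schedule time
    })
    assert cfg.max_input_tokens == 10000

from_dict falls back to the dataclass defaults for any key left out, so partial configs stay valid.
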
@@ -203,7 +166,7 @@ def _init_task_queue(self, config: Optional[TaskQueueConfig] = None) -> None: """ self._task_queue_config = config or TaskQueueConfig() # Per-key queues, but executed by a single global worker. - self._task_queues: Dict[str, _DequeTaskQueue] = {} + self._task_queues: Dict[str, asyncio.Queue] = {} self._queue_order: Deque[str] = deque() self._new_task_event: asyncio.Event = asyncio.Event() @@ -250,7 +213,7 @@ async def _ensure_worker_started(self) -> None: def _ensure_queue_registered(self, queue_key: str) -> None: if queue_key not in self._task_queues: - self._task_queues[queue_key] = _DequeTaskQueue() + self._task_queues[queue_key] = asyncio.Queue() if queue_key not in self._queue_order: self._queue_order.append(queue_key) @@ -301,37 +264,7 @@ async def _queue_worker(self) -> None: q.task_done() continue - # Rate limiting at execution time (requeue on limit) - if self._task_queue_config.enabled and task.token: - allowed, reason = await self._rate_limiter.check_and_record( - task.token, task.input_tokens - ) - if not allowed: - if task.first_rate_limited_at is None: - task.first_rate_limited_at = now - # If a task cannot get a slot within a window, fail it. - if (now - task.first_rate_limited_at) > self._task_queue_config.window_seconds: - error_payload = { - 'error': f"Rate limit wait exceeded window: {reason}", - 'category': 'Server' - } - self.state.store_future_status( - task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, - queue_state=QueueState.PAUSED_RATE_LIMIT.value, - queue_state_reason=reason, - ) - q.task_done() - continue - - # Put back to FRONT to preserve order, then try other queues - self.state.store_future_status( - task.request_id, TaskStatus.QUEUED.value, task.model_id, - queue_state=QueueState.PAUSED_RATE_LIMIT.value, - queue_state_reason=reason, - ) - await q.put_left(task) - q.task_done() - continue + # Rate limiting check has been moved to schedule_task(), so tasks here should pass rate limits # Execute executed_any = True @@ -384,7 +317,7 @@ async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: drained.append(q.get_nowait()) except asyncio.QueueEmpty: break - + for task in drained: error_payload = { 'error': reason, @@ -405,6 +338,7 @@ async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: except ValueError: pass + def fail_pending_tasks_for_model(self, model_id: str, reason: str) -> None: """Fail and drop queued tasks for a model. Safe to call from non-async threads.""" queue_key = self._queue_key(model_id=model_id, token=None) @@ -419,12 +353,85 @@ def _schedule() -> None: self._event_loop.call_soon_threadsafe(_schedule) + async def _perform_preflight_checks( + self, + request_id: str, + model_id: Optional[str], + token: Optional[str], + input_tokens: int, + batch_size: Optional[int] = None, + data_world_size: Optional[int] = None, + ) -> Optional[Dict[str, Any]]: + """Perform pre-flight checks including rate limiting and token validation. + + Args: + request_id: The request ID for status tracking. + model_id: Optional model_id for error reporting. + token: Optional user token for rate limiting. + input_tokens: Number of input tokens for validation. + batch_size: Optional batch size for validation. + data_world_size: Optional data world size for batch size validation. + + Returns: + None if checks pass, or error response dict if checks fail. 
+ """ + if not token or not self._task_queue_config.enabled: + return None + + # Check max input tokens + if input_tokens > self._task_queue_config.max_input_tokens: + error_msg = f"Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})" + error_payload = { + 'error': error_msg, + 'category': 'User' + } + self.state.store_future_status( + request_id, TaskStatus.FAILED.value, model_id, result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + # Check batch size if provided + if batch_size is not None and data_world_size is not None: + if batch_size < data_world_size: + error_msg = f"Batch size {batch_size} must be greater than or equal to data world size {data_world_size}" + error_payload = { + 'error': error_msg, + 'category': 'User' + } + self.state.store_future_status( + request_id, TaskStatus.FAILED.value, model_id, result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + # Check rate limits + allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) + if not allowed: + error_msg = f"Rate limit exceeded: {reason}" + error_payload = { + 'error': error_msg, + 'category': 'User' + } + self.state.store_future_status( + request_id, TaskStatus.FAILED.value, model_id, result=error_payload, + queue_state=QueueState.PAUSED_RATE_LIMIT.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + return None + async def schedule_task( self, coro_factory: Callable[[], Coroutine], model_id: Optional[str] = None, token: Optional[str] = None, input_tokens: int = 0, + batch_size: Optional[int] = None, + data_world_size: Optional[int] = None, task_type: Optional[str] = None, ) -> Dict[str, Any]: """Schedule an async task with rate limiting and status tracking. @@ -445,12 +452,22 @@ async def schedule_task( model_id: Optional model_id to associate with the result. token: Optional user token for rate limiting. input_tokens: Number of input tokens for tps rate limiting. + batch_size: Optional batch size for validation. + data_world_size: Optional data world size for batch size validation. task_type: Optional task type for logging/observability. Returns: Dict containing request_id and model_id for future retrieval. """ + # Generate request_id first so it can be included in error responses request_id = f"req_{uuid.uuid4().hex}" + + # 1. Pre-flight checks: rate limiting, max token validation, and batch size validation + preflight_result = await self._perform_preflight_checks( + request_id, model_id, token, input_tokens, batch_size, data_world_size + ) + if preflight_result is not None: + return preflight_result if self._event_loop is None: self._event_loop = asyncio.get_running_loop() @@ -458,20 +475,21 @@ async def schedule_task( print( f"[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}") - # 1. Register PENDING status FIRST (fixes race condition) + + # 2. Register PENDING status FIRST self.state.store_future_status( request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value ) - # 2. Route to per-model/per-token queue + # 3. Route to per-model/per-token queue queue_key = self._queue_key(model_id=model_id, token=token) self._ensure_queue_registered(queue_key) - # 3. 
Ensure worker is started + # 4. Ensure worker is started await self._ensure_worker_started() - # 4. Put task in queue and update status + # 5. Put task in queue and update status q = self._task_queues[queue_key] print( f"[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}") From 87ce96db8f25a686905a85d956d6ba3ed4ba51aa Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 17:59:39 +0800 Subject: [PATCH 07/22] update --- cookbook/client/tinker/grpo.py | 292 ------------------ cookbook/client/tinker/gsm8k_grpo.py | 9 +- .../tinker/megatron/server_config_7b.yaml | 8 +- cookbook/client/tinker/sample.py | 29 +- cookbook/client/tinker/self_congnition.py | 41 +-- 5 files changed, 47 insertions(+), 332 deletions(-) delete mode 100644 cookbook/client/tinker/grpo.py diff --git a/cookbook/client/tinker/grpo.py b/cookbook/client/tinker/grpo.py deleted file mode 100644 index 94145f2a..00000000 --- a/cookbook/client/tinker/grpo.py +++ /dev/null @@ -1,292 +0,0 @@ -# Tinker-Compatible Client - GRPO (Group Relative Policy Optimization) Training Example -# -# This script demonstrates GRPO reinforcement learning training using the -# Tinker-compatible client API with save_weights_for_sampler for weight sync. -# Instead of calling sync_weights directly, it periodically saves weights and -# creates a sampling client for generation. -# -# Flow: -# 1. Prepare Countdown dataset (client-side) -# 2. Initialize Tinker-compatible training & sampling clients -# 3. Training loop: -# a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client -# b. Sample completions from the sampling client -# c. Compute rewards and advantages (client-side) -# d. Train on sampled data weighted by advantages -# e. Optimizer step -# -# The server must be running first (see server.py and server_config.yaml). -# Requires both model and sampler services to be configured. 
- -import gc -import numpy as np -from typing import List, Tuple - -from tinker import types -from twinkle_client import init_tinker_compat_client -from twinkle import get_logger -from twinkle.advantage import GRPOAdvantage -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.metric import CompletionRewardMetric -from modelscope import AutoTokenizer - -logger = get_logger() - -# ========== Configuration ========== -BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct' -NUM_GENERATIONS = 4 -MAX_NEW_TOKENS = 1024 -LEARNING_RATE = 1e-5 -MAX_STEPS = 100 -BATCH_SIZE = 2 -TEMPERATURE = 1.0 -SYNC_INTERVAL = 2 # Save weights for sampler every N steps -LORA_RANK = 8 - - -def create_countdown_dataset(): - """Create Countdown Game dataset for GRPO training.""" - logger.info("Loading Countdown dataset...") - - dataset = Dataset(DatasetMeta( - "ms://zouxuhong/Countdown-Tasks-3to4", data_slice=range(500))) - dataset.set_template( - "Template", model_id=f'ms://{BASE_MODEL}', max_length=8192) - dataset.map('CountdownProcessor') - dataset.encode(add_generation_prompt=True) - - logger.info(f"Dataset loaded with {len(dataset)} samples") - return dataset - - -def compute_rewards( - trajectories: List[dict], -) -> Tuple[List[float], List[float], List[float]]: - """Compute format and accuracy rewards for Countdown game.""" - from twinkle.reward import CountDownAccuracy, FormatReward - format_rewards = FormatReward()(trajectories, []) - accuracy_rewards = CountDownAccuracy()(trajectories, []) - total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)] - return total_rewards, format_rewards, accuracy_rewards - - -def main(): - logger.info("Starting GRPO training...") - - # Step 1: Prepare dataset and dataloader (client-side) - dataset = create_countdown_dataset() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - tokenizer = AutoTokenizer.from_pretrained( - BASE_MODEL, trust_remote_code=True) - - logger.info("Dataset and tokenizer initialized") - - # Step 2: Initialize the Tinker-compatible client - logger.info("Connecting to Tinker server...") - service_client = init_tinker_compat_client( - base_url='http://localhost:8000') - - logger.info("Creating LoRA training client...") - # Create a LoRA training client for GRPO - training_client = service_client.create_lora_training_client( - base_model=BASE_MODEL, - rank=LORA_RANK, - ) - - logger.info("Training client created successfully") - - # Step 3: Setup metrics and advantage function - advantage_fn = GRPOAdvantage() - metrics = CompletionRewardMetric() - - sampling_params = types.SamplingParams( - max_tokens=MAX_NEW_TOKENS, - temperature=TEMPERATURE, - top_p=0.95, - ) - - # The sampling client is created on-demand via save_weights_for_sampler - sampling_client = None - - step = 0 - for batch in dataloader: - if step >= MAX_STEPS: - break - - metrics.reset() - prompts = batch if isinstance(batch, list) else [batch] - - # ========== 1. Save weights for sampler (instead of sync_weights) ========== - if step % SYNC_INTERVAL == 0: - logger.info(f"Step {step}: Saving weights for sampler...") - - sampling_client = ( - training_client.save_weights_and_get_sampling_client( - name=f'grpo-step-{step}')) - logger.info(f"Step {step}: Sampling client ready") - - if sampling_client is None: - logger.warning("No sampling client available, skipping step") - step += 1 - continue - - # ========== 2. 
Sample completions ========== - # Convert input features to token prompts for the sampling client - all_sequences = [] - for prompt_feature in prompts: - input_ids = prompt_feature['input_ids'] - if hasattr(input_ids, 'tolist'): - input_ids = input_ids.tolist() - prompt = types.ModelInput.from_ints(input_ids) - future = sampling_client.sample( - prompt=prompt, - sampling_params=sampling_params, - num_samples=NUM_GENERATIONS, - ) - result = future.result() - all_sequences.extend(result.sequences) - - if not all_sequences: - logger.warning(f"Step {step}: No valid samples, skipping") - step += 1 - continue - - # ========== 3. Build trajectories and collect logprobs ========== - trajectories = [] - old_logps_list = [] - completion_lengths = [] - - for seq in all_sequences: - decoded_text = tokenizer.decode(seq.tokens, skip_special_tokens=True) - trajectories.append({ - 'messages': [{'role': 'assistant', 'content': decoded_text}] - }) - old_logps_list.append( - [lp for lp in seq.logprobs] if seq.logprobs else []) - completion_lengths.append(len(seq.tokens)) - - # ========== 4. Compute rewards ========== - total_rewards, format_rewards, accuracy_rewards = compute_rewards( - trajectories) - metrics.accumulate( - None, None, - completion_lengths=completion_lengths, - rewards={ - 'total': total_rewards, - 'format': format_rewards, - 'accuracy': accuracy_rewards, - }) - - # ========== 5. Compute advantages ========== - advantages = advantage_fn( - total_rewards, - num_generations=NUM_GENERATIONS, - scale='group', - ).tolist() - - frac_zero_std = ( - 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0) - if frac_zero_std == 1.0: - logger.info( - f"Step {step}: All advantages are zero, skipping training") - step += 1 - continue - - # ========== 6. Train the policies with GRPO loss ========== - # Train the policies with the Advantage-Regularized policy - # gradient (GRPO) loss function. - # - # The GRPO loss function requires: - # 1. logprobs: The log probabilities of the tokens under the current policy - # 2. 
advantages: The advantage values for each completion - # - # The training data is constructed with: - # - model_input: The full prompt + completion tokens - # - target_tokens: The shifted tokens for next-token prediction - # - logprobs: The log probabilities from the sampling step - # - advantages: The computed advantage values - training_data = [] - for i, seq in enumerate(all_sequences): - # Build a Datum from the completion tokens with logprobs and advantages - prompt_feature = prompts[i // NUM_GENERATIONS] - prompt_ids = prompt_feature['input_ids'] - if hasattr(prompt_ids, 'tolist'): - prompt_ids = prompt_ids.tolist() - - sampled_tokens = list(seq.tokens) - logprobs = seq.logprobs if seq.logprobs else [0.0] * len(sampled_tokens) - advantage = float(advantages[i]) - - ob_len = len(prompt_ids) - 1 - input_tokens = prompt_ids + sampled_tokens[:-1] - target_tokens = [0] * ob_len + sampled_tokens - padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens) - padded_logprobs = [0.0] * ob_len + logprobs - - # Verify lengths match - assert len(input_tokens) == len(target_tokens) == len(padded_logprobs) == len(padded_advantages), \ - f"Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, " \ - f"logprobs={len(padded_logprobs)}, advantages={len(padded_advantages)}" - - datum = types.Datum( - model_input=types.ModelInput.from_ints(input_tokens), - loss_fn_inputs={ - 'target_tokens': target_tokens, - 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)), - 'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)), - }, - ) - training_data.append(datum) - - if not training_data: - logger.info( - f"Step {step}: No training data constructed, skipping") - step += 1 - continue - - # Forward-backward pass with importance_sampling (GRPO) loss - # The training data already contains logprobs and advantages for the GRPO loss - fwdbwd_future = training_client.forward_backward( - training_data, "importance_sampling") - optim_future = training_client.optim_step( - types.AdamParams(learning_rate=LEARNING_RATE)) - - fwdbwd_result = fwdbwd_future.result() - optim_result = optim_future.result() - - # Compute metrics from the forward-backward result - # For importance_sampling, we get logprobs and elementwise_loss - logprobs_list = [] - elementwise_losses = [] - for output in fwdbwd_result.loss_fn_outputs: - if output.get('logprobs') is not None: - logprobs_list.append(output['logprobs'].to_numpy()) - if output.get('elementwise_loss') is not None: - elementwise_losses.append(output['elementwise_loss'].to_numpy()) - - # Compute average loss per token (weighted by advantages) - if elementwise_losses: - all_losses = np.concatenate(elementwise_losses) - avg_loss = np.mean(all_losses) if len(all_losses) > 0 else 0.0 - else: - avg_loss = 0.0 - - gc.collect() - - # ========== 7. 
Log ========== - log_dict = metrics.calculate() - log_dict['train/loss_per_token'] = float(avg_loss) - log_dict['train/frac_reward_zero_std'] = frac_zero_std - log_dict['train/num_training_samples'] = len(training_data) - logger.info(f"Step {step}: {log_dict}") - step += 1 - - # Save final checkpoint - save_future = training_client.save_state("grpo-countdown-final") - save_result = save_future.result() - logger.info(f"Saved final checkpoint to {save_result.path}") - - -if __name__ == '__main__': - main() diff --git a/cookbook/client/tinker/gsm8k_grpo.py b/cookbook/client/tinker/gsm8k_grpo.py index 139dd40e..5bac2798 100644 --- a/cookbook/client/tinker/gsm8k_grpo.py +++ b/cookbook/client/tinker/gsm8k_grpo.py @@ -33,7 +33,7 @@ from twinkle.data_format import Trajectory, InputFeature, Message from twinkle.dataset import Dataset, DatasetMeta from twinkle.metric import CompletionRewardMetric -from modelscope import AutoTokenizer +from twinkle.template import Template logger = get_logger() @@ -208,10 +208,9 @@ def main(): # Step 1: Prepare dataset and dataloader (client-side) dataset = create_gsm8k_dataset() dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - tokenizer = AutoTokenizer.from_pretrained( - BASE_MODEL, trust_remote_code=True) + template = Template(model_id=f'ms://{BASE_MODEL}') - logger.info("Dataset and tokenizer initialized") + logger.info("Dataset and template initialized") # Step 2: Initialize the Tinker-compatible client logger.info("Connecting to Tinker server...") @@ -293,7 +292,7 @@ def main(): completion_lengths = [] for idx, seq in enumerate(all_sequences): - decoded_text = tokenizer.decode(seq.tokens, skip_special_tokens=True) + decoded_text = template.decode(seq.tokens, skip_special_tokens=True) # Use the corresponding user data for this sequence trajectories.append({ 'messages': [ diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml index 1dac6d6c..7b625570 100644 --- a/cookbook/client/tinker/megatron/server_config_7b.yaml +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -50,10 +50,12 @@ applications: dp_size: 2 queue_config: rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second + tps_limit: 10000 # Max tokens per second for a single user + max_input_tokens: 10000 # Maximum input tokens per request adapter_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters - adapter_timeout: 1800 # Seconds before idle adapter unload + per_token_adapter_limit: 30 # Max concurrent LoRA adapters per user + adapter_timeout: 60 # Seconds before idle adapter unload + adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) deployments: - name: ModelManagement autoscaling_config: diff --git a/cookbook/client/tinker/sample.py b/cookbook/client/tinker/sample.py index 4752c5b9..1a7d473d 100644 --- a/cookbook/client/tinker/sample.py +++ b/cookbook/client/tinker/sample.py @@ -6,28 +6,41 @@ from tinker import types from twinkle_client import init_tinker_compat_client -from modelscope import AutoTokenizer +from twinkle.data_format import Message, Trajectory +from twinkle.template import Template # Step 1: Define the base model and connect to the server base_model = "Qwen/Qwen2.5-7B-Instruct" -service_client = init_tinker_compat_client(base_url='http://localhost:8000', api_key="tml-EMPTY_TOKEN") +service_client = init_tinker_compat_client(base_url='http://localhost:8000') # Step 2: Create a sampling client by loading weights 
from a saved checkpoint.
 # The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
 # The server will load the base model and apply the LoRA adapter weights.
 sampling_client = service_client.create_sampling_client(
-    model_path="twinkle://20260130_133245-Qwen_Qwen2_5-0_5B-Instruct-ffebd239/weights/pig-latin-lora-epoch-1",
+    model_path="twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2",
     base_model=base_model)
 
-# Step 3: Load the tokenizer locally to encode the prompt and decode the results
+# Step 3: Build the chat template locally to encode the prompt and decode the results
 print(f"Using model {base_model}")
-tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+
+template = Template(model_id=f'ms://{base_model}')
+
+trajectory = Trajectory(
+    messages=[
+        Message(role='system', content='You are a helpful assistant'),
+        Message(role='user', content="你是谁?"),
+    ]
+)
+
+input_feature = template.encode(trajectory, add_generation_prompt=True)
+
+input_ids = input_feature['input_ids'].tolist()
 
 # Step 4: Prepare the prompt and sampling parameters
-prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
+prompt = types.ModelInput.from_ints(input_ids)
 params = types.SamplingParams(
-    max_tokens=20,  # Maximum number of tokens to generate
-    temperature=0.0,  # Greedy sampling (deterministic, always pick the top token)
+    max_tokens=128,  # Maximum number of tokens to generate
+    temperature=0.7,
     stop=["\n"]  # Stop generation when a newline character is produced
 )
 
@@ -40,4 +53,4 @@
 # Step 6: Decode and print the generated responses
 print("Responses:")
 for i, seq in enumerate(result.sequences):
-    print(f"{i}: {repr(tokenizer.decode(seq.tokens))}")
+    print(f"{i}: {repr(template.decode(seq.tokens))}")
diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py
index d846b5f5..669d87d0 100644
--- a/cookbook/client/tinker/self_congnition.py
+++ b/cookbook/client/tinker/self_congnition.py
@@ -11,11 +11,12 @@
 from tqdm import tqdm
 from tinker import types
 from twinkle_client import init_tinker_compat_client
+from twinkle.data_format import Message, Trajectory
+from twinkle.template import Template
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
 from twinkle.server.tinker.common import input_feature_to_datum
-from modelscope import AutoTokenizer
 
 # The base model to fine-tune / evaluate
 base_model = "Qwen/Qwen2.5-7B-Instruct"
@@ -83,7 +84,7 @@ def eval():
 
     # Step 1: Load the trained LoRA checkpoint for inference
     # Path to a previously saved LoRA checkpoint (twinkle:// URI)
-    weight_path = "twinkle://20260211_112719-Qwen_Qwen2_5-7B-Instruct-a74a4826/weights/twinkle-lora-2"
+    weight_path = "twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2"
 
     # Connect to the server and create a sampling client with the trained weights
     service_client = init_tinker_compat_client(base_url='http://localhost:8000')
@@ -91,30 +92,22 @@ def eval():
         model_path=weight_path,
         base_model=base_model)
 
-    # Load the tokenizer for encoding the prompt and decoding the output
-    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
-
     # Step 2: Prepare the chat prompt
     # Build a multi-turn conversation to test the model's self-cognition
-    inputs = [
-        {
-            'role': 'system',
-            'content': 'You are a helpful assistant.'
-        },
-        {
-            'role': 'user',
-            'content': 'what is your name?'
- } - ] - - # Apply the model's chat template to format the conversation - input_ids = tokenizer.apply_chat_template( - inputs, - tokenize=True, - add_generation_prompt=True # Adds the assistant prompt prefix + template = Template(model_id=f'ms://{base_model}') + + trajectory = Trajectory( + messages=[ + Message(role='system', content='You are a helpful assistant'), + Message(role='user', content="你是谁?"), + ] ) + input_feature = template.encode(trajectory, add_generation_prompt=True) + + input_ids = input_feature['input_ids'].tolist() + # Step 3: Generate responses prompt = types.ModelInput.from_ints(input_ids) @@ -132,9 +125,9 @@ def eval(): # Decode and print each response print("Responses:") for i, seq in enumerate(result.sequences): - print(f"{i}: {repr(tokenizer.decode(seq.tokens))}") + print(f"{i}: {repr(template.decode(seq.tokens))}") if __name__ == "__main__": - train() # Uncomment to run training - # eval() # Run evaluation / inference + # train() # Uncomment to run training + eval() # Run evaluation / inference From 41d92f8bd106c07b0623989db271925f9ff6804e Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 20:20:38 +0800 Subject: [PATCH 08/22] update --- src/twinkle/server/tinker/model.py | 58 ++-- src/twinkle/server/twinkle/model.py | 35 +- src/twinkle/server/twinkle/sampler.py | 243 +------------- src/twinkle/server/utils/adapter_manager.py | 354 +++++++++----------- src/twinkle/server/utils/state.py | 19 ++ 5 files changed, 246 insertions(+), 463 deletions(-) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 5835f8f3..06fbcd9b 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -126,12 +126,36 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() + def _cleanup_adapter(self, adapter_name: str) -> None: + """Common adapter cleanup logic used by both manual unload and automatic expiration. + + This method handles: + 1. Clearing adapter state + 2. Removing adapter from model + 3. Unregistering from adapter manager + 4. Removing from server state + + Args: + adapter_name: Name of the adapter to clean up + """ + # Remove from model if it exists + if self.get_adapter_info(adapter_name): + # Clear adapter state + self.clear_adapter_state(adapter_name) + + self.model.remove_adapter(adapter_name) + # Unregister from adapter manager + self.unregister_adapter(adapter_name) + + # Remove from server state + self.state.unload_model(adapter_name) + def _on_adapter_expired(self, adapter_name: str) -> None: # Called from AdapterManagerMixin's countdown thread. - self.clear_adapter_state(adapter_name) # Fail any pending tasks for this adapter/model. self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') - super()._on_adapter_expired(adapter_name) + # Perform common cleanup (without token since it's automatic) + self._cleanup_adapter(adapter_name) @app.post('/create_model') async def create_model( @@ -159,25 +183,21 @@ async def create_model( async def _create_adapter(): try: if body.lora_config: - # Check adapter limit before creating - allowed, reason = self.check_adapter_limit( - request.state.token, True) - if not allowed: - raise RuntimeError(reason) - # TODO: support more lora config parameters, train_unembed, etc. 
lora_cfg = LoraConfig( r=body.lora_config.rank, target_modules='all-linear') adapter_name = self.get_adapter_name( adapter_name=model_id) + + # Register adapter FIRST (limit check happens inside register_adapter) + self.register_adapter( + adapter_name, request.state.token, session_id=body.session_id) + + # Create adapter AFTER successful registration self.model.add_adapter_to_model( adapter_name=adapter_name, config_or_dir=lora_cfg) - # Register adapter with rate limiter for lifecycle tracking - self.register_adapter( - adapter_name, request.state.token) - self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model) self.model.set_processor('InputProcessor', adapter_name=adapter_name) @@ -195,9 +215,7 @@ async def _create_adapter(): except Exception: # Ensure we don't leave stale grad state. adapter_name = self.get_adapter_name(adapter_name=model_id) - self.clear_adapter_state(adapter_name) - # If adapter creation fails, decrement the count - self.check_adapter_limit(request.state.token, False) + self._cleanup_adapter(adapter_name) logger.error(traceback.format_exc()) return types.RequestFailedResponse( @@ -266,14 +284,8 @@ async def _do_unload(): # Only remove adapter, not the base model adapter_name = self.get_adapter_name( adapter_name=body.model_id) - self.clear_adapter_state(adapter_name) - if self.get_adapter_info(adapter_name): - self.model.remove_adapter(adapter_name) - # Unregister adapter from rate limiter - self.unregister_adapter(adapter_name) - # Decrement adapter count via rate limiter - self.check_adapter_limit(request.state.token, False) - self.state.unload_model(body.model_id) + # Use common cleanup logic + self._cleanup_adapter(adapter_name) return types.UnloadModelResponse(model_id=body.model_id) return await self.schedule_task( diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 44364459..6f1a6d6b 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -182,6 +182,30 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() + def _on_adapter_expired(self, adapter_name: str) -> None: + """Handle adapter expiration by removing it from the model. + + This method is called automatically by AdapterManagerMixin when + an adapter exceeds its timeout or TTL. + + Args: + adapter_name: Name of the expired adapter to remove. 
+        """
+        # Remove from model if it exists
+        if self.get_adapter_info(adapter_name):
+            # Clear adapter state
+            self.clear_adapter_state(adapter_name)
+
+            self.model.remove_adapter(adapter_name)
+            # Unregister from adapter manager
+            self.unregister_adapter(adapter_name)
+
+        # Remove from server state
+        self.state.unload_model(adapter_name)
+
+
     @app.post("/create")
     def create(self, request: Request, body: CreateRequest):
         return {'status': 'ok'}
@@ -496,16 +520,11 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest):
         token = request.state.token
         training_run_manager = create_training_run_manager(token)
 
-        with self._adapter_lock:
-            self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)
-
-        # Register adapter for lifecycle tracking
+        # Register adapter FIRST (limit check happens inside register_adapter)
        self.register_adapter(adapter_name, token)
 
-        # Check adapter limit (raises if exceeded)
-        allowed, reason = self.check_adapter_limit(token, True)
-        if not allowed:
-            raise RuntimeError(reason)
+        # Create adapter AFTER successful registration
+        self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)
 
         # Save training run metadata (similar to tinker's create_model)
         # Create a training run config from the adapter configuration
diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py
index e46fd967..e4a25230 100644
--- a/src/twinkle/server/twinkle/sampler.py
+++ b/src/twinkle/server/twinkle/sampler.py
@@ -2,17 +2,6 @@
 """
 Twinkle sampler (inference) server.
 
-This module provides a Ray Serve deployment for distributed text generation/inference.
-It supports:
-1. vLLM and Torch sampler backends
-2. LoRA adapter loading via adapter URIs (twinkle:// paths or local paths)
-3. Multi-user inference with adapter lifecycle management
-4. Flexible sampling parameters
-"""
-import traceback
-"""
-Twinkle sampler (inference) server.
-
 This module provides a Ray Serve deployment for distributed text generation/inference.
 It supports:
 1.
vLLM and Torch sampler backends @@ -23,8 +12,6 @@ import traceback from typing import Dict, Any, List, Optional, Union -from fastapi import FastAPI, Request -from pydantic import BaseModel, Field from fastapi import FastAPI, Request from pydantic import BaseModel, Field from ray import serve @@ -34,82 +21,12 @@ from twinkle.data_format import Trajectory, InputFeature, SamplingParams from twinkle.server.utils.adapter_manager import AdapterManagerMixin from twinkle.server.utils.validation import verify_request_token, get_token_from_request -from twinkle.data_format import Trajectory, InputFeature, SamplingParams -from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.validation import verify_request_token, get_token_from_request from twinkle.server.utils.state import get_server_state, ServerStateProxy from twinkle.utils.logger import get_logger logger = get_logger() -# ----- Request/Response Models ----- - -class SampleRequest(BaseModel): - """Request body for the /sample endpoint.""" - inputs: Any = Field(..., description="List of Trajectory or InputFeature dicts") - sampling_params: Optional[Dict[str, Any]] = Field( - None, description="Sampling parameters (max_tokens, temperature, etc.)") - adapter_name: str = Field('', description="Adapter name for LoRA inference") - adapter_uri: Optional[str] = Field( - None, description="Adapter URI (twinkle:// path or local path) for LoRA inference") - num_samples: int = Field(1, description="Number of completions to generate per prompt") - - -class SampleResponseModel(BaseModel): - """Response body for the /sample endpoint.""" - sequences: List[Dict[str, Any]] = Field( - ..., description="List of sampled sequences, each with tokens, logprobs, stop_reason") - prompt_logprobs: Optional[List[Optional[float]]] = None - topk_prompt_logprobs: Optional[List[Optional[List]]] = None - - -class SetTemplateRequest(BaseModel): - """Request body for the /set_template endpoint.""" - template_cls: str = Field(..., description="Template class name (e.g. 'Template')") - adapter_name: str = Field('', description="Adapter name to associate the template with") - - class Config: - extra = "allow" - - -class SetTemplateResponse(BaseModel): - """Response body for the /set_template endpoint.""" - status: str = "ok" - - -class AddAdapterRequest(BaseModel): - """Request body for the /add_adapter_to_sampler endpoint.""" - adapter_name: str = Field(..., description="Name of the adapter to add") - config: Any = Field(..., description="LoRA configuration dict") - - -class AddAdapterResponse(BaseModel): - """Response body for the /add_adapter_to_sampler endpoint.""" - status: str = "ok" - adapter_name: str - -class HeartbeatRequest(BaseModel): - """Request body for the /heartbeat endpoint.""" - adapter_name: str = Field(..., description="Adapter name to keep alive") - - -class HeartbeatResponse(BaseModel): - """Response body for the /heartbeat endpoint.""" - status: str = "ok" - - -class CreateResponse(BaseModel): - """Response body for the /create endpoint.""" - status: str = "ok" - - -# ----- Application Builder ----- -from twinkle.utils.logger import get_logger - -logger = get_logger() - - # ----- Request/Response Models ----- class SampleRequest(BaseModel): @@ -184,27 +101,6 @@ def build_sampler_app(model_id: str, **kwargs): """Build a sampler application for text generation inference. 
- Args: - model_id: Model identifier (e.g., "Qwen/Qwen2.5-7B-Instruct") - nproc_per_node: Number of GPU processes per node - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for parallelism - deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') - engine_args: Additional engine arguments for the sampler - adapter_config: Adapter lifecycle config (adapter_timeout, per_token_adapter_limit) - **kwargs: Additional arguments passed to the sampler - - Returns: - Ray Serve deployment bound with configuration - """ - app = FastAPI( - title="Twinkle Sampler", - description="REST API for distributed text generation inference", - version="1.0.0" - ) - """Build a sampler application for text generation inference. - Args: model_id: Model identifier (e.g., "Qwen/Qwen2.5-7B-Instruct") nproc_per_node: Number of GPU processes per node @@ -287,26 +183,11 @@ def _on_adapter_expired(self, adapter_name: str, token: str) -> None: try: self.sampler.remove_adapter(adapter_name) logger.info(f"Removed expired adapter {adapter_name}") - self.check_adapter_limit(token, False) - except Exception as e: - logger.warning(f"Failed to remove expired adapter {adapter_name}: {e}") - _adapter_config = adapter_config or {} - self._init_adapter_manager(**_adapter_config) - self.start_adapter_countdown() - - def _on_adapter_expired(self, adapter_name: str, token: str) -> None: - """Handle expired adapters by removing them from the sampler.""" - try: - self.sampler.remove_adapter(adapter_name) - logger.info(f"Removed expired adapter {adapter_name}") - self.check_adapter_limit(token, False) + # Adapter count is now tracked dynamically, no manual update needed except Exception as e: logger.warning(f"Failed to remove expired adapter {adapter_name}: {e}") @staticmethod - def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: - if adapter_name is None or adapter_name == '': - return None def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: if adapter_name is None or adapter_name == '': return None @@ -404,117 +285,10 @@ def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> A from peft import LoraConfig config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - - with self._adapter_lock: - self.sampler.add_adapter_to_sampler(full_adapter_name, config) - + self.register_adapter(full_adapter_name, token) - allowed, reason = self.check_adapter_limit(token, True) - if not allowed: - raise RuntimeError(reason) - - return AddAdapterResponse(adapter_name=full_adapter_name) - - @app.post("/create", response_model=CreateResponse) - def create(self, request: Request) -> CreateResponse: - """Health check / session creation endpoint.""" - return CreateResponse() - - @app.post("/sample", response_model=SampleResponseModel) - def sample(self, request: Request, body: SampleRequest) -> SampleResponseModel: - """Sample completions from the model. 
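
The deduplication here leaves a single copy of each endpoint. For orientation, a rough client-side sketch of the surviving /sample contract, assuming a hypothetical deployment at localhost:8000 and a bearer-style auth header (neither the URL, the header format, nor the token IDs are defined by this patch):

    import requests

    resp = requests.post(
        'http://localhost:8000/sample',
        json={
            'inputs': [{'input_ids': [1, 2, 3]}],  # InputFeature-style, pre-tokenized
            'sampling_params': {'max_tokens': 64, 'temperature': 0.7},
            'num_samples': 2,
        },
        headers={'Authorization': 'Bearer tml-EMPTY_TOKEN'},
    )
    for seq in resp.json()['sequences']:
        print(seq['stop_reason'], len(seq['tokens']))

Trajectory-style message inputs are also accepted, but then a template must have been set via /set_template first.
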
- - Supports: - - Trajectory inputs (messages-based, requires template to be set) - - InputFeature inputs (pre-tokenized input_ids) - - LoRA adapter via adapter_name or adapter_uri (twinkle:// path) - - Multiple completions per prompt via num_samples - """ - try: - # Resolve adapter - adapter_path = None - adapter_name = body.adapter_name or '' - full_adapter_name = self._get_adapter_name(request, adapter_name) or '' - - if body.adapter_uri: - from .common.io_utils import create_checkpoint_manager - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) - - # Parse inputs - inputs = body.inputs - if isinstance(inputs, list) and inputs: - first = inputs[0] - if isinstance(first, dict) and 'input_ids' in first: - inputs = [InputFeature(**item) for item in inputs] - else: - inputs = [Trajectory(**item) for item in inputs] - elif isinstance(inputs, dict): - if 'input_ids' in inputs: - inputs = [InputFeature(**inputs)] - else: - inputs = [Trajectory(**inputs)] - - # Build sampling params - params = None - if body.sampling_params: - params = SamplingParams.from_dict(body.sampling_params) - - # Call sampler - response = self.sampler.sample( - inputs, - params, - adapter_name=full_adapter_name, - adapter_path=adapter_path, - num_samples=body.num_samples, - ) - if callable(response): - response = response() - - # Convert to response model - sequences = [] - for seq in response.sequences: - sequences.append({ - 'stop_reason': seq.stop_reason, - 'tokens': list(seq.tokens), - 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, - }) - - return SampleResponseModel( - sequences=sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, - ) - except Exception: - logger.error(traceback.format_exc()) - raise - @app.post("/set_template", response_model=SetTemplateResponse) - def set_template(self, request: Request, body: SetTemplateRequest) -> SetTemplateResponse: - """Set the chat template for encoding Trajectory inputs.""" - full_adapter_name = self._get_adapter_name(request, body.adapter_name) or '' - extra_kwargs = body.model_extra or {} - self.sampler.set_template(body.template_cls, **extra_kwargs) - return SetTemplateResponse() - - @app.post("/add_adapter_to_sampler", response_model=AddAdapterResponse) - def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> AddAdapterResponse: - """Add a LoRA adapter to the sampler.""" - assert body.adapter_name, 'You need to specify a valid `adapter_name`' - full_adapter_name = self._get_adapter_name(request, body.adapter_name) - token = get_token_from_request(request) - - from peft import LoraConfig - config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - - with self._adapter_lock: - self.sampler.add_adapter_to_sampler(full_adapter_name, config) - - self.register_adapter(full_adapter_name, token) - allowed, reason = self.check_adapter_limit(token, True) - if not allowed: - raise RuntimeError(reason) + self.sampler.add_adapter_to_sampler(full_adapter_name, config) return AddAdapterResponse(adapter_name=full_adapter_name) @@ -525,15 +299,6 @@ def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatRespon self.assert_adapter_exists(adapter_name=full_adapter_name) self.touch_adapter(full_adapter_name) return HeartbeatResponse() - @app.post("/heartbeat", response_model=HeartbeatResponse) - def heartbeat(self, request: 
Request, body: HeartbeatRequest) -> HeartbeatResponse: - """Keep an adapter alive by resetting its inactivity timer.""" - full_adapter_name = self._get_adapter_name(request, body.adapter_name) - self.assert_adapter_exists(adapter_name=full_adapter_name) - self.touch_adapter(full_adapter_name) - return HeartbeatResponse() return SamplerManagement.options(**deploy_options).bind( - nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, **kwargs) - return SamplerManagement.options(**deploy_options).bind( - nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, **kwargs) + nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, **kwargs) \ No newline at end of file diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 3c0e2f2b..8730bcdf 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -30,18 +30,15 @@ class AdapterManagerMixin: that have been inactive for longer than the configured timeout period. Inheriting classes should: - 1. Have a `self.model` attribute for model operations - 2. Call _init_adapter_manager() in __init__ - 3. Optionally override _on_adapter_expired() to customize expiration handling + 1. Call _init_adapter_manager() in __init__ + 2. Override _on_adapter_expired() to customize expiration handling Attributes: _adapter_timeout: Timeout in seconds for inactive adapters. - model: Model instance for adapter operations (must be set by inheriting class). """ # Type hint for state attribute that inheriting classes must provide state: 'ServerStateProxy' - model: 'TwinkleModel' def _init_adapter_manager( self, @@ -54,8 +51,9 @@ def _init_adapter_manager( This should be called in the __init__ of the inheriting class. Args: - adapter_timeout: Timeout in seconds for inactive adapters. - Default is 1800.0 (30 minutes). + adapter_timeout: Timeout in seconds for inactive adapters and session-based expiration. + Default is 1800.0 (30 minutes). Adapters linked to sessions will expire + when their session hasn't been touched for this duration. per_token_adapter_limit: Maximum number of adapters per user token. Default is 30. adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation. @@ -66,35 +64,65 @@ def _init_adapter_manager( self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking - # Dict mapping adapter_name -> {'token': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} + # Dict mapping adapter_name -> {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} self._adapter_records: Dict[str, Dict[str, Any]] = {} # Track adapter count per token self._adapter_counts: Dict[str, int] = {} - self._adapter_lock = threading.Lock() # Countdown thread self._adapter_countdown_thread: Optional[threading.Thread] = None self._adapter_countdown_running = False - def register_adapter(self, adapter_name: str, token: str) -> None: + def register_adapter(self, adapter_name: str, token: str, session_id: Optional[str] = None) -> None: """Register a new adapter for lifecycle tracking. Args: adapter_name: Name of the adapter to register. token: User token that owns this adapter. + session_id: Optional session ID to associate with this adapter. + If provided, adapter will expire when the session expires. + + Raises: + RuntimeError: If adapter limit is exceeded for this token. 
""" - with self._adapter_lock: - current_time = time.time() - self._adapter_records[adapter_name] = { - 'token': token, - 'last_activity': current_time, - 'created_at': current_time, - 'inactivity_counter': 0, - 'state': {}, - 'expiring': False, - } - logger.debug( - f"[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...") + # Check adapter limit BEFORE registering + allowed, reason = self.check_adapter_limit(token) + if not allowed: + raise RuntimeError(reason) + + current_time = time.time() + self._adapter_records[adapter_name] = { + 'token': token, + 'session_id': session_id, + 'last_activity': current_time, + 'created_at': current_time, + 'inactivity_counter': 0, + 'state': {}, + 'expiring': False, + } + logger.debug( + f"[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}..." + + (f" (session: {session_id})" if session_id else "")) + + def _is_session_alive(self, session_id: str) -> bool: + """Check if a session is still alive via state proxy. + + Args: + session_id: Session ID to check + + Returns: + True if session is alive, False if expired or not found + """ + if not session_id: + return True # No session association means always alive + + # Get session last heartbeat through proxy + last_heartbeat = self.state.get_session_last_heartbeat(session_id) + if last_heartbeat is None: + return False # Session doesn't exist + + # Check if session has timed out using adapter_timeout + return (time.time() - last_heartbeat) < self._adapter_timeout def unregister_adapter(self, adapter_name: str) -> bool: """Unregister an adapter from lifecycle tracking. @@ -105,14 +133,13 @@ def unregister_adapter(self, adapter_name: str) -> bool: Returns: True if adapter was found and removed, False otherwise. """ - with self._adapter_lock: - if adapter_name in self._adapter_records: - adapter_info = self._adapter_records.pop(adapter_name) - token = adapter_info.get('token') - logger.debug( - f"[AdapterManager] Unregistered adapter {adapter_name} for token {token[:8] if token else 'unknown'}...") - return True - return False + if adapter_name in self._adapter_records: + adapter_info = self._adapter_records.pop(adapter_name) + token = adapter_info.get('token') + logger.debug( + f"[AdapterManager] Unregistered adapter {adapter_name} for token {token[:8] if token else 'unknown'}...") + return True + return False def set_adapter_state(self, adapter_name: str, key: str, value: Any) -> None: """Set a per-adapter state value. @@ -121,40 +148,36 @@ def set_adapter_state(self, adapter_name: str, key: str, value: Any) -> None: adapter-scoped state (e.g., training readiness) without maintaining separate side maps. 
""" - with self._adapter_lock: - info = self._adapter_records.get(adapter_name) - if info is None: - return - state = info.setdefault('state', {}) - state[key] = value + info = self._adapter_records.get(adapter_name) + if info is None: + return + state = info.setdefault('state', {}) + state[key] = value def get_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: """Get a per-adapter state value.""" - with self._adapter_lock: - info = self._adapter_records.get(adapter_name) - if info is None: - return default - state = info.get('state') or {} - return state.get(key, default) + info = self._adapter_records.get(adapter_name) + if info is None: + return default + state = info.get('state') or {} + return state.get(key, default) def pop_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: """Pop a per-adapter state value.""" - with self._adapter_lock: - info = self._adapter_records.get(adapter_name) - if info is None: - return default - state = info.get('state') - if not isinstance(state, dict): - return default - return state.pop(key, default) + info = self._adapter_records.get(adapter_name) + if info is None: + return default + state = info.get('state') + if not isinstance(state, dict): + return default + return state.pop(key, default) def clear_adapter_state(self, adapter_name: str) -> None: """Clear all per-adapter state values.""" - with self._adapter_lock: - info = self._adapter_records.get(adapter_name) - if info is None: - return - info['state'] = {} + info = self._adapter_records.get(adapter_name) + if info is None: + return + info['state'] = {} def touch_adapter(self, adapter_name: str) -> bool: """Update adapter activity timestamp to prevent timeout. @@ -165,15 +188,14 @@ def touch_adapter(self, adapter_name: str) -> bool: Returns: True if adapter was found and touched, False otherwise. """ - with self._adapter_lock: - info = self._adapter_records.get(adapter_name) - if not info: - return False - if info.get('expiring'): - return False - info['last_activity'] = time.time() - info['inactivity_counter'] = 0 - return True + info = self._adapter_records.get(adapter_name) + if not info: + return False + if info.get('expiring'): + return False + info['last_activity'] = time.time() + info['inactivity_counter'] = 0 + return True def get_adapter_info(self, adapter_name: str) -> Optional[Dict[str, Any]]: """Get information about a registered adapter. @@ -184,45 +206,23 @@ def get_adapter_info(self, adapter_name: str) -> Optional[Dict[str, Any]]: Returns: Dict with adapter information or None if not found. """ - with self._adapter_lock: - return self._adapter_records.get(adapter_name) - - def list_adapters(self, token: Optional[str] = None) -> List[str]: - """List all registered adapters, optionally filtered by token. - - Args: - token: Optional user token to filter by. - - Returns: - List of adapter names. - """ - with self._adapter_lock: - if token is None: - return list(self._adapter_records.keys()) - return [ - name for name, info in self._adapter_records.items() - if info.get('token') == token - ] + return self._adapter_records.get(adapter_name) def _on_adapter_expired(self, adapter_name: str) -> None: """Hook method called when an adapter expires. - Default implementation removes the adapter from the model. - This is called from the countdown thread, so be careful with blocking operations. + This method must be overridden by inheriting classes to handle + adapter expiration logic. The base implementation raises NotImplementedError. 
Args: adapter_name: Name of the expired adapter. + + Raises: + NotImplementedError: If not overridden by inheriting class. """ - try: - # Best-effort cleanup of adapter state - with self._adapter_lock: - info = self._adapter_records.get(adapter_name) - if info is not None: - info['state'] = {} - self.model.remove_adapter(adapter_name) - logger.info(f"[AdapterManager] Removed expired adapter {adapter_name}") - except Exception as e: - logger.warning(f"[AdapterManager] Failed to remove expired adapter {adapter_name}: {e}") + raise NotImplementedError( + f"_on_adapter_expired must be implemented by {self.__class__.__name__}" + ) @staticmethod def get_adapter_name(adapter_name: str) -> str: @@ -240,23 +240,9 @@ def get_adapter_name(adapter_name: str) -> str: def assert_adapter_exists(self, adapter_name: str) -> None: """Validate that an adapter exists and is not expiring.""" - with self._adapter_lock: - info = self._adapter_records.get(adapter_name) - assert adapter_name and info is not None and not info.get('expiring'), \ - f"Adapter {adapter_name} not found" - - def assert_adapter_valid(self, adapter_name: Optional[str]) -> None: - """Validate that an adapter name is valid. - - Args: - adapter_name: The adapter name to validate (can be None or empty) - - Raises: - AssertionError: If adapter name is invalid - """ - assert (adapter_name is None or adapter_name == '' or - self.get_adapter_info(adapter_name) is not None), \ - f"Adapter {adapter_name} is invalid" + info = self._adapter_records.get(adapter_name) + assert adapter_name and info is not None and not info.get('expiring'), \ + f"Adapter {adapter_name} not found" def _adapter_countdown_loop(self) -> None: """Background thread that monitors and handles inactive adapters. @@ -274,50 +260,69 @@ def _adapter_countdown_loop(self) -> None: now = time.time() expired_adapters: List[Tuple[str, Optional[str]]] = [] - with self._adapter_lock: - for adapter_name, info in list(self._adapter_records.items()): - if info.get('expiring'): - continue - - created_at = info.get("created_at") or now - exceeded_ttl = ( - self._adapter_max_lifetime - and self._adapter_max_lifetime > 0 - and (now - created_at) > self._adapter_max_lifetime - ) - + # Create snapshot to avoid modification during iteration + adapter_snapshot = list(self._adapter_records.items()) + for adapter_name, info in adapter_snapshot: + if info.get('expiring'): + continue + + session_id = info.get('session_id') + created_at = info.get("created_at") + + # Check TTL for both cases + exceeded_ttl = ( + self._adapter_max_lifetime + and self._adapter_max_lifetime > 0 + and (now - created_at) > self._adapter_max_lifetime + ) + + # Different logic based on session association + if session_id: + # Has session: check session expiration and TTL + session_expired = not self._is_session_alive(session_id) + should_expire = session_expired or exceeded_ttl + expiration_reasons = [] + if exceeded_ttl: + expiration_reasons.append("ttl_exceeded") + if session_expired: + expiration_reasons.append("session_expired") + else: + # No session: check inactivity timeout and TTL info["inactivity_counter"] = info.get("inactivity_counter", 0) + 1 exceeded_inactivity = info["inactivity_counter"] > self._adapter_timeout - - if exceeded_ttl or exceeded_inactivity: - info['expiring'] = True - info['state'] = {} # best-effort clear - token = info.get('token') - expired_adapters.append((adapter_name, token)) - logger.debug( - f"[AdapterManager] Adapter {adapter_name} expired " - f"(ttl={exceeded_ttl}, 
inactivity={exceeded_inactivity})" - ) + should_expire = exceeded_ttl or exceeded_inactivity + expiration_reasons = [] + if exceeded_ttl: + expiration_reasons.append("ttl_exceeded") + if exceeded_inactivity: + expiration_reasons.append("inactivity_timeout") + + if should_expire: + info['expiring'] = True + info['state'] = {} # best-effort clear + token = info.get('token') + expired_adapters.append((adapter_name, token)) + logger.debug( + f"[AdapterManager] Adapter {adapter_name} expired " + f"(reasons={','.join(expiration_reasons)}, session={session_id})" + ) for adapter_name, token in expired_adapters: success = False try: self._on_adapter_expired(adapter_name) - if token: - self.check_adapter_limit(token, False) success = True except Exception as e: logger.warning( f"[AdapterManager] Error while expiring adapter {adapter_name}: {e}" ) finally: - with self._adapter_lock: - if success: - self._adapter_records.pop(adapter_name, None) - else: - info = self._adapter_records.get(adapter_name) - if info is not None: - info['expiring'] = False + if success: + self._adapter_records.pop(adapter_name, None) + else: + info = self._adapter_records.get(adapter_name) + if info is not None: + info['expiring'] = False except Exception as e: logger.warning(f"[AdapterManager] Error in countdown loop: {e}") @@ -352,64 +357,27 @@ def stop_adapter_countdown(self) -> None: self._adapter_countdown_thread.join(timeout=2.0) logger.debug("[AdapterManager] Countdown thread stopped") - def get_adapter_stats(self) -> Dict[str, Any]: - """Get adapter manager statistics. - - Returns: - Dict with registered adapter count and configuration. - """ - with self._adapter_lock: - return { - "registered_adapters": len(self._adapter_records), - "tracked_adapter_counts": len(self._adapter_counts), - "countdown_running": self._adapter_countdown_running, - "adapter_timeout_seconds": self._adapter_timeout, - "adapter_max_lifetime_seconds": self._adapter_max_lifetime, - "per_token_adapter_limit": self._per_token_adapter_limit, - } - - def check_adapter_limit(self, token: str, add: bool) -> Tuple[bool, Optional[str]]: - """Check and update adapter count for a user token. + def check_adapter_limit(self, token: str) -> Tuple[bool, Optional[str]]: + """Check adapter count for a user token. This method enforces per-user adapter limits to prevent resource exhaustion. + Counts adapters directly from _adapter_records instead of using state storage. Args: - token: User token to check/update. - add: True to add an adapter (increment count), False to remove (decrement count). + token: User token to check. Returns: Tuple of (allowed: bool, reason: Optional[str]). If allowed is False, reason contains the explanation. """ - user_key = token + '_' + 'model_adapter' - with self._adapter_lock: - current_count = self.state.get_config(user_key) or 0 - - if add: - # Check if adding would exceed limit - if current_count >= self._per_token_adapter_limit: - return False, f"Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters" - # Increment count in global state - self.state.add_config(user_key, current_count + 1) - return True, None - else: - # Decrement count in global state - if current_count > 0: - current_count -= 1 - self.state.add_config(user_key, current_count) - if current_count <= 0: - self.state.pop_config(user_key) - return True, None - - def get_adapter_count(self, token: str) -> int: - """Get current adapter count for a user token. 
+ # Count adapters directly from _adapter_records + current_count = sum( + 1 for record in self._adapter_records.values() + if record.get('token') == token and not record.get('expiring', False) + ) - Args: - token: User token to query. + # Check if current count exceeds limit + if current_count >= self._per_token_adapter_limit: + return False, f"Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters" + return True, None - Returns: - Current number of adapters for this token. - """ - user_key = token + '_' + 'model_adapter' - with self._adapter_lock: - return self.state.get_config(user_key) or 0 diff --git a/src/twinkle/server/utils/state.py b/src/twinkle/server/utils/state.py index a413a98e..c1a8b3b4 100644 --- a/src/twinkle/server/utils/state.py +++ b/src/twinkle/server/utils/state.py @@ -64,6 +64,7 @@ def create_session(self, payload: Dict[str, Any]) -> str: 'user_metadata': payload.get('user_metadata') or {}, 'sdk_version': payload.get('sdk_version'), 'created_at': datetime.now().isoformat(), + 'last_heartbeat': time.time(), } return session_id @@ -82,6 +83,21 @@ def touch_session(self, session_id: str) -> bool: self.sessions[session_id]['last_heartbeat'] = time.time() return True + def get_session_last_heartbeat(self, session_id: str) -> Optional[float]: + """ + Get the last heartbeat timestamp for a session. + + Args: + session_id: The session ID to query + + Returns: + Last heartbeat timestamp, or None if session doesn't exist + """ + session_info = self.sessions.get(session_id) + if not session_info: + return None + return session_info.get('last_heartbeat') + # ----- Model Registration ----- def register_model(self, @@ -465,6 +481,9 @@ def create_session(self, payload: Dict[str, Any]) -> str: def touch_session(self, session_id: str) -> bool: return ray.get(self._actor.touch_session.remote(session_id)) + def get_session_last_heartbeat(self, session_id: str) -> Optional[float]: + return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) + # ----- Model Registration ----- def register_model(self, From 80c0fd8115b5167b0e7e3edd498eb50978c66c43 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Feb 2026 21:24:38 +0800 Subject: [PATCH 09/22] update --- src/twinkle/server/tinker/server.py | 41 ++++++++++++++++++----------- src/twinkle/server/utils/state.py | 10 ++++--- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 3e74e63b..c3c68879 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -33,6 +33,7 @@ def build_server_app( deploy_options: Dict[str, Any], supported_models: Optional[List[types.SupportedModel]] = None, + server_config: Dict[str, Any] = {}, **kwargs ): """Build and configure the Tinker-compatible server application. @@ -43,23 +44,12 @@ def build_server_app( Args: deploy_options: Ray Serve deployment configuration (num_replicas, etc.) supported_models: List of supported base models for validation + server_config: Server configuration options (per_token_adapter_limit, etc.) **kwargs: Additional keyword arguments (route_prefix, etc.) Returns: Configured Ray Serve deployment bound with options """ - # Normalize supported_models to objects; passing raw dicts can trigger internal errors - # when creating LoRA training clients via the tinker API. 
-    if supported_models:
-        normalized = []
-        for item in supported_models:
-            if isinstance(item, types.SupportedModel):
-                normalized.append(item)
-            elif isinstance(item, dict):
-                normalized.append(types.SupportedModel(**item))
-            else:
-                raise TypeError(...)
-        supported_models = normalized
 
     app = FastAPI()
 
     @app.middleware("http")
@@ -79,18 +69,19 @@ class TinkerCompatServer:
         - Training run and checkpoint CRUD operations
         """
 
-    def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None, **kwargs) -> None:
+    def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None, server_config: Dict[str, Any] = {}, **kwargs) -> None:
         """Initialize the Tinker-compatible server.
 
         Args:
             supported_models: List of supported base models for validation
             **kwargs: Additional configuration (route_prefix, etc.)
         """
-        self.state = get_server_state()
+        # Initialize shared server state, forwarding server-level config (e.g., per_token_adapter_limit)
+        self.state = get_server_state(**server_config)
         # Disable proxy for internal requests to avoid routing through external proxies
         self.client = httpx.AsyncClient(timeout=None, trust_env=False)
         self.route_prefix = kwargs.get("route_prefix", "/api/v1")
-        self.supported_models = supported_models or [
+        self.supported_models = self.normalize_models(supported_models) or [
             types.SupportedModel(model_name="Qwen/Qwen2.5-0.5B-Instruct"),
             types.SupportedModel(model_name="Qwen/Qwen2.5-3B-Instruct"),
             types.SupportedModel(model_name="Qwen/Qwen2.5-7B-Instruct"),
@@ -100,6 +91,20 @@ def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None
         # Lock for ModelScope config file operations (login writes, get_user_info reads)
         self._modelscope_config_lock = asyncio.Lock()
 
+    def normalize_models(self, supported_models):
+        # Normalize supported_models to objects; passing raw dicts can trigger internal errors
+        # when creating LoRA training clients via the tinker API.
+        if supported_models:
+            normalized = []
+            for item in supported_models:
+                if isinstance(item, types.SupportedModel):
+                    normalized.append(item)
+                elif isinstance(item, dict):
+                    normalized.append(types.SupportedModel(**item))
+                else:
+                    normalized.append(types.SupportedModel(model_name=item))
+        return normalized
+
     def _validate_base_model(self, base_model: str) -> None:
         """Validate that base_model is in supported_models list.
 
@@ -710,4 +715,8 @@ async def save_weights_for_sampler(
             base_model = self._get_base_model(body.model_id)
             return await self._proxy_to_model(request, "save_weights_for_sampler", base_model)
 
-    return TinkerCompatServer.options(**deploy_options).bind(supported_models=supported_models, **kwargs)
+    return TinkerCompatServer.options(**deploy_options).bind(
+        supported_models=supported_models,
+        server_config=server_config,
+        **kwargs
+    )
diff --git a/src/twinkle/server/utils/state.py b/src/twinkle/server/utils/state.py
index c1a8b3b4..fbd12744 100644
--- a/src/twinkle/server/utils/state.py
+++ b/src/twinkle/server/utils/state.py
@@ -584,7 +584,8 @@ def get_cleanup_stats(self) -> Dict[str, Any]:
 
 
 def get_server_state(actor_name: str = 'twinkle_server_state',
-                     auto_start_cleanup: bool = True) -> ServerStateProxy:
+                     auto_start_cleanup: bool = True,
+                     **server_state_kwargs) -> ServerStateProxy:
     """
     Get or create the ServerState Ray actor.
@@ -594,6 +595,8 @@ def get_server_state(actor_name: str = 'twinkle_server_state', Args: actor_name: Name for the Ray actor (default: 'twinkle_server_state') auto_start_cleanup: Whether to automatically start the cleanup task (default: True) + **server_state_kwargs: Additional keyword arguments passed to ServerState constructor + (e.g., expiration_timeout, cleanup_interval, per_token_adapter_limit) Returns: A ServerStateProxy for interacting with the actor @@ -603,7 +606,7 @@ def get_server_state(actor_name: str = 'twinkle_server_state', except ValueError: try: _ServerState = ray.remote(ServerState) - actor = _ServerState.options(name=actor_name, lifetime='detached').remote() + actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**server_state_kwargs) # Start cleanup task for newly created actor if auto_start_cleanup: try: @@ -613,5 +616,4 @@ def get_server_state(actor_name: str = 'twinkle_server_state', except ValueError: actor = ray.get_actor(actor_name) assert actor is not None - return ServerStateProxy(actor) - + return ServerStateProxy(actor) \ No newline at end of file From 16494e0961a56515ad56f9b4a4845a41720a8561 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 09:22:45 +0800 Subject: [PATCH 10/22] update --- .gitignore | 3 +- cookbook/client/tinker/megatron/server.py | 2 +- .../tinker/megatron/server_config_7b.yaml | 74 +++++++++---------- 3 files changed, 40 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index cc500d06..c9efb363 100644 --- a/.gitignore +++ b/.gitignore @@ -151,4 +151,5 @@ megatron_output/ # ast template ast_index_file.py test_cookbook/ -/test*.py \ No newline at end of file +/test*.py +swanlog/ \ No newline at end of file diff --git a/cookbook/client/tinker/megatron/server.py b/cookbook/client/tinker/megatron/server.py index f04dfffa..28e04f19 100644 --- a/cookbook/client/tinker/megatron/server.py +++ b/cookbook/client/tinker/megatron/server.py @@ -15,7 +15,7 @@ # Resolve the path to server_config.yaml relative to this script's location file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config.yaml') +config_path = os.path.join(file_dir, 'server_config_7b.yaml') # Launch the Twinkle server — this call blocks until the server is shut down launch_server(config_path=config_path) \ No newline at end of file diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml index 7b625570..2724ea4c 100644 --- a/cookbook/client/tinker/megatron/server_config_7b.yaml +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -21,7 +21,8 @@ applications: route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible) import_path: server # Python module to import args: - + server_config: + per_token_adapter_limit: 30 # Max concurrent LoRA adapters per user (global limit) deployments: - name: TinkerCompatServer autoscaling_config: @@ -53,8 +54,7 @@ applications: tps_limit: 10000 # Max tokens per second for a single user max_input_tokens: 10000 # Maximum input tokens per request adapter_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters per user - adapter_timeout: 60 # Seconds before idle adapter unload + adapter_timeout: 30 # Seconds before idle adapter unload adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) deployments: - name: ModelManagement @@ -71,37 +71,37 @@ applications: # 3. 
Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - - name: sampler-Qwen2.5-7B-Instruct - route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct - import_path: sampler - args: - model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node - sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - engine_args: # vLLM engine-specific settings - max_model_len: 4096 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - enable_lora: true # Allow loading LoRA adapters during inference - logprobs_mode: processed_logprobs # Logprobs mode for sampling results - device_group: # Logical device group for the sampler - name: sampler - ranks: [2] # GPU rank indices to use - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" + # - name: sampler-Qwen2.5-7B-Instruct + # route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier + # nproc_per_node: 2 # Number of GPU processes per node + # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + # engine_args: # vLLM engine-specific settings + # max_model_len: 4096 # Maximum sequence length the engine supports + # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + # enable_lora: true # Allow loading LoRA adapters during inference + # logprobs_mode: processed_logprobs # Logprobs mode for sampling results + # device_group: # Logical device group for the sampler + # name: sampler + # ranks: [2] # GPU rank indices to use + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "0" + # DEVICE_COUNT_PER_PHYSICAL_NODE: "8" From c7b235b9d3a36f8f178abc2dec6f52a51a3027bf Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 09:42:43 +0800 Subject: [PATCH 11/22] fix lint --- cookbook/client/tinker/sample.py | 14 +- cookbook/client/tinker/self_congnition.py | 8 +- cookbook/client/tinker/short_math_grpo.py | 4 +- .../server/tinker/common/compat_base.py | 2 +- src/twinkle/server/tinker/model.py | 63 +++---- src/twinkle/server/tinker/sampler.py | 6 +- src/twinkle/server/tinker/server.py | 35 ++-- src/twinkle/server/twinkle/model.py | 13 +- src/twinkle/server/twinkle/sampler.py | 24 ++- src/twinkle/server/utils/adapter_manager.py | 88 ++++----- src/twinkle/server/utils/state.py | 12 +- src/twinkle/server/utils/task_queue.py | 174 +++++++++--------- 12 files changed, 208 insertions(+), 235 deletions(-) diff --git a/cookbook/client/tinker/sample.py b/cookbook/client/tinker/sample.py index f8a76f07..6a1ce937 100644 --- 
a/cookbook/client/tinker/sample.py +++ b/cookbook/client/tinker/sample.py @@ -14,25 +14,25 @@ from twinkle.template import Template # Step 1: Define the base model and connect to the server -base_model = "Qwen/Qwen2.5-7B-Instruct" +base_model = 'Qwen/Qwen2.5-7B-Instruct' service_client = init_tinker_compat_client(base_url='http://localhost:8000') # Step 2: Create a sampling client by loading weights from a saved checkpoint. # The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint. # The server will load the base model and apply the LoRA adapter weights. sampling_client = service_client.create_sampling_client( - model_path="twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2", + model_path='twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2', base_model=base_model) # Step 3: Load the tokenizer locally to encode the prompt and decode the results -print(f"Using model {base_model}") +print(f'Using model {base_model}') template = Template(model_id=f'ms://{base_model}') trajectory = Trajectory( messages=[ Message(role='system', content='You are a helpful assistant'), - Message(role='user', content="你是谁?"), + Message(role='user', content='你是谁?'), ] ) @@ -44,8 +44,8 @@ prompt = types.ModelInput.from_ints(input_ids) params = types.SamplingParams( max_tokens=128, # Maximum number of tokens to generate - temperature=0.7, - stop=["\n"] # Stop generation when a newline character is produced + temperature=0.7, + stop=['\n'] # Stop generation when a newline character is produced ) # Step 5: Send the sampling request to the server. @@ -57,4 +57,4 @@ # Step 6: Decode and print the generated responses print('Responses:') for i, seq in enumerate(result.sequences): - print(f"{i}: {repr(template.decode(seq.tokens))}") + print(f'{i}: {repr(template.decode(seq.tokens))}') diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py index 078de3e7..925a32a4 100644 --- a/cookbook/client/tinker/self_congnition.py +++ b/cookbook/client/tinker/self_congnition.py @@ -82,7 +82,7 @@ def eval(): # Step 1: Load the trained LoRA checkpoint for inference # Path to a previously saved LoRA checkpoint (twinkle:// URI) - weight_path = "twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2" + weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2' # Connect to the server and create a sampling client with the trained weights service_client = init_tinker_compat_client(base_url='http://localhost:8000') @@ -96,7 +96,7 @@ def eval(): trajectory = Trajectory( messages=[ Message(role='system', content='You are a helpful assistant'), - Message(role='user', content="你是谁?"), + Message(role='user', content='你是谁?'), ] ) @@ -121,9 +121,9 @@ def eval(): # Decode and print each response print('Responses:') for i, seq in enumerate(result.sequences): - print(f"{i}: {repr(template.decode(seq.tokens))}") + print(f'{i}: {repr(template.decode(seq.tokens))}') -if __name__ == "__main__": +if __name__ == '__main__': # train() # Uncomment to run training eval() # Run evaluation / inference diff --git a/cookbook/client/tinker/short_math_grpo.py b/cookbook/client/tinker/short_math_grpo.py index 5d3926d5..944d54c2 100644 --- a/cookbook/client/tinker/short_math_grpo.py +++ b/cookbook/client/tinker/short_math_grpo.py @@ -208,8 +208,8 @@ def main(): dataset = create_Math_dataset() dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) template = 
Template(model_id=f'ms://{BASE_MODEL}') - - logger.info("Dataset and template initialized") + + logger.info('Dataset and template initialized') # Step 2: Initialize the Tinker-compatible client logger.info('Connecting to Tinker server...') diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py index 0cd4bcf9..1e476bbb 100644 --- a/src/twinkle/server/tinker/common/compat_base.py +++ b/src/twinkle/server/tinker/common/compat_base.py @@ -101,7 +101,7 @@ def _to_float(v): head, unit = s.split() # ignore unit/tail cleaned[f'{key}/{unit}'] = float(head) except Exception: - m = re.match(r"^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)", s) + m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) if m: cleaned[key] = float(m.group(1)) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index c901511b..2a119162 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -120,13 +120,13 @@ def __init__(self, def _cleanup_adapter(self, adapter_name: str) -> None: """Common adapter cleanup logic used by both manual unload and automatic expiration. - + This method handles: 1. Clearing adapter state 2. Removing adapter from model 3. Unregistering from adapter manager 4. Removing from server state - + Args: adapter_name: Name of the adapter to clean up """ @@ -134,11 +134,11 @@ def _cleanup_adapter(self, adapter_name: str) -> None: if self.get_adapter_info(adapter_name): # Clear adapter state self.clear_adapter_state(adapter_name) - + self.model.remove_adapter(adapter_name) # Unregister from adapter manager self.unregister_adapter(adapter_name) - + # Remove from server state self.state.unload_model(adapter_name) @@ -175,16 +175,13 @@ async def _create_adapter(): # TODO: support more lora config parameters, train_unembed, etc. lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') - adapter_name = self.get_adapter_name( - adapter_name=model_id) - + adapter_name = self.get_adapter_name(adapter_name=model_id) + # Register adapter FIRST (limit check happens inside register_adapter) - self.register_adapter( - adapter_name, request.state.token, session_id=body.session_id) - + self.register_adapter(adapter_name, request.state.token, session_id=body.session_id) + # Create adapter AFTER successful registration - self.model.add_adapter_to_model( - adapter_name=adapter_name, config_or_dir=lora_cfg) + self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model) self.model.set_processor('InputProcessor', adapter_name=adapter_name) @@ -193,8 +190,7 @@ async def _create_adapter(): # Fresh adapter has no accumulated gradients. 
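+        # Descriptive note on the 'grad_ready' flag (as used throughout this file):
+        # create/load reset it to False, forward_backward sets it to True, and
+        # optim_step requires it before stepping and clears it again afterwards.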
self.set_adapter_state(adapter_name, 'grad_ready', False) - training_run_manager = create_training_run_manager( - request.state.token) + training_run_manager = create_training_run_manager(request.state.token) training_run_manager.save(model_id, body) return types.CreateModelResponse(model_id=model_id) @@ -261,8 +257,7 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) - async def _do_unload(): # Only remove adapter, not the base model - adapter_name = self.get_adapter_name( - adapter_name=body.model_id) + adapter_name = self.get_adapter_name(adapter_name=body.model_id) # Use common cleanup logic self._cleanup_adapter(adapter_name) return types.UnloadModelResponse(model_id=body.model_id) @@ -315,9 +310,7 @@ async def _do_forward(): # Calculate input tokens and batch size for validation datum_list = body.forward_input.data - input_tokens = sum( - len(d.model_input.to_ints()) for d in datum_list - ) + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) batch_size = len(datum_list) return await self.schedule_task( _do_forward, @@ -360,11 +353,12 @@ async def _do_forward_backward(): loss_fn_config = body.forward_backward_input.loss_fn_config or {} # Unified forward_backward for both Megatron and Transformers - output, loss = self.model.forward_backward(inputs=datum_list, - adapter_name=adapter_name, - loss_fn=loss_fn, - **loss_fn_config) - output_type = 'ImportanceSamplingLossReturn' if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn' + output, loss = self.model.forward_backward( + inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) + if loss_fn == 'importance_sampling': + output_type = 'ImportanceSamplingLossReturn' + else: + output_type = 'CrossEntropyLossReturn' # Mark gradients as ready after a successful forward_backward. self.set_adapter_state(adapter_name, 'grad_ready', True) return types.ForwardBackwardOutput( @@ -381,9 +375,7 @@ async def _do_forward_backward(): # Calculate input tokens and batch size for validation datum_list = body.forward_backward_input.data - input_tokens = sum( - len(d.model_input.to_ints()) for d in datum_list - ) + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) batch_size = len(datum_list) return await self.schedule_task( _do_forward_backward, @@ -417,14 +409,13 @@ async def _do_optim(): # Disallow empty step (must have at least one forward_backward since last step) if not self.get_adapter_state(adapter_name, 'grad_ready', False): raise RuntimeError( - f"No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step" + f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501 ) # Touch adapter to reset inactivity counter self.touch_adapter(adapter_name) - self.model.step(adam_params=body.adam_params, - adapter_name=adapter_name) + self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) # Clear grad-ready after a successful step. 
self.set_adapter_state(adapter_name, 'grad_ready', False) metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) @@ -590,15 +581,15 @@ async def _do_load(): weight_path = body.path load_optimizer = body.optimizer - self.model.load(checkpoint_dir=weight_path, - load_optimizer=load_optimizer, - adapter_name=adapter_name, - token=token) + self.model.load( + checkpoint_dir=weight_path, + load_optimizer=load_optimizer, + adapter_name=adapter_name, + token=token) # Loading a checkpoint should reset step readiness. self.set_adapter_state(adapter_name, 'grad_ready', False) - return types.LoadWeightsResponse(path=body.path, - type='load_weights') + return types.LoadWeightsResponse(path=body.path, type='load_weights') except Exception: logger.error(traceback.format_exc()) return types.RequestFailedResponse( diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index 231da880..bf4108c9 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -160,14 +160,14 @@ async def _do_sample(): token = request.state.token checkpoint_manager = create_checkpoint_manager(token) adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) - + # Validate adapter URI existence if provided if not adapter_uri or not os.path.exists(adapter_uri): return types.RequestFailedResponse( - error=f"Adapter URI {model_path} does not exist. Please check the model_path.", + error=f'Adapter URI {model_path} does not exist. Please check the model_path.', category=types.RequestErrorCategory.User, ) - + # Convert tinker SamplingParams to twinkle SamplingParams if needed sampling_params = None if body.sampling_params: diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 68b55d61..2e669f56 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -28,12 +28,11 @@ logger = logging.getLogger(__name__) -def build_server_app( - deploy_options: Dict[str, Any], - supported_models: Optional[List[types.SupportedModel]] = None, - server_config: Dict[str, Any] = {}, - **kwargs -): + +def build_server_app(deploy_options: dict[str, Any], + supported_models: list[types.SupportedModel] | None = None, + server_config: dict[str, Any] = {}, + **kwargs): """Build and configure the Tinker-compatible server application. This factory function creates a FastAPI application with Ray Serve deployment @@ -66,8 +65,11 @@ class TinkerCompatServer: - Proxying to model/sampler deployments - Training run and checkpoint CRUD operations """ - - def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None, server_config: Dict[str, Any] = {}, **kwargs) -> None: + + def __init__(self, + supported_models: list[types.SupportedModel] | None = None, + server_config: dict[str, Any] = {}, + **kwargs) -> None: """Initialize the Tinker-compatible server. 
Args: @@ -78,13 +80,13 @@ def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None self.state = get_server_state(**server_config) # Disable proxy for internal requests to avoid routing through external proxies self.client = httpx.AsyncClient(timeout=None, trust_env=False) - self.route_prefix = kwargs.get("route_prefix", "/api/v1") + self.route_prefix = kwargs.get('route_prefix', '/api/v1') self.supported_models = self.normalize_models(supported_models) or [ - types.SupportedModel(model_name="Qwen/Qwen2.5-0.5B-Instruct"), - types.SupportedModel(model_name="Qwen/Qwen2.5-3B-Instruct"), - types.SupportedModel(model_name="Qwen/Qwen2.5-7B-Instruct"), - types.SupportedModel(model_name="Qwen/Qwen2.5-72B-Instruct"), - types.SupportedModel(model_name="Qwen/Qwen3-30B-A3B-Instruct-2507"), + types.SupportedModel(model_name='Qwen/Qwen2.5-0.5B-Instruct'), + types.SupportedModel(model_name='Qwen/Qwen2.5-3B-Instruct'), + types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'), + types.SupportedModel(model_name='Qwen/Qwen2.5-72B-Instruct'), + types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), ] # Lock for ModelScope config file operations (login writes, get_user_info reads) self._modelscope_config_lock = asyncio.Lock() @@ -682,7 +684,4 @@ async def save_weights_for_sampler(self, request: Request, body: types.SaveWeigh return await self._proxy_to_model(request, 'save_weights_for_sampler', base_model) return TinkerCompatServer.options(**deploy_options).bind( - supported_models=supported_models, - server_config=server_config, - **kwargs - ) + supported_models=supported_models, server_config=server_config, **kwargs) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 77fd7e1f..aa9450f9 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -200,18 +200,17 @@ def _on_adapter_expired(self, adapter_name: str) -> None: if self.get_adapter_info(adapter_name): # Clear adapter state self.clear_adapter_state(adapter_name) - + self.model.remove_adapter(adapter_name) # Unregister from adapter manager self.unregister_adapter(adapter_name) - + # Remove from server state self.state.unload_model(adapter_name) # Remove adapter from model self.model.remove_adapter(adapter_name) - - @app.post("/create") + @app.post('/create') def create(self, request: Request, body: CreateRequest): return {'status': 'ok'} @@ -508,13 +507,13 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): # Extract token for metadata storage token = request.state.token training_run_manager = create_training_run_manager(token) - + # Register adapter FIRST (limit check happens inside register_adapter) self.register_adapter(adapter_name, token) - + # Create adapter AFTER successful registration self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - + # Save training run metadata (similar to tinker's create_model) # Create a training run config from the adapter configuration lora_config = None diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py index fcb2edf9..857c53f6 100644 --- a/src/twinkle/server/twinkle/sampler.py +++ b/src/twinkle/server/twinkle/sampler.py @@ -10,26 +10,24 @@ 4. 
Flexible sampling parameters """ import traceback -from typing import Dict, Any, List, Optional, Union - from fastapi import FastAPI, Request from pydantic import BaseModel, Field from ray import serve -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import twinkle from twinkle import DeviceGroup, DeviceMesh -from twinkle.data_format import Trajectory, InputFeature, SamplingParams +from twinkle.data_format import InputFeature, SamplingParams, Trajectory from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.validation import verify_request_token, get_token_from_request -from twinkle.server.utils.state import get_server_state, ServerStateProxy +from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger logger = get_logger() - # ----- Request/Response Models ----- + class SampleRequest(BaseModel): """Request body for the /sample endpoint.""" inputs: Any = Field(..., description='List of Trajectory or InputFeature dicts') @@ -183,10 +181,10 @@ def _on_adapter_expired(self, adapter_name: str, token: str) -> None: """Handle expired adapters by removing them from the sampler.""" try: self.sampler.remove_adapter(adapter_name) - logger.info(f"Removed expired adapter {adapter_name}") + logger.info(f'Removed expired adapter {adapter_name}') # Adapter count is now tracked dynamically, no manual update needed except Exception as e: - logger.warning(f"Failed to remove expired adapter {adapter_name}: {e}") + logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') @staticmethod def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: @@ -285,14 +283,14 @@ def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> A from peft import LoraConfig config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - + self.register_adapter(full_adapter_name, token) self.sampler.add_adapter_to_sampler(full_adapter_name, config) return AddAdapterResponse(adapter_name=full_adapter_name) - @app.post("/heartbeat", response_model=HeartbeatResponse) + @app.post('/heartbeat', response_model=HeartbeatResponse) def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatResponse: """Keep an adapter alive by resetting its inactivity timer.""" full_adapter_name = self._get_adapter_name(request, body.adapter_name) @@ -300,5 +298,5 @@ def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatRespon self.touch_adapter(full_adapter_name) return HeartbeatResponse() - return SamplerManagement.options(**deploy_options).bind( - nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, **kwargs) \ No newline at end of file + return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, + engine_args, adapter_config, **kwargs) diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 66f46472..3af4d06a 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -39,8 +39,8 @@ class AdapterManagerMixin: """ # Type hint for state attribute that inheriting classes must provide - state: 'ServerStateProxy' - + state: ServerStateProxy + def _init_adapter_manager( self, adapter_timeout: float = 1800.0, @@ -65,16 
+65,17 @@ def _init_adapter_manager( self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking - # Dict mapping adapter_name -> {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} - self._adapter_records: Dict[str, Dict[str, Any]] = {} + # Dict mapping adapter_name -> + # {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} + self._adapter_records: dict[str, dict[str, Any]] = {} # Track adapter count per token - self._adapter_counts: Dict[str, int] = {} + self._adapter_counts: dict[str, int] = {} # Countdown thread self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False - def register_adapter(self, adapter_name: str, token: str, session_id: Optional[str] = None) -> None: + def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None) -> None: """Register a new adapter for lifecycle tracking. Args: @@ -82,7 +83,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: Optional[s token: User token that owns this adapter. session_id: Optional session ID to associate with this adapter. If provided, adapter will expire when the session expires. - + Raises: RuntimeError: If adapter limit is exceeded for this token. """ @@ -90,7 +91,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: Optional[s allowed, reason = self.check_adapter_limit(token) if not allowed: raise RuntimeError(reason) - + current_time = time.time() self._adapter_records[adapter_name] = { 'token': token, @@ -101,9 +102,8 @@ def register_adapter(self, adapter_name: str, token: str, session_id: Optional[s 'state': {}, 'expiring': False, } - logger.debug( - f"[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}..." + - (f" (session: {session_id})" if session_id else "")) + logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...' + + (f' (session: {session_id})' if session_id else '')) def _is_session_alive(self, session_id: str) -> bool: """Check if a session is still alive via state proxy. @@ -116,12 +116,12 @@ def _is_session_alive(self, session_id: str) -> bool: """ if not session_id: return True # No session association means always alive - + # Get session last heartbeat through proxy last_heartbeat = self.state.get_session_last_heartbeat(session_id) if last_heartbeat is None: return False # Session doesn't exist - + # Check if session has timed out using adapter_timeout return (time.time() - last_heartbeat) < self._adapter_timeout @@ -138,7 +138,8 @@ def unregister_adapter(self, adapter_name: str) -> bool: adapter_info = self._adapter_records.pop(adapter_name) token = adapter_info.get('token') logger.debug( - f"[AdapterManager] Unregistered adapter {adapter_name} for token {token[:8] if token else 'unknown'}...") + f"[AdapterManager] Unregistered adapter {adapter_name} for token {token[:8] if token else 'unknown'}..." + ) return True return False @@ -217,13 +218,11 @@ def _on_adapter_expired(self, adapter_name: str) -> None: Args: adapter_name: Name of the expired adapter. - + Raises: NotImplementedError: If not overridden by inheriting class. 
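+
+        Example override (an illustrative sketch only; it mirrors the cleanup
+        the sampler and model deployments in this series perform):
+
+            def _on_adapter_expired(self, adapter_name: str) -> None:
+                self.sampler.remove_adapter(adapter_name)
+                self.unregister_adapter(adapter_name)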
""" - raise NotImplementedError( - f"_on_adapter_expired must be implemented by {self.__class__.__name__}" - ) + raise NotImplementedError(f'_on_adapter_expired must be implemented by {self.__class__.__name__}') @staticmethod def get_adapter_name(adapter_name: str) -> str: @@ -243,7 +242,7 @@ def assert_adapter_exists(self, adapter_name: str) -> None: """Validate that an adapter exists and is not expiring.""" info = self._adapter_records.get(adapter_name) assert adapter_name and info is not None and not info.get('expiring'), \ - f"Adapter {adapter_name} not found" + f'Adapter {adapter_name} not found' def _adapter_countdown_loop(self) -> None: """Background thread that monitors and handles inactive adapters. @@ -259,7 +258,7 @@ def _adapter_countdown_loop(self) -> None: time.sleep(1) now = time.time() - expired_adapters: List[Tuple[str, Optional[str]]] = [] + expired_adapters: list[tuple[str, str | None]] = [] # Create snapshot to avoid modification during iteration adapter_snapshot = list(self._adapter_records.items()) for adapter_name, info in adapter_snapshot: @@ -267,15 +266,13 @@ def _adapter_countdown_loop(self) -> None: continue session_id = info.get('session_id') - created_at = info.get("created_at") - + created_at = info.get('created_at') + # Check TTL for both cases exceeded_ttl = ( - self._adapter_max_lifetime - and self._adapter_max_lifetime > 0 - and (now - created_at) > self._adapter_max_lifetime - ) - + self._adapter_max_lifetime and self._adapter_max_lifetime > 0 + and (now - created_at) > self._adapter_max_lifetime) + # Different logic based on session association if session_id: # Has session: check session expiration and TTL @@ -283,29 +280,27 @@ def _adapter_countdown_loop(self) -> None: should_expire = session_expired or exceeded_ttl expiration_reasons = [] if exceeded_ttl: - expiration_reasons.append("ttl_exceeded") + expiration_reasons.append('ttl_exceeded') if session_expired: - expiration_reasons.append("session_expired") + expiration_reasons.append('session_expired') else: # No session: check inactivity timeout and TTL - info["inactivity_counter"] = info.get("inactivity_counter", 0) + 1 - exceeded_inactivity = info["inactivity_counter"] > self._adapter_timeout + info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1 + exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout should_expire = exceeded_ttl or exceeded_inactivity expiration_reasons = [] if exceeded_ttl: - expiration_reasons.append("ttl_exceeded") + expiration_reasons.append('ttl_exceeded') if exceeded_inactivity: - expiration_reasons.append("inactivity_timeout") - + expiration_reasons.append('inactivity_timeout') + if should_expire: info['expiring'] = True info['state'] = {} # best-effort clear token = info.get('token') expired_adapters.append((adapter_name, token)) - logger.debug( - f"[AdapterManager] Adapter {adapter_name} expired " - f"(reasons={','.join(expiration_reasons)}, session={session_id})" - ) + logger.debug(f'[AdapterManager] Adapter {adapter_name} expired ' + f"(reasons={','.join(expiration_reasons)}, session={session_id})") for adapter_name, token in expired_adapters: success = False @@ -313,9 +308,7 @@ def _adapter_countdown_loop(self) -> None: self._on_adapter_expired(adapter_name) success = True except Exception as e: - logger.warning( - f"[AdapterManager] Error while expiring adapter {adapter_name}: {e}" - ) + logger.warning(f'[AdapterManager] Error while expiring adapter {adapter_name}: {e}') finally: if success: self._adapter_records.pop(adapter_name, 
None) @@ -325,7 +318,7 @@ def _adapter_countdown_loop(self) -> None: info['expiring'] = False except Exception as e: - logger.warning(f"[AdapterManager] Error in countdown loop: {e}") + logger.warning(f'[AdapterManager] Error in countdown loop: {e}') continue logger.debug('[AdapterManager] Countdown thread stopped') @@ -354,7 +347,7 @@ def stop_adapter_countdown(self) -> None: self._adapter_countdown_thread.join(timeout=2.0) logger.debug('[AdapterManager] Countdown thread stopped') - def check_adapter_limit(self, token: str) -> Tuple[bool, Optional[str]]: + def check_adapter_limit(self, token: str) -> tuple[bool, str | None]: """Check adapter count for a user token. This method enforces per-user adapter limits to prevent resource exhaustion. @@ -368,13 +361,10 @@ def check_adapter_limit(self, token: str) -> Tuple[bool, Optional[str]]: If allowed is False, reason contains the explanation. """ # Count adapters directly from _adapter_records - current_count = sum( - 1 for record in self._adapter_records.values() - if record.get('token') == token and not record.get('expiring', False) - ) + current_count = sum(1 for record in self._adapter_records.values() + if record.get('token') == token and not record.get('expiring', False)) # Check if current count exceeds limit if current_count >= self._per_token_adapter_limit: - return False, f"Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters" + return False, f'Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters' return True, None - diff --git a/src/twinkle/server/utils/state.py b/src/twinkle/server/utils/state.py index 00866d9a..1f03408e 100644 --- a/src/twinkle/server/utils/state.py +++ b/src/twinkle/server/utils/state.py @@ -86,13 +86,13 @@ def touch_session(self, session_id: str) -> bool: self.sessions[session_id]['last_heartbeat'] = time.time() return True - def get_session_last_heartbeat(self, session_id: str) -> Optional[float]: + def get_session_last_heartbeat(self, session_id: str) -> float | None: """ Get the last heartbeat timestamp for a session. 
- + Args: session_id: The session ID to query - + Returns: Last heartbeat timestamp, or None if session doesn't exist """ @@ -477,7 +477,7 @@ def create_session(self, payload: dict[str, Any]) -> str: def touch_session(self, session_id: str) -> bool: return ray.get(self._actor.touch_session.remote(session_id)) - def get_session_last_heartbeat(self, session_id: str) -> Optional[float]: + def get_session_last_heartbeat(self, session_id: str) -> float | None: return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) # ----- Model Registration ----- @@ -586,7 +586,7 @@ def get_server_state(actor_name: str = 'twinkle_server_state', auto_start_cleanup: Whether to automatically start the cleanup task (default: True) **server_state_kwargs: Additional keyword arguments passed to ServerState constructor (e.g., expiration_timeout, cleanup_interval, per_token_adapter_limit) - + Returns: A ServerStateProxy for interacting with the actor """ @@ -605,4 +605,4 @@ def get_server_state(actor_name: str = 'twinkle_server_state', except ValueError: actor = ray.get_actor(actor_name) assert actor is not None - return ServerStateProxy(actor) \ No newline at end of file + return ServerStateProxy(actor) diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index e4b6d904..e87a4d9a 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -70,8 +70,8 @@ class TaskQueueConfig: enabled: bool = True # Rate limiting enabled by default # Remove tokens after 10x window inactivity token_cleanup_multiplier: float = 10.0 - token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds - max_input_tokens: int = 10000 # Maximum input tokens per request + token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds + max_input_tokens: int = 10000 # Maximum input tokens per request @classmethod def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig: @@ -106,22 +106,22 @@ def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig if 'token_cleanup_multiplier' in config_dict: config.token_cleanup_multiplier = float(config_dict['token_cleanup_multiplier']) if 'token_cleanup_interval' in config_dict: - config.token_cleanup_interval = float( - config_dict['token_cleanup_interval']) + config.token_cleanup_interval = float(config_dict['token_cleanup_interval']) if 'max_input_tokens' in config_dict: config.max_input_tokens = int(config_dict['max_input_tokens']) return config + @dataclass class _QueuedTask: request_id: str coro_factory: Callable[[], Coroutine] - model_id: Optional[str] - token: Optional[str] + model_id: str | None + token: str | None input_tokens: int - task_type: Optional[str] + task_type: str | None created_at: float - first_rate_limited_at: Optional[float] = None + first_rate_limited_at: float | None = None class TaskQueueMixin: @@ -165,7 +165,7 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None: """ self._task_queue_config = config or TaskQueueConfig() # Per-key queues, but executed by a single global worker. - self._task_queues: Dict[str, asyncio.Queue] = {} + self._task_queues: dict[str, asyncio.Queue] = {} self._queue_order: Deque[str] = deque() self._new_task_event: asyncio.Event = asyncio.Event() @@ -181,23 +181,23 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None: self._rate_limiter.start_cleanup_task() # Single worker to ensure model operations remain serial. 
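+        # Tasks are still partitioned into per-key queues ('model:<id>', 'token:<token>'
+        # or 'default', see _queue_key below); the single worker drains them round-robin
+        # so one busy key cannot monopolize the model.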
- self._worker_task: Optional[asyncio.Task] = None + self._worker_task: asyncio.Task | None = None self._worker_started = False self._worker_start_lock = asyncio.Lock() # Event loop reference for thread-safe callbacks (e.g., adapter expiration thread) - self._event_loop: Optional[asyncio.AbstractEventLoop] = None + self._event_loop: asyncio.AbstractEventLoop | None = None @staticmethod def _queue_key( - model_id: Optional[str], - token: Optional[str], + model_id: str | None, + token: str | None, ) -> str: if model_id: - return f"model:{model_id}" + return f'model:{model_id}' if token: - return f"token:{token}" - return "default" + return f'token:{token}' + return 'default' async def _ensure_worker_started(self) -> None: """Ensure the single background worker is running.""" @@ -222,7 +222,7 @@ async def _queue_worker(self) -> None: Selection policy: round-robin across queue keys. If a task is rate-limited at execution time, it is requeued and the worker tries other queues. """ - print("[TaskQueue] Worker started") + print('[TaskQueue] Worker started') while True: try: # Wait until there is at least one queue with a task @@ -252,11 +252,14 @@ async def _queue_worker(self) -> None: # Global queue timeout if (now - task.created_at) > self._task_queue_config.queue_timeout: error_payload = { - 'error': f"Queue timeout exceeded: waited {now - task.created_at:.2f}s", + 'error': f'Queue timeout exceeded: waited {now - task.created_at:.2f}s', 'category': 'Server' } self.state.store_future_status( - task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, + task.request_id, + TaskStatus.FAILED.value, + task.model_id, + result=error_payload, queue_state=QueueState.PAUSED_CAPACITY.value, queue_state_reason=error_payload['error'], ) @@ -268,26 +271,25 @@ async def _queue_worker(self) -> None: # Execute executed_any = True self.state.store_future_status( - task.request_id, TaskStatus.RUNNING.value, task.model_id, - queue_state=QueueState.ACTIVE.value - ) + task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value) try: coro = task.coro_factory() result = await coro self.state.store_future_status( - task.request_id, TaskStatus.COMPLETED.value, task.model_id, result=result, - queue_state=QueueState.ACTIVE.value - ) + task.request_id, + TaskStatus.COMPLETED.value, + task.model_id, + result=result, + queue_state=QueueState.ACTIVE.value) except Exception: - error_payload = { - 'error': traceback.format_exc(), - 'category': 'Server' - } + error_payload = {'error': traceback.format_exc(), 'category': 'Server'} self.state.store_future_status( - task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, - queue_state=QueueState.ACTIVE.value - ) + task.request_id, + TaskStatus.FAILED.value, + task.model_id, + result=error_payload, + queue_state=QueueState.ACTIVE.value) finally: q.task_done() @@ -299,10 +301,10 @@ async def _queue_worker(self) -> None: await asyncio.sleep(min(self._task_queue_config.window_seconds, 0.1)) except asyncio.CancelledError: - logger.warning("[TaskQueue] Worker cancelled") + logger.warning('[TaskQueue] Worker cancelled') break except Exception: - logger.warning("Error in task queue worker") + logger.warning('Error in task queue worker') continue async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: @@ -316,14 +318,14 @@ async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: drained.append(q.get_nowait()) except asyncio.QueueEmpty: break - + for task in drained: - 
error_payload = { - 'error': reason, - 'category': 'Server' - } + error_payload = {'error': reason, 'category': 'Server'} self.state.store_future_status( - task.request_id, TaskStatus.FAILED.value, task.model_id, result=error_payload, + task.request_id, + TaskStatus.FAILED.value, + task.model_id, + result=error_payload, queue_state=QueueState.UNKNOWN.value, queue_state_reason=reason, ) @@ -337,14 +339,12 @@ async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: except ValueError: pass - def fail_pending_tasks_for_model(self, model_id: str, reason: str) -> None: """Fail and drop queued tasks for a model. Safe to call from non-async threads.""" queue_key = self._queue_key(model_id=model_id, token=None) if self._event_loop is None: # Best-effort: nothing we can do safely without a loop. - logger.warning( - f"[TaskQueue] fail_pending_tasks_for_model called without event loop: {queue_key}") + logger.warning(f'[TaskQueue] fail_pending_tasks_for_model called without event loop: {queue_key}') return def _schedule() -> None: @@ -355,12 +355,12 @@ def _schedule() -> None: async def _perform_preflight_checks( self, request_id: str, - model_id: Optional[str], - token: Optional[str], + model_id: str | None, + token: str | None, input_tokens: int, - batch_size: Optional[int] = None, - data_world_size: Optional[int] = None, - ) -> Optional[Dict[str, Any]]: + batch_size: int | None = None, + data_world_size: int | None = None, + ) -> dict[str, Any] | None: """Perform pre-flight checks including rate limiting and token validation. Args: @@ -379,13 +379,13 @@ async def _perform_preflight_checks( # Check max input tokens if input_tokens > self._task_queue_config.max_input_tokens: - error_msg = f"Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})" - error_payload = { - 'error': error_msg, - 'category': 'User' - } + error_msg = f'Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})' # noqa: E501 + error_payload = {'error': error_msg, 'category': 'User'} self.state.store_future_status( - request_id, TaskStatus.FAILED.value, model_id, result=error_payload, + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, queue_state=QueueState.UNKNOWN.value, queue_state_reason=error_msg, ) @@ -394,13 +394,13 @@ async def _perform_preflight_checks( # Check batch size if provided if batch_size is not None and data_world_size is not None: if batch_size < data_world_size: - error_msg = f"Batch size {batch_size} must be greater than or equal to data world size {data_world_size}" - error_payload = { - 'error': error_msg, - 'category': 'User' - } + error_msg = f'Batch size {batch_size} must be greater than or equal to data world size {data_world_size}' # noqa: E501 + error_payload = {'error': error_msg, 'category': 'User'} self.state.store_future_status( - request_id, TaskStatus.FAILED.value, model_id, result=error_payload, + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, queue_state=QueueState.UNKNOWN.value, queue_state_reason=error_msg, ) @@ -409,30 +409,30 @@ async def _perform_preflight_checks( # Check rate limits allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) if not allowed: - error_msg = f"Rate limit exceeded: {reason}" - error_payload = { - 'error': error_msg, - 'category': 'User' - } + error_msg = f'Rate limit exceeded: {reason}' + error_payload = {'error': error_msg, 'category': 'User'} self.state.store_future_status( - 
request_id, TaskStatus.FAILED.value, model_id, result=error_payload, + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, queue_state=QueueState.PAUSED_RATE_LIMIT.value, queue_state_reason=error_msg, ) return {'request_id': request_id, 'model_id': model_id} - + return None async def schedule_task( self, coro_factory: Callable[[], Coroutine], - model_id: Optional[str] = None, - token: Optional[str] = None, + model_id: str | None = None, + token: str | None = None, input_tokens: int = 0, - batch_size: Optional[int] = None, - data_world_size: Optional[int] = None, - task_type: Optional[str] = None, - ) -> Dict[str, Any]: + batch_size: int | None = None, + data_world_size: int | None = None, + task_type: str | None = None, + ) -> dict[str, Any]: """Schedule an async task with rate limiting and status tracking. This method replaces the old `schedule_task` function with proper @@ -459,12 +459,11 @@ async def schedule_task( Dict containing request_id and model_id for future retrieval. """ # Generate request_id first so it can be included in error responses - request_id = f"req_{uuid.uuid4().hex}" - + request_id = f'req_{uuid.uuid4().hex}' + # 1. Pre-flight checks: rate limiting, max token validation, and batch size validation - preflight_result = await self._perform_preflight_checks( - request_id, model_id, token, input_tokens, batch_size, data_world_size - ) + preflight_result = await self._perform_preflight_checks(request_id, model_id, token, input_tokens, batch_size, + data_world_size) if preflight_result is not None: return preflight_result @@ -472,8 +471,8 @@ async def schedule_task( self._event_loop = asyncio.get_running_loop() print( - f"[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}") - + f'[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}' # noqa: E501 + ) # 2. Register PENDING status FIRST self.state.store_future_status( @@ -489,7 +488,8 @@ async def schedule_task( # 5. 
Put task in queue and update status q = self._task_queues[queue_key] print( - f"[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}") + f'[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}' # noqa: E501 + ) await q.put( _QueuedTask( request_id=request_id, @@ -499,14 +499,10 @@ async def schedule_task( input_tokens=input_tokens, task_type=task_type, created_at=time.monotonic(), - ) - ) + )) self.state.store_future_status( - request_id, TaskStatus.QUEUED.value, model_id, - queue_state=QueueState.ACTIVE.value - ) - print( - f"[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}") + request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value) + print(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}') self._new_task_event.set() @@ -571,4 +567,4 @@ async def shutdown_task_queue(self) -> None: self._task_queues.clear() self._queue_order.clear() - print("[TaskQueue] Task queue shutdown complete") + print('[TaskQueue] Task queue shutdown complete') From 82ad72a86541c48492afdff0648f064376bcc9c3 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 10:17:42 +0800 Subject: [PATCH 12/22] update --- src/twinkle/server/utils/adapter_manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 3af4d06a..b45e6f61 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: from twinkle.server.utils.state import ServerStateProxy - from twinkle.model import TwinkleModel from twinkle.utils.logger import get_logger @@ -278,6 +277,8 @@ def _adapter_countdown_loop(self) -> None: # Has session: check session expiration and TTL session_expired = not self._is_session_alive(session_id) should_expire = session_expired or exceeded_ttl + logger.info(f'[AdapterManager] Adapter {adapter_name} session expiration check ' + f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})') expiration_reasons = [] if exceeded_ttl: expiration_reasons.append('ttl_exceeded') @@ -288,6 +289,8 @@ def _adapter_countdown_loop(self) -> None: info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1 exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout should_expire = exceeded_ttl or exceeded_inactivity + logger.info(f'[AdapterManager] Adapter {adapter_name} inactivity check ' + f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})') expiration_reasons = [] if exceeded_ttl: expiration_reasons.append('ttl_exceeded') @@ -299,7 +302,7 @@ def _adapter_countdown_loop(self) -> None: info['state'] = {} # best-effort clear token = info.get('token') expired_adapters.append((adapter_name, token)) - logger.debug(f'[AdapterManager] Adapter {adapter_name} expired ' + logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' f"(reasons={','.join(expiration_reasons)}, session={session_id})") for adapter_name, token in expired_adapters: From 720aebc5cc128ca60fb0cf0869469e46f24f9063 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 10:24:46 +0800 Subject: [PATCH 13/22] update --- src/twinkle/server/utils/adapter_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index b45e6f61..6c206939 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -277,7 +277,7 @@ def _adapter_countdown_loop(self) -> None: # Has session: check session expiration and TTL session_expired = not self._is_session_alive(session_id) should_expire = session_expired or exceeded_ttl - logger.info(f'[AdapterManager] Adapter {adapter_name} session expiration check ' + logger.debug(f'[AdapterManager] Adapter {adapter_name} session expiration check ' f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})') expiration_reasons = [] if exceeded_ttl: @@ -289,7 +289,7 @@ def _adapter_countdown_loop(self) -> None: info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1 exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout should_expire = exceeded_ttl or exceeded_inactivity - logger.info(f'[AdapterManager] Adapter {adapter_name} inactivity check ' + logger.debug(f'[AdapterManager] Adapter {adapter_name} inactivity check ' f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})') expiration_reasons = [] if exceeded_ttl: @@ -302,13 +302,13 @@ def _adapter_countdown_loop(self) -> None: info['state'] = {} # best-effort clear token = info.get('token') expired_adapters.append((adapter_name, token)) - logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' - f"(reasons={','.join(expiration_reasons)}, session={session_id})") for adapter_name, token in expired_adapters: success = False try: self._on_adapter_expired(adapter_name) + logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' + f"(reasons={','.join(expiration_reasons)}, session={session_id})") success = True except Exception as e: logger.warning(f'[AdapterManager] Error while expiring adapter {adapter_name}: {e}') From 01f88f7b9670d65dc150360f5acad17e38a5b0b1 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 10:26:48 +0800 Subject: [PATCH 14/22] update --- cookbook/client/tinker/self_congnition.py | 6 +++--- src/twinkle/model/megatron/megatron.py | 6 +++--- src/twinkle/server/utils/state.py | 3 ++- src/twinkle/server/utils/task_queue.py | 10 +++++----- src/twinkle_client/__init__.py | 2 +- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py index 925a32a4..5a565cc5 100644 --- a/cookbook/client/tinker/self_congnition.py +++ b/cookbook/client/tinker/self_congnition.py @@ -8,7 +8,7 @@ # The server must be running first (see server.py and server_config.yaml). import numpy as np import os -from modelscope import AutoTokenizer +from tqdm import tqdm from tinker import types from twinkle_client import init_tinker_compat_client from twinkle.data_format import Message, Trajectory @@ -125,5 +125,5 @@ def eval(): if __name__ == '__main__': - # train() # Uncomment to run training - eval() # Run evaluation / inference + train() # Uncomment to run training + # eval() # Run evaluation / inference diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 7aacc90e..d9bf4207 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -848,13 +848,13 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): Args: name: Checkpoint name or HuggingFace Hub model id. 
output_dir: Parent directory that contains the checkpoint folder. - If None **and** ``resume`` is False, downloads from Hub. - resume: If True, restore optimizer, lr_scheduler and RNG state + If None **and** ``load_optimizer`` is False, downloads from Hub. + load_optimizer: If True, restore optimizer, lr_scheduler and RNG state from the mcore sub-checkpoint for training resumption. **kwargs: Additional arguments (``adapter_name``, ``no_load_optim``, ``no_load_rng``, etc.). """ - resume = kwargs.pop('resume', False) + resume = kwargs.pop('load_optimizer', False) if output_dir is None and not resume: # Load from hub token = kwargs.pop('token', None) diff --git a/src/twinkle/server/utils/state.py b/src/twinkle/server/utils/state.py index 1f03408e..e191d80a 100644 --- a/src/twinkle/server/utils/state.py +++ b/src/twinkle/server/utils/state.py @@ -31,7 +31,8 @@ class ServerState: def __init__( self, expiration_timeout: float = 86400.0, # 24 hours in seconds - cleanup_interval: float = 3600.0) -> None: # 1 hour in seconds + cleanup_interval: float = 3600.0, + **kwargs) -> None: # 1 hour in seconds # Session tracking self.sessions: dict[str, dict[str, Any]] = {} # Model registration diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index e87a4d9a..4d272a07 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -222,7 +222,7 @@ async def _queue_worker(self) -> None: Selection policy: round-robin across queue keys. If a task is rate-limited at execution time, it is requeued and the worker tries other queues. """ - print('[TaskQueue] Worker started') + logger.debug('[TaskQueue] Worker started') while True: try: # Wait until there is at least one queue with a task @@ -470,7 +470,7 @@ async def schedule_task( if self._event_loop is None: self._event_loop = asyncio.get_running_loop() - print( + logger.debug( f'[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}' # noqa: E501 ) @@ -487,7 +487,7 @@ async def schedule_task( # 5. 
Put task in queue and update status q = self._task_queues[queue_key] - print( + logger.debug( f'[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}' # noqa: E501 ) await q.put( @@ -502,7 +502,7 @@ async def schedule_task( )) self.state.store_future_status( request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value) - print(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}') + logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}') self._new_task_event.set() @@ -567,4 +567,4 @@ async def shutdown_task_queue(self) -> None: self._task_queues.clear() self._queue_order.clear() - print('[TaskQueue] Task queue shutdown complete') + logger.debug('[TaskQueue] Task queue shutdown complete') diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index e600023c..5a6928e9 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -21,7 +21,7 @@ def init_tinker_compat_client(base_url: str | None = None, api_key: str | None = # Apply patch to bypass tinker:// prefix validation patch_tinker() - if api_key is None: + if not api_key: api_key = get_api_key() if base_url and not base_url.startswith(('http://', 'https://')): From d6c274da5ccfacdce001e4b47335afd8c70587e7 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 11:01:37 +0800 Subject: [PATCH 15/22] update --- cookbook/client/tinker/megatron/server_config.yaml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index 852fd85c..3560cdcd 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -56,6 +56,9 @@ applications: device_mesh: device_type: cuda dp_size: 4 + queue_config: + rps_limit: 20 # Max requests per second + tps_limit: 10000 # Max tokens per second deployments: - name: SamplerManagement autoscaling_config: @@ -88,11 +91,12 @@ applications: ep_size: 2 queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second + rps_limit: 20 # Max requests per second + tps_limit: 10000 # Max tokens per second adapter_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters - adapter_timeout: 1800 # Seconds before idle adapter unload + per_token_adapter_limit: 3 # Max concurrent LoRA adapters + adapter_timeout: 30 # Seconds before idle adapter unload + adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) deployments: - name: ModelManagement autoscaling_config: From 23038ffac9639d91a599640c5d465f3cb672433b Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 11:02:45 +0800 Subject: [PATCH 16/22] update --- .../client/tinker/megatron/server_config.yaml | 4 +++- .../tinker/megatron/server_config_7b.yaml | 4 ++-- cookbook/client/twinkle/grpo.py | 6 ++--- .../twinkle/transformer/server_config.yaml | 22 ++++++++++++++----- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index 852fd85c..74347863 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -77,7 +77,9 @@ applications: args: use_megatron: true # Use HuggingFace Transformers backend model_id: 
"ms://Qwen/Qwen3-30B-A3B-Instruct-2507" # ModelScope model identifier - nproc_per_node: 4 # Number of GPU processes per node + max_length: 10240 # model max length + max_loras: 5 # model max loras + nproc_per_node: 4 # Number of GPU processes per node device_group: name: model ranks: [4,5,6,7] # GPU rank indices diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml index de7e2a9a..690781e2 100644 --- a/cookbook/client/tinker/megatron/server_config_7b.yaml +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -21,8 +21,7 @@ applications: route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible) import_path: server # Python module to import args: - server_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters per user (global limit) + deployments: - name: TinkerCompatServer autoscaling_config: @@ -56,6 +55,7 @@ applications: adapter_config: adapter_timeout: 30 # Seconds before idle adapter unload adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) + per_token_adapter_limit: 30 deployments: - name: ModelManagement autoscaling_config: diff --git a/cookbook/client/twinkle/grpo.py b/cookbook/client/twinkle/grpo.py index 9d374beb..30a33c0e 100644 --- a/cookbook/client/twinkle/grpo.py +++ b/cookbook/client/twinkle/grpo.py @@ -42,13 +42,13 @@ # ========== Configuration ========== MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct' -NUM_GENERATIONS = 8 +NUM_GENERATIONS = 4 MAX_NEW_TOKENS = 1024 LEARNING_RATE = 1e-5 MAX_STEPS = 10 -BATCH_SIZE = 4 +BATCH_SIZE = 2 TEMPERATURE = 1.0 -SYNC_INTERVAL = 5 # Save weights for sampler every N steps +SYNC_INTERVAL = 1 # Save weights for sampler every N steps GRADIENT_ACCUMULATION_STEPS = 4 diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml index 680e6f63..1ba07c60 100644 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ b/cookbook/client/twinkle/transformer/server_config.yaml @@ -49,8 +49,7 @@ applications: device_type: cuda device_mesh: # Distributed training mesh configuration device_type: cuda - mesh: [0,1] # Device indices in the mesh - mesh_dim_names: ['dp'] # Mesh dimension names: 'dp' = data parallel + dp_size: 2 # Mesh dimension names: 'dp' = data parallel deployments: - name: ModelManagement autoscaling_config: @@ -59,6 +58,10 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 3. Processor Service - Handles data preprocessing on CPU # Runs tokenization, template application, and other CPU-bound tasks. @@ -84,6 +87,10 @@ applications: target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 4. Sampler Service - Handles text generation inference # Uses vLLM for efficient batched generation with optional LoRA adapters. 
@@ -93,7 +100,7 @@ applications: args: model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier to load sampler_type: vllm # Sampler backend (vllm or torch) - nproc_per_node: 1 # Number of GPU processes per node + nproc_per_node: 2 # Number of GPU processes per node engine_args: # vLLM engine configuration gpu_memory_utilization: 0.4 max_model_len: 1024 @@ -102,12 +109,11 @@ applications: adapter_timeout: 1800 # Seconds before idle adapter is unloaded device_group: name: sampler - ranks: [0] # GPU rank indices to use + ranks: [4] # GPU rank indices to use device_type: cuda device_mesh: device_type: cuda - mesh: [0] - mesh_dim_names: ['dp'] + dp_size: 1 deployments: - name: SamplerManagement autoscaling_config: @@ -116,3 +122,7 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" \ No newline at end of file From d8fa0b085898e43c48b18827af63f1dfb36fcfa2 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 11:25:38 +0800 Subject: [PATCH 17/22] update --- .pre-commit-config.yaml | 16 +-- .../twinkle/transformer/server_config.yaml | 4 +- src/twinkle_client/dataloader/dataloader.py | 42 +++--- src/twinkle_client/dataset/base.py | 121 ++++++++---------- .../dataset/iterable_dataset.py | 53 ++++---- .../dataset/iterable_packing_dataset.py | 55 ++++---- src/twinkle_client/dataset/lazy_dataset.py | 43 ++++--- src/twinkle_client/dataset/packing_dataset.py | 19 +-- .../model/multi_lora_transformers.py | 5 +- src/twinkle_client/processor/base.py | 33 ++--- src/twinkle_client/sampler/vllm_sampler.py | 43 ++++--- 11 files changed, 213 insertions(+), 221 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 198c5575..eb1b6e27 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,23 +22,23 @@ repos: hooks: - id: pyupgrade args: [--py38-plus] - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - id: trailing-whitespace - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - id: check-yaml - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - id: end-of-file-fixer - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - id: requirements-txt-fixer - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - id: double-quote-string-fixer - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - id: check-merge-conflict - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - id: mixed-line-ending args: ["--fix=lf"] - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml index 1ba07c60..93fe8592 100644 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ b/cookbook/client/twinkle/transformer/server_config.yaml @@ -109,7 +109,7 @@ applications: adapter_timeout: 1800 # Seconds before idle adapter is unloaded device_group: name: sampler - ranks: [4] # GPU rank indices to use + ranks: [2] # GPU rank indices to use device_type: cuda device_mesh: device_type: cuda @@ -125,4 +125,4 @@ applications: runtime_env: 
env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" \ No newline at end of file + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index f43ecea9..3cd2b564 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -10,13 +10,11 @@ # ============================================================================ from typing import Callable, Type, Union - +from twinkle_client.http import http_post, heartbeat_manager from twinkle.dataset import Dataset from twinkle.processor import InputProcessor -from twinkle_client.http import heartbeat_manager, http_post - -class DataLoader: +class DataLoader(object): """Client wrapper for DataLoader that calls server HTTP endpoints.""" def __init__(self, dataset: Union[Dataset, Callable], **kwargs): @@ -28,11 +26,9 @@ def __init__(self, dataset: Union[Dataset, Callable], **kwargs): json_data={ 'processor_type': 'dataloader', 'class_type': 'DataLoader', - **{ - 'dataset': dataset - }, - **kwargs - }) + **{'dataset': dataset}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -43,6 +39,7 @@ def __del__(self): except: pass + def __len__(self): response = http_post( url=f'{self.server_url}/processors/call', @@ -50,9 +47,11 @@ def __len__(self): 'processor_id': self.processor_id, 'function': '__len__', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputProcessor, Callable], **kwargs): response = http_post( @@ -60,13 +59,13 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputPro json_data={ 'processor_id': self.processor_id, 'function': 'set_processor', - **{ - 'processor_cls': processor_cls - }, + **{'processor_cls': processor_cls}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __iter__(self): response = http_post( @@ -75,16 +74,19 @@ def __iter__(self): 'processor_id': self.processor_id, 'function': '__iter__', **{}, - }) + } + ) response.raise_for_status() return self - + def __next__(self): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py index 351a5a3b..3d5b5062 100644 --- a/src/twinkle_client/dataset/base.py +++ b/src/twinkle_client/dataset/base.py @@ -10,14 +10,14 @@ # ============================================================================ from typing import Any, Callable, Dict, Type, Union - -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.preprocessor import DataFilter, Preprocessor +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta +from twinkle.preprocessor import DataFilter +from twinkle.preprocessor import Preprocessor from twinkle.template import Template -from twinkle_client.http import heartbeat_manager, http_post - -class Dataset: +class Dataset(object): """Client wrapper for Dataset that calls server HTTP 
endpoints.""" def __init__(self, dataset_meta: DatasetMeta, **kwargs): @@ -29,11 +29,9 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): json_data={ 'processor_type': 'dataset', 'class_type': 'Dataset', - **{ - 'dataset_meta': dataset_meta - }, - **kwargs - }) + **{'dataset_meta': dataset_meta}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -44,19 +42,20 @@ def __del__(self): except: pass + def set_template(self, template_func: Union[Template, Type[Template], str], **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', - **{ - 'template_func': template_func - }, + **{'template_func': template_func}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def encode(self, add_generation_prompt: bool = False, **kwargs): response = http_post( @@ -64,13 +63,13 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): json_data={ 'processor_id': self.processor_id, 'function': 'encode', - **{ - 'add_generation_prompt': add_generation_prompt - }, + **{'add_generation_prompt': add_generation_prompt}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def check(self, **kwargs): response = http_post( @@ -80,49 +79,39 @@ def check(self, **kwargs): 'function': 'check', **{}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + - def map(self, - preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]], - dataset_meta: DatasetMeta = None, - init_args: Dict[str, Any] = None, - **kwargs): + def map(self, preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'map', - **{ - 'preprocess_func': preprocess_func, - 'dataset_meta': dataset_meta, - 'init_args': init_args - }, + **{'preprocess_func': preprocess_func, 'dataset_meta': dataset_meta, 'init_args': init_args}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + - def filter(self, - filter_func: Union[Callable, str, Type[DataFilter], DataFilter], - dataset_meta: DatasetMeta = None, - init_args: Dict[str, Any] = None, - **kwargs): + def filter(self, filter_func: Union[Callable, str, Type[DataFilter], DataFilter], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'filter', - **{ - 'filter_func': filter_func, - 'dataset_meta': dataset_meta, - 'init_args': init_args - }, + **{'filter_func': filter_func, 'dataset_meta': dataset_meta, 'init_args': init_args}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( @@ -130,26 +119,26 @@ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', - **{ - 'dataset_meta': dataset_meta - }, + **{'dataset_meta': dataset_meta}, 
**kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + - def mix_dataset(self, interleave=True): + def mix_dataset(self, interleave = True): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'mix_dataset', - **{ - 'interleave': interleave - }, - }) + **{'interleave': interleave}, + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __getitem__(self, idx): response = http_post( @@ -157,12 +146,12 @@ def __getitem__(self, idx): json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', - **{ - 'idx': idx - }, - }) + **{'idx': idx}, + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __len__(self): response = http_post( @@ -171,6 +160,8 @@ def __len__(self): 'processor_id': self.processor_id, 'function': '__len__', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py index f8bd650a..347d1012 100644 --- a/src/twinkle_client/dataset/iterable_dataset.py +++ b/src/twinkle_client/dataset/iterable_dataset.py @@ -9,12 +9,11 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta from torch.utils.data import IterableDataset -from twinkle.dataset import Dataset, DatasetMeta -from twinkle_client.http import heartbeat_manager, http_post - - class IterableDataset(IterableDataset): """Client wrapper for IterableDataset that calls server HTTP endpoints.""" @@ -27,11 +26,9 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): json_data={ 'processor_type': 'dataset', 'class_type': 'IterableDataset', - **{ - 'dataset_meta': dataset_meta - }, - **kwargs - }) + **{'dataset_meta': dataset_meta}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -42,19 +39,20 @@ def __del__(self): except: pass + def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', - **{ - 'dataset_meta': dataset_meta - }, + **{'dataset_meta': dataset_meta}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __len__(self): response = http_post( @@ -63,9 +61,11 @@ def __len__(self): 'processor_id': self.processor_id, 'function': '__len__', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __getitem__(self, idx): response = http_post( @@ -73,12 +73,12 @@ def __getitem__(self, idx): json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', - **{ - 'idx': idx - }, - }) + **{'idx': idx}, + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __iter__(self): response = http_post( @@ -87,16 +87,19 @@ def __iter__(self): 'processor_id': self.processor_id, 'function': '__iter__', **{}, - }) + } + ) 
response.raise_for_status() return self - + def __next__(self): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py index 8383e55a..ce2d918d 100644 --- a/src/twinkle_client/dataset/iterable_packing_dataset.py +++ b/src/twinkle_client/dataset/iterable_packing_dataset.py @@ -9,23 +9,17 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from torch.utils.data import IterableDataset from typing import Type, Union - -from twinkle.dataset import Dataset, DatasetMeta +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta from twinkle.template import Template -from twinkle_client.http import heartbeat_manager, http_post - +from torch.utils.data import IterableDataset class IterablePackingDataset(IterableDataset): """Client wrapper for IterablePackingDataset that calls server HTTP endpoints.""" - def __init__(self, - dataset_meta: DatasetMeta, - packing_interval: int = 128, - packing_num_proc: int = 1, - cyclic: bool = False, - **kwargs): + def __init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packing_num_proc: int = 1, cyclic: bool = False, **kwargs): from twinkle_client.http import get_base_url self.server_url = get_base_url() @@ -34,14 +28,9 @@ def __init__(self, json_data={ 'processor_type': 'dataset', 'class_type': 'IterablePackingDataset', - **{ - 'dataset_meta': dataset_meta, - 'packing_interval': packing_interval, - 'packing_num_proc': packing_num_proc, - 'cyclic': cyclic - }, - **kwargs - }) + **{'dataset_meta': dataset_meta, 'packing_interval': packing_interval, 'packing_num_proc': packing_num_proc, 'cyclic': cyclic}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -52,19 +41,20 @@ def __del__(self): except: pass + def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', - **{ - 'template_cls': template_cls - }, + **{'template_cls': template_cls}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def pack_dataset(self): response = http_post( @@ -73,9 +63,11 @@ def pack_dataset(self): 'processor_id': self.processor_id, 'function': 'pack_dataset', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __iter__(self): response = http_post( @@ -84,16 +76,19 @@ def __iter__(self): 'processor_id': self.processor_id, 'function': '__iter__', **{}, - }) + } + ) response.raise_for_status() return self - + def __next__(self): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/lazy_dataset.py 
b/src/twinkle_client/dataset/lazy_dataset.py index 586e6927..ce8178b1 100644 --- a/src/twinkle_client/dataset/lazy_dataset.py +++ b/src/twinkle_client/dataset/lazy_dataset.py @@ -9,11 +9,11 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from twinkle.dataset import Dataset, DatasetMeta -from twinkle_client.http import heartbeat_manager, http_post +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta from .base import Dataset - class LazyDataset(Dataset): """Client wrapper for LazyDataset that calls server HTTP endpoints.""" @@ -26,11 +26,9 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): json_data={ 'processor_type': 'dataset', 'class_type': 'LazyDataset', - **{ - 'dataset_meta': dataset_meta - }, - **kwargs - }) + **{'dataset_meta': dataset_meta}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -41,6 +39,7 @@ def __del__(self): except: pass + def encode(self, **kwargs): response = http_post( url=f'{self.server_url}/processors/call', @@ -49,9 +48,11 @@ def encode(self, **kwargs): 'function': 'encode', **{}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def check(self, **kwargs): response = http_post( @@ -61,9 +62,11 @@ def check(self, **kwargs): 'function': 'check', **{}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __getitem__(self, idx): response = http_post( @@ -71,12 +74,12 @@ def __getitem__(self, idx): json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', - **{ - 'idx': idx - }, - }) + **{'idx': idx}, + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __len__(self): response = http_post( @@ -85,6 +88,8 @@ def __len__(self): 'processor_id': self.processor_id, 'function': '__len__', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py index 8783c2ab..0d91546f 100644 --- a/src/twinkle_client/dataset/packing_dataset.py +++ b/src/twinkle_client/dataset/packing_dataset.py @@ -9,11 +9,11 @@ # 2. 
Run: python client_tools/client_generator.py # ============================================================================ -from twinkle.dataset import Dataset, DatasetMeta -from twinkle_client.http import heartbeat_manager, http_post +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta from .base import Dataset - class PackingDataset(Dataset): """Client wrapper for PackingDataset that calls server HTTP endpoints.""" @@ -39,7 +39,7 @@ def __del__(self): except: pass - + def pack_dataset(self): response = http_post( url=f'{self.server_url}/processors/call', @@ -50,8 +50,8 @@ def pack_dataset(self): } ) response.raise_for_status() - return response.json()['result'] - + return response.json()["result"] + def __getitem__(self, index): response = http_post( @@ -63,8 +63,8 @@ def __getitem__(self, index): } ) response.raise_for_status() - return response.json()['result'] - + return response.json()["result"] + def __len__(self): response = http_post( @@ -76,4 +76,5 @@ def __len__(self): } ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 35f8c849..f681c96b 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -8,12 +8,11 @@ # 1. Modify the source files in src/twinkle/ # 2. Run: python client_tools/client_generator.py # ============================================================================ +from typing import Any, Optional, Union, Type, Dict, Literal, List import uuid -from typing import Any, Dict, List, Literal, Optional, Type, Union - +from twinkle_client.http import http_post, heartbeat_manager from twinkle import DeviceMesh from twinkle.data_format import InputFeature, Trajectory -from twinkle_client.http import heartbeat_manager, http_post class MultiLoraTransformersModel: diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py index 47e28fd2..d59572a7 100644 --- a/src/twinkle_client/processor/base.py +++ b/src/twinkle_client/processor/base.py @@ -10,20 +10,14 @@ # ============================================================================ from typing import List, Literal, Optional, Union - +from twinkle_client.http import http_post, heartbeat_manager from twinkle import DeviceMesh from twinkle.data_format import InputFeature -from twinkle_client.http import heartbeat_manager, http_post - -class InputProcessor: +class InputProcessor(object): """Client wrapper for InputProcessor that calls server HTTP endpoints.""" - def __init__(self, - device_mesh: Optional[DeviceMesh] = None, - padding_free: bool = False, - framework: Literal['transformers', 'megatron'] = 'transformers', - **kwargs): + def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool = False, framework: Literal['transformers', 'megatron'] = 'transformers', **kwargs): from twinkle_client.http import get_base_url self.server_url = get_base_url() @@ -32,13 +26,9 @@ def __init__(self, json_data={ 'processor_type': 'processor', 'class_type': 'InputProcessor', - **{ - 'device_mesh': device_mesh, - 'padding_free': padding_free, - 'framework': framework - }, - **kwargs - }) + **{'device_mesh': device_mesh, 'padding_free': padding_free, 'framework': framework}, **kwargs + } + ) response.raise_for_status() 
self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -49,16 +39,17 @@ def __del__(self): except: pass + def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': '__call__', - **{ - 'inputs': inputs - }, + **{'inputs': inputs}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py index 5faed5e1..907881a4 100644 --- a/src/twinkle_client/sampler/vllm_sampler.py +++ b/src/twinkle_client/sampler/vllm_sampler.py @@ -8,12 +8,11 @@ # 1. Modify the source files in src/twinkle/ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from peft import PeftConfig -from typing import Any, Dict, List, Optional, Union - -from twinkle.data_format import InputFeature, Trajectory +from typing import Any, Optional, List, Dict, Union +from twinkle_client.http import http_post, heartbeat_manager from twinkle.sampler.base import Sampler -from twinkle_client.http import heartbeat_manager, http_post +from peft import PeftConfig +from twinkle.data_format import Trajectory, InputFeature class vLLMSampler(Sampler): @@ -32,14 +31,20 @@ def __init__(self, model_id: str, **kwargs): if '://' in model_id: model_id = model_id.split('://')[1] self.server_url = f'{self.server_url}/samplers/{model_id}' - response = http_post(url=f'{self.server_url}/create', json_data=kwargs) + response = http_post( + url=f'{self.server_url}/create', + json_data=kwargs + ) response.raise_for_status() def _send_adapter_heartbeat(self): """Internal method to send adapter heartbeat.""" if not self.adapter_name: return - response = http_post(url=f'{self.server_url}/heartbeat', json_data={'adapter_name': self.adapter_name}) + response = http_post( + url=f'{self.server_url}/heartbeat', + json_data={'adapter_name': self.adapter_name} + ) response.raise_for_status() def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs): @@ -48,16 +53,16 @@ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs config = config.__dict__ response = http_post( url=f'{self.server_url}/add_adapter_to_sampler', - json_data={ - 'adapter_name': adapter_name, - 'config': config, - **kwargs - }) + json_data={'adapter_name': adapter_name, 'config': config, **kwargs} + ) response.raise_for_status() # Register adapter for automatic heartbeat after successful creation self.adapter_name = adapter_name - heartbeat_manager.register_adapter(self.adapter_name, self._send_adapter_heartbeat) + heartbeat_manager.register_adapter( + self.adapter_name, + self._send_adapter_heartbeat + ) return response.json() @@ -98,7 +103,10 @@ def sample( if adapter_uri is not None: json_data['adapter_uri'] = adapter_uri - response = http_post(url=f'{self.server_url}/sample', json_data=json_data) + response = http_post( + url=f'{self.server_url}/sample', + json_data=json_data + ) response.raise_for_status() return response.json() @@ -106,10 +114,7 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs): """Set the template for encoding trajectories.""" response = http_post( url=f'{self.server_url}/set_template', - json_data={ - 'template_cls': template_cls, - 
'adapter_name': adapter_name, - **kwargs - }) + json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs} + ) response.raise_for_status() return response.json() From ca1eeabcdb90d0eacb787247d601d78c2c55f2f7 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 11:41:53 +0800 Subject: [PATCH 18/22] update --- cookbook/client/tinker/megatron/server.py | 2 +- .../Usage Guide/Server and Client/Server.md | 117 ++++++++++++++++-- .../\346\234\215\345\212\241\347\253\257.md" | 92 +++++++++----- 3 files changed, 168 insertions(+), 43 deletions(-) diff --git a/cookbook/client/tinker/megatron/server.py b/cookbook/client/tinker/megatron/server.py index 51c85164..e38f43a4 100644 --- a/cookbook/client/tinker/megatron/server.py +++ b/cookbook/client/tinker/megatron/server.py @@ -15,7 +15,7 @@ # Resolve the path to server_config.yaml relative to this script's location file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config_7b.yaml') +config_path = os.path.join(file_dir, 'server_config.yaml') # Launch the Twinkle server — this call blocks until the server is shut down launch_server(config_path=config_path) diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md index 54b2d46b..73583246 100644 --- a/docs/source_en/Usage Guide/Server and Client/Server.md +++ b/docs/source_en/Usage Guide/Server and Client/Server.md @@ -50,8 +50,77 @@ This configuration starts 3 nodes: - **Node 1** (Worker): 4 GPUs (cards 4-7) - **Node 2** (Worker): CPU-only node +#### 4. Set Environment Variables + +Before starting the Server, you need to set the following environment variables: + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Specify the total number of GPUs on each physical machine +export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code (security consideration) +``` + +> **Important Note**: `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to the actual number of physical GPUs on the machine, which is crucial for correctly parsing the `ranks` configuration. + ### Node Rank in YAML Configuration +In the YAML configuration file, **each component needs to occupy a separate Node**. 
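+
+Before mapping components to nodes, it can help to confirm what Ray has actually registered. A quick check (output format varies with the Ray version):
+
+```bash
+# Run on the head node after all `ray start` commands: lists every node in
+# the cluster and the GPU/CPU resources Ray will schedule against.
+ray status
+```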
+ +**Example configuration:** + +```yaml +applications: + # Model service occupies GPU 0-3 (physical card numbers) + - name: models-Qwen2.5-7B-Instruct + route_prefix: /models/Qwen/Qwen2.5-7B-Instruct + import_path: model + args: + nproc_per_node: 4 + device_group: + name: model + ranks: [0, 1, 2, 3] # Physical GPU card numbers + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 4 # Data parallel size + # tp_size: 1 # Tensor parallel size (optional) + # pp_size: 1 # Pipeline parallel size (optional) + # ep_size: 1 # Expert parallel size (optional) + + # Sampler service occupies GPU 4-5 (physical card numbers) + - name: sampler-Qwen2.5-7B-Instruct + route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct + import_path: sampler + args: + nproc_per_node: 2 + device_group: + name: sampler + ranks: [4, 5] # Physical GPU card numbers 4-5 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 # Data parallel size + + # Processor service occupies CPU + - name: processor + route_prefix: /processors + import_path: processor + args: + ncpu_proc_per_node: 4 + device_group: + name: processor + ranks: 0 # CPU index + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 4 # Data parallel size +``` +**Important notes:** +- The `ranks` configuration uses **physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine +- The `device_mesh` configuration uses parameters like `dp_size`, `tp_size`, `pp_size`, `ep_size` instead of the original `mesh` and `mesh_dim_names` +- The environment variable `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to inform the system of the total number of physical GPUs on each machine +- Different components will be automatically assigned to different Nodes +- Ray will automatically schedule to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`) + In the YAML configuration file, **each component needs to occupy a separate Node**, and the `ranks` within each Node are numbered starting from 0. **Example configuration:** @@ -204,8 +273,9 @@ applications: device_type: cuda device_mesh: # Distributed training mesh device_type: cuda - mesh: [0, 1] # Device indices in the mesh - mesh_dim_names: ['dp'] # Mesh dimensions: dp=data parallel + dp_size: 2 # Data parallel size + # tp_size: 1 # Tensor parallel size (optional) + # pp_size: 1 # Pipeline parallel size (optional) deployments: - name: ModelManagement autoscaling_config: @@ -229,8 +299,7 @@ applications: device_type: CPU device_mesh: device_type: CPU - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # Data parallel size deployments: - name: ProcessorManagement autoscaling_config: @@ -260,8 +329,7 @@ The difference from the Transformers backend is only in the `use_megatron` param device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # Data parallel size ``` > **Note**: The Megatron backend does not need `adapter_config` (LoRA adapter management is handled internally by Megatron). @@ -314,8 +382,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # Data parallel size deployments: - name: ModelManagement autoscaling_config: @@ -324,6 +391,9 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine # 3. 
Sampler service (optional, for inference sampling) - name: sampler-Qwen2.5-0.5B-Instruct @@ -343,8 +413,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0] - mesh_dim_names: ['dp'] + dp_size: 1 # Data parallel size deployments: - name: SamplerManagement autoscaling_config: @@ -354,6 +423,9 @@ applications: ray_actor_options: num_cpus: 0.1 num_gpus: 1 # Sampler needs independent GPU + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine ``` ## Configuration Item Description @@ -375,11 +447,30 @@ applications: ```yaml device_group: name: model # Device group name - ranks: [0, 1] # GPU card number list + ranks: [0, 1] # Physical GPU card number list device_type: cuda # Device type: cuda / CPU device_mesh: device_type: cuda - mesh: [0, 1] # Device indices in the mesh - mesh_dim_names: ['dp'] # Dimension names, commonly used: dp (data parallel), tp (tensor parallel), pp (pipeline parallel) + dp_size: 2 # Data parallel size + # tp_size: 1 # Tensor parallel size (optional) + # pp_size: 1 # Pipeline parallel size (optional) + # ep_size: 1 # Expert parallel size (optional) +``` + +**Important configuration parameters:** + +| Parameter | Type | Description | +|------|------|------| +| `ranks` | list[int] | **Physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine | +| `dp_size` | int | Data parallel size | +| `tp_size` | int (optional) | Tensor parallel size | +| `pp_size` | int (optional) | Pipeline parallel size | +| `ep_size` | int (optional) | Expert parallel size (for MoE models) | + +**Environment variables:** + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Total number of GPUs on each physical machine (must be set) +export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code ``` diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" index d194159d..ab7a2436 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" @@ -50,15 +50,26 @@ ray start --address=10.28.252.9:6379 --num-gpus=0 - **Node 1**(Worker):4 个 GPU(卡 4-7) - **Node 2**(Worker):纯 CPU 节点 +#### 4. 
设置环境变量 + +在启动 Server 之前,需要设置以下环境变量: + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 指定每台物理机上的 GPU 总数 +export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码(安全考虑) +``` + +> **重要提示**:`DEVICE_COUNT_PER_PHYSICAL_NODE` 必须设置为机器上实际的物理 GPU 数量,这对于正确解析 `ranks` 配置至关重要。 + ### YAML 配置中的 Node Rank -在 YAML 配置文件中,**每个组件需要占用一个独立的 Node**,`ranks` 配置在各自的 Node 内都是从 0 开始编号的。 +在 YAML 配置文件中,**每个组件需要占用一个独立的 Node**。 **示例配置:** ```yaml applications: - # 模型服务占用 Node 0(Head 节点,GPU 0-3) + # 模型服务占用 GPU 0-3(物理卡号) - name: models-Qwen2.5-7B-Instruct route_prefix: /models/Qwen/Qwen2.5-7B-Instruct import_path: model @@ -66,14 +77,16 @@ applications: nproc_per_node: 4 device_group: name: model - ranks: [0, 1, 2, 3] # Node 0 内的 GPU 编号 + ranks: [0, 1, 2, 3] # 物理 GPU 卡号 device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1, 2, 3] - mesh_dim_names: ['dp'] + dp_size: 4 # 数据并行大小 + # tp_size: 1 # 张量并行大小(可选) + # pp_size: 1 # 流水线并行大小(可选) + # ep_size: 1 # 专家并行大小(可选) - # Sampler 服务占用 Node 1(Worker 节点,GPU 4-7) + # Sampler 服务占用 GPU 4-5(物理卡号) - name: sampler-Qwen2.5-7B-Instruct route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct import_path: sampler @@ -81,14 +94,13 @@ applications: nproc_per_node: 2 device_group: name: sampler - ranks: [0, 1] # Node 1 内的 GPU 编号(对应物理 GPU 4-5) + ranks: [4, 5] # 物理 GPU 卡号 4-5 device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 - # Processor 服务占用 Node 2(CPU 节点) + # Processor 服务占用 CPU - name: processor route_prefix: /processors import_path: processor @@ -96,16 +108,16 @@ applications: ncpu_proc_per_node: 4 device_group: name: processor - ranks: 0 # Node 2 内的 CPU 编号 + ranks: 0 # CPU 编号 device_type: CPU device_mesh: device_type: CPU - mesh: [0, 1, 2, 3] - mesh_dim_names: ['dp'] + dp_size: 4 # 数据并行大小 ``` - **重要提示:** -- 每个组件的 `ranks` 配置都是相对于其所占用的 Ray Node 而言 +- `ranks` 配置使用**物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 +- `device_mesh` 配置使用 `dp_size`、`tp_size`、`pp_size`、`ep_size` 等参数替代原来的 `mesh` 和 `mesh_dim_names` +- 必须设置环境变量 `DEVICE_COUNT_PER_PHYSICAL_NODE` 来告知系统每台机器的物理 GPU 总数 - 不同组件会自动分配到不同的 Node 上 - Ray 会根据资源需求(`ray_actor_options` 中的 `num_gpus`、`num_cpus`)自动调度到合适的 Node @@ -200,12 +212,13 @@ applications: nproc_per_node: 2 # 每节点 GPU 进程数 device_group: # 逻辑设备组 name: model - ranks: [0, 1] # 使用的 GPU 卡号 + ranks: [0, 1] # 物理 GPU 卡号 device_type: cuda device_mesh: # 分布式训练网格 device_type: cuda - mesh: [0, 1] # 网格中的设备索引 - mesh_dim_names: ['dp'] # 网格维度:dp=数据并行 + dp_size: 2 # 数据并行大小 + # tp_size: 1 # 张量并行大小(可选) + # pp_size: 1 # 流水线并行大小(可选) deployments: - name: ModelManagement autoscaling_config: @@ -229,8 +242,7 @@ applications: device_type: CPU device_mesh: device_type: CPU - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 deployments: - name: ProcessorManagement autoscaling_config: @@ -260,8 +272,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 ``` > **注意**:Megatron 后端不需要 `adapter_config`(LoRA 适配器管理由 Megatron 内部处理)。 @@ -314,8 +325,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 deployments: - name: ModelManagement autoscaling_config: @@ -324,6 +334,9 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数 # 3. 
Sampler 服务(可选,用于推理采样) - name: sampler-Qwen2.5-0.5B-Instruct @@ -343,8 +356,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0] - mesh_dim_names: ['dp'] + dp_size: 1 # 数据并行大小 deployments: - name: SamplerManagement autoscaling_config: @@ -354,6 +366,9 @@ applications: ray_actor_options: num_cpus: 0.1 num_gpus: 1 # Sampler 需要独立 GPU + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数 ``` ## 配置项说明 @@ -375,11 +390,30 @@ applications: ```yaml device_group: name: model # 设备组名称 - ranks: [0, 1] # GPU 卡号列表 + ranks: [0, 1] # 物理 GPU 卡号列表 device_type: cuda # 设备类型:cuda / CPU device_mesh: device_type: cuda - mesh: [0, 1] # 网格中的设备索引 - mesh_dim_names: ['dp'] # 维度名称,常用:dp(数据并行), tp(张量并行), pp(流水线并行) + dp_size: 2 # 数据并行大小 + # tp_size: 1 # 张量并行大小(可选) + # pp_size: 1 # 流水线并行大小(可选) + # ep_size: 1 # 专家并行大小(可选) +``` + +**重要配置参数说明:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| `ranks` | list[int] | **物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 | +| `dp_size` | int | 数据并行大小 | +| `tp_size` | int (可选) | 张量并行大小 | +| `pp_size` | int (可选) | 流水线并行大小 | +| `ep_size` | int (可选) | 专家并行大小(用于 MoE 模型) | + +**环境变量:** + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 每台物理机上的 GPU 总数(必须设置) +export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码 ``` From 7b5412bc761a61954667a49b4cc9fe9cff8c8267 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Feb 2026 11:45:00 +0800 Subject: [PATCH 19/22] fix lint --- src/twinkle/server/utils/adapter_manager.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 6c206939..04e56922 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -277,8 +277,10 @@ def _adapter_countdown_loop(self) -> None: # Has session: check session expiration and TTL session_expired = not self._is_session_alive(session_id) should_expire = session_expired or exceeded_ttl - logger.debug(f'[AdapterManager] Adapter {adapter_name} session expiration check ' - f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})') + logger.debug( + f'[AdapterManager] Adapter {adapter_name} session expiration check ' + f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})' # noqa:E501 + ) expiration_reasons = [] if exceeded_ttl: expiration_reasons.append('ttl_exceeded') @@ -289,8 +291,10 @@ def _adapter_countdown_loop(self) -> None: info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1 exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout should_expire = exceeded_ttl or exceeded_inactivity - logger.debug(f'[AdapterManager] Adapter {adapter_name} inactivity check ' - f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})') + logger.debug( + f'[AdapterManager] Adapter {adapter_name} inactivity check ' + f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})' # noqa:E501 + ) expiration_reasons = [] if exceeded_ttl: expiration_reasons.append('ttl_exceeded') @@ -308,7 +312,7 @@ def _adapter_countdown_loop(self) -> None: try: self._on_adapter_expired(adapter_name) logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' - f"(reasons={','.join(expiration_reasons)}, session={session_id})") + f"(reasons={','.join(expiration_reasons)}, session={session_id})") success = True except Exception as e: 
From 7b5412bc761a61954667a49b4cc9fe9cff8c8267 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Fri, 13 Feb 2026 11:45:00 +0800
Subject: [PATCH 19/22] fix lint

---
 src/twinkle/server/utils/adapter_manager.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py
index 6c206939..04e56922 100644
--- a/src/twinkle/server/utils/adapter_manager.py
+++ b/src/twinkle/server/utils/adapter_manager.py
@@ -277,8 +277,10 @@ def _adapter_countdown_loop(self) -> None:
                     # Has session: check session expiration and TTL
                     session_expired = not self._is_session_alive(session_id)
                     should_expire = session_expired or exceeded_ttl
-                    logger.debug(f'[AdapterManager] Adapter {adapter_name} session expiration check '
-                                 f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})')
+                    logger.debug(
+                        f'[AdapterManager] Adapter {adapter_name} session expiration check '
+                        f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})'  # noqa:E501
+                    )
                     expiration_reasons = []
                     if exceeded_ttl:
                         expiration_reasons.append('ttl_exceeded')
@@ -289,8 +291,10 @@ def _adapter_countdown_loop(self) -> None:
                     info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1
                     exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout
                     should_expire = exceeded_ttl or exceeded_inactivity
-                    logger.debug(f'[AdapterManager] Adapter {adapter_name} inactivity check '
-                                 f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})')
+                    logger.debug(
+                        f'[AdapterManager] Adapter {adapter_name} inactivity check '
+                        f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})'  # noqa:E501
+                    )
                     expiration_reasons = []
                     if exceeded_ttl:
                         expiration_reasons.append('ttl_exceeded')
@@ -308,7 +312,7 @@ def _adapter_countdown_loop(self) -> None:
                 try:
                     self._on_adapter_expired(adapter_name)
                     logger.info(f'[AdapterManager] Adapter {adapter_name} expired '
-                                f"(reasons={','.join(expiration_reasons)}, session={session_id})")
+                                f"(reasons={','.join(expiration_reasons)}, session={session_id})")
                     success = True
                 except Exception as e:
                     logger.warning(f'[AdapterManager] Error while expiring adapter {adapter_name}: {e}')

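The two debug branches reformatted above encode one expiration decision per countdown tick: a session-bound adapter expires when its session dies or its TTL runs out, while an unbound adapter additionally ages an inactivity counter. A condensed sketch of that decision; names follow the patched code where visible, but the reason labels for the session and inactivity cases are assumptions, since those appends fall outside the hunks shown:

```python
# Standalone distillation of the branch logic in _adapter_countdown_loop;
# the surrounding loop, locking, and TTL computation are omitted.
def expiration_reasons_for(info: dict, session_alive: bool,
                           exceeded_ttl: bool, adapter_timeout: int) -> list:
    reasons = []
    if exceeded_ttl:
        reasons.append('ttl_exceeded')
    if info.get('session_id') is not None:
        # Session-bound: tied to the session's liveness (plus TTL).
        if not session_alive:
            reasons.append('session_expired')
    else:
        # Unbound: expires after adapter_timeout idle ticks (plus TTL).
        info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1
        if info['inactivity_counter'] > adapter_timeout:
            reasons.append('inactivity_timeout')
    return reasons  # non-empty means the adapter should expire

assert expiration_reasons_for({'session_id': 's1'}, session_alive=False,
                              exceeded_ttl=False,
                              adapter_timeout=60) == ['session_expired']
```
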
From 987d89e7551c10155041e1d6074e063adc5528df Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Fri, 13 Feb 2026 12:01:09 +0800
Subject: [PATCH 20/22] update

---
 .../tinker/megatron/server_config_7b.yaml  | 68 +++++++++----------
 cookbook/client/tinker/sample.py           |  1 -
 cookbook/client/tinker/short_math_grpo.py  |  8 ++-
 3 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml
index 690781e2..cad014c9 100644
--- a/cookbook/client/tinker/megatron/server_config_7b.yaml
+++ b/cookbook/client/tinker/megatron/server_config_7b.yaml
@@ -71,37 +71,37 @@ applications:
 
   # 3. Sampler Service - Runs inference / sampling using vLLM engine
   # Used for generating text from the model (e.g., evaluating LoRA results).
-  # - name: sampler-Qwen2.5-7B-Instruct
-  #   route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
-  #   import_path: sampler
-  #   args:
-  #     model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
-  #     nproc_per_node: 2  # Number of GPU processes per node
-  #     sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
-  #     engine_args:  # vLLM engine-specific settings
-  #       max_model_len: 4096  # Maximum sequence length the engine supports
-  #       gpu_memory_utilization: 0.5  # Fraction of GPU memory to use (0.0-1.0)
-  #       enable_lora: true  # Allow loading LoRA adapters during inference
-  #       logprobs_mode: processed_logprobs  # Logprobs mode for sampling results
-  #     device_group:  # Logical device group for the sampler
-  #       name: sampler
-  #       ranks: [2]  # GPU rank indices to use
-  #       device_type: cuda
-  #       device_mesh:
-  #         device_type: cuda
-  #         dp_size: 1
-  #     queue_config:
-  #       rps_limit: 100  # Max requests per second
-  #       tps_limit: 100000  # Max tokens per second
-  #   deployments:
-  #     - name: SamplerManagement
-  #       autoscaling_config:
-  #         min_replicas: 1
-  #         max_replicas: 1
-  #         target_ongoing_requests: 16
-  #       ray_actor_options:
-  #         num_cpus: 0.1
-  #         runtime_env:
-  #           env_vars:
-  #             TWINKLE_TRUST_REMOTE_CODE: "0"
-  #             DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+  - name: sampler-Qwen2.5-7B-Instruct
+    route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
+    import_path: sampler
+    args:
+      model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
+      nproc_per_node: 2  # Number of GPU processes per node
+      sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
+      engine_args:  # vLLM engine-specific settings
+        max_model_len: 4096  # Maximum sequence length the engine supports
+        gpu_memory_utilization: 0.5  # Fraction of GPU memory to use (0.0-1.0)
+        enable_lora: true  # Allow loading LoRA adapters during inference
+        logprobs_mode: processed_logprobs  # Logprobs mode for sampling results
+      device_group:  # Logical device group for the sampler
+        name: sampler
+        ranks: [2]  # GPU rank indices to use
+        device_type: cuda
+        device_mesh:
+          device_type: cuda
+          dp_size: 1
+      queue_config:
+        rps_limit: 100  # Max requests per second
+        tps_limit: 100000  # Max tokens per second
+    deployments:
+      - name: SamplerManagement
+        autoscaling_config:
+          min_replicas: 1
+          max_replicas: 1
+          target_ongoing_requests: 16
+        ray_actor_options:
+          num_cpus: 0.1
+          runtime_env:
+            env_vars:
+              TWINKLE_TRUST_REMOTE_CODE: "0"
+              DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/tinker/sample.py b/cookbook/client/tinker/sample.py
index 6a1ce937..29bc3ce3 100644
--- a/cookbook/client/tinker/sample.py
+++ b/cookbook/client/tinker/sample.py
@@ -4,7 +4,6 @@
 # for text generation (sampling) via the Tinker-compatible client API.
 # The server must be running first (see server.py and server_config.yaml).
 
-from modelscope import AutoTokenizer
 from tinker import types
 from twinkle.data_format import Message, Trajectory
 
diff --git a/cookbook/client/tinker/short_math_grpo.py b/cookbook/client/tinker/short_math_grpo.py
index 944d54c2..6ab037f3 100644
--- a/cookbook/client/tinker/short_math_grpo.py
+++ b/cookbook/client/tinker/short_math_grpo.py
@@ -21,15 +21,17 @@
 import numpy as np
 import os
 import re
-from modelscope import AutoTokenizer
 from tinker import types
 from typing import List, Tuple
+from twinkle_client import init_tinker_compat_client
 
 from twinkle import get_logger
 from twinkle.advantage import GRPOAdvantage
 from twinkle.data_format import Message, Trajectory
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import Preprocessor
+from twinkle.reward.base import Reward
 from twinkle.metric import CompletionRewardMetric
 from twinkle.template import Template
@@ -332,6 +334,10 @@ def main():
         ).tolist()
         frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0)
 
+        if frac_zero_std == 1.0:
+            logger.info(f'Step {step}: All advantages are zero, skipping training')
+            step += 1
+            continue
 
         # ========== 6. Train the policies with GRPO loss ==========
         # Train the policies with the Advantage-Regularized policy

From 65f1f7fe0373ffb97a2e96091b6964598031deb0 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Fri, 13 Feb 2026 12:10:28 +0800
Subject: [PATCH 21/22] fix

---
 .pre-commit-config.yaml          | 14 +++++++-------
 cookbook/client/tinker/sample.py | 19 +++++++++++--------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index eb1b6e27..f1979a9a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -28,17 +28,17 @@ repos:
     rev: v6.0.0
     hooks:
       - id: trailing-whitespace
-        exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+        exclude: ^(client_tools/|src/twinkle_client/)
      - id: check-yaml
-        exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+        exclude: ^(client_tools/|src/twinkle_client/)
      - id: end-of-file-fixer
-        exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+        exclude: ^(client_tools/|src/twinkle_client/)
      - id: requirements-txt-fixer
-        exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+        exclude: ^(client_tools/|src/twinkle_client/)
      - id: double-quote-string-fixer
-        exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+        exclude: ^(client_tools/|src/twinkle_client/)
      - id: check-merge-conflict
-        exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+        exclude: ^(client_tools/|src/twinkle_client/)
      - id: mixed-line-ending
        args: ["--fix=lf"]
-        exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+        exclude: ^(client_tools/|src/twinkle_client/)
diff --git a/cookbook/client/tinker/sample.py b/cookbook/client/tinker/sample.py
index 29bc3ce3..e995123f 100644
--- a/cookbook/client/tinker/sample.py
+++ b/cookbook/client/tinker/sample.py
@@ -9,19 +9,22 @@
+import os
+
 from twinkle.data_format import Message, Trajectory
 from twinkle.template import Template
 from twinkle_client import init_tinker_compat_client
-from twinkle.data_format import Message, Trajectory
-from twinkle.template import Template
 
 # Step 1: Define the base model and connect to the server
-base_model = 'Qwen/Qwen2.5-7B-Instruct'
-service_client = init_tinker_compat_client(base_url='http://localhost:8000')
-
+base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
+service_client = init_tinker_compat_client(
+    base_url='http://www.modelscope.cn/twinkle',
+    api_key=os.environ.get('MODELSCOPE_SDK_TOKEN')
+)
 # Step 2: Create a sampling client by loading weights from a saved checkpoint.
 # The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
 # The server will load the base model and apply the LoRA adapter weights.
-sampling_client = service_client.create_sampling_client(
-    model_path='twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2',
-    base_model=base_model)
+sampling_client = service_client.create_sampling_client(
+    model_path='twinkle://xxx-Qwen_Qwen3-30B-A3B-Instruct-2507-xxx/weights/twinkle-lora-1',
+    base_model=base_model
+)
 
 # Step 3: Load the tokenizer locally to encode the prompt and decode the results
 print(f'Using model {base_model}')

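The early `continue` added to short_math_grpo.py above guards a degenerate GRPO step. A toy illustration of the condition it tests, assuming (as is standard for GRPO, though `GRPOAdvantage` itself is not shown in this patch) that advantages are group-mean-centered rewards:

```python
import numpy as np

# If every completion in a group earns the same reward, mean-centering
# leaves all advantages at zero, so the policy-gradient update is a no-op
# and the step is better skipped than trained.
rewards = np.array([0.5, 0.5, 0.5, 0.5])   # identical rewards in one group
advantages = (rewards - rewards.mean()).tolist()

frac_zero_std = 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0
assert frac_zero_std == 1.0  # the patched loop logs, bumps `step`, continues
```
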
From 45a9cdb0296fd9a7d651621db574f9b1addf6696 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Fri, 13 Feb 2026 12:17:22 +0800
Subject: [PATCH 22/22] update

---
 docs/source_en/Usage Guide/Server and Client/Server.md | 2 +-
 src/twinkle/server/twinkle/model.py                    | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md
index 73583246..ec7b4b42 100644
--- a/docs/source_en/Usage Guide/Server and Client/Server.md
+++ b/docs/source_en/Usage Guide/Server and Client/Server.md
@@ -121,7 +121,7 @@ applications:
 - Different components will be automatically assigned to different Nodes
 - Ray will automatically schedule to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`)
 
-In the YAML configuration file, **each component needs to occupy a separate Node**, and the `ranks` within each Node are numbered starting from 0.
+In the YAML configuration file, **each component needs to occupy a separate Node**.
 
 **Example configuration:**
 
diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py
index aa9450f9..1660cd10 100644
--- a/src/twinkle/server/twinkle/model.py
+++ b/src/twinkle/server/twinkle/model.py
@@ -200,8 +200,6 @@ def _on_adapter_expired(self, adapter_name: str) -> None:
         if self.get_adapter_info(adapter_name):
             # Clear adapter state
             self.clear_adapter_state(adapter_name)
-
-            self.model.remove_adapter(adapter_name)
             # Unregister from adapter manager
             self.unregister_adapter(adapter_name)
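With `ranks` now documented as physical GPU IDs rather than per-Node indices, `DEVICE_COUNT_PER_PHYSICAL_NODE` is what allows a global rank to be located on a machine. A hypothetical helper (`locate_rank` is not a twinkle API, and the node/local mapping is an inference from the env var's description) showing the arithmetic this convention implies:

```python
import os

def locate_rank(rank: int) -> tuple:
    """Map a physical GPU rank to (node index, local device index)."""
    per_node = int(os.environ['DEVICE_COUNT_PER_PHYSICAL_NODE'])
    return rank // per_node, rank % per_node

os.environ.setdefault('DEVICE_COUNT_PER_PHYSICAL_NODE', '8')
print(locate_rank(10))  # -> (1, 2): node 1, local GPU 2 on an 8-GPU machine
```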