Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cookbook/client/tinker/megatron/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ applications:
route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
import_path: server # Python module to import
args:
server_config:
per_token_model_limit: 3 # Maximum number of models (adapters) per token (enforced server-wide)

deployments:
- name: TinkerCompatServer
Expand Down Expand Up @@ -95,7 +97,6 @@ applications:
rps_limit: 20 # Max requests per second
tps_limit: 16000 # Max tokens per second
adapter_config:
per_token_adapter_limit: 3 # Max concurrent LoRA adapters
adapter_timeout: 30 # Seconds before idle adapter unload
adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
deployments:
Expand Down
69 changes: 35 additions & 34 deletions cookbook/client/tinker/megatron/server_config_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ applications:
route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
import_path: server # Python module to import
args:
server_config:
per_token_model_limit: 1 # Maximum number of models (adapters) per token (enforced server-wide)
supported_models:
- Qwen/Qwen2.5-7B-Instruct
deployments:
Expand Down Expand Up @@ -56,7 +58,6 @@ applications:
adapter_config:
adapter_timeout: 30 # Seconds before idle adapter unload
adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
per_token_adapter_limit: 30
deployments:
- name: ModelManagement
autoscaling_config:
Expand All @@ -71,36 +72,36 @@ applications:

# 3. Sampler Service - Runs inference / sampling using vLLM engine
# Used for generating text from the model (e.g., evaluating LoRA results).
- name: sampler-Qwen2.5-7B-Instruct
route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
import_path: sampler
args:
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
nproc_per_node: 2 # Number of GPU processes per node
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
engine_args: # vLLM engine-specific settings
max_model_len: 4096 # Maximum sequence length the engine supports
gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
enable_lora: true # Allow loading LoRA adapters during inference
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
device_group: # Logical device group for the sampler
name: sampler
ranks: [2] # GPU rank indices to use
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 1
queue_config:
rps_limit: 100 # Max requests per second
tps_limit: 100000 # Max tokens per second
deployments:
- name: SamplerManagement
autoscaling_config:
min_replicas: 1
max_replicas: 1
target_ongoing_requests: 16
ray_actor_options:
num_cpus: 0.1
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
# - name: sampler-Qwen2.5-7B-Instruct
# route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
# import_path: sampler
# args:
# model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
# nproc_per_node: 2 # Number of GPU processes per node
# sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
# engine_args: # vLLM engine-specific settings
# max_model_len: 4096 # Maximum sequence length the engine supports
# gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
# enable_lora: true # Allow loading LoRA adapters during inference
# logprobs_mode: processed_logprobs # Logprobs mode for sampling results
# device_group: # Logical device group for the sampler
# name: sampler
# ranks: [2] # GPU rank indices to use
# device_type: cuda
# device_mesh:
# device_type: cuda
# dp_size: 1
# queue_config:
# rps_limit: 100 # Max requests per second
# tps_limit: 100000 # Max tokens per second
# deployments:
# - name: SamplerManagement
# autoscaling_config:
# min_replicas: 1
# max_replicas: 1
# target_ongoing_requests: 16
# ray_actor_options:
# num_cpus: 0.1
# runtime_env:
# env_vars:
# TWINKLE_TRUST_REMOTE_CODE: "0"
1 change: 0 additions & 1 deletion cookbook/client/tinker/self_congnition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# 2. eval(): Load a trained checkpoint and sample from it to verify
# that the model has learned the custom identity.
# The server must be running first (see server.py and server_config.yaml).
import numpy as np
import os
from tqdm import tqdm
from tinker import types
Expand Down
4 changes: 2 additions & 2 deletions cookbook/client/tinker/transformer/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ applications:
route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
import_path: server # Python module to import
args:

server_config:
per_token_model_limit: 3 # Maximum number of models (adapters) per token (enforced server-wide)
deployments:
- name: TinkerCompatServer
autoscaling_config:
Expand Down Expand Up @@ -52,7 +53,6 @@ applications:
rps_limit: 100 # Max requests per second
tps_limit: 100000 # Max tokens per second
adapter_config:
per_token_adapter_limit: 30 # Max concurrent LoRA adapters
adapter_timeout: 1800 # Seconds before idle adapter unload
deployments:
- name: ModelManagement
Expand Down
4 changes: 2 additions & 2 deletions cookbook/client/twinkle/megatron/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ applications:
route_prefix: /server # API endpoint prefix
import_path: server # Python module to import
args:

server_config:
per_token_model_limit: 3 # Maximum number of models (adapters) per token (enforced server-wide)
deployments:
- name: TwinkleServer
autoscaling_config:
Expand Down Expand Up @@ -50,7 +51,6 @@ applications:
mesh: [0,1] # Device indices in the mesh
mesh_dim_names: ['dp'] # Mesh dimension names: 'dp' = data parallel
adapter_config:
per_token_adapter_limit: 30 # Max concurrent LoRA adapters
adapter_timeout: 1800 # Seconds before idle adapter unload
deployments:
- name: ModelManagement
Expand Down
5 changes: 2 additions & 3 deletions cookbook/client/twinkle/transformer/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ applications:
route_prefix: /server # API endpoint prefix
import_path: server # Python module to import
args:

server_config:
per_token_model_limit: 3 # Maximum number of models (adapters) per token (enforced server-wide)
deployments:
- name: TwinkleServer
autoscaling_config:
Expand All @@ -40,7 +41,6 @@ applications:
use_megatron: false # Use HuggingFace Transformers (not Megatron)
model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier to load
adapter_config:
per_token_adapter_limit: 30 # Max LoRA adapters that can be active simultaneously
adapter_timeout: 1800 # Seconds before an idle adapter is unloaded
nproc_per_node: 2 # Number of GPU processes per node
device_group: # Logical device group for this model
Expand Down Expand Up @@ -103,7 +103,6 @@ applications:
gpu_memory_utilization: 0.4
max_model_len: 1024
adapter_config: # Adapter lifecycle management
per_token_adapter_limit: 30 # Max LoRA adapters per user
adapter_timeout: 1800 # Seconds before idle adapter is unloaded
device_group:
name: sampler
Expand Down
3 changes: 2 additions & 1 deletion docs/source_en/Usage Guide/Server and Client/Server.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ applications:
use_megatron: false # Use Transformers backend
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
adapter_config: # LoRA adapter configuration
per_token_adapter_limit: 30 # Maximum number of LoRAs that can be activated simultaneously
adapter_timeout: 1800 # Idle adapter timeout unload time (seconds)
nproc_per_node: 2 # Number of GPU processes per node
device_group: # Logical device group
Expand Down Expand Up @@ -354,6 +353,8 @@ applications:
route_prefix: /api/v1 # Tinker protocol API prefix
import_path: server
args:
server_config:
per_token_model_limit: 30 # Maximum number of models (adapters) per token (enforced server-wide)
deployments:
- name: TinkerCompatServer
autoscaling_config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ applications:
use_megatron: false # 使用 Transformers 后端
model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope 模型标识
adapter_config: # LoRA 适配器配置
per_token_adapter_limit: 30 # 同时可激活的最大 LoRA 数量
adapter_timeout: 1800 # 空闲适配器超时卸载时间(秒)
nproc_per_node: 2 # 每节点 GPU 进程数
device_group: # 逻辑设备组
Expand Down Expand Up @@ -297,6 +296,8 @@ applications:
route_prefix: /api/v1 # Tinker 协议 API 前缀
import_path: server
args:
server_config:
per_token_model_limit: 30 # 每个 token 最多可创建的模型(适配器)数量(服务器全局生效)
deployments:
- name: TinkerCompatServer
autoscaling_config:
Expand Down
15 changes: 9 additions & 6 deletions src/twinkle/server/tinker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,18 +188,21 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) -
Returns:
UntypedAPIFuture wrapping CreateModelResponse with model_id
"""
# Register a new model_id for each create_model call
model_id = self.state.register_model(body.model_dump(), token=request.state.token)

async def _create_adapter():
model_id = None
try:
# Register a new model_id for each create_model call
model_id = self.state.register_model(body.model_dump(), token=request.state.token)

# Create a new LoRA adapter for the model
if body.lora_config:
# TODO: support more lora config parameters, train_unembed, etc.
lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')

adapter_name = self.get_adapter_name(adapter_name=model_id)

# Register adapter FIRST (limit check happens inside register_adapter)
# Register adapter FIRST
self.register_adapter(adapter_name, request.state.token, session_id=body.session_id)

# Create adapter AFTER successful registration
Expand All @@ -218,8 +221,9 @@ async def _create_adapter():
return types.CreateModelResponse(model_id=model_id)
except Exception:
# Ensure we don't leave stale grad state.
adapter_name = self.get_adapter_name(adapter_name=model_id)
self._cleanup_adapter(adapter_name)
if model_id:
adapter_name = self.get_adapter_name(adapter_name=model_id)
self._cleanup_adapter(adapter_name)

logger.error(traceback.format_exc())
return types.RequestFailedResponse(
Expand All @@ -229,7 +233,6 @@ async def _create_adapter():

return await self.schedule_task(
_create_adapter,
model_id=model_id,
token=request.state.token,
task_type='create_model',
)
Expand Down
1 change: 0 additions & 1 deletion src/twinkle/server/tinker/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from __future__ import annotations

import asyncio
import dataclasses
import httpx
import logging
import os
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/server/twinkle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest):
token = request.state.token
training_run_manager = create_training_run_manager(token)

# Register adapter FIRST (limit check happens inside register_adapter)
# Register adapter FIRST
self.register_adapter(adapter_name, token)

# Create adapter AFTER successful registration
Expand Down
36 changes: 0 additions & 36 deletions src/twinkle/server/utils/adapter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class AdapterManagerMixin:
def _init_adapter_manager(
self,
adapter_timeout: float = 1800.0,
per_token_adapter_limit: int = 30,
adapter_max_lifetime: float = 12 * 60 * 60,
) -> None:
"""Initialize the adapter manager.
Expand All @@ -54,21 +53,16 @@ def _init_adapter_manager(
adapter_timeout: Timeout in seconds for inactive adapters and session-based expiration.
Default is 1800.0 (30 minutes). Adapters linked to sessions will expire
when their session hasn't been touched for this duration.
per_token_adapter_limit: Maximum number of adapters per user token.
Default is 30.
adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation.
Default is 43200.0 (12 hours). If <= 0, lifetime enforcement is disabled.
"""
self._adapter_timeout = adapter_timeout
self._per_token_adapter_limit = per_token_adapter_limit
self._adapter_max_lifetime = adapter_max_lifetime

# Adapter lifecycle tracking
# Dict mapping adapter_name ->
# {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int}
self._adapter_records: dict[str, dict[str, Any]] = {}
# Track adapter count per token
self._adapter_counts: dict[str, int] = {}

# Countdown thread
self._adapter_countdown_thread: threading.Thread | None = None
Expand All @@ -82,15 +76,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str | None
token: User token that owns this adapter.
session_id: Optional session ID to associate with this adapter.
If provided, adapter will expire when the session expires.

Raises:
RuntimeError: If adapter limit is exceeded for this token.
"""
# Check adapter limit BEFORE registering
allowed, reason = self.check_adapter_limit(token)
if not allowed:
raise RuntimeError(reason)

current_time = time.time()
self._adapter_records[adapter_name] = {
'token': token,
Expand Down Expand Up @@ -353,25 +339,3 @@ def stop_adapter_countdown(self) -> None:
# Wait for thread to finish (it checks the flag every second)
self._adapter_countdown_thread.join(timeout=2.0)
logger.debug('[AdapterManager] Countdown thread stopped')

def check_adapter_limit(self, token: str) -> tuple[bool, str | None]:
    """Report whether ``token`` is still under its adapter quota.

    The tally is taken directly from ``self._adapter_records``: a record is
    counted when its ``'token'`` entry equals ``token`` and it is not flagged
    as ``'expiring'``.

    Args:
        token: User token whose adapters are counted.

    Returns:
        ``(True, None)`` while the tally is below
        ``self._per_token_adapter_limit``; otherwise ``(False, reason)``,
        where ``reason`` is a message stating the tally and the limit.
    """
    active = 0
    for entry in self._adapter_records.values():
        # Skip records owned by someone else or already on their way out.
        if entry.get('token') != token or entry.get('expiring', False):
            continue
        active += 1

    limit = self._per_token_adapter_limit
    if active < limit:
        return True, None
    return False, f'Adapter limit exceeded: {active}/{limit} adapters'
Loading
Loading