Skip to content

Commit 80c0fd8

Browse files
committed
update
1 parent 41d92f8 commit 80c0fd8

File tree

2 files changed

+31
-20
lines changed

2 files changed

+31
-20
lines changed

src/twinkle/server/tinker/server.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
def build_server_app(
3434
deploy_options: Dict[str, Any],
3535
supported_models: Optional[List[types.SupportedModel]] = None,
36+
server_config: Dict[str, Any] = {},
3637
**kwargs
3738
):
3839
"""Build and configure the Tinker-compatible server application.
@@ -43,23 +44,12 @@ def build_server_app(
4344
Args:
4445
deploy_options: Ray Serve deployment configuration (num_replicas, etc.)
4546
supported_models: List of supported base models for validation
47+
server_config: Server configuration options (per_token_adapter_limit, etc.)
4648
**kwargs: Additional keyword arguments (route_prefix, etc.)
4749
4850
Returns:
4951
Configured Ray Serve deployment bound with options
5052
"""
51-
# Normalize supported_models to objects; passing raw dicts can trigger internal errors
52-
# when creating LoRA training clients via the tinker API.
53-
if supported_models:
54-
normalized = []
55-
for item in supported_models:
56-
if isinstance(item, types.SupportedModel):
57-
normalized.append(item)
58-
elif isinstance(item, dict):
59-
normalized.append(types.SupportedModel(**item))
60-
else:
61-
raise TypeError(...)
62-
supported_models = normalized
6353
app = FastAPI()
6454

6555
@app.middleware("http")
@@ -79,18 +69,19 @@ class TinkerCompatServer:
7969
- Training run and checkpoint CRUD operations
8070
"""
8171

82-
def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None, **kwargs) -> None:
72+
def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None, server_config: Dict[str, Any] = {}, **kwargs) -> None:
8373
"""Initialize the Tinker-compatible server.
8474
8575
Args:
8676
supported_models: List of supported base models for validation
8777
server_config: Server configuration options (per_token_adapter_limit, etc.)
**kwargs: Additional configuration (route_prefix, etc.)
8878
"""
89-
self.state = get_server_state()
79+
# Read server options (e.g. per_token_adapter_limit) from server_config
80+
self.state = get_server_state(**server_config)
9081
# Disable proxy for internal requests to avoid routing through external proxies
9182
self.client = httpx.AsyncClient(timeout=None, trust_env=False)
9283
self.route_prefix = kwargs.get("route_prefix", "/api/v1")
93-
self.supported_models = supported_models or [
84+
self.supported_models = self.normalize_models(supported_models) or [
9485
types.SupportedModel(model_name="Qwen/Qwen2.5-0.5B-Instruct"),
9586
types.SupportedModel(model_name="Qwen/Qwen2.5-3B-Instruct"),
9687
types.SupportedModel(model_name="Qwen/Qwen2.5-7B-Instruct"),
@@ -100,6 +91,20 @@ def __init__(self, supported_models: Optional[List[types.SupportedModel]] = None
10091
# Lock for ModelScope config file operations (login writes, get_user_info reads)
10192
self._modelscope_config_lock = asyncio.Lock()
10293

94+
def normalize_models(self, supported_models):
    """Normalize supported_models entries to types.SupportedModel objects.

    Passing raw dicts can trigger internal errors when creating LoRA
    training clients via the tinker API, so coerce every entry up front.

    Args:
        supported_models: Optional iterable whose items may be
            types.SupportedModel instances, dicts of constructor kwargs,
            or bare model-name strings.

    Returns:
        A list of types.SupportedModel, or None when supported_models is
        empty/None (the caller falls back to its default list via ``or``).
    """
    if not supported_models:
        return None
    normalized = []
    for item in supported_models:
        if isinstance(item, types.SupportedModel):
            normalized.append(item)
        elif isinstance(item, dict):
            normalized.append(types.SupportedModel(**item))
        else:
            # Bare string (or other scalar): treat it as the model name.
            # Fixed: sibling code constructs SupportedModel(model_name=...),
            # so the previous `name=item` keyword was wrong and would raise.
            normalized.append(types.SupportedModel(model_name=item))
    return normalized
107+
103108
def _validate_base_model(self, base_model: str) -> None:
104109
"""Validate that base_model is in supported_models list.
105110
@@ -710,4 +715,8 @@ async def save_weights_for_sampler(
710715
base_model = self._get_base_model(body.model_id)
711716
return await self._proxy_to_model(request, "save_weights_for_sampler", base_model)
712717

713-
return TinkerCompatServer.options(**deploy_options).bind(supported_models=supported_models, **kwargs)
718+
return TinkerCompatServer.options(**deploy_options).bind(
719+
supported_models=supported_models,
720+
server_config=server_config,
721+
**kwargs
722+
)

src/twinkle/server/utils/state.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,8 @@ def get_cleanup_stats(self) -> Dict[str, Any]:
584584

585585

586586
def get_server_state(actor_name: str = 'twinkle_server_state',
587-
auto_start_cleanup: bool = True) -> ServerStateProxy:
587+
auto_start_cleanup: bool = True,
588+
**server_state_kwargs) -> ServerStateProxy:
588589
"""
589590
Get or create the ServerState Ray actor.
590591
@@ -594,6 +595,8 @@ def get_server_state(actor_name: str = 'twinkle_server_state',
594595
Args:
595596
actor_name: Name for the Ray actor (default: 'twinkle_server_state')
596597
auto_start_cleanup: Whether to automatically start the cleanup task (default: True)
598+
**server_state_kwargs: Additional keyword arguments passed to ServerState constructor
599+
(e.g., expiration_timeout, cleanup_interval, per_token_adapter_limit)
597600
598601
Returns:
599602
A ServerStateProxy for interacting with the actor
@@ -603,7 +606,7 @@ def get_server_state(actor_name: str = 'twinkle_server_state',
603606
except ValueError:
604607
try:
605608
_ServerState = ray.remote(ServerState)
606-
actor = _ServerState.options(name=actor_name, lifetime='detached').remote()
609+
actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**server_state_kwargs)
607610
# Start cleanup task for newly created actor
608611
if auto_start_cleanup:
609612
try:
@@ -613,5 +616,4 @@ def get_server_state(actor_name: str = 'twinkle_server_state',
613616
except ValueError:
614617
actor = ray.get_actor(actor_name)
615618
assert actor is not None
616-
return ServerStateProxy(actor)
617-
619+
return ServerStateProxy(actor)

0 commit comments

Comments
 (0)