Skip to content

Commit f7a3d0e

Browse files
committed
update async
1 parent 2472409 commit f7a3d0e

File tree

11 files changed

+676
-581
lines changed

11 files changed

+676
-581
lines changed

src/twinkle/server/model/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import twinkle
1616
from twinkle import DeviceGroup, DeviceMesh
17-
from twinkle.server.utils.adapter_manager import AdapterManagerMixin
17+
from twinkle.server.utils.lifecycle import AdapterManagerMixin
1818
from twinkle.server.utils.state import ServerStateProxy, get_server_state
1919
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
2020
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
@@ -83,7 +83,7 @@ def __init__(self,
8383
# Initialize mixins
8484
self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
8585
self._init_adapter_manager(**adapter_config)
86-
self.start_adapter_countdown()
86+
# Note: countdown task is started lazily in _ensure_sticky()
8787

8888
async def _ensure_replica_registered(self):
8989
"""Lazily register replica on first async request."""
@@ -98,6 +98,8 @@ async def _sticky_entry(self, sticky_key: str):
9898
async def _ensure_sticky(self):
9999
sticky_key = serve.get_multiplexed_model_id()
100100
await self._sticky_entry(sticky_key)
101+
# Lazy-start countdown task on first request (requires running event loop)
102+
self._ensure_countdown_started()
101103

102104
async def _on_request_start(self, request: Request) -> str:
103105
await self._ensure_sticky()

src/twinkle/server/processor/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import twinkle
2222
from twinkle import DeviceGroup, DeviceMesh, get_logger
23-
from twinkle.server.utils.processor_manager import ProcessorManagerMixin
23+
from twinkle.server.utils.lifecycle import ProcessorManagerMixin
2424
from twinkle.server.utils.state import ServerStateProxy, get_server_state
2525
from twinkle.server.utils.validation import verify_request_token
2626
from .twinkle_handlers import _register_processor_routes
@@ -69,7 +69,7 @@ def __init__(self,
6969
processor_timeout=float(_cfg.get('processor_timeout', 1800.0)),
7070
per_token_processor_limit=int(_cfg.get('per_token_processor_limit', _env_limit)),
7171
)
72-
self.start_processor_countdown()
72+
# Note: countdown task is started lazily in _ensure_sticky()
7373

7474
@serve.multiplexed(max_num_models_per_replica=100)
7575
async def _sticky_entry(self, sticky_key: str):
@@ -78,6 +78,8 @@ async def _sticky_entry(self, sticky_key: str):
7878
async def _ensure_sticky(self):
7979
sticky_key = serve.get_multiplexed_model_id()
8080
await self._sticky_entry(sticky_key)
81+
# Lazy-start countdown task on first request (requires running event loop)
82+
self._ensure_countdown_started()
8183

8284
def _on_processor_expired(self, processor_id: str) -> None:
8385
"""Called by the countdown thread when a processor's session expires."""

src/twinkle/server/sampler/app.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import twinkle
1515
from twinkle import DeviceGroup, DeviceMesh
16-
from twinkle.server.utils.adapter_manager import AdapterManagerMixin
1716
from twinkle.server.utils.state import ServerStateProxy, get_server_state
1817
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
1918
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
@@ -25,14 +24,13 @@
2524
logger = get_logger()
2625

2726

28-
class SamplerManagement(TaskQueueMixin, AdapterManagerMixin):
27+
class SamplerManagement(TaskQueueMixin):
2928
"""Unified sampler management service.
3029
3130
Manages:
3231
- vLLM or Torch sampler initialization and lifecycle
3332
- Tinker inference requests (/tinker/asample) with rate limiting via TaskQueueMixin
3433
- Twinkle inference requests (/twinkle/*) calling sampler directly
35-
- Adapter lifecycle via AdapterManagerMixin
3634
- Template configuration for trajectory encoding
3735
"""
3836

@@ -43,7 +41,6 @@ def __init__(self,
4341
device_mesh: dict[str, Any],
4442
sampler_type: str = 'vllm',
4543
engine_args: dict[str, Any] | None = None,
46-
adapter_config: dict[str, Any] | None = None,
4744
queue_config: dict[str, Any] | None = None,
4845
**kwargs):
4946
self.device_group = DeviceGroup(**device_group)
@@ -82,11 +79,8 @@ def __init__(self,
8279
self.sampler.set_template('Template', model_id=model_id)
8380
self.state: ServerStateProxy = get_server_state()
8481

85-
# Initialize both mixins
82+
# Initialize task queue mixin
8683
self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
87-
_adapter_config = adapter_config or {}
88-
self._init_adapter_manager(**_adapter_config)
89-
self.start_adapter_countdown()
9084

9185
@serve.multiplexed(max_num_models_per_replica=5)
9286
async def _sticky_entry(self, sticky_key: str):
@@ -101,14 +95,6 @@ async def _on_request_start(self, request: Request) -> str:
10195
token = get_token_from_request(request)
10296
return token
10397

104-
async def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None:
105-
"""Handle expired adapters by removing them from the sampler."""
106-
try:
107-
self.sampler.remove_adapter(adapter_name)
108-
logger.info(f'Removed expired adapter {adapter_name}')
109-
except Exception as e:
110-
logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}')
111-
11298

11399
def build_sampler_app(model_id: str,
114100
nproc_per_node: int,
@@ -117,7 +103,6 @@ def build_sampler_app(model_id: str,
117103
deploy_options: dict[str, Any],
118104
sampler_type: str = 'vllm',
119105
engine_args: dict[str, Any] | None = None,
120-
adapter_config: dict[str, Any] | None = None,
121106
queue_config: dict[str, Any] | None = None,
122107
**kwargs):
123108
"""Build a unified sampler application for text generation inference.
@@ -133,7 +118,6 @@ def build_sampler_app(model_id: str,
133118
deploy_options: Ray Serve deployment options
134119
sampler_type: Type of sampler to use ('vllm' or 'torch')
135120
engine_args: Additional engine arguments for the sampler
136-
adapter_config: Adapter lifecycle config (timeout, per-token limits)
137121
queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.)
138122
**kwargs: Additional arguments passed to the sampler
139123
@@ -161,8 +145,7 @@ def get_self() -> SamplerManagement:
161145
SamplerManagementWithIngress = serve.ingress(app)(SamplerManagement)
162146
DeploymentClass = serve.deployment(name='SamplerManagement')(SamplerManagementWithIngress)
163147
return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh,
164-
sampler_type, engine_args, adapter_config, queue_config,
165-
**kwargs)
148+
sampler_type, engine_args, queue_config, **kwargs)
166149

167150

168151
build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app)

src/twinkle/server/sampler/twinkle_handlers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,10 @@ def add_adapter_to_sampler(
154154
"""Add a LoRA adapter to the sampler."""
155155
assert body.adapter_name, 'You need to specify a valid `adapter_name`'
156156
full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name)
157-
from twinkle.server.utils.validation import get_token_from_request
158-
token = get_token_from_request(request)
159157

160158
from peft import LoraConfig
161159
config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config
162160

163-
self.register_adapter(full_adapter_name, token)
164161
self.sampler.add_adapter_to_sampler(full_adapter_name, config)
165162

166163
return types.AddAdapterResponse(adapter_name=full_adapter_name)
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
from .adapter_manager import AdapterManagerMixin
32
from .checkpoint_base import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager,
43
BaseTrainingRunManager)
54
from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env
6-
from .processor_manager import ProcessorManagerMixin
5+
from .lifecycle import AdapterManagerMixin, ProcessorManagerMixin, SessionResourceMixin
76
from .rate_limiter import RateLimiter
87
from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus

0 commit comments

Comments (0)