Skip to content

Commit 2472409

Browse files
committed
update async
1 parent 5002a52 commit 2472409

File tree

11 files changed: +135 additions, -85 deletions

src/twinkle/server/gateway/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def _validate_base_model(self, base_model: str) -> None:
6161
detail=f"Base model '{base_model}' is not supported. "
6262
f"Supported models: {', '.join(supported_model_names)}")
6363

64-
async def _get_base_model(self, model_id: str) -> str:
    """Return the ``base_model`` stored in the metadata of ``model_id``.

    Args:
        model_id: Identifier passed to ``self.state.get_model_metadata``.

    Returns:
        The value stored under the ``'base_model'`` key of the metadata.

    Raises:
        HTTPException: 404 when no metadata is returned or when it carries no
            truthy ``'base_model'`` entry.
    """
    metadata = await self.state.get_model_metadata(model_id)
    base_model = metadata.get('base_model') if metadata else None
    if not base_model:
        raise HTTPException(status_code=404, detail=f'Model {model_id} not found')
    return base_model

src/twinkle/server/gateway/tinker_gateway_handlers.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def create_session(
5656
body: types.CreateSessionRequest,
5757
self: GatewayServer = Depends(self_fn),
5858
) -> types.CreateSessionResponse:
59-
session_id = self.state.create_session(body.model_dump())
59+
session_id = await self.state.create_session(body.model_dump())
6060
return types.CreateSessionResponse(session_id=session_id)
6161

6262
@app.post('/session_heartbeat')
@@ -72,7 +72,7 @@ async def session_heartbeat(
7272
async def create_sampling_session(
7373
request: Request, body: types.CreateSamplingSessionRequest, self: GatewayServer = Depends(self_fn)
7474
) -> types.CreateSamplingSessionResponse: # noqa: E125
75-
sampling_session_id = self.state.create_sampling_session(body.model_dump())
75+
sampling_session_id = await self.state.create_sampling_session(body.model_dump())
7676
return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id)
7777

7878
@app.post('/retrieve_future')
@@ -223,44 +223,44 @@ async def create_model(request: Request, body: types.CreateModelRequest,
223223

224224
@app.post('/get_info')
async def get_info(request: Request, body: types.GetInfoRequest, self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy a ``get_info`` request using the base model resolved from ``body.model_id``."""
    # Resolve the base model first; ``_get_base_model`` is a coroutine and must be awaited.
    base_model = await self._get_base_model(body.model_id)
    return await self.proxy.proxy_to_model(request, 'get_info', base_model)
227227

228228
@app.post('/unload_model')
async def unload_model(request: Request, body: types.UnloadModelRequest,
                       self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy an ``unload_model`` request using the base model resolved from ``body.model_id``."""
    # Resolve the base model first; ``_get_base_model`` is a coroutine and must be awaited.
    base_model = await self._get_base_model(body.model_id)
    return await self.proxy.proxy_to_model(request, 'unload_model', base_model)
232232

233233
@app.post('/forward')
async def forward(request: Request, body: types.ForwardRequest, self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy a ``forward`` request using the base model resolved from ``body.model_id``."""
    # Resolve the base model first; ``_get_base_model`` is a coroutine and must be awaited.
    base_model = await self._get_base_model(body.model_id)
    return await self.proxy.proxy_to_model(request, 'forward', base_model)
236236

237237
@app.post('/forward_backward')
async def forward_backward(request: Request,
                           body: types.ForwardBackwardRequest,
                           self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy a ``forward_backward`` request using the base model resolved from ``body.model_id``."""
    # Resolve the base model first; ``_get_base_model`` is a coroutine and must be awaited.
    base_model = await self._get_base_model(body.model_id)
    return await self.proxy.proxy_to_model(request, 'forward_backward', base_model)
242242

243243
@app.post('/optim_step')
async def optim_step(request: Request, body: types.OptimStepRequest, self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy an ``optim_step`` request using the base model resolved from ``body.model_id``."""
    # Resolve the base model first; ``_get_base_model`` is a coroutine and must be awaited.
    base_model = await self._get_base_model(body.model_id)
    return await self.proxy.proxy_to_model(request, 'optim_step', base_model)
246246

247247
@app.post('/save_weights')
async def save_weights(request: Request, body: types.SaveWeightsRequest,
                       self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy a ``save_weights`` request using the base model resolved from ``body.model_id``."""
    # Resolve the base model first; ``_get_base_model`` is a coroutine and must be awaited.
    base_model = await self._get_base_model(body.model_id)
    return await self.proxy.proxy_to_model(request, 'save_weights', base_model)
251251

252252
@app.post('/load_weights')
async def load_weights(request: Request, body: types.LoadWeightsRequest,
                       self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy a ``load_weights`` request using the base model resolved from ``body.model_id``."""
    # Resolve the base model first; ``_get_base_model`` is a coroutine and must be awaited.
    base_model = await self._get_base_model(body.model_id)
    return await self.proxy.proxy_to_model(request, 'load_weights', base_model)
256256

257257
# --- Sampler Proxy Endpoints ---
258258

259259
@app.post('/asample')
async def asample(request: Request, body: types.SampleRequest, self: GatewayServer = Depends(self_fn)) -> Any:
    """Proxy an ``asample`` request to the sampler.

    The base model comes from ``body.base_model`` when it is set. Otherwise,
    if ``body.sampling_session_id`` is set, it is read from the ``'base_model'``
    entry of the stored sampling session. Whatever value results (possibly
    falsy) is handed to ``self.proxy.proxy_to_sampler``.
    """
    target = body.base_model
    needs_lookup = not target and body.sampling_session_id
    if needs_lookup:
        record = await self.state.get_sampling_session(body.sampling_session_id)
        # An unknown session leaves the (falsy) request value untouched.
        target = record.get('base_model') if record else target
    return await self.proxy.proxy_to_sampler(request, 'asample', target)
@@ -271,4 +271,5 @@ async def save_weights_for_sampler(
271271
body: types.SaveWeightsForSamplerRequest,
272272
self: GatewayServer = Depends(self_fn),
273273
) -> Any:
274-
return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', self._get_base_model(body.model_id))
274+
return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', await
275+
self._get_base_model(body.model_id))

src/twinkle/server/gateway/twinkle_gateway_handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def create_session(
4141
body: types.CreateSessionRequest,
4242
self: GatewayServer = Depends(self_fn),
4343
) -> types.CreateSessionResponse:
44-
session_id = self.state.create_session(body.model_dump())
44+
session_id = await self.state.create_session(body.model_dump())
4545
return types.CreateSessionResponse(session_id=session_id)
4646

4747
@app.post('/twinkle/session_heartbeat', response_model=types.SessionHeartbeatResponse)

src/twinkle/server/model/app.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,19 @@ def __init__(self,
7878
**kwargs)
7979

8080
self.state: ServerStateProxy = get_server_state()
81-
self.state.register_replica(self.replica_id, self.max_loras)
81+
self._replica_registered = False
8282

8383
# Initialize mixins
8484
self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
8585
self._init_adapter_manager(**adapter_config)
8686
self.start_adapter_countdown()
8787

88+
async def _ensure_replica_registered(self):
    """Register this replica with the server state exactly once.

    Registration is deferred to the first async request because
    ``self.state.register_replica`` must be awaited. Concurrent first
    requests are serialized with a lock so that only one of them performs
    the registration; the others wait and then see the flag already set.
    """
    import asyncio

    # Fast path: nothing to do once registration has completed.
    if self._replica_registered:
        return

    # The lock is created lazily. There is no ``await`` between the lookup
    # and the assignment, so every coroutine on the loop ends up sharing the
    # same lock instance.
    lock = getattr(self, '_replica_register_lock', None)
    if lock is None:
        lock = self._replica_register_lock = asyncio.Lock()

    async with lock:
        # Re-check under the lock: another request may have finished the
        # registration while this one was waiting.
        if not self._replica_registered:
            await self.state.register_replica(self.replica_id, self.max_loras)
            # Only flip the flag after success so a failed attempt is retried.
            self._replica_registered = True
93+
8894
@serve.multiplexed(max_num_models_per_replica=5)
async def _sticky_entry(self, sticky_key: str) -> str:
    """Return ``sticky_key`` unchanged.

    The coroutine is wrapped by ``serve.multiplexed`` configured with
    ``max_num_models_per_replica=5``.
    NOTE(review): presumably this exists only so that Ray Serve's model
    multiplexing keeps a given key on the same replica — confirm against
    ``_ensure_sticky``.
    """
    return sticky_key
@@ -95,22 +101,30 @@ async def _ensure_sticky(self):
95101

96102
async def _on_request_start(self, request: Request) -> str:
    """Run the per-request preamble and return the request's token.

    Awaits ``_ensure_sticky`` and then ``_ensure_replica_registered`` before
    extracting the token with ``get_token_from_request``.
    """
    # Order matters only in that both preparations finish before the token is read.
    for prepare in (self._ensure_sticky, self._ensure_replica_registered):
        await prepare()
    return get_token_from_request(request)
100107

101108
def __del__(self):
    """Best-effort unregistration of this replica when it is finalized.

    ``self.state.unregister_replica`` is a coroutine, so it can only be
    scheduled when an event loop is running in the current thread. When no
    loop is running (for example during interpreter shutdown) the cleanup is
    skipped. A finalizer must never raise, so scheduling errors are ignored.
    """
    import asyncio

    # Registration is lazy; if it never happened (or __init__ failed before
    # the flag was set) there is nothing to undo.
    if not getattr(self, '_replica_registered', False):
        return

    try:
        # get_running_loop() never creates a loop and raises RuntimeError
        # when none is running in this thread.
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return

    try:
        loop.create_task(self.state.unregister_replica(self.replica_id))
    except Exception:
        # Deliberate best-effort: the loop may be closing or attributes may
        # be partially torn down while the object is being finalized.
        pass
117+
118+
async def _cleanup_adapter(self, adapter_name: str) -> None:
    """Tear down everything tracked for ``adapter_name``.

    Does nothing when ``get_adapter_info`` returns a falsy value. Otherwise,
    in order: clears the adapter state, removes the adapter from
    ``self.model``, unregisters it, and awaits ``self.state.unload_model``.
    """
    if not self.get_adapter_info(adapter_name):
        return

    # Local, synchronous teardown steps; the order is preserved from the original.
    local_steps = (
        self.clear_adapter_state,
        self.model.remove_adapter,
        self.unregister_adapter,
    )
    for step in local_steps:
        step(adapter_name)

    await self.state.unload_model(adapter_name)
110124

111-
def _on_adapter_expired(self, adapter_name: str) -> None:
125+
async def _on_adapter_expired(self, adapter_name: str) -> None:
    """Handle an expired adapter.

    Pending tasks for the adapter are failed with the reason
    ``'Adapter expired'`` first; the adapter is then cleaned up via
    ``_cleanup_adapter``.
    """
    expiry_reason = 'Adapter expired'
    self.fail_pending_tasks_for_model(adapter_name, reason=expiry_reason)
    await self._cleanup_adapter(adapter_name)
114128

115129

116130
def build_model_app(model_id: str,

src/twinkle/server/model/tinker_handlers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def create_model(
4040
async def _create_adapter():
4141
_model_id = None
4242
try:
43-
_model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id)
43+
_model_id = await self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id)
4444
if body.lora_config:
4545
lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')
4646
adapter_name = self.get_adapter_name(adapter_name=_model_id)
@@ -56,7 +56,7 @@ async def _create_adapter():
5656
except Exception:
5757
if _model_id:
5858
adapter_name = self.get_adapter_name(adapter_name=_model_id)
59-
self._cleanup_adapter(adapter_name)
59+
await self._cleanup_adapter(adapter_name)
6060
logger.error(traceback.format_exc())
6161
return types.RequestFailedResponse(
6262
error=traceback.format_exc(),
@@ -95,7 +95,7 @@ async def unload_model(
9595

9696
async def _do_unload():
9797
adapter_name = self.get_adapter_name(adapter_name=body.model_id)
98-
self._cleanup_adapter(adapter_name)
98+
await self._cleanup_adapter(adapter_name)
9999
return types.UnloadModelResponse(model_id=body.model_id)
100100

101101
return await self.schedule_task(_do_unload, model_id=body.model_id, token=token, task_type='unload_model')
@@ -260,10 +260,10 @@ async def _do_save_for_sampler():
260260
name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False)
261261
payload = body.model_dump()
262262
payload['model_path'] = tinker_path
263-
metadata = self.state.get_model_metadata(body.model_id) or {}
263+
metadata = await self.state.get_model_metadata(body.model_id) or {}
264264
if metadata.get('base_model'):
265265
payload['base_model'] = metadata['base_model']
266-
sampling_session_id = self.state.create_sampling_session(payload)
266+
sampling_session_id = await self.state.create_sampling_session(payload)
267267
return types.SaveWeightsForSamplerResponseInternal(path=None, sampling_session_id=sampling_session_id)
268268
except Exception:
269269
logger.error(traceback.format_exc())

src/twinkle/server/sampler/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def _on_request_start(self, request: Request) -> str:
101101
token = get_token_from_request(request)
102102
return token
103103

104-
def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None:
104+
async def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None:
105105
"""Handle expired adapters by removing them from the sampler."""
106106
try:
107107
self.sampler.remove_adapter(adapter_name)

src/twinkle/server/sampler/tinker_handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def _do_sample():
5151
# Get model_path from body or sampling session
5252
model_path = body.model_path
5353
if not model_path and body.sampling_session_id:
54-
session = self.state.get_sampling_session(body.sampling_session_id)
54+
session = await self.state.get_sampling_session(body.sampling_session_id)
5555
if session:
5656
model_path = session.get('model_path')
5757

src/twinkle/server/utils/adapter_manager.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,21 @@ def _is_session_alive(self, session_id: str) -> bool:
107107
if not session_id:
108108
return True # No session association means always alive
109109

110-
# Get session last heartbeat through proxy
111-
last_heartbeat = self.state.get_session_last_heartbeat(session_id)
110+
# Get session last heartbeat through proxy (async method called from sync thread)
111+
import asyncio
112+
try:
113+
loop = asyncio.get_event_loop()
114+
if loop.is_running():
115+
# If loop is running, we can't use asyncio.run()
116+
# Use run_coroutine_threadsafe instead
117+
future = asyncio.run_coroutine_threadsafe(self.state.get_session_last_heartbeat(session_id), loop)
118+
last_heartbeat = future.result(timeout=5.0)
119+
else:
120+
last_heartbeat = asyncio.run(self.state.get_session_last_heartbeat(session_id))
121+
except Exception as e:
122+
logger.warning(f'[AdapterManager] Failed to check session liveness: {e}')
123+
return True # Assume alive on error
124+
112125
if last_heartbeat is None:
113126
return False # Session doesn't exist
114127

@@ -182,7 +195,7 @@ def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None:
182195
"""
183196
return self._adapter_records.get(adapter_name)
184197

185-
def _on_adapter_expired(self, adapter_name: str) -> None:
198+
async def _on_adapter_expired(self, adapter_name: str) -> None:
186199
"""Hook method called when an adapter expires.
187200
188201
This method must be overridden by inheriting classes to handle
@@ -277,7 +290,18 @@ def _adapter_countdown_loop(self) -> None:
277290
for adapter_name, _token, session_id in expired_adapters:
278291
success = False
279292
try:
280-
self._on_adapter_expired(adapter_name)
293+
# Call async _on_adapter_expired from sync thread
294+
import asyncio
295+
try:
296+
loop = asyncio.get_event_loop()
297+
if loop.is_running():
298+
future = asyncio.run_coroutine_threadsafe(self._on_adapter_expired(adapter_name), loop)
299+
future.result(timeout=10.0)
300+
else:
301+
asyncio.run(self._on_adapter_expired(adapter_name))
302+
except Exception as async_e:
303+
logger.warning(f'[AdapterManager] Async call failed for {adapter_name}: {async_e}')
304+
raise
281305
logger.info(f'[AdapterManager] Adapter {adapter_name} expired '
282306
f'(reason=session_expired, session={session_id})')
283307
success = True

src/twinkle/server/utils/processor_manager.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,18 @@ def _is_session_alive(self, session_id: str) -> bool:
125125
"""Check if a session is still alive via state proxy."""
126126
if not session_id:
127127
return True
128-
last_heartbeat = self.state.get_session_last_heartbeat(session_id)
128+
# Get session last heartbeat through proxy (async method called from sync thread)
129+
import asyncio
130+
try:
131+
loop = asyncio.get_event_loop()
132+
if loop.is_running():
133+
future = asyncio.run_coroutine_threadsafe(self.state.get_session_last_heartbeat(session_id), loop)
134+
last_heartbeat = future.result(timeout=5.0)
135+
else:
136+
last_heartbeat = asyncio.run(self.state.get_session_last_heartbeat(session_id))
137+
except Exception as e:
138+
logger.warning(f'[ProcessorManager] Failed to check session liveness: {e}')
139+
return True # Assume alive on error
129140
if last_heartbeat is None:
130141
return False
131142
return (time.time() - last_heartbeat) < self._processor_timeout

src/twinkle/server/utils/ray_serve_patch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _patched_setup_request_context_and_handle(
6363
multiplexed_model_id = value.decode()
6464
handle = handle.options(multiplexed_model_id=multiplexed_model_id)
6565
request_context_info['multiplexed_model_id'] = multiplexed_model_id
66-
logger.info(f'[Ray Serve Patch] Matched multiplexed_model_id: {multiplexed_model_id}')
66+
logger.debug(f'[Ray Serve Patch] Matched multiplexed_model_id: {multiplexed_model_id}')
6767

6868
# Original logic for other headers (unchanged)
6969
if decoded_key == 'x-request-id':
@@ -91,8 +91,8 @@ def _apply_patch_in_worker_process():
9191
HTTPProxy.setup_request_context_and_handle = _patched_setup_request_context_and_handle
9292
_patch_applied = True
9393

94-
logger.info('[Ray Serve Patch] Applied in worker process: '
95-
'HTTPProxy.setup_request_context_and_handle patched')
94+
logger.debug('[Ray Serve Patch] Applied in worker process: '
95+
'HTTPProxy.setup_request_context_and_handle patched')
9696
except ImportError:
9797
# Ray Serve not available in this worker
9898
pass

0 commit comments

Comments
 (0)