1313from fastapi import FastAPI , Request
1414from peft import LoraConfig
1515from ray import serve
16+ from ray .serve .config import RequestRouterConfig
1617from tinker import types
1718from typing import Any , Dict , Optional
1819
2122from twinkle .server .utils .adapter_manager import AdapterManagerMixin
2223from twinkle .server .utils .state import ServerStateProxy , get_server_state
2324from twinkle .server .utils .task_queue import TaskQueueConfig , TaskQueueMixin
24- from twinkle .server .utils .validation import verify_request_token
25+ from twinkle .server .utils .validation import get_token_from_request , verify_request_token
2526from twinkle .utils .logger import get_logger
2627from .common .io_utils import create_checkpoint_manager , create_training_run_manager
28+ from .common .router import StickyLoraRequestRouter
2729
2830logger = get_logger ()
2931
@@ -62,7 +64,10 @@ async def verify_token(request: Request, call_next):
6264 """Middleware to verify authentication token for all requests."""
6365 return await verify_request_token (request = request , call_next = call_next )
6466
65- @serve .deployment (name = 'ModelManagement' )
67+ @serve .deployment (
68+ name = 'ModelManagement' ,
69+ request_router_config = RequestRouterConfig (request_router_class = StickyLoraRequestRouter , ),
70+ )
6671 @serve .ingress (app )
6772 class ModelManagement (TaskQueueMixin , AdapterManagerMixin ):
6873 """Model management service handling training operations.
@@ -99,28 +104,31 @@ def __init__(self,
99104 else :
100105 self .device_mesh = DeviceMesh .from_sizes (** device_mesh )
101106 self .use_megatron = use_megatron
102- replica_context = serve .get_replica_context ()
103- replica_id = replica_context . replica_id . unique_id
107+ self . replica_id = serve .get_replica_context (). replica_id . unique_id
108+ self . max_loras = kwargs . get ( 'max_loras' , 5 )
104109 # Initialize model immediately - choose backend based on use_megatron
105110 if use_megatron :
106111 from .common .megatron_model import TwinkleCompatMegatronModel
107112 self .model = TwinkleCompatMegatronModel (
108113 model_id = model_id ,
109114 device_mesh = self .device_mesh ,
110115 remote_group = self .device_group .name ,
111- instance_id = replica_id ,
116+ instance_id = self . replica_id ,
112117 ** kwargs )
113118 else :
114119 from .common .transformers_model import TwinkleCompatTransformersModel
115120 self .model = TwinkleCompatTransformersModel (
116121 model_id = model_id ,
117122 device_mesh = self .device_mesh ,
118123 remote_group = self .device_group .name ,
119- instance_id = replica_id ,
124+ instance_id = self . replica_id ,
120125 ** kwargs )
121126 self .base_model = model_id
122127 self .state : ServerStateProxy = get_server_state ()
123128
129+ # Register this replica so the router can track capacity
130+ self .state .register_replica (self .replica_id , self .max_loras )
131+
124132 # Initialize task queue
125133 self ._init_task_queue (TaskQueueConfig .from_dict (queue_config ))
126134
@@ -136,9 +144,18 @@ def __init__(self,
136144 4. Direct call actor instead of http or handler in server.py
137145 """
138146
139- # @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
140- # async def get_multiplexed_adapter(self, request_id: str):
141- # return request_id
    # NOTE(review): Serve's model-multiplexing cache is being (ab)used purely for
    # sticky routing — the decorated coroutine just echoes the key back so Serve
    # registers it as a "loaded model" on this replica, bounded by max_loras.
    # NOTE(review): `kwargs.get('max_loras', 5)` in the decorator is evaluated when
    # this `def` statement runs. If this method sits at class-body level (the diff
    # rendering makes the indentation ambiguous), `kwargs` is NOT in scope there and
    # this raises NameError at import time — confirm it is defined inside `__init__`
    # (where `kwargs` exists, see `self.max_loras`) or use a module-level constant.
    @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
    async def _sticky_entry(self, sticky_key: str):
        """Multiplexed no-op: returning `sticky_key` makes Serve track it per replica."""
        return sticky_key
150+
151+ async def _ensure_sticky (self ):
152+ sticky_key = serve .get_multiplexed_model_id ()
153+ await self ._sticky_entry (sticky_key )
154+
155+ async def _on_request_start (self , request : Request ) -> str :
156+ await self ._ensure_sticky ()
157+ token = get_token_from_request (request )
158+ return token
142159
143160 def _cleanup_adapter (self , adapter_name : str ) -> None :
144161 """Common adapter cleanup logic used by both manual unload and automatic expiration.
@@ -188,12 +205,13 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) -
188205 Returns:
189206 UntypedAPIFuture wrapping CreateModelResponse with model_id
190207 """
208+ token = await self ._on_request_start (request )
191209
192210 async def _create_adapter ():
193211 model_id = None
194212 try :
195213 # Register a new model_id for each create_model call
196- model_id = self .state .register_model (body .model_dump (), token = request . state . token )
214+ model_id = self .state .register_model (body .model_dump (), token = token , replica_id = self . replica_id )
197215
198216 # Create a new LoRA adapter for the model
199217 if body .lora_config :
@@ -203,7 +221,7 @@ async def _create_adapter():
203221 adapter_name = self .get_adapter_name (adapter_name = model_id )
204222
205223 # Register adapter FIRST
206- self .register_adapter (adapter_name , request . state . token , session_id = body .session_id )
224+ self .register_adapter (adapter_name , token , session_id = body .session_id )
207225
208226 # Create adapter AFTER successful registration
209227 self .model .add_adapter_to_model (adapter_name = adapter_name , config_or_dir = lora_cfg )
@@ -215,7 +233,7 @@ async def _create_adapter():
215233 # Fresh adapter has no accumulated gradients.
216234 self .set_adapter_state (adapter_name , 'grad_ready' , False )
217235
218- training_run_manager = create_training_run_manager (request . state . token )
236+ training_run_manager = create_training_run_manager (token )
219237 training_run_manager .save (model_id , body )
220238
221239 return types .CreateModelResponse (model_id = model_id )
@@ -233,7 +251,7 @@ async def _create_adapter():
233251
234252 return await self .schedule_task (
235253 _create_adapter ,
236- token = request . state . token ,
254+ token = token ,
237255 task_type = 'create_model' ,
238256 )
239257
@@ -248,9 +266,10 @@ async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.
248266 Returns:
249267 GetInfoResponse with model metadata (name, lora_rank, etc.)
250268 """
269+ token = await self ._on_request_start (request )
251270 # Note: get_info doesn't require token for reading metadata in tinker
252271 # Using a default token or None since this is read-only
253- training_run_manager = create_training_run_manager (request . state . token )
272+ training_run_manager = create_training_run_manager (token )
254273 metadata = training_run_manager .get (str (body .model_id ))
255274 model_name = metadata .base_model if metadata else model_id
256275 lora_rank = None
@@ -279,6 +298,7 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) -
279298 Returns:
280299 UntypedAPIFuture wrapping UnloadModelResponse
281300 """
301+ token = await self ._on_request_start (request )
282302
283303 async def _do_unload ():
284304 # Only remove adapter, not the base model
@@ -290,7 +310,7 @@ async def _do_unload():
290310 return await self .schedule_task (
291311 _do_unload ,
292312 model_id = body .model_id ,
293- token = request . state . token ,
313+ token = token ,
294314 task_type = 'unload_model' ,
295315 )
296316
@@ -307,6 +327,7 @@ async def forward(self, request: Request, body: types.ForwardRequest) -> types.U
307327 Returns:
308328 UntypedAPIFuture wrapping ForwardBackwardOutput with loss
309329 """
330+ token = await self ._on_request_start (request )
310331
311332 async def _do_forward ():
312333 try :
@@ -340,7 +361,7 @@ async def _do_forward():
340361 return await self .schedule_task (
341362 _do_forward ,
342363 model_id = body .model_id ,
343- token = request . state . token ,
364+ token = token ,
344365 input_tokens = input_tokens ,
345366 batch_size = batch_size ,
346367 data_world_size = self .device_mesh .data_world_size ,
@@ -364,6 +385,7 @@ async def forward_backward(self, request: Request,
364385 Returns:
365386 UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics
366387 """
388+ token = await self ._on_request_start (request )
367389
368390 async def _do_forward_backward ():
369391 try :
@@ -405,7 +427,7 @@ async def _do_forward_backward():
405427 return await self .schedule_task (
406428 _do_forward_backward ,
407429 model_id = body .model_id ,
408- token = request . state . token ,
430+ token = token ,
409431 input_tokens = input_tokens ,
410432 batch_size = batch_size ,
411433 data_world_size = self .device_mesh .data_world_size ,
@@ -425,6 +447,7 @@ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> ty
425447 Returns:
426448 UntypedAPIFuture wrapping OptimStepResponse
427449 """
450+ token = await self ._on_request_start (request )
428451
429452 async def _do_optim ():
430453 try :
@@ -455,7 +478,7 @@ async def _do_optim():
455478 return await self .schedule_task (
456479 _do_optim ,
457480 model_id = body .model_id ,
458- token = request . state . token ,
481+ token = token ,
459482 task_type = 'optim_step' ,
460483 )
461484
@@ -473,6 +496,7 @@ async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -
473496 Returns:
474497 UntypedAPIFuture wrapping SaveWeightsResponse with saved path
475498 """
499+ token = await self ._on_request_start (request )
476500
477501 async def _do_save ():
478502 try :
@@ -482,8 +506,6 @@ async def _do_save():
482506 # Touch adapter to reset inactivity counter
483507 self .touch_adapter (adapter_name )
484508
485- # Extract token from request for user isolation
486- token = request .state .token
487509 checkpoint_manager = create_checkpoint_manager (token )
488510
489511 # get save dir with token-based isolation
@@ -506,7 +528,7 @@ async def _do_save():
506528 return await self .schedule_task (
507529 _do_save ,
508530 model_id = body .model_id ,
509- token = request . state . token ,
531+ token = token ,
510532 task_type = 'save_weights' ,
511533 )
512534
@@ -525,6 +547,7 @@ async def save_weights_for_sampler(self, request: Request,
525547 Returns:
526548 UntypedAPIFuture wrapping SaveWeightsForSamplerResponseInternal
527549 """
550+ token = await self ._on_request_start (request )
528551
529552 async def _do_save_for_sampler ():
530553 try :
@@ -535,8 +558,6 @@ async def _do_save_for_sampler():
535558 # Touch adapter to reset inactivity counter
536559 self .touch_adapter (adapter_name )
537560
538- # Extract token from request for user isolation
539- token = request .state .token
540561 checkpoint_manager = create_checkpoint_manager (token )
541562
542563 # get save dir with token-based isolation
@@ -571,7 +592,7 @@ async def _do_save_for_sampler():
571592 return await self .schedule_task (
572593 _do_save_for_sampler ,
573594 model_id = body .model_id ,
574- token = request . state . token ,
595+ token = token ,
575596 task_type = 'save_weights_for_sampler' ,
576597 )
577598
@@ -589,6 +610,7 @@ async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -
589610 Returns:
590611 UntypedAPIFuture wrapping LoadWeightsResponse
591612 """
613+ token = await self ._on_request_start (request )
592614
593615 async def _do_load ():
594616 try :
@@ -600,9 +622,6 @@ async def _do_load():
600622 # Touch adapter to reset inactivity counter
601623 self .touch_adapter (adapter_name )
602624
603- # Extract token from request for user isolation
604- token = request .state .token
605-
606625 weight_path = body .path
607626 load_optimizer = body .optimizer
608627
@@ -625,7 +644,7 @@ async def _do_load():
625644 return await self .schedule_task (
626645 _do_load ,
627646 model_id = body .model_id ,
628- token = request . state . token ,
647+ token = token ,
629648 task_type = 'load_weights' ,
630649 )
631650
0 commit comments