Commit 8d418d2

update custom route
1 parent d4c5db5 commit 8d418d2

File tree: 8 files changed, +283 -69 lines

cookbook/client/tinker/megatron/server_config_7b.yaml

Lines changed: 4 additions & 3 deletions
@@ -22,7 +22,7 @@ applications:
     import_path: server  # Python module to import
     args:
       server_config:
-        per_token_model_limit: 1  # Maximum number of models (adapters) per token (server-globally enforced)
+        per_token_model_limit: 3  # Maximum number of models (adapters) per token (server-globally enforced)
         supported_models:
           - Qwen/Qwen2.5-7B-Instruct
     deployments:
@@ -58,11 +58,12 @@ applications:
       adapter_config:
         adapter_timeout: 30  # Seconds before idle adapter unload
         adapter_max_lifetime: 36000  # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
+        max_loras: 1  # Maximum number of LoRA adapters per model
     deployments:
       - name: ModelManagement
         autoscaling_config:
-          min_replicas: 1
-          max_replicas: 1
+          min_replicas: 2
+          max_replicas: 2
           target_ongoing_requests: 16
         ray_actor_options:
           num_cpus: 0.1

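Taken together, these knobs interact: with max_loras: 1, each ModelManagement replica holds a single LoRA adapter, so two replicas give the server two resident-adapter slots while a single token may now register up to three models. A back-of-the-envelope check (values copied from the config above; the slot-per-replica reading of max_loras follows from how model.py registers capacity below):

# Hypothetical capacity arithmetic for the configuration above.
per_token_model_limit = 3   # adapters one token may register (server-wide)
max_loras_per_replica = 1   # LoRA slots per ModelManagement replica (max_loras)
replicas = 2                # min_replicas == max_replicas == 2

total_lora_slots = replicas * max_loras_per_replica  # == 2

# One token can request more adapters (3) than there are free slots (2),
# so the router's fallback path for "no replica has remaining LoRA
# capacity" is reachable even with a single active user.
assert per_token_model_limit > total_lora_slots
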
src/twinkle/model/multi_lora.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def acquire_lora(self, tenant_adapter_name: str, config: LoraConfig) -> str:
             raise ValueError(f'Lora {tenant_adapter_name} already exists')
         _available_lora = self._get_available_lora()
         if _available_lora is None:
-            raise RuntimeError(f'No lora available for tenant {tenant_adapter_name}')
+            raise RuntimeError(f'No lora available for tenant {tenant_adapter_name}. Max loras: {self.max_loras}')
         if config.r > self.max_r:
            raise RuntimeError(f'Too big rank for lora: {config.r}')
         _available_lora.tenant_config = config
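
For context, acquire_lora draws from a fixed pool of max_loras slots: _get_available_lora returns a free slot, or None once every slot is claimed, which is the case the enriched error message now surfaces. A minimal sketch of that pool shape, assuming a simple slot list (LoraSlot and LoraSlotPool are illustrative names, not from the repo):

from dataclasses import dataclass
from typing import Optional

@dataclass
class LoraSlot:
    # Illustrative slot record: the tenant config that currently owns it.
    tenant_config: Optional[object] = None

class LoraSlotPool:
    """Sketch of the pool shape acquire_lora appears to assume."""

    def __init__(self, max_loras: int, max_r: int):
        self.max_loras = max_loras   # size of the fixed slot pool
        self.max_r = max_r           # largest LoRA rank a slot accepts
        self._slots = [LoraSlot() for _ in range(max_loras)]

    def _get_available_lora(self) -> Optional[LoraSlot]:
        # First slot with no tenant attached; None when the pool is
        # exhausted, which acquire_lora reports as
        # 'No lora available ... Max loras: N'.
        return next((s for s in self._slots if s.tenant_config is None), None)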

src/twinkle/server/tinker/common/router.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+from ray.serve.request_router import (FIFOMixin, MultiplexMixin, PendingRequest, ReplicaID,
+                                      RequestRouter, RunningReplica)
+from typing import Dict, List, Optional
+
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+
+
+class StickyLoraRequestRouter(FIFOMixin, MultiplexMixin, RequestRouter):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.state: ServerStateProxy = get_server_state()
+
+    async def choose_replicas(
+        self,
+        candidate_replicas: List[RunningReplica],
+        pending_request: Optional[PendingRequest] = None,
+    ) -> List[List[RunningReplica]]:
+        """Choose the best replica for the request based on the multiplexed
+        model ID and the available LoRA slot count. The algorithm works as
+        follows:
+
+        1. Rank the candidate replicas by multiplexed_model_id; if one
+           already hosts the model, return that single replica.
+        2. Otherwise, filter replicas by availability and by remaining LoRA
+           slots (queried from server state), then pick the replica with
+           the lowest throughput.
+        """
+        # Take the best set of replicas for the multiplexed model
+        if pending_request is not None and pending_request.metadata.multiplexed_model_id:
+            ranked_replicas_multiplex: List[RunningReplica] = self.rank_replicas_via_multiplex(
+                replicas=candidate_replicas,
+                multiplexed_model_id=pending_request.metadata.multiplexed_model_id,
+            )[0]
+
+            # If a replica already hosts the model, return it
+            if ranked_replicas_multiplex:
+                print('[Router] Found replica for multiplexed model')
+                return [ranked_replicas_multiplex]
+
+        # Dictionary to hold the top-ranked replicas
+        top_ranked_replicas: Dict[ReplicaID, RunningReplica] = {}
+
+        # Filter out replicas that are not available (queue length exceeds max ongoing requests)
+        ranked_replicas_locality = self.select_available_replicas(candidates=candidate_replicas)
+        for replica in ranked_replicas_locality:
+            top_ranked_replicas[replica.replica_id] = replica
+
+        # Filter out replicas that have no free LoRA slots (queried from server state)
+        candidate_ids = [r.replica_id.unique_id for r in top_ranked_replicas.values()]
+        available_ids = set(self.state.get_available_replica_ids(candidate_ids))
+        if available_ids:
+            top_ranked_replicas = {
+                rid: r
+                for rid, r in top_ranked_replicas.items() if r.replica_id.unique_id in available_ids
+            }
+
+        if not top_ranked_replicas:
+            # No replica has remaining LoRA capacity -- fall back to all candidates
+            print('[Router] No replica has remaining LoRA capacity')
+            return [candidate_replicas]
+
+        print('[Router] StickyLoraRequestRouter choosing replica for request')
+
+        # Take the replica with the minimum throughput.
+        min_throughput_replica = min(
+            top_ranked_replicas.values(),
+            key=lambda r: r.routing_stats.get('throughput', 0),
+        )
+        return [[min_throughput_replica]]
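
The router leans on two server-state calls introduced alongside it: register_replica(replica_id, max_loras), called from ModelManagement.__init__ below, and get_available_replica_ids(candidate_ids). The diff does not show ServerStateProxy itself; a plausible sketch of the bookkeeping it would need, assuming a simple per-replica slot counter decremented as adapters are placed (the real register_model call also carries the model payload and token):

from typing import Dict, List

class LoraCapacityState:
    """Illustrative per-replica LoRA capacity tracker (assumed behavior,
    not the repo's ServerStateProxy)."""

    def __init__(self):
        self._capacity: Dict[str, int] = {}  # replica_id -> free LoRA slots

    def register_replica(self, replica_id: str, max_loras: int) -> None:
        # Called from ModelManagement.__init__ on each replica.
        self._capacity[replica_id] = max_loras

    def register_model(self, replica_id: str) -> None:
        # Each adapter placed on a replica consumes one slot.
        self._capacity[replica_id] -= 1

    def get_available_replica_ids(self, candidate_ids: List[str]) -> List[str]:
        # Replicas that still have at least one free LoRA slot.
        return [rid for rid in candidate_ids if self._capacity.get(rid, 0) > 0]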

src/twinkle/server/tinker/model.py

Lines changed: 47 additions & 28 deletions
@@ -13,6 +13,7 @@
 from fastapi import FastAPI, Request
 from peft import LoraConfig
 from ray import serve
+from ray.serve.config import RequestRouterConfig
 from tinker import types
 from typing import Any, Dict, Optional

@@ -21,9 +22,10 @@
 from twinkle.server.utils.adapter_manager import AdapterManagerMixin
 from twinkle.server.utils.state import ServerStateProxy, get_server_state
 from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
-from twinkle.server.utils.validation import verify_request_token
+from twinkle.server.utils.validation import get_token_from_request, verify_request_token
 from twinkle.utils.logger import get_logger
 from .common.io_utils import create_checkpoint_manager, create_training_run_manager
+from .common.router import StickyLoraRequestRouter

 logger = get_logger()

@@ -62,7 +64,10 @@ async def verify_token(request: Request, call_next):
     """Middleware to verify authentication token for all requests."""
     return await verify_request_token(request=request, call_next=call_next)

-@serve.deployment(name='ModelManagement')
+@serve.deployment(
+    name='ModelManagement',
+    request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter),
+)
 @serve.ingress(app)
 class ModelManagement(TaskQueueMixin, AdapterManagerMixin):
     """Model management service handling training operations.
@@ -99,28 +104,31 @@ def __init__(self,
         else:
             self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
         self.use_megatron = use_megatron
-        replica_context = serve.get_replica_context()
-        replica_id = replica_context.replica_id.unique_id
+        self.replica_id = serve.get_replica_context().replica_id.unique_id
+        self.max_loras = kwargs.get('max_loras', 5)
         # Initialize model immediately - choose backend based on use_megatron
         if use_megatron:
             from .common.megatron_model import TwinkleCompatMegatronModel
             self.model = TwinkleCompatMegatronModel(
                 model_id=model_id,
                 device_mesh=self.device_mesh,
                 remote_group=self.device_group.name,
-                instance_id=replica_id,
+                instance_id=self.replica_id,
                 **kwargs)
         else:
             from .common.transformers_model import TwinkleCompatTransformersModel
             self.model = TwinkleCompatTransformersModel(
                 model_id=model_id,
                 device_mesh=self.device_mesh,
                 remote_group=self.device_group.name,
-                instance_id=replica_id,
+                instance_id=self.replica_id,
                 **kwargs)
         self.base_model = model_id
         self.state: ServerStateProxy = get_server_state()

+        # Register this replica so the router can track capacity
+        self.state.register_replica(self.replica_id, self.max_loras)
+
         # Initialize task queue
         self._init_task_queue(TaskQueueConfig.from_dict(queue_config))

@@ -136,9 +144,18 @@ def __init__(self,
         4. Direct call actor instead of http or handler in server.py
         """

-    # @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
-    # async def get_multiplexed_adapter(self, request_id: str):
-    #     return request_id
+    @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
+    async def _sticky_entry(self, sticky_key: str):
+        return sticky_key
+
+    async def _ensure_sticky(self):
+        sticky_key = serve.get_multiplexed_model_id()
+        await self._sticky_entry(sticky_key)
+
+    async def _on_request_start(self, request: Request) -> str:
+        await self._ensure_sticky()
+        token = get_token_from_request(request)
+        return token

     def _cleanup_adapter(self, adapter_name: str) -> None:
         """Common adapter cleanup logic used by both manual unload and automatic expiration.
@@ -188,12 +205,13 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) -
         Returns:
             UntypedAPIFuture wrapping CreateModelResponse with model_id
         """
+        token = await self._on_request_start(request)

         async def _create_adapter():
             model_id = None
             try:
                 # Register a new model_id for each create_model call
-                model_id = self.state.register_model(body.model_dump(), token=request.state.token)
+                model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id)

                 # Create a new LoRA adapter for the model
                 if body.lora_config:
@@ -203,7 +221,7 @@ async def _create_adapter():
                 adapter_name = self.get_adapter_name(adapter_name=model_id)

                 # Register adapter FIRST
-                self.register_adapter(adapter_name, request.state.token, session_id=body.session_id)
+                self.register_adapter(adapter_name, token, session_id=body.session_id)

                 # Create adapter AFTER successful registration
                 self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg)
@@ -215,7 +233,7 @@ async def _create_adapter():
                 # Fresh adapter has no accumulated gradients.
                 self.set_adapter_state(adapter_name, 'grad_ready', False)

-                training_run_manager = create_training_run_manager(request.state.token)
+                training_run_manager = create_training_run_manager(token)
                 training_run_manager.save(model_id, body)

                 return types.CreateModelResponse(model_id=model_id)
@@ -233,7 +251,7 @@ async def _create_adapter():

         return await self.schedule_task(
             _create_adapter,
-            token=request.state.token,
+            token=token,
             task_type='create_model',
         )

@@ -248,9 +266,10 @@ async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.
         Returns:
             GetInfoResponse with model metadata (name, lora_rank, etc.)
         """
+        token = await self._on_request_start(request)
         # Note: get_info doesn't require token for reading metadata in tinker
         # Using a default token or None since this is read-only
-        training_run_manager = create_training_run_manager(request.state.token)
+        training_run_manager = create_training_run_manager(token)
         metadata = training_run_manager.get(str(body.model_id))
         model_name = metadata.base_model if metadata else model_id
         lora_rank = None
@@ -279,6 +298,7 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) -
         Returns:
             UntypedAPIFuture wrapping UnloadModelResponse
         """
+        token = await self._on_request_start(request)

         async def _do_unload():
             # Only remove adapter, not the base model
@@ -290,7 +310,7 @@ async def _do_unload():
         return await self.schedule_task(
             _do_unload,
             model_id=body.model_id,
-            token=request.state.token,
+            token=token,
             task_type='unload_model',
         )

@@ -307,6 +327,7 @@ async def forward(self, request: Request, body: types.ForwardRequest) -> types.U
         Returns:
             UntypedAPIFuture wrapping ForwardBackwardOutput with loss
         """
+        token = await self._on_request_start(request)

         async def _do_forward():
             try:
@@ -340,7 +361,7 @@ async def _do_forward():
         return await self.schedule_task(
             _do_forward,
             model_id=body.model_id,
-            token=request.state.token,
+            token=token,
             input_tokens=input_tokens,
             batch_size=batch_size,
             data_world_size=self.device_mesh.data_world_size,
@@ -364,6 +385,7 @@ async def forward_backward(self, request: Request,
         Returns:
             UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics
         """
+        token = await self._on_request_start(request)

         async def _do_forward_backward():
             try:
@@ -405,7 +427,7 @@ async def _do_forward_backward():
         return await self.schedule_task(
             _do_forward_backward,
             model_id=body.model_id,
-            token=request.state.token,
+            token=token,
             input_tokens=input_tokens,
             batch_size=batch_size,
             data_world_size=self.device_mesh.data_world_size,
@@ -425,6 +447,7 @@ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> ty
         Returns:
             UntypedAPIFuture wrapping OptimStepResponse
         """
+        token = await self._on_request_start(request)

         async def _do_optim():
             try:
@@ -455,7 +478,7 @@ async def _do_optim():
         return await self.schedule_task(
             _do_optim,
             model_id=body.model_id,
-            token=request.state.token,
+            token=token,
             task_type='optim_step',
         )

@@ -473,6 +496,7 @@ async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -
         Returns:
             UntypedAPIFuture wrapping SaveWeightsResponse with saved path
         """
+        token = await self._on_request_start(request)

         async def _do_save():
             try:
@@ -482,8 +506,6 @@ async def _do_save():
                 # Touch adapter to reset inactivity counter
                 self.touch_adapter(adapter_name)

-                # Extract token from request for user isolation
-                token = request.state.token
                 checkpoint_manager = create_checkpoint_manager(token)

                 # get save dir with token-based isolation
@@ -506,7 +528,7 @@ async def _do_save():
         return await self.schedule_task(
             _do_save,
             model_id=body.model_id,
-            token=request.state.token,
+            token=token,
             task_type='save_weights',
         )

@@ -525,6 +547,7 @@ async def save_weights_for_sampler(self, request: Request,
         Returns:
             UntypedAPIFuture wrapping SaveWeightsForSamplerResponseInternal
         """
+        token = await self._on_request_start(request)

         async def _do_save_for_sampler():
             try:
@@ -535,8 +558,6 @@ async def _do_save_for_sampler():
                 # Touch adapter to reset inactivity counter
                 self.touch_adapter(adapter_name)

-                # Extract token from request for user isolation
-                token = request.state.token
                 checkpoint_manager = create_checkpoint_manager(token)

                 # get save dir with token-based isolation
@@ -571,7 +592,7 @@ async def _do_save_for_sampler():
         return await self.schedule_task(
             _do_save_for_sampler,
             model_id=body.model_id,
-            token=request.state.token,
+            token=token,
             task_type='save_weights_for_sampler',
         )

@@ -589,6 +610,7 @@ async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -
         Returns:
             UntypedAPIFuture wrapping LoadWeightsResponse
         """
+        token = await self._on_request_start(request)

         async def _do_load():
             try:
@@ -600,9 +622,6 @@ async def _do_load():
                 # Touch adapter to reset inactivity counter
                 self.touch_adapter(adapter_name)

-                # Extract token from request for user isolation
-                token = request.state.token
-
                 weight_path = body.path
                 load_optimizer = body.optimizer

@@ -625,7 +644,7 @@ async def _do_load():
         return await self.schedule_task(
             _do_load,
             model_id=body.model_id,
-            token=request.state.token,
+            token=token,
             task_type='load_weights',
         )
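
End to end, stickiness works through Ray Serve's model multiplexing: a request carrying a multiplexed model ID is first offered to the replica that already served that ID (step 1 in choose_replicas), and _sticky_entry, wrapped in @serve.multiplexed, records the ID on whichever replica handles the call. A client would pin all traffic for one adapter roughly like this (URL, route, and payload are placeholders; serve_multiplexed_model_id is Ray Serve's standard multiplexing header):

import requests

MODEL_ID = 'model-abc123'  # placeholder model_id returned by create_model

# Ray Serve reads this header, sets the request's multiplexed_model_id,
# and StickyLoraRequestRouter then prefers the replica that already
# loaded this id; the replica's _sticky_entry call keeps the mapping warm.
resp = requests.post(
    'http://localhost:8000/forward_backward',   # placeholder route
    headers={'serve_multiplexed_model_id': MODEL_ID},
    json={'model_id': MODEL_ID},                # placeholder body
)
print(resp.status_code)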
