1313
1414import twinkle
1515from twinkle import DeviceGroup , DeviceMesh
16- from twinkle .server .utils .adapter_manager import AdapterManagerMixin
1716from twinkle .server .utils .state import ServerStateProxy , get_server_state
1817from twinkle .server .utils .task_queue import TaskQueueConfig , TaskQueueMixin
1918from twinkle .server .utils .validation import get_token_from_request , verify_request_token
2524logger = get_logger ()
2625
2726
28- class SamplerManagement (TaskQueueMixin , AdapterManagerMixin ):
27+ class SamplerManagement (TaskQueueMixin ):
2928 """Unified sampler management service.
3029
3130 Manages:
3231 - vLLM or Torch sampler initialization and lifecycle
3332 - Tinker inference requests (/tinker/asample) with rate limiting via TaskQueueMixin
3433 - Twinkle inference requests (/twinkle/*) calling sampler directly
35- - Adapter lifecycle via AdapterManagerMixin
3634 - Template configuration for trajectory encoding
3735 """
3836
@@ -43,7 +41,6 @@ def __init__(self,
4341 device_mesh : dict [str , Any ],
4442 sampler_type : str = 'vllm' ,
4543 engine_args : dict [str , Any ] | None = None ,
46- adapter_config : dict [str , Any ] | None = None ,
4744 queue_config : dict [str , Any ] | None = None ,
4845 ** kwargs ):
4946 self .device_group = DeviceGroup (** device_group )
@@ -82,11 +79,8 @@ def __init__(self,
8279 self .sampler .set_template ('Template' , model_id = model_id )
8380 self .state : ServerStateProxy = get_server_state ()
8481
85- # Initialize both mixins
82+ # Initialize task queue mixin
8683 self ._init_task_queue (TaskQueueConfig .from_dict (queue_config ))
87- _adapter_config = adapter_config or {}
88- self ._init_adapter_manager (** _adapter_config )
89- self .start_adapter_countdown ()
9084
9185 @serve .multiplexed (max_num_models_per_replica = 5 )
9286 async def _sticky_entry (self , sticky_key : str ):
@@ -101,14 +95,6 @@ async def _on_request_start(self, request: Request) -> str:
10195 token = get_token_from_request (request )
10296 return token
10397
104- async def _on_adapter_expired (self , adapter_name : str , token : str = None ) -> None :
105- """Handle expired adapters by removing them from the sampler."""
106- try :
107- self .sampler .remove_adapter (adapter_name )
108- logger .info (f'Removed expired adapter { adapter_name } ' )
109- except Exception as e :
110- logger .warning (f'Failed to remove expired adapter { adapter_name } : { e } ' )
111-
11298
11399def build_sampler_app (model_id : str ,
114100 nproc_per_node : int ,
@@ -117,7 +103,6 @@ def build_sampler_app(model_id: str,
117103 deploy_options : dict [str , Any ],
118104 sampler_type : str = 'vllm' ,
119105 engine_args : dict [str , Any ] | None = None ,
120- adapter_config : dict [str , Any ] | None = None ,
121106 queue_config : dict [str , Any ] | None = None ,
122107 ** kwargs ):
123108 """Build a unified sampler application for text generation inference.
@@ -133,7 +118,6 @@ def build_sampler_app(model_id: str,
133118 deploy_options: Ray Serve deployment options
134119 sampler_type: Type of sampler to use ('vllm' or 'torch')
135120 engine_args: Additional engine arguments for the sampler
136- adapter_config: Adapter lifecycle config (timeout, per-token limits)
137121 queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.)
138122 **kwargs: Additional arguments passed to the sampler
139123
@@ -161,8 +145,7 @@ def get_self() -> SamplerManagement:
161145 SamplerManagementWithIngress = serve .ingress (app )(SamplerManagement )
162146 DeploymentClass = serve .deployment (name = 'SamplerManagement' )(SamplerManagementWithIngress )
163147 return DeploymentClass .options (** deploy_options ).bind (model_id , nproc_per_node , device_group , device_mesh ,
164- sampler_type , engine_args , adapter_config , queue_config ,
165- ** kwargs )
148+ sampler_type , engine_args , queue_config , ** kwargs )
166149
167150
168151build_sampler_app = wrap_builder_with_device_group_env (build_sampler_app )
0 commit comments