|
1 | 1 | # Copyright (c) ModelScope Contributors. All rights reserved. |
2 | 2 | """ |
3 | | -Processor management application (moved from twinkle/processor.py). |
| 3 | +Processor management application. |
4 | 4 |
|
5 | 5 | Provides a Ray Serve deployment for managing distributed processors |
6 | 6 | (datasets, dataloaders, preprocessors, rewards, templates, weight loaders, etc.). |
| 7 | +
|
| 8 | +Follows the same structural pattern as model/app.py: |
| 9 | +- ProcessorManagement is a top-level class inheriting ProcessorManagerMixin |
| 10 | +- Routes are registered in build_processor_app() via _register_processor_routes() |
| 11 | +- serve.ingress(app)(ProcessorManagement) applied before deployment |
| 12 | +- Sticky session routing via @serve.multiplexed keyed on session ID |
7 | 13 | """ |
8 | | -import importlib |
| 14 | +from __future__ import annotations |
| 15 | + |
9 | 16 | import os |
10 | | -import uuid |
11 | | -from fastapi import FastAPI, HTTPException, Request |
| 17 | +from fastapi import FastAPI, Request |
12 | 18 | from ray import serve |
13 | | -from typing import Any, Dict |
| 19 | +from typing import Any, Dict, Optional |
14 | 20 |
|
15 | 21 | import twinkle |
16 | | -import twinkle_client.types as types |
17 | 22 | from twinkle import DeviceGroup, DeviceMesh, get_logger |
18 | | -from twinkle.server.common.serialize import deserialize_object |
| 23 | +from twinkle.server.utils.processor_manager import ProcessorManagerMixin |
19 | 24 | from twinkle.server.utils.state import ServerStateProxy, get_server_state |
20 | 25 | from twinkle.server.utils.validation import verify_request_token |
| 26 | +from .twinkle_handlers import _register_processor_routes |
21 | 27 |
|
22 | 28 | logger = get_logger() |
23 | 29 |
|
24 | 30 |
|
class ProcessorManagement(ProcessorManagerMixin):
    """Processor management service.

    Owns the lifecycle and invocation of distributed processor objects
    (datasets, dataloaders, rewards, templates, etc.).

    Lifecycle duties are delegated to ProcessorManagerMixin:
    - Each processor is registered under a session ID at creation time.
    - A background countdown thread expires processors whose session timed out.
    - A per-user processor limit is enforced at registration.
    - Sticky session routing keeps a session's requests on one replica.
    """

    def __init__(self,
                 ncpu_proc_per_node: int,
                 device_group: dict[str, Any],
                 device_mesh: dict[str, Any],
                 nproc_per_node: int = 1,
                 processor_config: dict[str, Any] | None = None):
        self.device_group = DeviceGroup(**device_group)
        twinkle.initialize(
            mode='ray',
            nproc_per_node=nproc_per_node,
            groups=[self.device_group],
            lazy_collect=False,
            ncpu_proc_per_node=ncpu_proc_per_node)
        # Explicit dim names select direct construction; otherwise the mesh
        # is derived from the provided sizes.
        mesh_factory = DeviceMesh if 'mesh_dim_names' in device_mesh else DeviceMesh.from_sizes
        self.device_mesh = mesh_factory(**device_mesh)

        # Live processor objects keyed by processor_id.
        self.resource_dict: dict[str, Any] = {}
        self.state: ServerStateProxy = get_server_state()

        config = processor_config or {}
        # Env var supplies the fallback limit; an explicit config key wins.
        default_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20))
        self._init_processor_manager(
            processor_timeout=float(config.get('processor_timeout', 1800.0)),
            per_token_processor_limit=int(config.get('per_token_processor_limit', default_limit)),
        )
        self.start_processor_countdown()

    @serve.multiplexed(max_num_models_per_replica=100)
    async def _sticky_entry(self, sticky_key: str):
        # Identity "model loader": registering the key is what pins the
        # session to this replica; the value itself is unused.
        return sticky_key

    async def _ensure_sticky(self):
        # Touch the multiplexed entry for the current request's session key
        # so Serve routes subsequent requests for it to this replica.
        await self._sticky_entry(serve.get_multiplexed_model_id())

    def _on_processor_expired(self, processor_id: str) -> None:
        """Called by the countdown thread when a processor's session expires."""
        self.resource_dict.pop(processor_id, None)
        self.unregister_processor(processor_id)
| 87 | + |
def build_processor_app(ncpu_proc_per_node: int,
                        device_group: dict[str, Any],
                        device_mesh: dict[str, Any],
                        deploy_options: dict[str, Any],
                        nproc_per_node: int = 1,
                        processor_config: dict[str, Any] | None = None,
                        **kwargs):
    """Build the processor management application.

    Mirrors build_model_app(): the FastAPI app and its routes are assembled
    here BEFORE serve.ingress, so the frozen app exposes the complete route
    table to ProxyActor.

    Args:
        ncpu_proc_per_node: Number of CPU processes per node.
        device_group: Device group configuration dict.
        device_mesh: Device mesh configuration dict.
        deploy_options: Ray Serve deployment options.
        nproc_per_node: Number of GPU processes per node (default 1).
        processor_config: Optional lifecycle configuration dict.
            Supported keys:
            - ``processor_timeout`` (float): Session inactivity timeout seconds. Default 1800.0.
            - ``per_token_processor_limit`` (int): Max processors per user.
              Overrides ``TWINKLE_PER_USER_PROCESSOR_LIMIT`` env var when provided.
        **kwargs: Additional arguments.

    Returns:
        Ray Serve deployment bound with configuration.
    """
    # All routes must exist on the app before serve.ingress freezes it.
    app = FastAPI()

    @app.middleware('http')
    async def verify_token(request: Request, call_next):
        return await verify_request_token(request=request, call_next=call_next)

    # Route handlers resolve the live replica object at request time.
    def get_self() -> ProcessorManagement:
        return serve.get_replica_context().servable_object

    _register_processor_routes(app, get_self)

    # Apply ingress first, then wrap the result as a named deployment.
    ingress_cls = serve.ingress(app)(ProcessorManagement)
    deployment = serve.deployment(name='ProcessorManagement')(ingress_cls)
    return deployment.options(**deploy_options).bind(
        ncpu_proc_per_node,
        device_group,
        device_mesh,
        nproc_per_node,
        processor_config,
    )
0 commit comments