Skip to content

Commit 415da4e

Browse files
committed
fix processor
1 parent 641302e commit 415da4e

File tree

15 files changed

+507
-168
lines changed

15 files changed

+507
-168
lines changed

cookbook/client/server/transformer/server_config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ applications:
2424
- Qwen/Qwen3.5-4B
2525
deployments:
2626
- name: TinkerCompatServer
27-
max_ongoing_requests: 10
27+
max_ongoing_requests: 50
2828
autoscaling_config:
2929
min_replicas: 1 # Minimum number of replicas
3030
max_replicas: 1 # Maximum number of replicas
@@ -107,14 +107,14 @@ applications:
107107
route_prefix: /api/v1/processor
108108
import_path: processor
109109
args:
110-
ncpu_proc_per_node: 1 # 每节点 CPU 进程数
110+
ncpu_proc_per_node: 2 # 每节点 CPU 进程数
111111
device_group:
112112
name: model
113-
ranks: 1
113+
ranks: 2
114114
device_type: CPU
115115
device_mesh:
116116
device_type: CPU
117-
dp_size: 1 # 数据并行大小
117+
dp_size: 2 # 数据并行大小
118118
deployments:
119119
- name: ProcessorManagement
120120
autoscaling_config:

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
2222
[flake8]
2323
max-line-length = 120
2424
select = B,E,F,P,T4,W,B9
25-
ignore = F401,F403,F405,F821,W503,E251,W504,E126
25+
ignore = F401,F403,F405,F821,W503,E251,W504,E126,E125
2626
exclude = docs/src,*.pyi,.git,peft.py
2727

2828
[darglint]

src/twinkle/hub/hub.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def push_to_hub(cls,
374374
ignore_patterns = []
375375
if revision is None or revision == 'main':
376376
revision = 'master'
377-
return push_to_hub(
377+
result = push_to_hub(
378378
repo_id,
379379
folder_path,
380380
token or cls.ms_token,
@@ -383,6 +383,8 @@ def push_to_hub(cls,
383383
ignore_file_pattern=ignore_patterns,
384384
revision=revision,
385385
tag=path_in_repo)
386+
if not result:
387+
raise Exception('Failed to push to hub')
386388

387389
@classmethod
388390
def load_dataset(cls,

src/twinkle/server/common/datum.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
# Moved from tinker/common/datum.py — logic unchanged.
32
from __future__ import annotations
43

54
import numpy as np

src/twinkle/server/common/router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def choose_replicas(
5656

5757
# Filter out replicas that exceed max lora count (query from server state)
5858
candidate_ids = [r.replica_id.unique_id for r in top_ranked_replicas.values()]
59-
available_ids = set(self.state.get_available_replica_ids(candidate_ids))
59+
available_ids = set(await self.state.get_available_replica_ids(candidate_ids))
6060
if available_ids:
6161
top_ranked_replicas = {
6262
rid: r

src/twinkle/server/gateway/tinker_gateway_handlers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def create_session(
6161
async def session_heartbeat(
6262
request: Request, body: types.SessionHeartbeatRequest, self: GatewayServer = Depends(self_fn)
6363
) -> types.SessionHeartbeatResponse: # noqa: E125
64-
alive = self.state.touch_session(body.session_id)
64+
alive = await self.state.touch_session(body.session_id)
6565
if not alive:
6666
raise HTTPException(status_code=404, detail='Unknown session')
6767
return types.SessionHeartbeatResponse()
@@ -84,7 +84,7 @@ async def retrieve_future(request: Request,
8484
start = asyncio.get_event_loop().time()
8585

8686
while True:
87-
record = self.state.get_future(request_id)
87+
record = await self.state.get_future(request_id)
8888

8989
if record is None:
9090
return {'type': 'try_again'}
@@ -103,7 +103,7 @@ async def retrieve_future(request: Request,
103103

104104
await asyncio.sleep(poll_interval)
105105

106-
record = self.state.get_future(request_id)
106+
record = await self.state.get_future(request_id)
107107
if not record:
108108
return {'type': 'try_again'}
109109

@@ -207,7 +207,7 @@ async def publish_checkpoint(request: Request,
207207

208208
checkpoint_name = checkpoint_id.split('/')[-1]
209209
hub_model_id = f'{username}/{run_id}_{checkpoint_name}'
210-
HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True)
210+
HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True)
211211

212212
return Response(status_code=204)
213213

src/twinkle/server/gateway/twinkle_gateway_handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def session_heartbeat(
4343
body: types.SessionHeartbeatRequest,
4444
self: GatewayServer = Depends(self_fn),
4545
) -> types.SessionHeartbeatResponse:
46-
alive = self.state.touch_session(body.session_id)
46+
alive = await self.state.touch_session(body.session_id)
4747
if not alive:
4848
raise HTTPException(status_code=404, detail='Unknown session')
4949
return types.SessionHeartbeatResponse()
Lines changed: 101 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,155 +1,133 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
"""
3-
Processor management application (moved from twinkle/processor.py).
3+
Processor management application.
44
55
Provides a Ray Serve deployment for managing distributed processors
66
(datasets, dataloaders, preprocessors, rewards, templates, weight loaders, etc.).
7+
8+
Follows the same structural pattern as model/app.py:
9+
- ProcessorManagement is a top-level class inheriting ProcessorManagerMixin
10+
- Routes are registered in build_processor_app() via _register_processor_routes()
11+
- serve.ingress(app)(ProcessorManagement) applied before deployment
12+
- Sticky session routing via @serve.multiplexed keyed on session ID
713
"""
8-
import importlib
14+
from __future__ import annotations
15+
916
import os
10-
import uuid
11-
from fastapi import FastAPI, HTTPException, Request
17+
from fastapi import FastAPI, Request
1218
from ray import serve
13-
from typing import Any, Dict
19+
from typing import Any, Dict, Optional
1420

1521
import twinkle
16-
import twinkle_client.types as types
1722
from twinkle import DeviceGroup, DeviceMesh, get_logger
18-
from twinkle.server.common.serialize import deserialize_object
23+
from twinkle.server.utils.processor_manager import ProcessorManagerMixin
1924
from twinkle.server.utils.state import ServerStateProxy, get_server_state
2025
from twinkle.server.utils.validation import verify_request_token
26+
from .twinkle_handlers import _register_processor_routes
2127

2228
logger = get_logger()
2329

2430

31+
class ProcessorManagement(ProcessorManagerMixin):
32+
"""Processor management service.
33+
34+
Manages lifecycle and invocation of distributed processor objects
35+
(datasets, dataloaders, rewards, templates, etc.).
36+
37+
Lifecycle is handled by ProcessorManagerMixin:
38+
- Processors are registered with a session ID on creation.
39+
- A background thread expires processors whose session has timed out.
40+
- Per-user processor limit is enforced at registration.
41+
- Sticky session routing ensures session requests hit the same replica.
42+
"""
43+
44+
def __init__(self,
45+
ncpu_proc_per_node: int,
46+
device_group: dict[str, Any],
47+
device_mesh: dict[str, Any],
48+
nproc_per_node: int = 1,
49+
processor_config: dict[str, Any] | None = None):
50+
self.device_group = DeviceGroup(**device_group)
51+
twinkle.initialize(
52+
mode='ray',
53+
nproc_per_node=nproc_per_node,
54+
groups=[self.device_group],
55+
lazy_collect=False,
56+
ncpu_proc_per_node=ncpu_proc_per_node)
57+
if 'mesh_dim_names' in device_mesh:
58+
self.device_mesh = DeviceMesh(**device_mesh)
59+
else:
60+
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
61+
62+
# processor objects keyed by processor_id
63+
self.resource_dict: dict[str, Any] = {}
64+
self.state: ServerStateProxy = get_server_state()
65+
66+
_cfg = processor_config or {}
67+
_env_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20))
68+
self._init_processor_manager(
69+
processor_timeout=float(_cfg.get('processor_timeout', 1800.0)),
70+
per_token_processor_limit=int(_cfg.get('per_token_processor_limit', _env_limit)),
71+
)
72+
self.start_processor_countdown()
73+
74+
@serve.multiplexed(max_num_models_per_replica=100)
75+
async def _sticky_entry(self, sticky_key: str):
76+
return sticky_key
77+
78+
async def _ensure_sticky(self):
79+
sticky_key = serve.get_multiplexed_model_id()
80+
await self._sticky_entry(sticky_key)
81+
82+
def _on_processor_expired(self, processor_id: str) -> None:
83+
"""Called by the countdown thread when a processor's session expires."""
84+
self.resource_dict.pop(processor_id, None)
85+
self.unregister_processor(processor_id)
86+
87+
2588
def build_processor_app(ncpu_proc_per_node: int,
26-
device_group: Dict[str, Any],
27-
device_mesh: Dict[str, Any],
28-
deploy_options: Dict[str, Any],
89+
device_group: dict[str, Any],
90+
device_mesh: dict[str, Any],
91+
deploy_options: dict[str, Any],
2992
nproc_per_node: int = 1,
93+
processor_config: dict[str, Any] | None = None,
3094
**kwargs):
3195
"""Build the processor management application.
3296
97+
Follows the same pattern as build_model_app(): FastAPI app and routes are
98+
built here BEFORE serve.ingress so that the frozen app contains the full
99+
route table visible to ProxyActor.
100+
33101
Args:
34-
ncpu_proc_per_node: Number of CPU processes per node
35-
device_group: Device group configuration dict
36-
device_mesh: Device mesh configuration dict
37-
deploy_options: Ray Serve deployment options
38-
nproc_per_node: Number of GPU processes per node (default 1, not used for CPU-only tasks)
39-
**kwargs: Additional arguments
102+
ncpu_proc_per_node: Number of CPU processes per node.
103+
device_group: Device group configuration dict.
104+
device_mesh: Device mesh configuration dict.
105+
deploy_options: Ray Serve deployment options.
106+
nproc_per_node: Number of GPU processes per node (default 1).
107+
processor_config: Optional lifecycle configuration dict.
108+
Supported keys:
109+
- ``processor_timeout`` (float): Session inactivity timeout seconds. Default 1800.0.
110+
- ``per_token_processor_limit`` (int): Max processors per user.
111+
Overrides ``TWINKLE_PER_USER_PROCESSOR_LIMIT`` env var when provided.
112+
**kwargs: Additional arguments.
40113
41114
Returns:
42-
Ray Serve deployment bound with configuration
115+
Ray Serve deployment bound with configuration.
43116
"""
117+
# Build the FastAPI app and register all routes BEFORE serve.ingress so that
118+
# the frozen app contains the complete route table (visible to ProxyActor).
44119
app = FastAPI()
45120

46121
@app.middleware('http')
47122
async def verify_token(request: Request, call_next):
48123
return await verify_request_token(request=request, call_next=call_next)
49124

50-
processors = ['dataset', 'dataloader', 'preprocessor', 'processor', 'reward', 'template', 'weight_loader']
51-
52-
@serve.deployment(name='ProcessorManagement')
53-
@serve.ingress(app)
54-
class ProcessorManagement:
55-
"""Processor management service.
56-
57-
Manages lifecycle and invocation of distributed processor objects
58-
(datasets, dataloaders, rewards, templates, etc.).
59-
"""
60-
61-
def __init__(self,
62-
ncpu_proc_per_node: int,
63-
device_group: Dict[str, Any],
64-
device_mesh: Dict[str, Any],
65-
nproc_per_node: int = 1):
66-
self.device_group = DeviceGroup(**device_group)
67-
twinkle.initialize(
68-
mode='ray',
69-
nproc_per_node=nproc_per_node,
70-
groups=[self.device_group],
71-
lazy_collect=False,
72-
ncpu_proc_per_node=ncpu_proc_per_node)
73-
if 'mesh_dim_names' in device_mesh:
74-
self.device_mesh = DeviceMesh(**device_mesh)
75-
else:
76-
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
77-
self.resource_dict = {}
78-
self.state: ServerStateProxy = get_server_state()
79-
self.per_token_processor_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20))
80-
self.key_token_dict = {}
81-
82-
def assert_processor_exists(self, processor_id: str):
83-
assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found'
84-
85-
@app.post('/twinkle/create', response_model=types.ProcessorCreateResponse)
86-
def create(self, request: Request, body: types.ProcessorCreateRequest) -> types.ProcessorCreateResponse:
87-
processor_type_name = body.processor_type
88-
class_type = body.class_type
89-
_kwargs = body.model_extra or {}
90-
91-
assert processor_type_name in processors, f'Invalid processor type: {processor_type_name}'
92-
processor_module = importlib.import_module(f'twinkle.{processor_type_name}')
93-
assert hasattr(processor_module, class_type), f'Class {class_type} not found in {processor_type_name}'
94-
processor_id = str(uuid.uuid4().hex)
95-
self.key_token_dict[processor_id] = request.state.token
96-
97-
_kwargs.pop('remote_group', None)
98-
_kwargs.pop('device_mesh', None)
99-
100-
resolved_kwargs = {}
101-
for key, value in _kwargs.items():
102-
if isinstance(value, str) and value.startswith('pid:'):
103-
ref_id = value[4:]
104-
resolved_kwargs[key] = self.resource_dict[ref_id]
105-
else:
106-
value = deserialize_object(value)
107-
resolved_kwargs[key] = value
108-
109-
processor = getattr(processor_module, class_type)(
110-
remote_group=self.device_group.name,
111-
device_mesh=self.device_mesh,
112-
instance_id=processor_id,
113-
**resolved_kwargs)
114-
self.resource_dict[processor_id] = processor
115-
return types.ProcessorCreateResponse(processor_id='pid:' + processor_id)
116-
117-
@app.post('/twinkle/call', response_model=types.ProcessorCallResponse)
118-
def call(self, body: types.ProcessorCallRequest) -> types.ProcessorCallResponse:
119-
processor_id = body.processor_id
120-
function_name = body.function
121-
_kwargs = body.model_extra or {}
122-
processor_id = processor_id[4:]
123-
self.assert_processor_exists(processor_id=processor_id)
124-
processor = self.resource_dict.get(processor_id)
125-
function = getattr(processor, function_name, None)
126-
127-
assert function is not None, f'`{function_name}` not found in {processor.__class__}'
128-
assert hasattr(function, '_execute'), f'Cannot call inner method of {processor.__class__}'
129-
130-
resolved_kwargs = {}
131-
for key, value in _kwargs.items():
132-
if isinstance(value, str) and value.startswith('pid:'):
133-
ref_id = value[4:]
134-
resolved_kwargs[key] = self.resource_dict[ref_id]
135-
else:
136-
value = deserialize_object(value)
137-
resolved_kwargs[key] = value
138-
139-
# Special handling for __next__ to catch StopIteration
140-
if function_name == '__next__':
141-
try:
142-
result = function(**resolved_kwargs)
143-
return types.ProcessorCallResponse(result=result)
144-
except StopIteration:
145-
# HTTP 410 Gone signals iterator exhausted
146-
raise HTTPException(status_code=410, detail='Iterator exhausted')
147-
148-
result = function(**resolved_kwargs)
149-
if function_name == '__iter__':
150-
return types.ProcessorCallResponse(result='ok')
151-
else:
152-
return types.ProcessorCallResponse(result=result)
153-
154-
return ProcessorManagement.options(**deploy_options).bind(
155-
ncpu_proc_per_node, device_group, device_mesh, nproc_per_node=nproc_per_node)
125+
def get_self() -> ProcessorManagement:
126+
return serve.get_replica_context().servable_object
127+
128+
_register_processor_routes(app, get_self)
129+
130+
ProcessorManagementWithIngress = serve.ingress(app)(ProcessorManagement)
131+
DeploymentClass = serve.deployment(name='ProcessorManagement')(ProcessorManagementWithIngress)
132+
return DeploymentClass.options(**deploy_options).bind(ncpu_proc_per_node, device_group, device_mesh, nproc_per_node,
133+
processor_config)

0 commit comments

Comments
 (0)