Skip to content

Commit 2c24077

Browse files
authored
Fix proxy (#87)
* update init * update proxy * update proxy * update proxy * update * update * update
1 parent cc2e9c0 commit 2c24077

File tree

16 files changed

+269
-132
lines changed

16 files changed

+269
-132
lines changed

cookbook/client/tinker/custom_service/megatron/server_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ applications:
6262
deployments:
6363
- name: ModelManagement
6464
autoscaling_config:
65-
min_replicas: 2
66-
max_replicas: 2
65+
min_replicas: 1
66+
max_replicas: 1
6767
target_ongoing_requests: 16
6868
ray_actor_options:
6969
num_cpus: 0.1

cookbook/client/tinker/custom_service/self_cognition.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# that the model has learned the custom identity.
88
# The server must be running first (see server.py and server_config.yaml).
99
import os
10+
import numpy as np
1011
from tqdm import tqdm
1112
from tinker import types
1213
from twinkle import init_tinker_client
@@ -76,9 +77,9 @@ def train():
7677
optim_result = optim_future.result()
7778

7879
# Compute weighted average log-loss per token for monitoring
79-
# logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
80-
# weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
81-
# print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
80+
logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
81+
weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
82+
print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
8283
print(f'Training Metrics: {optim_result}')
8384

8485
# Save a checkpoint after each epoch

docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ for item in service_client.get_server_capabilities().supported_models:
2828
When calling `init_tinker_client`, the following operations are automatically executed:
2929

3030
1. **Patch Tinker SDK**: Bypass Tinker's `tinker://` prefix validation, allowing it to connect to standard HTTP addresses
31-
2. **Set Request Headers**: Inject necessary authentication headers such as `serve_multiplexed_model_id` and `Authorization`
31+
2. **Set Request Headers**: Inject necessary authentication headers such as `X-Ray-Serve-Request-Id` and `Authorization`
3232

3333
After initialization, simply import `from tinker import ServiceClient` to connect to Twinkle Server, and **all existing Tinker training code can be used directly** without any modifications.
3434

docs/source_zh/使用指引/服务端和客户端/Tinker兼容客户端.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ for item in service_client.get_server_capabilities().supported_models:
2828
调用 `init_tinker_client` 时,会自动执行以下操作:
2929

3030
1. **Patch Tinker SDK**:绕过 Tinker 的 `tinker://` 前缀校验,使其可以连接到标准 HTTP 地址
31-
2. **设置请求头**:注入 `serve_multiplexed_model_id``Authorization` 等必要的认证头
31+
2. **设置请求头**:注入 `X-Ray-Serve-Request-Id`、`Authorization` 等必要的认证头
3232

3333
初始化之后,直接导入 `from tinker import ServiceClient` 即可连接到 Twinkle Server,**所有已有的 Tinker 训练代码都可以直接使用**,无需任何修改。
3434

src/twinkle/server/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,7 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
from .launcher import ServerLauncher, launch_server
3-
from .twinkle.model import build_model_app
4-
from .twinkle.processor import build_processor_app
5-
from .twinkle.sampler import build_sampler_app
6-
from .twinkle.server import build_server_app
73

84
__all__ = [
9-
'build_model_app',
10-
'build_processor_app',
11-
'build_sampler_app',
12-
'build_server_app',
135
'ServerLauncher',
146
'launch_server',
157
]

src/twinkle/server/launcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _get_builders(self) -> dict[str, Callable]:
101101
'build_sampler_app': build_sampler_app,
102102
}
103103
else: # twinkle
104-
from twinkle.server import build_model_app, build_processor_app, build_sampler_app, build_server_app
104+
from twinkle.server.twinkle import build_model_app, build_processor_app, build_sampler_app, build_server_app
105105
self._builders = {
106106
'build_server_app': build_server_app,
107107
'build_model_app': build_model_app,
@@ -214,6 +214,11 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None:
214214
# Copy all deployment options from the config, except 'name'.
215215
deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'}
216216

217+
# Pass http_options to server apps for internal proxy routing
218+
http_options = self.config.get('http_options', {})
219+
if http_options:
220+
args['http_options'] = http_options
221+
217222
# Build and deploy the application
218223
app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()})
219224

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import sys
3+
from typing import TYPE_CHECKING
24

3-
from ..utils import wrap_builder_with_device_group_env
4-
from .model import build_model_app as _build_model_app
5-
from .sampler import build_sampler_app as _build_sampler_app
6-
from .server import build_server_app
5+
from twinkle.utils.import_utils import _LazyModule
76

8-
build_model_app = wrap_builder_with_device_group_env(_build_model_app)
9-
build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app)
7+
_import_structure = {
8+
'model': ['build_model_app'],
9+
'sampler': ['build_sampler_app'],
10+
'server': ['build_server_app'],
11+
}
1012

11-
__all__ = [
12-
'build_model_app',
13-
'build_sampler_app',
14-
'build_server_app',
15-
]
13+
if TYPE_CHECKING:
14+
from .model import build_model_app
15+
from .sampler import build_sampler_app
16+
from .server import build_server_app
17+
else:
18+
sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)

src/twinkle/server/tinker/common/datum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
from collections import defaultdict
55
from tinker import types
6-
from typing import List, Union
76

87
from twinkle.data_format.input_feature import InputFeature
98
from twinkle.template import Template
@@ -92,6 +91,8 @@ def input_feature_to_datum(input_feature: InputFeature) -> types.Datum:
9291
labels_raw = input_feature['labels']
9392
if isinstance(labels_raw, np.ndarray):
9493
labels_arr = labels_raw.astype(np.int64)
94+
elif isinstance(labels_raw, list):
95+
labels_arr = np.asarray(labels_raw, dtype=np.int64)
9596
else:
9697
labels_arr = np.asarray(labels_raw.cpu(), dtype=np.int64)
9798

src/twinkle/server/tinker/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
2525
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
2626
from twinkle.utils.logger import get_logger
27+
from ..utils import wrap_builder_with_device_group_env
2728
from .common.io_utils import create_checkpoint_manager, create_training_run_manager
2829
from .common.router import StickyLoraRequestRouter
2930

@@ -653,3 +654,6 @@ async def _do_load():
653654

654655
return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron,
655656
queue_config, **kwargs)
657+
658+
659+
build_model_app = wrap_builder_with_device_group_env(build_model_app)

src/twinkle/server/tinker/proxy.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# Copyright (c) ModelScope Contributors. All rights reserved.
"""
Proxy utilities for forwarding requests to internal services.

This module provides HTTP proxy functionality to route requests from the Tinker server
to appropriate model or sampler services based on base_model routing.
"""

from __future__ import annotations

import httpx
from fastapi import Request, Response
from typing import Any

from twinkle.utils.logger import get_logger

logger = get_logger()

# Hop-by-hop / framing headers that must NOT be copied verbatim from the
# upstream response. httpx transparently decompresses response bodies, so
# forwarding the original `content-encoding`/`content-length` would describe
# a payload we are no longer sending and corrupt the proxied response.
_RESPONSE_HEADERS_TO_DROP = frozenset({'content-encoding', 'content-length', 'transfer-encoding', 'connection'})


class ServiceProxy:
    """HTTP proxy for routing requests to internal model and sampler services.

    This proxy handles:
    1. URL construction using localhost to avoid external routing loops
    2. Header forwarding with appropriate cleanup
    3. Debug logging for troubleshooting
    4. Error handling and response forwarding
    """

    def __init__(
        self,
        http_options: dict[str, Any] | None = None,
        route_prefix: str = '/api/v1',
    ):
        """Initialize the service proxy.

        Args:
            http_options: HTTP server options (host, port) for internal routing
            route_prefix: URL prefix for routing (default: '/api/v1')
        """
        self.http_options = http_options or {}
        self.route_prefix = route_prefix
        # trust_env=False disables proxy env vars (HTTP_PROXY etc.) so internal
        # requests never route through an external proxy; timeout=None because
        # training/sampling calls can legitimately run for a long time.
        self.client = httpx.AsyncClient(timeout=None, trust_env=False)

    def _build_target_url(self, service_type: str, base_model: str, endpoint: str) -> str:
        """Build the target URL for internal service routing.

        Constructs URLs using localhost to avoid extra external hops.
        When requests come from www.modelscope.com/twinkle, we proxy to
        localhost:port directly instead of back to modelscope.com.

        Args:
            service_type: Either 'model' or 'sampler'
            base_model: The base model name for routing
            endpoint: The target endpoint name

        Returns:
            Complete target URL for the internal service
        """
        prefix = self.route_prefix.rstrip('/') if self.route_prefix else ''
        host = self.http_options.get('host', 'localhost')
        port = self.http_options.get('port', 8000)

        # A wildcard bind address is not routable as a client target;
        # use localhost for internal routing instead.
        if host == '0.0.0.0':
            host = 'localhost'

        base_url = f'http://{host}:{port}'
        return f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}'

    def _prepare_headers(self, request_headers) -> dict[str, str]:
        """Prepare headers for proxying by removing problematic headers.

        Args:
            request_headers: Original request headers (case-insensitive from FastAPI)

        Returns:
            Cleaned headers safe for proxying
        """
        logger.debug('prepare_headers request_headers=%s', request_headers)
        # dict() lowercases Starlette header keys; drop headers that describe
        # the *original* hop and must not be forwarded as-is.
        headers = dict(request_headers)
        headers.pop('host', None)
        headers.pop('content-length', None)
        # Add serve_multiplexed_model_id for sticky sessions if present.
        # Use the case-insensitive lookup on the original request_headers.
        request_id = request_headers.get('X-Ray-Serve-Request-Id')
        if request_id is not None:
            headers['serve_multiplexed_model_id'] = request_id
        return headers

    async def proxy_request(
        self,
        request: Request,
        endpoint: str,
        base_model: str,
        service_type: str,
    ) -> Response:
        """Generic proxy method to forward requests to model or sampler services.

        This method consolidates the common proxy logic for both model and sampler endpoints.

        Args:
            request: The incoming FastAPI request
            endpoint: The target endpoint name (e.g., 'create_model', 'asample')
            base_model: The base model name for routing
            service_type: Either 'model' or 'sampler' to determine the target service

        Returns:
            Proxied response from the target service (502 on transport errors)
        """
        body_bytes = await request.body()
        target_url = self._build_target_url(service_type, base_model, endpoint)
        # Pass original request.headers (case-insensitive) instead of dict conversion
        headers = self._prepare_headers(request.headers)

        try:
            # Debug logging for troubleshooting proxy issues
            logger.debug(
                'proxy_request service=%s endpoint=%s target_url=%s request_id=%s',
                service_type,
                endpoint,
                target_url,
                headers.get('serve_multiplexed_model_id'),
            )

            # Forward the request to the target service
            response = await self.client.request(
                method=request.method,
                url=target_url,
                content=body_bytes,
                headers=headers,
                params=request.query_params,
            )

            # Debug logging for response
            logger.debug(
                'proxy_response status=%s body_preview=%s',
                response.status_code,
                response.text[:200],
            )

            # Strip framing/encoding headers: httpx already decoded the body,
            # and FastAPI recomputes content-length for the re-sent payload.
            response_headers = {
                k: v
                for k, v in response.headers.items() if k.lower() not in _RESPONSE_HEADERS_TO_DROP
            }
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=response_headers,
                media_type=response.headers.get('content-type'),
            )
        except Exception as e:
            logger.error('Proxy error: %s', str(e), exc_info=True)
            # 502 Bad Gateway: the proxy itself failed to reach/read upstream.
            return Response(content=f'Proxy Error: {str(e)}', status_code=502)

    async def proxy_to_model(self, request: Request, endpoint: str, base_model: str) -> Response:
        """Proxy request to model endpoint.

        Routes the request to the appropriate model deployment based on base_model.

        Args:
            request: The incoming FastAPI request
            endpoint: The target endpoint name (e.g., 'create_model', 'forward')
            base_model: The base model name for routing

        Returns:
            Proxied response from the model service
        """
        return await self.proxy_request(request, endpoint, base_model, 'model')

    async def proxy_to_sampler(self, request: Request, endpoint: str, base_model: str) -> Response:
        """Proxy request to sampler endpoint.

        Routes the request to the appropriate sampler deployment based on base_model.

        Args:
            request: The incoming FastAPI request
            endpoint: The target endpoint name (e.g., 'asample')
            base_model: The base model name for routing

        Returns:
            Proxied response from the sampler service
        """
        return await self.proxy_request(request, endpoint, base_model, 'sampler')

0 commit comments

Comments
 (0)