Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ applications:
deployments:
- name: ModelManagement
autoscaling_config:
min_replicas: 2
max_replicas: 2
min_replicas: 1
max_replicas: 1
target_ongoing_requests: 16
ray_actor_options:
num_cpus: 0.1
Expand Down
7 changes: 4 additions & 3 deletions cookbook/client/tinker/custom_service/self_cognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# that the model has learned the custom identity.
# The server must be running first (see server.py and server_config.yaml).
import os
import numpy as np
from tqdm import tqdm
from tinker import types
from twinkle import init_tinker_client
Expand Down Expand Up @@ -76,9 +77,9 @@ def train():
optim_result = optim_future.result()

# Compute weighted average log-loss per token for monitoring
# logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
# weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
# print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
print(f'Training Metrics: {optim_result}')

# Save a checkpoint after each epoch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ for item in service_client.get_server_capabilities().supported_models:
When calling `init_tinker_client`, the following operations are automatically executed:

1. **Patch Tinker SDK**: Bypass Tinker's `tinker://` prefix validation, allowing it to connect to standard HTTP addresses
2. **Set Request Headers**: Inject necessary authentication headers such as `serve_multiplexed_model_id` and `Authorization`
2. **Set Request Headers**: Inject necessary authentication headers such as `X-Ray-Serve-Request-Id` and `Authorization`

After initialization, simply import `from tinker import ServiceClient` to connect to Twinkle Server, and **all existing Tinker training code can be used directly** without any modifications.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ for item in service_client.get_server_capabilities().supported_models:
调用 `init_tinker_client` 时,会自动执行以下操作:

1. **Patch Tinker SDK**:绕过 Tinker 的 `tinker://` 前缀校验,使其可以连接到标准 HTTP 地址
2. **设置请求头**:注入 `serve_multiplexed_model_id` 和 `Authorization` 等必要的认证头
2. **设置请求头**:注入 `X-Ray-Serve-Request-Id` 和 `Authorization` 等必要的认证头

初始化之后,直接导入 `from tinker import ServiceClient` 即可连接到 Twinkle Server,**所有已有的 Tinker 训练代码都可以直接使用**,无需任何修改。

Expand Down
8 changes: 0 additions & 8 deletions src/twinkle/server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .launcher import ServerLauncher, launch_server
from .twinkle.model import build_model_app
from .twinkle.processor import build_processor_app
from .twinkle.sampler import build_sampler_app
from .twinkle.server import build_server_app

__all__ = [
'build_model_app',
'build_processor_app',
'build_sampler_app',
'build_server_app',
'ServerLauncher',
'launch_server',
]
7 changes: 6 additions & 1 deletion src/twinkle/server/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _get_builders(self) -> dict[str, Callable]:
'build_sampler_app': build_sampler_app,
}
else: # twinkle
from twinkle.server import build_model_app, build_processor_app, build_sampler_app, build_server_app
from twinkle.server.twinkle import build_model_app, build_processor_app, build_sampler_app, build_server_app
self._builders = {
'build_server_app': build_server_app,
'build_model_app': build_model_app,
Expand Down Expand Up @@ -214,6 +214,11 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None:
# Copy all deployment options from the config, except 'name'.
deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'}

# Pass http_options to server apps for internal proxy routing
http_options = self.config.get('http_options', {})
if http_options:
args['http_options'] = http_options

# Build and deploy the application
app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()})

Expand Down
25 changes: 14 additions & 11 deletions src/twinkle/server/tinker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import sys
from typing import TYPE_CHECKING

from ..utils import wrap_builder_with_device_group_env
from .model import build_model_app as _build_model_app
from .sampler import build_sampler_app as _build_sampler_app
from .server import build_server_app
from twinkle.utils.import_utils import _LazyModule

build_model_app = wrap_builder_with_device_group_env(_build_model_app)
build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app)
_import_structure = {
'model': ['build_model_app'],
'sampler': ['build_sampler_app'],
'server': ['build_server_app'],
}

__all__ = [
'build_model_app',
'build_sampler_app',
'build_server_app',
]
if TYPE_CHECKING:
from .model import build_model_app
from .sampler import build_sampler_app
from .server import build_server_app
else:
sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)
3 changes: 2 additions & 1 deletion src/twinkle/server/tinker/common/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
from collections import defaultdict
from tinker import types
from typing import List, Union

from twinkle.data_format.input_feature import InputFeature
from twinkle.template import Template
Expand Down Expand Up @@ -92,6 +91,8 @@ def input_feature_to_datum(input_feature: InputFeature) -> types.Datum:
labels_raw = input_feature['labels']
if isinstance(labels_raw, np.ndarray):
labels_arr = labels_raw.astype(np.int64)
elif isinstance(labels_raw, list):
labels_arr = np.asarray(labels_raw, dtype=np.int64)
else:
labels_arr = np.asarray(labels_raw.cpu(), dtype=np.int64)

Expand Down
4 changes: 4 additions & 0 deletions src/twinkle/server/tinker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
from twinkle.utils.logger import get_logger
from ..utils import wrap_builder_with_device_group_env
from .common.io_utils import create_checkpoint_manager, create_training_run_manager
from .common.router import StickyLoraRequestRouter

Expand Down Expand Up @@ -653,3 +654,6 @@ async def _do_load():

return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron,
queue_config, **kwargs)


build_model_app = wrap_builder_with_device_group_env(build_model_app)
184 changes: 184 additions & 0 deletions src/twinkle/server/tinker/proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
"""
Proxy utilities for forwarding requests to internal services.

This module provides HTTP proxy functionality to route requests from the Tinker server
to appropriate model or sampler services based on base_model routing.
"""

from __future__ import annotations

import httpx
import os
from fastapi import Request, Response
from typing import Any

from twinkle.utils.logger import get_logger

logger = get_logger()


class ServiceProxy:
    """HTTP proxy for routing requests to internal model and sampler services.

    This proxy handles:
    1. URL construction using localhost to avoid external routing loops
    2. Header forwarding with appropriate cleanup
    3. Debug logging for troubleshooting
    4. Error handling and response forwarding
    """

    # Headers that must NOT be copied from the upstream response into the
    # outgoing Response: httpx transparently decompresses the body, so the
    # upstream content-length / content-encoding no longer describe
    # ``response.content``; transfer-encoding / connection are hop-by-hop
    # (RFC 7230 section 6.1) and are connection-specific.
    _EXCLUDED_RESPONSE_HEADERS = frozenset({
        'content-length',
        'content-encoding',
        'transfer-encoding',
        'connection',
    })

    def __init__(
        self,
        http_options: dict[str, Any] | None = None,
        route_prefix: str = '/api/v1',
    ):
        """Initialize the service proxy.

        Args:
            http_options: HTTP server options (host, port) for internal routing
            route_prefix: URL prefix for routing (default: '/api/v1')
        """
        self.http_options = http_options or {}
        self.route_prefix = route_prefix
        # No timeout: training/sampling requests can be very long-running.
        # trust_env=False disables proxy-related environment variables so
        # internal requests are never routed through an external proxy.
        self.client = httpx.AsyncClient(timeout=None, trust_env=False)

    async def aclose(self) -> None:
        """Release the underlying HTTP connection pool.

        Call on server shutdown to avoid leaking open connections held by
        the shared ``httpx.AsyncClient``.
        """
        await self.client.aclose()

    def _build_target_url(self, service_type: str, base_model: str, endpoint: str) -> str:
        """Build the target URL for internal service routing.

        Constructs URLs using localhost to avoid extra external hops.
        When requests come from www.modelscope.com/twinkle, we proxy to
        localhost:port directly instead of back to modelscope.com.

        Args:
            service_type: Either 'model' or 'sampler'
            base_model: The base model name for routing
            endpoint: The target endpoint name

        Returns:
            Complete target URL for the internal service
        """
        prefix = self.route_prefix.rstrip('/') if self.route_prefix else ''
        host = self.http_options.get('host', 'localhost')
        port = self.http_options.get('port', 8000)

        # A wildcard bind address is not a routable destination; use
        # localhost for the internal hop instead.
        if host == '0.0.0.0':
            host = 'localhost'

        base_url = f'http://{host}:{port}'
        return f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}'

    def _prepare_headers(self, request_headers) -> dict[str, str]:
        """Prepare headers for proxying by removing problematic headers.

        Args:
            request_headers: Original request headers (case-insensitive from FastAPI)

        Returns:
            Cleaned headers safe for proxying
        """
        logger.debug('prepare_headers request_headers=%s', request_headers)
        # Convert to dict while preserving case-insensitive lookups for special headers
        headers = dict(request_headers)
        # Remove headers that should not be forwarded: 'host' belongs to the
        # original target, and 'content-length' is recomputed by httpx.
        headers.pop('host', None)
        headers.pop('content-length', None)
        # Add serve_multiplexed_model_id for sticky sessions if present.
        # Use case-insensitive lookup from original request_headers.
        request_id = request_headers.get('X-Ray-Serve-Request-Id')
        if request_id is not None:
            headers['serve_multiplexed_model_id'] = request_id
        return headers

    async def proxy_request(
        self,
        request: Request,
        endpoint: str,
        base_model: str,
        service_type: str,
    ) -> Response:
        """Generic proxy method to forward requests to model or sampler services.

        This method consolidates the common proxy logic for both model and sampler endpoints.

        Args:
            request: The incoming FastAPI request
            endpoint: The target endpoint name (e.g., 'create_model', 'asample')
            base_model: The base model name for routing
            service_type: Either 'model' or 'sampler' to determine the target service

        Returns:
            Proxied response from the target service, or a 502 response if the
            forwarded request fails.
        """
        body_bytes = await request.body()
        target_url = self._build_target_url(service_type, base_model, endpoint)
        # request.headers is case-insensitive; _prepare_headers relies on that
        # for the X-Ray-Serve-Request-Id lookup.
        headers = self._prepare_headers(request.headers)

        try:
            # Debug logging for troubleshooting proxy issues
            logger.debug(
                'proxy_request service=%s endpoint=%s target_url=%s request_id=%s',
                service_type,
                endpoint,
                target_url,
                headers.get('serve_multiplexed_model_id'),
            )

            # Forward the request to the target service
            response = await self.client.request(
                method=request.method,
                url=target_url,
                content=body_bytes,
                headers=headers,
                params=request.query_params,
            )

            # Debug logging for response
            logger.debug(
                'proxy_response status=%s body_preview=%s',
                response.status_code,
                response.text[:200],
            )

            # Strip headers that no longer describe the (already decoded)
            # body, plus hop-by-hop headers; Starlette recomputes
            # content-length for the outgoing response itself.
            fwd_headers = {
                key: value
                for key, value in response.headers.items()
                if key.lower() not in self._EXCLUDED_RESPONSE_HEADERS
            }
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=fwd_headers,
                media_type=response.headers.get('content-type'),
            )
        except Exception as e:
            logger.error('Proxy error: %s', str(e), exc_info=True)
            return Response(content=f'Proxy Error: {str(e)}', status_code=502)

    async def proxy_to_model(self, request: Request, endpoint: str, base_model: str) -> Response:
        """Proxy request to model endpoint.

        Routes the request to the appropriate model deployment based on base_model.

        Args:
            request: The incoming FastAPI request
            endpoint: The target endpoint name (e.g., 'create_model', 'forward')
            base_model: The base model name for routing

        Returns:
            Proxied response from the model service
        """
        return await self.proxy_request(request, endpoint, base_model, 'model')

    async def proxy_to_sampler(self, request: Request, endpoint: str, base_model: str) -> Response:
        """Proxy request to sampler endpoint.

        Routes the request to the appropriate sampler deployment based on base_model.

        Args:
            request: The incoming FastAPI request
            endpoint: The target endpoint name (e.g., 'asample')
            base_model: The base model name for routing

        Returns:
            Proxied response from the sampler service
        """
        return await self.proxy_request(request, endpoint, base_model, 'sampler')
4 changes: 4 additions & 0 deletions src/twinkle/server/tinker/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
from twinkle.utils.logger import get_logger
from ..utils import wrap_builder_with_device_group_env
from .common.io_utils import create_checkpoint_manager

logger = get_logger()
Expand Down Expand Up @@ -245,3 +246,6 @@ async def _do_sample():

return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type,
engine_args, queue_config, **kwargs)


build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app)
Loading
Loading