From 9bee81d02f41dab20ed533fe878e0ab3f88f798c Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 1 Mar 2026 12:43:30 +0800 Subject: [PATCH 1/7] update init --- src/twinkle/server/__init__.py | 8 -------- src/twinkle/server/launcher.py | 2 +- src/twinkle/server/tinker/__init__.py | 25 +++++++++++++---------- src/twinkle/server/tinker/common/datum.py | 3 ++- src/twinkle/server/tinker/model.py | 4 ++++ src/twinkle/server/tinker/sampler.py | 4 ++++ src/twinkle/server/tinker/server.py | 4 ++-- src/twinkle/server/twinkle/__init__.py | 23 +++++++++++++++++---- 8 files changed, 46 insertions(+), 27 deletions(-) diff --git a/src/twinkle/server/__init__.py b/src/twinkle/server/__init__.py index b2f890a6..5bdb1f97 100644 --- a/src/twinkle/server/__init__.py +++ b/src/twinkle/server/__init__.py @@ -1,15 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .launcher import ServerLauncher, launch_server -from .twinkle.model import build_model_app -from .twinkle.processor import build_processor_app -from .twinkle.sampler import build_sampler_app -from .twinkle.server import build_server_app __all__ = [ - 'build_model_app', - 'build_processor_app', - 'build_sampler_app', - 'build_server_app', 'ServerLauncher', 'launch_server', ] diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index b5b53f6a..7b7c735c 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -101,7 +101,7 @@ def _get_builders(self) -> dict[str, Callable]: 'build_sampler_app': build_sampler_app, } else: # twinkle - from twinkle.server import build_model_app, build_processor_app, build_sampler_app, build_server_app + from twinkle.server.twinkle import build_model_app, build_processor_app, build_sampler_app, build_server_app self._builders = { 'build_server_app': build_server_app, 'build_model_app': build_model_app, diff --git a/src/twinkle/server/tinker/__init__.py b/src/twinkle/server/tinker/__init__.py index 6c1570ff..40688d64 100644 --- a/src/twinkle/server/tinker/__init__.py +++ b/src/twinkle/server/tinker/__init__.py @@ -1,15 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import sys +from typing import TYPE_CHECKING -from ..utils import wrap_builder_with_device_group_env -from .model import build_model_app as _build_model_app -from .sampler import build_sampler_app as _build_sampler_app -from .server import build_server_app +from twinkle.utils.import_utils import _LazyModule -build_model_app = wrap_builder_with_device_group_env(_build_model_app) -build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app) +_import_structure = { + 'model': ['build_model_app'], + 'sampler': ['build_sampler_app'], + 'server': ['build_server_app'], +} -__all__ = [ - 'build_model_app', - 'build_sampler_app', - 'build_server_app', -] +if TYPE_CHECKING: + from .model import build_model_app + from .sampler import build_sampler_app + from .server import build_server_app +else: + sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) diff --git a/src/twinkle/server/tinker/common/datum.py b/src/twinkle/server/tinker/common/datum.py index 289176fe..0eb74f82 100644 --- a/src/twinkle/server/tinker/common/datum.py +++ b/src/twinkle/server/tinker/common/datum.py @@ -3,7 +3,6 @@ import numpy as np from collections import defaultdict from tinker import types -from typing import List, Union from twinkle.data_format.input_feature import InputFeature from twinkle.template import Template @@ -92,6 +91,8 @@ def input_feature_to_datum(input_feature: InputFeature) -> types.Datum: labels_raw = input_feature['labels'] if isinstance(labels_raw, np.ndarray): labels_arr = labels_raw.astype(np.int64) + elif isinstance(labels_raw, list): + labels_arr = np.asarray(labels_raw, dtype=np.int64) else: labels_arr = np.asarray(labels_raw.cpu(), dtype=np.int64) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 30ced15e..80778c36 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -24,6 +24,7 @@ from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger +from ..utils import wrap_builder_with_device_group_env from .common.io_utils import create_checkpoint_manager, create_training_run_manager from .common.router import StickyLoraRequestRouter @@ -653,3 +654,6 @@ async def _do_load(): return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, queue_config, **kwargs) + + +build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index 20b0a5a1..406524f3 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -23,6 +23,7 @@ from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger +from ..utils import wrap_builder_with_device_group_env from .common.io_utils import create_checkpoint_manager logger = get_logger() @@ -245,3 +246,6 @@ async def _do_sample(): return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, engine_args, queue_config, **kwargs) + + +build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app) diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 808f23c1..b35048e0 100644 
--- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -13,7 +13,6 @@ import asyncio import httpx -import logging import os from fastapi import FastAPI, HTTPException, Request, Response from ray import serve @@ -24,9 +23,10 @@ from twinkle.server.utils.state import get_server_state from twinkle.server.utils.task_queue import QueueState from twinkle.server.utils.validation import get_token_from_request, verify_request_token +from twinkle.utils.logger import get_logger from .common.io_utils import create_checkpoint_manager, create_training_run_manager -logger = logging.getLogger(__name__) +logger = get_logger() def build_server_app(deploy_options: dict[str, Any], diff --git a/src/twinkle/server/twinkle/__init__.py b/src/twinkle/server/twinkle/__init__.py index 54cc96be..7371b1d7 100644 --- a/src/twinkle/server/twinkle/__init__.py +++ b/src/twinkle/server/twinkle/__init__.py @@ -1,5 +1,20 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .model import build_model_app -from .processor import build_processor_app -from .sampler import build_sampler_app -from .server import build_server_app +import sys +from typing import TYPE_CHECKING + +from twinkle.utils.import_utils import _LazyModule + +_import_structure = { + 'model': ['build_model_app'], + 'processor': ['build_processor_app'], + 'sampler': ['build_sampler_app'], + 'server': ['build_server_app'], +} + +if TYPE_CHECKING: + from .model import build_model_app + from .processor import build_processor_app + from .sampler import build_sampler_app + from .server import build_server_app +else: + sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) From a42bd71e28168c42a1e0d68f5c09acd10f8a0eca Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 1 Mar 2026 13:16:55 +0800 Subject: [PATCH 2/7] update proxy --- .../Tinker-Compatible-Client.md | 2 +- ...71\345\256\242\346\210\267\347\253\257.md" | 2 +- src/twinkle/server/launcher.py | 5 + src/twinkle/server/tinker/proxy.py | 180 ++++++++++++++++++ src/twinkle/server/tinker/server.py | 113 ++--------- src/twinkle/server/utils/validation.py | 7 +- src/twinkle_client/http/http_utils.py | 2 +- src/twinkle_client/utils/patch_tinker.py | 2 +- 8 files changed, 213 insertions(+), 100 deletions(-) create mode 100644 src/twinkle/server/tinker/proxy.py diff --git a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md index 57e86366..a01fd141 100644 --- a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md @@ -28,7 +28,7 @@ for item in service_client.get_server_capabilities().supported_models: When calling `init_tinker_client`, the following operations are automatically executed: 1. **Patch Tinker SDK**: Bypass Tinker's `tinker://` prefix validation, allowing it to connect to standard HTTP addresses -2. **Set Request Headers**: Inject necessary authentication headers such as `serve_multiplexed_model_id` and `Authorization` +2. **Set Request Headers**: Inject necessary authentication headers such as `X-Ray-Serve-Request-Id` and `Authorization` After initialization, simply import `from tinker import ServiceClient` to connect to Twinkle Server, and **all existing Tinker training code can be used directly** without any modifications. 
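For reference, a minimal client-side sketch under the new header scheme (not part of this diff; the base URL is a placeholder, the exact `init_tinker_client`/`ServiceClient` arguments are elided, and the header names follow `_build_headers` in `src/twinkle_client/http/http_utils.py`):

```python
# Sketch only: wiring unmodified Tinker code to Twinkle Server.
from twinkle import init_tinker_client

# Patches the Tinker SDK and installs the request headers; after this
# patch every request carries:
#   X-Ray-Serve-Request-Id: <request id>     # sticky-session routing
#   Authorization: Bearer <api key>
#   Twinkle-Authorization: Bearer <api key>  # server compatibility
init_tinker_client()  # arguments (server URL, API key, ...) elided

from tinker import ServiceClient  # unmodified Tinker import

service_client = ServiceClient(base_url='http://localhost:8000/api/v1')  # placeholder URL
for item in service_client.get_server_capabilities().supported_models:
    print(item)
```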
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" index e11ded44..11b51303 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" @@ -28,7 +28,7 @@ for item in service_client.get_server_capabilities().supported_models: 调用 `init_tinker_client` 时,会自动执行以下操作: 1. **Patch Tinker SDK**:绕过 Tinker 的 `tinker://` 前缀校验,使其可以连接到标准 HTTP 地址 -2. **设置请求头**:注入 `serve_multiplexed_model_id` 和 `Authorization` 等必要的认证头 +2. **设置请求头**:注入 `X-Ray-Serve-Request-Id` 和 `Authorization` 等必要的认证头 初始化之后,直接导入 `from tinker import ServiceClient` 即可连接到 Twinkle Server,**所有已有的 Tinker 训练代码都可以直接使用**,无需任何修改。 diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 7b7c735c..334cd99d 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -214,6 +214,11 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: # Copy all deployment options from the config, except 'name'. deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} + # Pass http_options to server apps for internal proxy routing + http_options = self.config.get('http_options', {}) + if http_options: + args['http_options'] = http_options + # Build and deploy the application app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()}) diff --git a/src/twinkle/server/tinker/proxy.py b/src/twinkle/server/tinker/proxy.py new file mode 100644 index 00000000..0f4d8868 --- /dev/null +++ b/src/twinkle/server/tinker/proxy.py @@ -0,0 +1,180 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Proxy utilities for forwarding requests to internal services. + +This module provides HTTP proxy functionality to route requests from the Tinker server +to appropriate model or sampler services based on base_model routing. +""" + +from __future__ import annotations + +import httpx +import os +from fastapi import Request, Response +from typing import Any + +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +class ServiceProxy: + """HTTP proxy for routing requests to internal model and sampler services. + + This proxy handles: + 1. URL construction using localhost to avoid external routing loops + 2. Header forwarding with appropriate cleanup + 3. Debug logging for troubleshooting + 4. Error handling and response forwarding + """ + + def __init__( + self, + http_options: dict[str, Any] | None = None, + route_prefix: str = '/api/v1', + ): + """Initialize the service proxy. 
+ + Args: + http_options: HTTP server options (host, port) for internal routing + route_prefix: URL prefix for routing (default: '/api/v1') + """ + self.http_options = http_options or {} + self.route_prefix = route_prefix + # Disable proxy for internal requests to avoid routing through external proxies + self.client = httpx.AsyncClient(timeout=None, trust_env=False) + + def _build_target_url(self, service_type: str, base_model: str, endpoint: str) -> str: + """Build the target URL for internal service routing. + + Constructs URLs using localhost to avoid extra external hops. + When requests come from www.modelscope.com/twinkle, we proxy to + localhost:port directly instead of back to modelscope.com. + + Args: + service_type: Either 'model' or 'sampler' + base_model: The base model name for routing + endpoint: The target endpoint name + + Returns: + Complete target URL for the internal service + """ + prefix = self.route_prefix.rstrip('/') if self.route_prefix else '' + host = self.http_options.get('host', 'localhost') + port = self.http_options.get('port', 8000) + + # Use localhost for internal routing + if host == '0.0.0.0': + host = 'localhost' + + base_url = f'http://{host}:{port}' + return f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}' + + def _prepare_headers(self, request_headers: dict[str, str]) -> dict[str, str]: + """Prepare headers for proxying by removing problematic headers. + + Args: + request_headers: Original request headers + + Returns: + Cleaned headers safe for proxying + """ + headers = dict(request_headers) + # Remove headers that should not be forwarded + headers.pop('host', None) + headers.pop('content-length', None) + # Add serve_multiplexed_model_id for sticky sessions + headers['serve_multiplexed_model_id'] = request_headers.get('X-Ray-Serve-Request-Id') + return headers + + async def proxy_request( + self, + request: Request, + endpoint: str, + base_model: str, + service_type: str, + ) -> Response: + """Generic proxy method to forward requests to model or sampler services. + + This method consolidates the common proxy logic for both model and sampler endpoints. 
+ + Args: + request: The incoming FastAPI request + endpoint: The target endpoint name (e.g., 'create_model', 'asample') + base_model: The base model name for routing + service_type: Either 'model' or 'sampler' to determine the target service + + Returns: + Proxied response from the target service + """ + body_bytes = await request.body() + target_url = self._build_target_url(service_type, base_model, endpoint) + headers = self._prepare_headers(dict(request.headers)) + + try: + # Debug logging for troubleshooting proxy issues + if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': + logger.info( + 'proxy_request service=%s endpoint=%s target_url=%s request_id=%s', + service_type, + endpoint, + target_url, + headers.get('x-ray-serve-request-id'), + ) + + # Forward the request to the target service + response = await self.client.request( + method=request.method, + url=target_url, + content=body_bytes, + headers=headers, + params=request.query_params, + ) + + # Debug logging for response + if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': + logger.info( + 'proxy_response status=%s body_preview=%s', + response.status_code, + response.text[:200], + ) + + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.headers.get('content-type'), + ) + except Exception as e: + logger.error('Proxy error: %s', str(e), exc_info=True) + return Response(content=f'Proxy Error: {str(e)}', status_code=502) + + async def proxy_to_model(self, request: Request, endpoint: str, base_model: str) -> Response: + """Proxy request to model endpoint. + + Routes the request to the appropriate model deployment based on base_model. + + Args: + request: The incoming FastAPI request + endpoint: The target endpoint name (e.g., 'create_model', 'forward') + base_model: The base model name for routing + + Returns: + Proxied response from the model service + """ + return await self.proxy_request(request, endpoint, base_model, 'model') + + async def proxy_to_sampler(self, request: Request, endpoint: str, base_model: str) -> Response: + """Proxy request to sampler endpoint. + + Routes the request to the appropriate sampler deployment based on base_model. + + Args: + request: The incoming FastAPI request + endpoint: The target endpoint name (e.g., 'asample') + base_model: The base model name for routing + + Returns: + Proxied response from the sampler service + """ + return await self.proxy_request(request, endpoint, base_model, 'sampler') diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index b35048e0..81543c58 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -12,7 +12,6 @@ from __future__ import annotations import asyncio -import httpx import os from fastapi import FastAPI, HTTPException, Request, Response from ray import serve @@ -25,6 +24,7 @@ from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger from .common.io_utils import create_checkpoint_manager, create_training_run_manager +from .proxy import ServiceProxy logger = get_logger() @@ -32,6 +32,7 @@ def build_server_app(deploy_options: dict[str, Any], supported_models: list[types.SupportedModel] | None = None, server_config: dict[str, Any] = {}, + http_options: dict[str, Any] | None = None, **kwargs): """Build and configure the Tinker-compatible server application. 
@@ -69,18 +70,23 @@ class TinkerCompatServer: def __init__(self, supported_models: list[types.SupportedModel] | None = None, server_config: dict[str, Any] = {}, + http_options: dict[str, Any] | None = None, **kwargs) -> None: """Initialize the Tinker-compatible server. Args: supported_models: List of supported base models for validation + server_config: Server configuration options + http_options: HTTP server options (host, port) for internal proxy routing **kwargs: Additional configuration (route_prefix, etc.) """ - # Get per_token_adapter_limit from kwargs or use default self.state = get_server_state(**server_config) - # Disable proxy for internal requests to avoid routing through external proxies - self.client = httpx.AsyncClient(timeout=None, trust_env=False) self.route_prefix = kwargs.get('route_prefix', '/api/v1') + self.http_options = http_options or {} + + # Initialize service proxy for routing requests to model/sampler services + self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) + self.supported_models = self.normalize_models(supported_models) or [ types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), ] @@ -135,83 +141,6 @@ def _get_base_model(self, model_id: str) -> str: return metadata['base_model'] raise HTTPException(status_code=404, detail=f'Model {model_id} not found') - async def _proxy_request(self, request: Request, endpoint: str, base_model: str, service_type: str) -> Response: - """Generic proxy method to forward requests to model or sampler services. - - This method consolidates the common proxy logic for both model and sampler endpoints. - - Args: - request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'create_model', 'asample') - base_model: The base model name for routing - service_type: Either 'model' or 'sampler' to determine the target service - - Returns: - Proxied response from the target service - """ - body_bytes = await request.body() - - # Construct target URL: /{service_type}/{base_model}/{endpoint} - prefix = self.route_prefix.rstrip('/') if self.route_prefix else '' - base_url = f'{request.url.scheme}://{request.url.netloc}' - target_url = f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}' - - headers = dict(request.headers) - headers.pop('host', None) - headers.pop('content-length', None) - - try: - if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': - logger.info('proxy_to_model endpoint=%s target_url=%s x-ray-serve-request-id=%s', endpoint, - target_url, headers.get('x-ray-serve-request-id')) - rp_ = await self.client.request( - method=request.method, - url=target_url, - content=body_bytes, - headers=headers, - params=request.query_params, - ) - if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': - logger.info('proxy_to_model response status=%s body=%s', rp_.status_code, rp_.text[:200]) - return Response( - content=rp_.content, - status_code=rp_.status_code, - headers=dict(rp_.headers), - media_type=rp_.headers.get('content-type'), - ) - except Exception as e: - return Response(content=f'Proxy Error: {str(e)}', status_code=502) - - async def _proxy_to_model(self, request: Request, endpoint: str, base_model: str) -> Response: - """Proxy request to model endpoint. - - Routes the request to the appropriate model deployment based on base_model. 
- - Args: - request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'create_model', 'forward') - base_model: The base model name for routing - - Returns: - Proxied response from the model service - """ - return await self._proxy_request(request, endpoint, base_model, 'model') - - async def _proxy_to_sampler(self, request: Request, endpoint: str, base_model: str) -> Response: - """Proxy request to sampler endpoint. - - Routes the request to the appropriate sampler deployment based on base_model. - - Args: - request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'asample') - base_model: The base model name for routing - - Returns: - Proxied response from the sampler service - """ - return await self._proxy_request(request, endpoint, base_model, 'sampler') - # --- Endpoints --------------------------------------------------------- @app.get('/healthz') @@ -553,7 +482,7 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) - Proxied response from model service """ self._validate_base_model(body.base_model) - return await self._proxy_to_model(request, 'create_model', body.base_model) + return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) @app.post('/get_info') async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any: @@ -565,7 +494,7 @@ async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any: Returns: Proxied response from model service """ - return await self._proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) @app.post('/unload_model') async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> Any: @@ -577,7 +506,7 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) - Returns: Proxied response from model service """ - return await self._proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) @app.post('/forward') async def forward(self, request: Request, body: types.ForwardRequest) -> Any: @@ -589,7 +518,7 @@ async def forward(self, request: Request, body: types.ForwardRequest) -> Any: Returns: Proxied response from model service """ - return await self._proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) @app.post('/forward_backward') async def forward_backward(self, request: Request, body: types.ForwardBackwardRequest) -> Any: @@ -601,7 +530,7 @@ async def forward_backward(self, request: Request, body: types.ForwardBackwardRe Returns: Proxied response from model service """ - return await self._proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) @app.post('/optim_step') async def optim_step(self, request: Request, body: types.OptimStepRequest) -> Any: @@ -613,7 +542,7 @@ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> An Returns: Proxied response from model service """ - return await self._proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) 
@app.post('/save_weights') async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> Any: @@ -625,7 +554,7 @@ async def save_weights(self, request: Request, body: types.SaveWeightsRequest) - Returns: Proxied response from model service """ - return await self._proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) @app.post('/load_weights') async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> Any: @@ -637,7 +566,7 @@ async def load_weights(self, request: Request, body: types.LoadWeightsRequest) - Returns: Proxied response from model service """ - return await self._proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) + return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) # --- Sampler Proxy Endpoints ---------------------------------------- @@ -662,7 +591,7 @@ async def asample(self, request: Request, body: types.SampleRequest) -> Any: if session: base_model = session.get('base_model') - return await self._proxy_to_sampler(request, 'asample', base_model) + return await self.proxy.proxy_to_sampler(request, 'asample', base_model) @app.post('/save_weights_for_sampler') async def save_weights_for_sampler(self, request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: @@ -678,7 +607,7 @@ async def save_weights_for_sampler(self, request: Request, body: types.SaveWeigh """ # Proxy to model service for save_weights_for_sampler base_model = self._get_base_model(body.model_id) - return await self._proxy_to_model(request, 'save_weights_for_sampler', base_model) + return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', base_model) return TinkerCompatServer.options(**deploy_options).bind( - supported_models=supported_models, server_config=server_config, **kwargs) + supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index d419818a..23539ed8 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -11,7 +11,7 @@ async def verify_request_token(request: Request, call_next): This middleware: 1. Extracts the Bearer token from Authorization header 2. Validates the token - 3. Extracts serve_multiplexed_model_id for sticky sessions + 3. Extracts X-Ray-Serve-Request-Id for sticky sessions 4. 
Stores token and request_id in request.state for later use Args: @@ -26,11 +26,10 @@ async def verify_request_token(request: Request, call_next): if not is_token_valid(token): return JSONResponse(status_code=403, content={'detail': 'Invalid token'}) - request_id = request.headers.get('serve_multiplexed_model_id') + request_id = request.headers.get('X-Ray-Serve-Request-Id') if not request_id: return JSONResponse( - status_code=400, - content={'detail': 'Missing serve_multiplexed_model_id header, required for sticky session'}) + status_code=400, content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'}) request.state.request_id = request_id request.state.token = token response = await call_next(request) diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index f9cafa1c..522b46af 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -16,7 +16,7 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[ Dictionary of headers """ headers = { - 'serve_multiplexed_model_id': get_request_id(), + 'X-Ray-Serve-Request-Id': get_request_id(), 'Authorization': 'Bearer ' + get_api_key(), 'Twinkle-Authorization': 'Bearer ' + get_api_key(), # For server compatibility } diff --git a/src/twinkle_client/utils/patch_tinker.py b/src/twinkle_client/utils/patch_tinker.py index d0245c20..826274ae 100644 --- a/src/twinkle_client/utils/patch_tinker.py +++ b/src/twinkle_client/utils/patch_tinker.py @@ -127,7 +127,7 @@ def _patched_service_client_init(self, user_metadata=None, **kwargs): api_key = get_api_key() twinkle_headers = { - 'serve_multiplexed_model_id': get_request_id(), + 'X-Ray-Serve-Request-Id': get_request_id(), 'Authorization': 'Bearer ' + api_key, 'Twinkle-Authorization': 'Bearer ' + api_key, } From 1816c9d284517dcaad1e86eaa4831f711b754cff Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 1 Mar 2026 13:31:44 +0800 Subject: [PATCH 3/7] update proxy --- src/twinkle/server/tinker/proxy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/twinkle/server/tinker/proxy.py b/src/twinkle/server/tinker/proxy.py index 0f4d8868..51cc43c3 100644 --- a/src/twinkle/server/tinker/proxy.py +++ b/src/twinkle/server/tinker/proxy.py @@ -79,12 +79,15 @@ def _prepare_headers(self, request_headers: dict[str, str]) -> dict[str, str]: Returns: Cleaned headers safe for proxying """ + logger.info('prepare_headers request_headers=%s', request_headers) headers = dict(request_headers) # Remove headers that should not be forwarded headers.pop('host', None) headers.pop('content-length', None) - # Add serve_multiplexed_model_id for sticky sessions - headers['serve_multiplexed_model_id'] = request_headers.get('X-Ray-Serve-Request-Id') + # Add serve_multiplexed_model_id for sticky sessions if present + request_id = request_headers.get('X-Ray-Serve-Request-Id') + if request_id is not None: + headers['serve_multiplexed_model_id'] = request_id return headers async def proxy_request( From 42ae691d73392d5670b37e7775caab0b0365fbce Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 1 Mar 2026 13:58:15 +0800 Subject: [PATCH 4/7] update proxy --- src/twinkle/server/tinker/proxy.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/twinkle/server/tinker/proxy.py b/src/twinkle/server/tinker/proxy.py index 51cc43c3..f6c3f942 100644 --- a/src/twinkle/server/tinker/proxy.py +++ b/src/twinkle/server/tinker/proxy.py @@ -70,22 +70,24 @@ def 
_build_target_url(self, service_type: str, base_model: str, endpoint: str) - base_url = f'http://{host}:{port}' return f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}' - def _prepare_headers(self, request_headers: dict[str, str]) -> dict[str, str]: + def _prepare_headers(self, request_headers) -> dict[str, str]: """Prepare headers for proxying by removing problematic headers. Args: - request_headers: Original request headers + request_headers: Original request headers (case-insensitive from FastAPI) Returns: Cleaned headers safe for proxying """ logger.info('prepare_headers request_headers=%s', request_headers) + # Convert to dict while preserving case-insensitive lookups for special headers headers = dict(request_headers) # Remove headers that should not be forwarded headers.pop('host', None) headers.pop('content-length', None) # Add serve_multiplexed_model_id for sticky sessions if present - request_id = request_headers.get('X-Ray-Serve-Request-Id') + # Use case-insensitive lookup from original request_headers + request_id = request_headers.get('x-ray-serve-request-id') if request_id is not None: headers['serve_multiplexed_model_id'] = request_id return headers @@ -112,7 +114,8 @@ async def proxy_request( """ body_bytes = await request.body() target_url = self._build_target_url(service_type, base_model, endpoint) - headers = self._prepare_headers(dict(request.headers)) + # Pass original request.headers (case-insensitive) instead of dict conversion + headers = self._prepare_headers(request.headers) try: # Debug logging for troubleshooting proxy issues @@ -122,7 +125,7 @@ async def proxy_request( service_type, endpoint, target_url, - headers.get('x-ray-serve-request-id'), + headers.get('serve_multiplexed_model_id'), ) # Forward the request to the target service From 599cb3e7a106df757a3ff3006582a18ae5db4734 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 1 Mar 2026 14:18:35 +0800 Subject: [PATCH 5/7] update --- src/twinkle/server/tinker/proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/server/tinker/proxy.py b/src/twinkle/server/tinker/proxy.py index f6c3f942..a9bccef4 100644 --- a/src/twinkle/server/tinker/proxy.py +++ b/src/twinkle/server/tinker/proxy.py @@ -79,7 +79,7 @@ def _prepare_headers(self, request_headers) -> dict[str, str]: Returns: Cleaned headers safe for proxying """ - logger.info('prepare_headers request_headers=%s', request_headers) + logger.debug('prepare_headers request_headers=%s', request_headers) # Convert to dict while preserving case-insensitive lookups for special headers headers = dict(request_headers) # Remove headers that should not be forwarded From 894785f0874eb3789f091736be1a1ec963e0b9b4 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 1 Mar 2026 14:27:12 +0800 Subject: [PATCH 6/7] update --- .../client/tinker/custom_service/megatron/server_config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cookbook/client/tinker/custom_service/megatron/server_config.yaml b/cookbook/client/tinker/custom_service/megatron/server_config.yaml index 04f8c12c..b8fa1abd 100644 --- a/cookbook/client/tinker/custom_service/megatron/server_config.yaml +++ b/cookbook/client/tinker/custom_service/megatron/server_config.yaml @@ -62,8 +62,8 @@ applications: deployments: - name: ModelManagement autoscaling_config: - min_replicas: 2 - max_replicas: 2 + min_replicas: 1 + max_replicas: 1 target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 From 7634154602a56e584fddb67c8f7af46780ece657 Mon Sep 17 00:00:00 
2001 From: Yunnglin Date: Mon, 2 Mar 2026 10:47:06 +0800 Subject: [PATCH 7/7] update --- .../tinker/custom_service/self_cognition.py | 7 +++-- src/twinkle/server/tinker/proxy.py | 28 +++++++++---------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/cookbook/client/tinker/custom_service/self_cognition.py b/cookbook/client/tinker/custom_service/self_cognition.py index 9b78a14f..4acc88f7 100644 --- a/cookbook/client/tinker/custom_service/self_cognition.py +++ b/cookbook/client/tinker/custom_service/self_cognition.py @@ -7,6 +7,7 @@ # that the model has learned the custom identity. # The server must be running first (see server.py and server_config.yaml). import os +import numpy as np from tqdm import tqdm from tinker import types from twinkle import init_tinker_client @@ -76,9 +77,9 @@ def train(): optim_result = optim_future.result() # Compute weighted average log-loss per token for monitoring - # logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) - # weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum]) - # print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}') + logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) + weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum]) + print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}') print(f'Training Metrics: {optim_result}') # Save a checkpoint after each epoch diff --git a/src/twinkle/server/tinker/proxy.py b/src/twinkle/server/tinker/proxy.py index a9bccef4..bc429199 100644 --- a/src/twinkle/server/tinker/proxy.py +++ b/src/twinkle/server/tinker/proxy.py @@ -87,7 +87,7 @@ def _prepare_headers(self, request_headers) -> dict[str, str]: headers.pop('content-length', None) # Add serve_multiplexed_model_id for sticky sessions if present # Use case-insensitive lookup from original request_headers - request_id = request_headers.get('x-ray-serve-request-id') + request_id = request_headers.get('X-Ray-Serve-Request-Id') if request_id is not None: headers['serve_multiplexed_model_id'] = request_id return headers @@ -119,14 +119,13 @@ async def proxy_request( try: # Debug logging for troubleshooting proxy issues - if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': - logger.info( - 'proxy_request service=%s endpoint=%s target_url=%s request_id=%s', - service_type, - endpoint, - target_url, - headers.get('serve_multiplexed_model_id'), - ) + logger.debug( + 'proxy_request service=%s endpoint=%s target_url=%s request_id=%s', + service_type, + endpoint, + target_url, + headers.get('serve_multiplexed_model_id'), + ) # Forward the request to the target service response = await self.client.request( @@ -138,12 +137,11 @@ async def proxy_request( ) # Debug logging for response - if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': - logger.info( - 'proxy_response status=%s body_preview=%s', - response.status_code, - response.text[:200], - ) + logger.debug( + 'proxy_response status=%s body_preview=%s', + response.status_code, + response.text[:200], + ) return Response( content=response.content,
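Taken together, the series ends with every Tinker endpoint routed through `ServiceProxy`. A condensed sketch of the final routing behavior (simplified from `src/twinkle/server/tinker/proxy.py`; the httpx forwarding, error handling, and debug logging are omitted):

```python
# Sketch of ServiceProxy's two routing decisions after PATCH 7.

def build_target_url(http_options: dict, route_prefix: str,
                     service_type: str, base_model: str, endpoint: str) -> str:
    """URL construction as in ServiceProxy._build_target_url."""
    host = http_options.get('host', 'localhost')
    if host == '0.0.0.0':  # bind-all address is not routable; use loopback
        host = 'localhost'
    port = http_options.get('port', 8000)
    prefix = route_prefix.rstrip('/') if route_prefix else ''
    return f'http://{host}:{port}{prefix}/{service_type}/{base_model}/{endpoint}'


def prepare_headers(request_headers) -> dict:
    """Header cleanup as in ServiceProxy._prepare_headers.

    In the server this receives starlette's case-insensitive Headers,
    so the canonical-case lookup succeeds regardless of wire casing;
    with a plain dict (as in the demo below) the key must match exactly.
    """
    headers = dict(request_headers)
    headers.pop('host', None)            # the target sets its own Host
    headers.pop('content-length', None)  # httpx recomputes the body length
    request_id = request_headers.get('X-Ray-Serve-Request-Id')
    if request_id is not None:
        # Ray Serve reads this header to pin the request to one replica
        headers['serve_multiplexed_model_id'] = request_id
    return headers


if __name__ == '__main__':
    print(build_target_url({'host': '0.0.0.0', 'port': 8000}, '/api/v1',
                           'model', 'Qwen/Qwen3-30B-A3B-Instruct-2507', 'forward'))
    # -> http://localhost:8000/api/v1/model/Qwen/Qwen3-30B-A3B-Instruct-2507/forward
    print(prepare_headers({'host': 'example.com',
                           'X-Ray-Serve-Request-Id': 'req-1',
                           'authorization': 'Bearer <api key>'}))
```

Routing via loopback keeps intra-cluster hops off the public ingress (the docstring's www.modelscope.com/twinkle example), and dropping `content-length` lets the client recompute it for the re-sent body.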