diff --git a/.vscode/settings.json b/.vscode/settings.json index 8d44549..d8fa722 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -65,7 +65,7 @@ "editor.defaultFormatter": "redhat.vscode-yaml", "editor.formatOnSave": true }, - "python.languageServer": "Pylance", + "python.languageServer": "None", "rust-analyzer.cargo.extraEnv": { "PYO3_PYTHON": "${workspaceFolder}/src-tauri/pyembed/python/python.exe" } diff --git a/app/components/panels/ConfigGroupPanel.vue b/app/components/panels/ConfigGroupPanel.vue index da50007..272b44c 100644 --- a/app/components/panels/ConfigGroupPanel.vue +++ b/app/components/panels/ConfigGroupPanel.vue @@ -248,6 +248,10 @@ const handleSave = async () => { configGroups.value.push(payload); currentIndex.value = configGroups.value.length - 1; } else if (hasSelection.value) { + const currentGroup = configGroups.value[selectedIndex.value]; + if (currentGroup?.id) { + payload.id = currentGroup.id; + } configGroups.value.splice(selectedIndex.value, 1, payload); } @@ -449,7 +453,7 @@ const moveDown = async () => { { themeDialogOpen.value = true; }; +const enable429Failover = computed({ + get: () => store.enable429Failover.value, + set: (value) => { + store.enable429Failover.value = value; + }, +}); + +const failover429CooldownSeconds = computed({ + get: () => store.failover429CooldownSeconds.value, + set: (value) => { + store.failover429CooldownSeconds.value = value; + }, +}); + +const routingGroupIds = computed({ + get: () => store.routingGroupIds.value, + set: (value) => { + store.routingGroupIds.value = value; + }, +}); + +type RoutingGroupOption = { + key: string; + label: string; + ids: string[]; +}; + +const routingGroupOptions = computed(() => { + const grouped = new Map(); + configGroups.value.forEach((group, index) => { + const id = (group.id || "").trim(); + if (!id) { + return; + } + const name = (group.name || "").trim(); + const key = name ? 
`name:${name}` : `id:${id}`; + const label = name || `配置组 ${index + 1}`; + const current = grouped.get(key); + if (current) { + current.ids.push(id); + return; + } + grouped.set(key, { key, label, ids: [id] }); + }); + return Array.from(grouped.values()); +}); + +const selectedRoutingGroupKeys = computed({ + get: () => { + const selectedIds = new Set(routingGroupIds.value); + return routingGroupOptions.value + .filter((option) => option.ids.some((id) => selectedIds.has(id))) + .map((option) => option.key); + }, + set: (keys: string[]) => { + const selectedKeys = new Set(keys); + routingGroupIds.value = Array.from( + new Set( + routingGroupOptions.value + .filter((option) => selectedKeys.has(option.key)) + .flatMap((option) => option.ids), + ), + ); + }, +}); + +const handleFailoverChange = async () => { + const ok = await store.saveConfig(); + if (ok) { + store.appendLog(`API 智能调度已${enable429Failover.value ? "启用" : "禁用"}`); + } else { + store.appendLog("保存配置失败"); + } +}; + +const handleRoutingGroupsChange = async () => { + const ok = await store.saveConfig(); + if (ok) { + if (selectedRoutingGroupKeys.value.length) { + store.appendLog(`已设置轮询配置组数量:${selectedRoutingGroupKeys.value.length}`); + } else { + store.appendLog("轮询配置组未指定,仅使用当前激活配置"); + } + return; + } + store.appendLog("保存配置失败"); +}; + const handleThemeSave = (value: ThemeConfig) => { const normalized = sanitizeThemeConfig(value); copyThemeConfig(themeConfig, normalized); @@ -139,6 +228,58 @@ const handleThemeSave = (value: ThemeConfig) => {
+
+
+
转发策略
+
负载均衡与自动容错
+
+ +
+ 节点冷却周期 (秒) + +
+
+
轮询配置组(可多选,未选则仅使用当前激活配置)
+
+ +
+
暂无可选配置组
+
+
+ 开启后,请求将在选中配置组间轮询分发。若节点触发 429 (Too Many Requests) + 频率限制,将自动静默切换至可用节点并对受限节点执行冷却隔离,确保服务连续性。 +
+
+
用户数据
diff --git a/app/composables/mtgaTypes.ts b/app/composables/mtgaTypes.ts index 4d8e2a8..6768667 100644 --- a/app/composables/mtgaTypes.ts +++ b/app/composables/mtgaTypes.ts @@ -1,4 +1,5 @@ export type ConfigGroup = { + id?: string; name?: string; api_url: string; model_id: string; @@ -13,6 +14,9 @@ export type ConfigPayload = { current_config_index: number; mapped_model_id: string; mtga_auth_key: string; + routing_group_ids?: string[]; + enable_429_failover?: boolean; + failover_429_cooldown_seconds?: number; }; export type AppInfo = { diff --git a/app/composables/useMtgaStore.ts b/app/composables/useMtgaStore.ts index 35d6174..985af87 100644 --- a/app/composables/useMtgaStore.ts +++ b/app/composables/useMtgaStore.ts @@ -162,6 +162,19 @@ const clampIndex = (value: number, max: number) => { return Math.min(Math.max(value, 0), max - 1); }; +const createConfigGroupId = (index: number) => { + if (typeof globalThis !== "undefined" && globalThis.crypto?.randomUUID) { + return globalThis.crypto.randomUUID(); + } + return `group-${Date.now()}-${index}-${Math.random().toString(16).slice(2, 10)}`; +}; + +const normalizeConfigGroups = (groups: ConfigGroup[]) => + groups.map((group, index) => { + const id = coerceText(group.id).trim() || createConfigGroupId(index); + return { ...group, id }; + }); + export const useMtgaStore = () => { const api = useMtgaApi(); @@ -169,6 +182,12 @@ export const useMtgaStore = () => { const currentConfigIndex = useState("mtga-current-config-index", () => 0); const mappedModelId = useState("mtga-mapped-model-id", () => ""); const mtgaAuthKey = useState("mtga-auth-key", () => ""); + const routingGroupIds = useState("mtga-routing-group-ids", () => []); + const enable429Failover = useState("mtga-enable-429-failover", () => false); + const failover429CooldownSeconds = useState( + "mtga-failover-429-cooldown-seconds", + () => 60, + ); const runtimeOptions = useState("mtga-runtime-options", () => ({ ...DEFAULT_RUNTIME_OPTIONS, })); @@ -449,24 +468,59 
@@ export const useMtgaStore = () => { if (!result) { return false; } - configGroups.value = result.config_groups || []; + const loadedGroups = Array.isArray(result.config_groups) ? result.config_groups : []; + configGroups.value = normalizeConfigGroups(loadedGroups); currentConfigIndex.value = clampIndex( result.current_config_index ?? 0, configGroups.value.length, ); + const availableGroupIds = new Set( + configGroups.value.map((group) => coerceText(group.id).trim()), + ); + routingGroupIds.value = Array.isArray(result.routing_group_ids) + ? Array.from( + new Set( + result.routing_group_ids + .map((id) => coerceText(id).trim()) + .filter((id) => id && availableGroupIds.has(id)), + ), + ) + : []; mappedModelId.value = coerceText(result.mapped_model_id); mtgaAuthKey.value = coerceText(result.mtga_auth_key); + enable429Failover.value = Boolean(result.enable_429_failover); + const cooldownRaw = Number(result.failover_429_cooldown_seconds); + failover429CooldownSeconds.value = Number.isFinite(cooldownRaw) + ? 
Math.max(1, Math.round(cooldownRaw)) + : 60; return true; }; const saveConfig = async () => { + configGroups.value = normalizeConfigGroups(configGroups.value); const clampedIndex = clampIndex(currentConfigIndex.value, configGroups.value.length); currentConfigIndex.value = clampedIndex; + const availableGroupIds = new Set( + configGroups.value.map((group) => coerceText(group.id).trim()), + ); + routingGroupIds.value = Array.from( + new Set( + routingGroupIds.value + .map((id) => coerceText(id).trim()) + .filter((id) => id && availableGroupIds.has(id)), + ), + ); const payload: ConfigPayload = { config_groups: configGroups.value, current_config_index: clampedIndex, mapped_model_id: coerceText(mappedModelId.value), mtga_auth_key: coerceText(mtgaAuthKey.value), + routing_group_ids: routingGroupIds.value, + enable_429_failover: enable429Failover.value, + failover_429_cooldown_seconds: Math.max( + 1, + Math.round(Number(failover429CooldownSeconds.value) || 60), + ), }; const ok = await api.saveConfig(payload); return Boolean(ok); @@ -807,6 +861,9 @@ export const useMtgaStore = () => { currentConfigIndex, mappedModelId, mtgaAuthKey, + routingGroupIds, + enable429Failover, + failover429CooldownSeconds, runtimeOptions, logs, systemPrompts, diff --git a/app/docs/ui-migration.md b/app/docs/ui-migration.md index ccb045e..b143e8e 100644 --- a/app/docs/ui-migration.md +++ b/app/docs/ui-migration.md @@ -56,6 +56,8 @@ app/ - [x] `MainTabs` 支持切换并挂载各 Tab 内容(证书/hosts/代理/数据/关于)。 - [x] `ConfigGroupPanel` 改为可交互:列表数据、选中状态、增删改弹窗。 - [x] `GlobalConfigPanel` 与 `RuntimeOptionsPanel` 接入真实数据与保存逻辑。 +- [x] 新增 `routing_group_ids` 配置读写,支持按组路由选择。 +- [x] 设置页接入配置组多选,轮询仅针对选中组(空则仅使用当前激活配置)。 - [x] `LogPanel` 支持追加日志流(从后端或前端事件)。 - [x] `UpdateDialog`、确认弹窗完善交互与 HTML 内容渲染。 - [x] 用 `pyInvoke` 串起最小功能链路(例如 `greet` -> 日志输出)。 @@ -146,6 +148,7 @@ config_groups: ConfigGroup[] current_config_index: number mapped_model_id: string mtga_auth_key: string +routing_group_ids: string[] runtime_options: { debugMode: boolean 
disableSslStrict: boolean diff --git a/python-src/modules/proxy/proxy_app.py b/python-src/modules/proxy/proxy_app.py index 8910444..4ed9c7e 100644 --- a/python-src/modules/proxy/proxy_app.py +++ b/python-src/modules/proxy/proxy_app.py @@ -13,13 +13,21 @@ from flask import Flask, Response, jsonify, request from modules.proxy.proxy_auth import ProxyAuth -from modules.proxy.proxy_config import DEFAULT_MIDDLE_ROUTE, ProxyConfig, build_proxy_config +from modules.proxy.proxy_config import ( + DEFAULT_MIDDLE_ROUTE, + ProxyApiEndpoint, + ProxyConfig, + build_proxy_config, +) from modules.proxy.proxy_transport import ProxyTransport from modules.runtime.error_codes import ErrorCode from modules.runtime.operation_result import OperationResult from modules.runtime.resource_manager import ResourceManager from modules.services.system_prompt_service import SystemPromptStore +HTTP_STATUS_TOO_MANY_REQUESTS = 429 +HTTP_STATUS_BAD_REQUEST = 400 + class ProxyApp: """代理服务的领域逻辑:配置解析 + Flask 路由 + 上游转发。""" @@ -39,6 +47,8 @@ def __init__( self._transport_ref_counts: dict[int, int] = {} self._root_logger_default_level = logging.getLogger().level self._app_logger_default_level = logging.WARNING + self._endpoint_cursor = 0 + self._endpoint_429_until: dict[str, float] = {} self.app: Flask | None = None self.valid = True self.proxy_config: ProxyConfig | None = None @@ -111,6 +121,7 @@ def _snapshot_runtime_state(self) -> dict[str, Any]: "transport": self.transport, "http_client": self.http_client, "proxy_config": self.proxy_config, + "endpoint_cursor": self._endpoint_cursor, } def _snapshot_chat_runtime_state(self) -> dict[str, Any]: @@ -133,6 +144,7 @@ def _snapshot_chat_runtime_state(self) -> dict[str, Any]: "transport": transport, "http_client": self.http_client, "proxy_config": self.proxy_config, + "endpoint_cursor": self._endpoint_cursor, } def _release_transport_ref(self, transport: ProxyTransport | None) -> None: @@ -210,6 +222,7 @@ def apply_runtime_config(self, raw_config: dict[str, 
Any] | None) -> OperationRe self.auth = new_auth self.transport = new_transport self.http_client = new_transport.session + self._endpoint_cursor = 0 self._apply_debug_logging(self.debug_mode) @@ -227,6 +240,10 @@ def _timestamp_ms() -> str: ms = int((now % 1) * 1000) return f"{base}.{ms:03d}" + def _set_endpoint_cursor(self, value: int) -> None: + with self._config_lock: + self._endpoint_cursor = value + def _log_request(self, request_id: str, message: str) -> None: self.log_func(f"{self._timestamp_ms()} [{request_id}] {message}") @@ -445,13 +462,13 @@ def log(message: str) -> None: inbound_route = str(snapshot["inbound_route"]) target_model_id = str(snapshot["target_model_id"]) target_api_base_url = str(snapshot["target_api_base_url"]) - middle_route = str(snapshot["middle_route"]) stream_mode = snapshot["stream_mode"] debug_mode = bool(snapshot["debug_mode"]) auth = snapshot["auth"] transport = snapshot["transport"] http_client = snapshot["http_client"] proxy_config = snapshot["proxy_config"] + endpoint_cursor = int(snapshot["endpoint_cursor"] or 0) transport_released = False def release_transport() -> None: @@ -507,12 +524,10 @@ def release_transport() -> None: client_requested_stream = request_data.get("stream", False) log(f"客户端请求的流模式: {client_requested_stream}") - if "model" in request_data: - original_model = request_data["model"] - log(f"替换模型名: {original_model} -> {target_model_id}") - request_data["model"] = target_model_id - else: - log(f"请求中没有 model 字段,添加 model: {target_model_id}") + # 记录客户端原始请求的模型名,用于后续日志展示 + original_model = request_data.get("model", "unknown") + + if "model" not in request_data: request_data["model"] = target_model_id if stream_mode is not None: @@ -533,36 +548,208 @@ def release_transport() -> None: {"error": {"message": "Invalid authentication", "type": "authentication_error"}} ), 401 - target_api_key = "" - if isinstance(proxy_config, ProxyConfig): - target_api_key = proxy_config.api_key - forward_headers = 
auth.build_forward_headers( - auth_header, - target_api_key, - log_func=log, - ) - try: - target_url = ( - f"{target_api_base_url.rstrip('/')}" - f"{self._build_route(middle_route, 'chat/completions')}" - ) - log(f"转发请求到: {target_url}") - is_stream = request_data.get("stream", False) log(f"流模式: {is_stream}") - response_from_target = http_client.post( - target_url, - json=request_data, - headers=forward_headers, - stream=is_stream, - timeout=300, + api_endpoints: tuple[ProxyApiEndpoint, ...] + if isinstance(proxy_config, ProxyConfig): + api_endpoints = proxy_config.api_endpoints + else: + api_endpoints = ( + ProxyApiEndpoint( + api_url=target_api_base_url, + api_key="", + target_model_id=target_model_id, + ), + ) + + enable_routing = ( + isinstance(proxy_config, ProxyConfig) + and bool(proxy_config.enable_429_failover) + and len(api_endpoints) > 1 ) - response_from_target.raise_for_status() - if debug_mode: - log(f"上游响应状态码: {response_from_target.status_code}") - log(f"上游 Content-Type: {response_from_target.headers.get('content-type')}") + start_index = endpoint_cursor % len(api_endpoints) + + response_from_target = None + endpoint_order = [ + (start_index + offset) % len(api_endpoints) for offset in range(len(api_endpoints)) + ] + + def endpoint_key(endpoint: ProxyApiEndpoint) -> str: + return ( + f"{endpoint.api_url}|{endpoint.api_key}|" + f"{endpoint.target_model_id}|{endpoint.middle_route}" + ) + + def cooldown_remaining_seconds(key: str) -> float: + now = time.monotonic() + with self._config_lock: + until = float(self._endpoint_429_until.get(key, 0.0)) + if until <= now: + self._endpoint_429_until.pop(key, None) + return 0.0 + return until - now + + if enable_routing: + available: list[int] = [] + for idx in endpoint_order: + remaining = cooldown_remaining_seconds(endpoint_key(api_endpoints[idx])) + if remaining <= 0: + available.append(idx) + if available: + endpoint_order = available + else: + min_idx = endpoint_order[0] + min_remaining = 
cooldown_remaining_seconds( + endpoint_key(api_endpoints[min_idx]) + ) + for idx in endpoint_order[1:]: + remaining = cooldown_remaining_seconds(endpoint_key(api_endpoints[idx])) + if remaining < min_remaining: + min_idx = idx + min_remaining = remaining + endpoint_order = [min_idx] + log( + "所有节点处于 429 冷却中" + f"(节点={min_idx},剩余={min_remaining:.1f}s)" + ) + + for attempt, endpoint_index in enumerate(endpoint_order): + endpoint = api_endpoints[endpoint_index] + next_cursor = (endpoint_index + 1) % len(api_endpoints) + target_url = ( + f"{endpoint.api_url.rstrip('/')}" + f"{self._build_route(endpoint.middle_route, 'chat/completions')}" + ) + + if endpoint.target_model_id: + log(f"替换模型名: {original_model} -> {endpoint.target_model_id}") + request_data["model"] = endpoint.target_model_id + else: + # 如果端点没有指定特定模型,使用默认目标模型 + log(f"替换模型名: {original_model} -> {target_model_id}") + request_data["model"] = target_model_id + + current_request_data = dict(request_data) + + # 仅作兼容,后续移除 + if "siliconflow.cn" in target_url or "siliconflow.com" in target_url: + thinking_obj = current_request_data.get("thinking") + if thinking_obj is not None: + log(f"适配 SiliconFlow 参数,thinking={json.dumps(thinking_obj)}") + if isinstance(thinking_obj, dict): + thinking_map = cast(dict[str, Any], thinking_obj) + t_type = thinking_map.get("type") + t_budget = thinking_map.get("budget_tokens") or thinking_map.get( + "budget" + ) + if isinstance(t_type, str) and t_type: + current_request_data["enable_thinking"] = t_type != "disabled" + if isinstance(t_budget, (int, float)): + current_request_data["thinking_budget"] = int(t_budget) + elif isinstance(thinking_obj, str): + current_request_data["enable_thinking"] = thinking_obj != "disabled" + current_request_data.pop("thinking", None) + + log(f"转发请求到: {target_url}") + + target_api_key = endpoint.api_key + if not target_api_key and isinstance(proxy_config, ProxyConfig): + target_api_key = proxy_config.api_key + forward_headers = auth.build_forward_headers( 
+ auth_header, + target_api_key, + log_func=log, + ) + + try: + response_from_target = http_client.post( + target_url, + json=current_request_data, + headers=forward_headers, + stream=is_stream, + timeout=300, + ) + except requests.exceptions.RequestException as request_exc: + log(f"上游请求异常: {request_exc}") + if enable_routing: + self._set_endpoint_cursor(next_cursor) + if enable_routing and attempt < len(endpoint_order) - 1: + log( + f"切换到下一个节点 {endpoint_index} -> {endpoint_order[attempt + 1]}" + ) + continue + raise + + if response_from_target.status_code == HTTP_STATUS_TOO_MANY_REQUESTS: + retry_after_seconds: float | None = None + retry_after = response_from_target.headers.get("retry-after") + retry_after_text = retry_after if retry_after else "-" + if retry_after and retry_after.isdigit(): + retry_after_seconds = float(int(retry_after)) + if retry_after_seconds is None: + retry_after_seconds = ( + float(proxy_config.failover_429_cooldown_seconds) + if isinstance(proxy_config, ProxyConfig) + else 60.0 + ) + key = endpoint_key(endpoint) + if enable_routing: + with self._config_lock: + self._endpoint_429_until[key] = ( + time.monotonic() + retry_after_seconds + ) + log( + "上游触发 429" + f"(节点={endpoint_index},总节点={len(api_endpoints)},retry-after={retry_after_text})" + ) + if enable_routing: + self._set_endpoint_cursor(next_cursor) + if enable_routing and attempt < len(endpoint_order) - 1: + log(f"切换到下一个节点 {endpoint_index} -> {endpoint_order[attempt + 1]}") + with contextlib.suppress(Exception): + response_from_target.close() + response_from_target = None + continue + + failover_status_codes = {400, 401, 403, 404, 408, 500, 502, 503, 504} + if ( + enable_routing + and response_from_target.status_code in failover_status_codes + and attempt < len(endpoint_order) - 1 + ): + raw_error_text = response_from_target.text + log( + "上游原始错误响应: " + f"status={response_from_target.status_code}, body={raw_error_text}" + ) + self._set_endpoint_cursor(next_cursor) + log( + 
f"上游返回错误状态码: {response_from_target.status_code},尝试切换节点" + ) + log( + f"切换到下一个节点 {endpoint_index} -> {endpoint_order[attempt + 1]}" + ) + with contextlib.suppress(Exception): + response_from_target.close() + response_from_target = None + continue + + response_from_target.raise_for_status() + # 只有请求成功才更新 cursor,确保下一次请求从下一个节点开始 + if enable_routing: + self._set_endpoint_cursor(next_cursor) + + with self._config_lock: + self._endpoint_429_until.pop(endpoint_key(endpoint), None) + if debug_mode: + log(f"上游响应状态码: {response_from_target.status_code}") + log(f"上游 Content-Type: {response_from_target.headers.get('content-type')}") + break + + if response_from_target is None: + raise requests.exceptions.RequestException("No available target API endpoint") if is_stream: log("返回流式响应") @@ -749,6 +936,13 @@ def simulate_stream() -> Generator[str]: except requests.exceptions.HTTPError as e: error_msg = f"目标 API HTTP 错误: {e.response.status_code} - {e.response.text}" log(error_msg) + if e.response.status_code == HTTP_STATUS_BAD_REQUEST: + with contextlib.suppress(Exception): + request_dump = json.dumps(request_data, indent=2, ensure_ascii=False) + log(f"--- 触发 400 错误的请求参数 ---\\n{request_dump}") + with contextlib.suppress(Exception): + if e.response is not None: + e.response.close() release_transport() return jsonify( {"error": f"Target API error: {e.response.status_code}", "details": e.response.text} diff --git a/python-src/modules/proxy/proxy_config.py b/python-src/modules/proxy/proxy_config.py index 4e9e81a..c864c40 100644 --- a/python-src/modules/proxy/proxy_config.py +++ b/python-src/modules/proxy/proxy_config.py @@ -3,7 +3,7 @@ import os from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, cast import yaml @@ -14,9 +14,18 @@ type LogFunc = Callable[[str], None] +@dataclass(frozen=True) +class ProxyApiEndpoint: + api_url: str + api_key: str + target_model_id: str + middle_route: str = DEFAULT_MIDDLE_ROUTE + + 
@dataclass(frozen=True) class ProxyConfig: target_api_base_url: str + api_endpoints: tuple[ProxyApiEndpoint, ...] middle_route: str custom_model_id: str target_model_id: str @@ -25,6 +34,8 @@ class ProxyConfig: disable_ssl_strict_mode: bool api_key: str mtga_auth_key: str + enable_429_failover: bool + failover_429_cooldown_seconds: int def load_global_config( @@ -53,6 +64,97 @@ def _resolve_target_model_id(*, raw_config: dict[str, Any], custom_model_id: str return target_model_id if target_model_id else custom_model_id +def _parse_endpoint_from_group( + group: dict[str, Any], *, custom_model_id: str +) -> ProxyApiEndpoint | None: + url = (group.get("api_url") or "").strip() + if not url or url == PLACEHOLDER_API_URL: + return None + key = (group.get("api_key") or "").strip() + model = (group.get("model_id") or "").strip() or custom_model_id + route = normalize_middle_route(group.get("middle_route")) + return ProxyApiEndpoint( + api_url=url, + api_key=key, + target_model_id=model, + middle_route=route, + ) + + +def _extract_config_groups(global_config: dict[str, Any]) -> list[dict[str, Any]]: + config_groups: list[dict[str, Any]] = [] + raw_groups_obj = global_config.get("config_groups") + if not isinstance(raw_groups_obj, list): + return config_groups + for group_any in cast(list[object], raw_groups_obj): + if isinstance(group_any, dict): + config_groups.append(cast(dict[str, Any], group_any)) + return config_groups + + +def _extract_routing_group_ids(global_config: dict[str, Any]) -> list[str]: + routing_group_ids: list[str] = [] + routing_group_ids_raw = global_config.get("routing_group_ids") + if not isinstance(routing_group_ids_raw, list): + return routing_group_ids + for item in cast(list[object], routing_group_ids_raw): + if isinstance(item, str): + value = item.strip() + if value: + routing_group_ids.append(value) + return routing_group_ids + + +def _parse_api_endpoints( + *, + raw_config: dict[str, Any], + global_config: dict[str, Any], + custom_model_id: 
str, +) -> tuple[ProxyApiEndpoint, ...]: + enable_failover = bool(global_config.get("enable_429_failover", False)) + config_groups = _extract_config_groups(global_config) + routing_group_ids = _extract_routing_group_ids(global_config) + endpoints: list[ProxyApiEndpoint] = [] + endpoint_signatures: set[tuple[str, str, str, str]] = set() + + def append_unique(group: dict[str, Any]) -> None: + endpoint = _parse_endpoint_from_group(group, custom_model_id=custom_model_id) + if endpoint is None: + return + signature = ( + endpoint.api_url, + endpoint.api_key, + endpoint.target_model_id, + endpoint.middle_route, + ) + if signature in endpoint_signatures: + return + endpoint_signatures.add(signature) + endpoints.append(endpoint) + + selected_mode = bool(enable_failover and routing_group_ids) + if selected_mode: + selected_ids = set(routing_group_ids) + for group in config_groups: + group_id = str(group.get("id") or "").strip() + if group_id and group_id in selected_ids: + append_unique(group) + + # 如果启用了选择模式但结果为空(例如:选中组均无效),回退到仅使用当前激活配置 + # 这意味着如果用户开启轮询但没选任何组,就相当于没开启轮询 + if selected_mode and not endpoints: + selected_mode = False + # 清空以便重新添加单点 + endpoints.clear() + endpoint_signatures.clear() + + # 如果未启用选择模式(或回退),只添加当前激活的配置 + if not selected_mode: + append_unique(raw_config) + + return tuple(endpoints) + + def normalize_middle_route(value: str | None) -> str: raw_value = (value or "").strip() if not raw_value: @@ -66,6 +168,18 @@ def normalize_middle_route(value: str | None) -> str: return raw_value +def _parse_cooldown_seconds(global_config: dict[str, Any]) -> int: + try: + val = global_config.get("failover_429_cooldown_seconds") + if isinstance(val, (int, float)): + return max(1, int(val)) + if isinstance(val, str) and val.strip().isdigit(): + return max(1, int(val)) + return 60 + except (ValueError, TypeError): + return 60 + + def build_proxy_config( raw_config: dict[str, Any] | None, *, @@ -75,15 +189,24 @@ def build_proxy_config( raw_config = raw_config or {} 
global_config = load_global_config(resource_manager=resource_manager, log_func=log_func) - target_api_base_url = raw_config.get("api_url", PLACEHOLDER_API_URL) - if target_api_base_url == PLACEHOLDER_API_URL: - log_func("错误: 请在配置中设置正确的 API URL") - return None - custom_model_id = _resolve_custom_model_id( global_config=global_config, raw_config=raw_config, ) + + api_endpoints = _parse_api_endpoints( + raw_config=raw_config, + global_config=global_config, + custom_model_id=custom_model_id, + ) + if not api_endpoints: + log_func("错误: 没有可用的 API 端点") + return None + target_api_base_url = api_endpoints[0].api_url + if target_api_base_url == PLACEHOLDER_API_URL: + log_func("错误: 请在配置中设置正确的 API URL") + return None + target_model_id = _resolve_target_model_id( raw_config=raw_config, custom_model_id=custom_model_id, @@ -92,19 +215,23 @@ def build_proxy_config( return ProxyConfig( target_api_base_url=target_api_base_url, + api_endpoints=api_endpoints, middle_route=middle_route, custom_model_id=custom_model_id, target_model_id=target_model_id, stream_mode=raw_config.get("stream_mode"), debug_mode=bool(raw_config.get("debug_mode", False)), disable_ssl_strict_mode=bool(raw_config.get("disable_ssl_strict_mode", False)), - api_key=(raw_config.get("api_key") or ""), + api_key=api_endpoints[0].api_key, mtga_auth_key=(global_config.get("mtga_auth_key") or ""), + enable_429_failover=bool(global_config.get("enable_429_failover", False)), + failover_429_cooldown_seconds=_parse_cooldown_seconds(global_config), ) __all__ = [ "DEFAULT_MIDDLE_ROUTE", + "ProxyApiEndpoint", "ProxyConfig", "PLACEHOLDER_API_URL", "build_proxy_config", diff --git a/python-src/modules/proxy/proxy_transport.py b/python-src/modules/proxy/proxy_transport.py index 4dcab05..f25039a 100644 --- a/python-src/modules/proxy/proxy_transport.py +++ b/python-src/modules/proxy/proxy_transport.py @@ -107,11 +107,18 @@ def extract_sse_events( log_file = None buffer += chunk while True: - sep = buffer.find(b"\n\n") - if sep == -1: 
+ sep_lf = buffer.find(b"\n\n") + sep_crlf = buffer.find(b"\r\n\r\n") + candidates = [pos for pos in (sep_lf, sep_crlf) if pos != -1] + if not candidates: break - event = buffer[:sep] - buffer = buffer[sep + 2 :] + sep = min(candidates) + if sep == sep_crlf: + event = buffer[:sep] + buffer = buffer[sep + 4 :] + else: + event = buffer[:sep] + buffer = buffer[sep + 2 :] yield chunk_index, event if buffer.strip(): log("警告: 上游 SSE 结束时存在未完整分隔的残留数据") diff --git a/python-src/modules/services/config_service.py b/python-src/modules/services/config_service.py index a5455f1..917246c 100644 --- a/python-src/modules/services/config_service.py +++ b/python-src/modules/services/config_service.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass -from typing import Any +from typing import Any, cast import yaml @@ -24,7 +24,7 @@ def load_config_groups(self) -> tuple[list[dict[str, Any]], int]: pass return [], 0 - def load_global_config(self) -> tuple[str, str]: + def load_global_config(self) -> tuple[str, str, bool, int, list[str]]: try: if os.path.exists(self.config_file): with open(self.config_file, encoding="utf-8") as f: @@ -32,17 +32,41 @@ def load_global_config(self) -> tuple[str, str]: if config: mapped_model_id = config.get("mapped_model_id", "") mtga_auth_key = config.get("mtga_auth_key", "") - return mapped_model_id, mtga_auth_key + enable_429_failover = bool(config.get("enable_429_failover", False)) + cooldown = config.get("failover_429_cooldown_seconds", 60) + routing_group_ids_raw = config.get("routing_group_ids") + routing_group_ids: list[str] = [] + if isinstance(routing_group_ids_raw, list): + seen: set[str] = set() + for item in cast(list[object], routing_group_ids_raw): + if not isinstance(item, str): + continue + group_id = item.strip() + if not group_id or group_id in seen: + continue + seen.add(group_id) + routing_group_ids.append(group_id) + try: + cooldown_seconds = max(1, int(cooldown or 60)) + except Exception: + cooldown_seconds = 60 + return ( + 
mapped_model_id, + mtga_auth_key, + enable_429_failover, + cooldown_seconds, + routing_group_ids, + ) except Exception: pass - return "", "" + return "", "", False, 60, [] def save_config_groups( self, config_groups: list[dict[str, Any]], current_index: int = 0, - mapped_model_id: str | None = None, - mtga_auth_key: str | None = None, + *, + global_config_updates: dict[str, Any] | None = None, ) -> bool: try: config_data: dict[str, Any] = {} @@ -53,10 +77,41 @@ def save_config_groups( config_data["config_groups"] = config_groups config_data["current_config_index"] = current_index - if mapped_model_id is not None: - config_data["mapped_model_id"] = mapped_model_id - if mtga_auth_key is not None: - config_data["mtga_auth_key"] = mtga_auth_key + if global_config_updates: + mapped_model_id = global_config_updates.get("mapped_model_id") + if mapped_model_id is not None: + config_data["mapped_model_id"] = mapped_model_id + + mtga_auth_key = global_config_updates.get("mtga_auth_key") + if mtga_auth_key is not None: + config_data["mtga_auth_key"] = mtga_auth_key + + enable_429_failover = global_config_updates.get("enable_429_failover") + if enable_429_failover is not None: + config_data["enable_429_failover"] = bool(enable_429_failover) + + failover_429_cooldown_seconds = global_config_updates.get( + "failover_429_cooldown_seconds" + ) + if failover_429_cooldown_seconds is not None: + config_data["failover_429_cooldown_seconds"] = max( + 1, int(failover_429_cooldown_seconds or 60) + ) + + routing_group_ids_raw = global_config_updates.get("routing_group_ids") + if routing_group_ids_raw is not None: + routing_group_ids: list[str] = [] + seen: set[str] = set() + if isinstance(routing_group_ids_raw, list): + for item in cast(list[object], routing_group_ids_raw): + if not isinstance(item, str): + continue + group_id = item.strip() + if not group_id or group_id in seen: + continue + seen.add(group_id) + routing_group_ids.append(group_id) + config_data["routing_group_ids"] = 
routing_group_ids os.makedirs(os.path.dirname(self.config_file), exist_ok=True) diff --git a/python-src/modules/services/proxy_orchestration.py b/python-src/modules/services/proxy_orchestration.py index 821a028..1e36d78 100644 --- a/python-src/modules/services/proxy_orchestration.py +++ b/python-src/modules/services/proxy_orchestration.py @@ -35,9 +35,9 @@ class GlobalConfigCheckResult: def ensure_global_config_ready( *, - load_global_config: Callable[[], tuple[str, str]], + load_global_config: Callable[[], tuple[str, str, bool, int, list[str]]], ) -> GlobalConfigCheckResult: - mapped_model_id, mtga_auth_key = load_global_config() + mapped_model_id, mtga_auth_key, *_ = load_global_config() mapped_model_id = (mapped_model_id or "").strip() mtga_auth_key = (mtga_auth_key or "").strip() diff --git a/python-src/modules/services/user_data_service.py b/python-src/modules/services/user_data_service.py index dd16f4b..42b749b 100644 --- a/python-src/modules/services/user_data_service.py +++ b/python-src/modules/services/user_data_service.py @@ -126,7 +126,7 @@ def find_latest_backup( if not backup_folders: raise NoBackupsError("未找到任何备份") - latest_backup = max(backup_folders, key=lambda x: os.path.basename(x)) + latest_backup = max(backup_folders, key=os.path.basename) backup_name = os.path.basename(latest_backup) return LatestBackupInfo(backup_name=backup_name, backup_path=latest_backup) diff --git a/python-src/mtga_app/__init__.py b/python-src/mtga_app/__init__.py index 7b1e014..cf14eba 100644 --- a/python-src/mtga_app/__init__.py +++ b/python-src/mtga_app/__init__.py @@ -258,6 +258,9 @@ class SaveConfigPayload(BaseModel): current_config_index: int mapped_model_id: str | None = None mtga_auth_key: str | None = None + routing_group_ids: list[str] | None = None + enable_429_failover: bool | None = None + failover_429_cooldown_seconds: int | None = None @lru_cache(maxsize=1) @@ -280,23 +283,39 @@ async def greet(body: GreetPayload) -> str: async def load_config() -> dict[str, 
Any]: config_store = _get_config_store() config_groups, current_index = config_store.load_config_groups() - mapped_model_id, mtga_auth_key = config_store.load_global_config() + mapped_model_id, mtga_auth_key, enable_429_failover, cooldown_seconds, routing_group_ids = ( + config_store.load_global_config() + ) return { "config_groups": config_groups, "current_config_index": current_index, "mapped_model_id": mapped_model_id, "mtga_auth_key": mtga_auth_key, + "routing_group_ids": routing_group_ids, + "enable_429_failover": enable_429_failover, + "failover_429_cooldown_seconds": cooldown_seconds, } @command_registry.command() async def save_config(body: SaveConfigPayload) -> bool: config_store = _get_config_store() + global_updates: dict[str, Any] = {} + if body.mapped_model_id is not None: + global_updates["mapped_model_id"] = body.mapped_model_id + if body.mtga_auth_key is not None: + global_updates["mtga_auth_key"] = body.mtga_auth_key + if body.routing_group_ids is not None: + global_updates["routing_group_ids"] = body.routing_group_ids + if body.enable_429_failover is not None: + global_updates["enable_429_failover"] = body.enable_429_failover + if body.failover_429_cooldown_seconds is not None: + global_updates["failover_429_cooldown_seconds"] = body.failover_429_cooldown_seconds + return config_store.save_config_groups( body.config_groups, body.current_config_index, - body.mapped_model_id, - body.mtga_auth_key, + global_config_updates=global_updates if global_updates else None, )