From fddb79dd56ae8f6b735e46ff7aa86764febfd972 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Mon, 9 Feb 2026 14:38:59 +0800 Subject: [PATCH 01/27] refactor: streamline asset service methods by integrating reverse request handlers --- app/services/grok/services/assets.py | 178 ++++-------------- app/services/reverse/__init__.py | 17 ++ app/services/reverse/assets_delete.py | 102 +++++++++++ app/services/reverse/assets_download.py | 132 ++++++++++++++ app/services/reverse/assets_list.py | 104 +++++++++++ app/services/reverse/assets_upload.py | 111 ++++++++++++ app/services/reverse/utils/headers.py | 101 +++++++++++ app/services/reverse/utils/retry.py | 229 ++++++++++++++++++++++++ app/services/reverse/utils/statsig.py | 59 ++++++ 9 files changed, 887 insertions(+), 146 deletions(-) create mode 100644 app/services/reverse/__init__.py create mode 100644 app/services/reverse/assets_delete.py create mode 100644 app/services/reverse/assets_download.py create mode 100644 app/services/reverse/assets_list.py create mode 100644 app/services/reverse/assets_upload.py create mode 100644 app/services/reverse/utils/headers.py create mode 100644 app/services/reverse/utils/retry.py create mode 100644 app/services/reverse/utils/statsig.py diff --git a/app/services/grok/services/assets.py b/app/services/grok/services/assets.py index 9197f853..12d31d68 100644 --- a/app/services/grok/services/assets.py +++ b/app/services/grok/services/assets.py @@ -9,7 +9,6 @@ import re import time from contextlib import asynccontextmanager -from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse @@ -26,15 +25,15 @@ from app.core.exceptions import AppException, UpstreamException, ValidationException from app.core.logger import logger from app.core.storage import DATA_DIR -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie -from 
app.services.token.service import TokenService +from app.services.reverse import ( + AssetsDeleteReverse, + AssetsDownloadReverse, + AssetsListReverse, + AssetsUploadReverse, +) # ==================== 常量 ==================== -UPLOAD_API = "https://grok.com/rest/app-chat/upload-file" -LIST_API = "https://grok.com/rest/assets" -DELETE_API = "https://grok.com/rest/assets-metadata" -DOWNLOAD_API = "https://assets.grok.com" LOCK_DIR = DATA_DIR / ".locks" # 全局信号量(运行时动态初始化) @@ -115,67 +114,15 @@ async def _file_lock(name: str, timeout: int = 10): fd.close() -@dataclass -class ServiceConfig: - """服务配置""" - - proxy: str - timeout: int - browser: str - user_agent: str - - @classmethod - def from_settings(cls, proxy: Optional[str] = None): - return cls( - proxy=proxy - or get_config("network.asset_proxy_url") - or get_config("network.base_proxy_url"), - timeout=get_config("network.timeout"), - browser=get_config("security.browser"), - user_agent=get_config("security.user_agent"), - ) - - def get_proxies(self) -> Optional[dict]: - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - - # ==================== 基础服务 ==================== class BaseService: """基础服务类""" - def __init__(self, proxy: Optional[str] = None): - self.config = ServiceConfig.from_settings(proxy) + def __init__(self): self._session: Optional[AsyncSession] = None - def _build_headers( - self, token: str, referer: str = "https://grok.com/", download: bool = False - ) -> dict: - """构建请求头""" - if download: - headers = { - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", - "Sec-Fetch-Dest": "document", - "Sec-Fetch-Mode": "navigate", - "Sec-Fetch-Site": "same-site", - "Sec-Fetch-User": "?1", - "Referer": referer, - "User-Agent": self.config.user_agent, - } - else: - headers = { - "Accept": "*/*", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Referer": referer, - 
"User-Agent": self.config.user_agent, - } - apply_statsig(headers) - - headers["Cookie"] = build_sso_cookie(token) - return headers - async def _get_session(self) -> AsyncSession: """获取复用 Session""" if self._session is None: @@ -298,44 +245,19 @@ async def upload(self, file_input: str, token: str) -> Tuple[str, str]: # 执行上传 session = await self._get_session() - response = await session.post( - UPLOAD_API, - headers=self._build_headers(token), - json={"fileName": filename, "fileMimeType": mime, "content": b64}, - impersonate=self.config.browser, - timeout=self.config.timeout, - proxies=self.config.get_proxies(), + response = await AssetsUploadReverse.request( + session, + token, + filename, + mime, + b64, ) - # 处理响应 - if response.status_code == 200: - result = response.json() - file_id = result.get("fileMetadataId", "") - file_uri = result.get("fileUri", "") - logger.info(f"Upload success: {filename} -> {file_id}") - return file_id, file_uri - - # 认证失败 - if response.status_code in (401, 403): - logger.warning(f"Upload auth failed: {response.status_code}") - try: - await TokenService.record_fail( - token, response.status_code, "upload_auth_failed" - ) - except Exception as e: - logger.error(f"Failed to record token failure: {e}") - - raise UpstreamException( - message=f"Upload authentication failed: {response.status_code}", - details={"status": response.status_code, "token_invalidated": True}, - ) - - # 其他错误 - logger.error(f"Upload failed: {filename} - {response.status_code}") - raise UpstreamException( - message=f"Upload failed: {response.status_code}", - details={"status": response.status_code}, - ) + result = response.json() + file_id = result.get("fileMetadataId", "") + file_uri = result.get("fileUri", "") + logger.info(f"Upload success: {filename} -> {file_id}") + return file_id, file_uri # ==================== 列表服务 ==================== @@ -346,7 +268,6 @@ class ListService(BaseService): async def iter_assets(self, token: str): """分页迭代资产列表""" - headers = 
self._build_headers(token, referer="https://grok.com/files") params = { "pageSize": 50, "orderBy": "ORDER_BY_LAST_USE_TIME", @@ -367,21 +288,12 @@ async def iter_assets(self, token: str): else: params.pop("pageToken", None) - response = await session.get( - LIST_API, - headers=headers, - params=params, - impersonate=self.config.browser, - timeout=self.config.timeout, - proxies=self.config.get_proxies(), + response = await AssetsListReverse.request( + session, + token, + params, ) - if response.status_code != 200: - raise UpstreamException( - message=f"List failed: {response.status_code}", - details={"status": response.status_code}, - ) - result = response.json() page_assets = result.get("assets", []) yield page_assets @@ -417,28 +329,19 @@ async def delete(self, token: str, asset_id: str) -> bool: """删除单个文件""" async with _get_assets_semaphore(): session = await self._get_session() - response = await session.delete( - f"{DELETE_API}/{asset_id}", - headers=self._build_headers(token, referer="https://grok.com/files"), - impersonate=self.config.browser, - timeout=self.config.timeout, - proxies=self.config.get_proxies(), + response = await AssetsDeleteReverse.request( + session, + token, + asset_id, ) - if response.status_code == 200: - logger.debug(f"Deleted: {asset_id}") - return True - - logger.error(f"Delete failed: {asset_id} - {response.status_code}") - raise UpstreamException( - message=f"Delete failed: {asset_id}", - details={"status": response.status_code}, - ) + logger.debug(f"Deleted: {asset_id}") + return True async def delete_all(self, token: str) -> Dict[str, int]: """删除所有文件""" total = success = failed = 0 - list_service = ListService(self.config.proxy) + list_service = ListService() try: async for assets in list_service.iter_assets(token): @@ -496,8 +399,8 @@ async def _delete_one(self, token: str, asset: Dict, index: int) -> bool: class DownloadService(BaseService): """文件下载服务""" - def __init__(self, proxy: Optional[str] = None): - super().__init__(proxy) 
+ def __init__(self): + super().__init__() self.base_dir = DATA_DIR / "tmp" self.image_dir = self.base_dir / "image" self.video_dir = self.base_dir / "video" @@ -552,25 +455,8 @@ async def _download_file(self, file_path: str, token: str, cache_path: Path) -> if not file_path.startswith("/"): file_path = f"/{file_path}" - url = f"{DOWNLOAD_API}{file_path}" - headers = self._build_headers(token, download=True) - session = await self._get_session() - response = await session.get( - url, - headers=headers, - proxies=self.config.get_proxies(), - timeout=self.config.timeout, - allow_redirects=True, - impersonate=self.config.browser, - stream=True, - ) - - if response.status_code != 200: - raise UpstreamException( - message=f"Download failed: {response.status_code}", - details={"path": file_path, "status": response.status_code}, - ) + response = await AssetsDownloadReverse.request(session, token, file_path) # 保存文件 tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py new file mode 100644 index 00000000..b86b1fdf --- /dev/null +++ b/app/services/reverse/__init__.py @@ -0,0 +1,17 @@ +"""Reverse interfaces for Grok endpoints.""" + +from .assets_delete import AssetsDeleteReverse +from .assets_download import AssetsDownloadReverse +from .assets_list import AssetsListReverse +from .assets_upload import AssetsUploadReverse +from .utils.headers import build_headers +from .utils.statsig import StatsigGenerator + +__all__ = [ + "AssetsDeleteReverse", + "AssetsDownloadReverse", + "AssetsListReverse", + "AssetsUploadReverse", + "StatsigGenerator", + "build_headers", +] diff --git a/app/services/reverse/assets_delete.py b/app/services/reverse/assets_delete.py new file mode 100644 index 00000000..794f2ba1 --- /dev/null +++ b/app/services/reverse/assets_delete.py @@ -0,0 +1,102 @@ +""" +Reverse interface: delete asset metadata. 
+""" + +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +DELETE_API = "https://grok.com/rest/assets-metadata" + + +class AssetsDeleteReverse: + """/rest/assets-metadata/{file_id} reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, asset_id: str) -> Any: + """Delete asset from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + asset_id: str, the ID of the asset to delete. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + assert_proxy = get_config("network.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/files", + ) + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.delete( + f"{DELETE_API}/{asset_id}", + headers=headers, + proxies=proxies, + timeout=timeout, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"AssetsDeleteReverse: Delete failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsDeleteReverse: Delete failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle 
upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "assets_delete_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsDeleteReverse: Delete failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsDeleteReverse: Delete failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + +__all__ = ["AssetsDeleteReverse"] diff --git a/app/services/reverse/assets_download.py b/app/services/reverse/assets_download.py new file mode 100644 index 00000000..7b4addd5 --- /dev/null +++ b/app/services/reverse/assets_download.py @@ -0,0 +1,132 @@ +""" +Reverse interface: download asset. +""" + +import urllib.parse +from typing import Any +from pathlib import Path +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +DOWNLOAD_API = "https://assets.grok.com" + +_CONTENT_TYPES = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".webp": "image/webp", + ".mp4": "video/mp4", + ".webm": "video/webm", +} + + +class AssetsDownloadReverse: + """assets.grok.com/{path} reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, file_path: str) -> Any: + """Download asset from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + file_path: str, the path of the file to download. + + Returns: + Any: The response from the request. 
+ """ + try: + # Normalize path + if not file_path.startswith("/"): + file_path = f"/{file_path}" + url = f"{DOWNLOAD_API}{file_path}" + + # Get proxies + base_proxy = get_config("network.base_proxy_url") + assert_proxy = get_config("network.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Guess content type by extension for Accept/Sec-Fetch-Dest + content_type = _CONTENT_TYPES.get(Path(urllib.parse.urlparse(file_path).path).suffix.lower()) + + # Build headers + headers = build_headers( + cookie_token=token, + content_type=content_type, + origin="https://grok.com", + referer="https://grok.com/", + ) + ## Align with browser download navigation headers + headers["Cache-Control"] = "no-cache" + headers["Pragma"] = "no-cache" + headers["Priority"] = "u=0, i" + headers["Sec-Fetch-Mode"] = "navigate" + headers["Sec-Fetch-User"] = "?1" + headers["Upgrade-Insecure-Requests"] = "1" + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.get( + url, + headers=headers, + proxies=proxies, + timeout=timeout, + allow_redirects=True, + impersonate=browser, + stream=True, + ) + + if response.status_code != 200: + logger.error( + f"AssetsDownloadReverse: Download failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsDownloadReverse: Download failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + + if status == 401: + try: + await TokenService.record_fail(token, status, 
"assets_download_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsDownloadReverse: Download failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsDownloadReverse: Download failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AssetsDownloadReverse"] diff --git a/app/services/reverse/assets_list.py b/app/services/reverse/assets_list.py new file mode 100644 index 00000000..07263725 --- /dev/null +++ b/app/services/reverse/assets_list.py @@ -0,0 +1,104 @@ +""" +Reverse interface: list assets. +""" + +from typing import Any, Dict +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +LIST_API = "https://grok.com/rest/assets" + + +class AssetsListReverse: + """/rest/assets reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, params: Dict[str, Any]) -> Any: + """List assets from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + params: Dict[str, Any], the parameters for the request. + + Returns: + Any: The response from the request. 
+ """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + assert_proxy = get_config("network.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/files", + ) + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.get( + LIST_API, + headers=headers, + params=params, + proxies=proxies, + timeout=timeout, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"AssetsListReverse: List failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsListReverse: List failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "assets_list_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsListReverse: List failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsListReverse: List failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AssetsListReverse"] diff --git a/app/services/reverse/assets_upload.py b/app/services/reverse/assets_upload.py new file mode 100644 index 00000000..e4f65acf --- /dev/null +++ b/app/services/reverse/assets_upload.py @@ -0,0 +1,111 
@@ +""" +Reverse interface: upload asset. +""" + +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.core.config import get_config +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +UPLOAD_API = "https://grok.com/rest/app-chat/upload-file" + + +class AssetsUploadReverse: + """/rest/app-chat/upload-file reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, fileName: str, fileMimeType: str, content: str) -> Any: + """Upload asset to Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + fileName: str, the name of the file. + fileMimeType: str, the MIME type of the file. + content: str, the content of the file. + + Returns: + Any: The response from the request. 
+ """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + assert_proxy = get_config("network.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = { + "fileName": fileName, + "fileMimeType": fileMimeType, + "content": content, + } + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + UPLOAD_API, + headers=headers, + json=payload, + proxies=proxies, + timeout=timeout, + impersonate=browser, + ) + if response.status_code != 200: + logger.error( + f"AssetsUploadReverse: Upload failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsUploadReverse: Upload failed, {response.status_code}", + details={"status": response.status_code}, + ) + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "assets_upload_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsUploadReverse: Upload failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsUploadReverse: Upload failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AssetsUploadReverse"] diff --git a/app/services/reverse/utils/headers.py 
b/app/services/reverse/utils/headers.py new file mode 100644 index 00000000..874425fd --- /dev/null +++ b/app/services/reverse/utils/headers.py @@ -0,0 +1,101 @@ +"""Shared header builders for reverse interfaces.""" + +import uuid +import orjson +from typing import Dict, Optional + +from app.core.config import get_config +from app.core.logger import logger +from app.services.reverse.utils.statsig import StatsigGenerator + + +def _build_sso_cookie(sso_token: str) -> str: + """ + Build SSO Cookie string. + """ + # Remove "sso=" prefix if present + sso_token = sso_token[4:] if sso_token.startswith("sso=") else sso_token + + # SSO Cookie + cookie = f"sso={sso_token}; sso-rw={sso_token}" + + # CF Clearance + cf_clearance = get_config("security.cf_clearance") + if cf_clearance: + cookie += f";cf_clearance={cf_clearance}" + + return cookie + + +def build_headers( + cookie_token: str, + content_type: Optional[str] = None, + origin: Optional[str] = None, + referer: Optional[str] = None, +) -> Dict[str, str]: + """ + Build headers for reverse interfaces. + + Args: + cookie_token: The SSO token. + content_type: Optional Content-Type value. + origin: Optional Origin value. Defaults to "https://grok.com" if not provided. + referer: Optional Referer value. Defaults to "https://grok.com/" if not provided. + + Returns: + Dict[str, str]: The headers dictionary. 
+ """ + headers = { + "Accept-Encoding": "gzip, deflate, br, zstd", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", + "Origin": origin or "https://grok.com", + "Priority": "u=1, i", + "Referer": referer or "https://grok.com/", + "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', + "Sec-Ch-Ua-Arch": "arm", + "Sec-Ch-Ua-Bitness": "64", + "Sec-Ch-Ua-Mobile": "?0", + "Sec-Ch-Ua-Model": "", + "Sec-Ch-Ua-Platform": '"macOS"', + "Sec-Fetch-Mode": "cors", + "User-Agent": get_config("security.user_agent"), + } + + # Cookie + headers["Cookie"] = _build_sso_cookie(cookie_token) + + # Content-Type and Accept/Sec-Fetch-Dest + if content_type and content_type == "application/json": + headers["Content-Type"] = "application/json" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + elif content_type in ["image/jpeg", "image/png", "video/mp4", "video/webm"]: + headers["Content-Type"] = content_type + headers["Accept"] = "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7" + headers["Sec-Fetch-Dest"] = "document" + else: + headers["Content-Type"] = "application/json" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + + # Sec-Fetch-Site + if headers["Origin"] == headers["Referer"]: + headers["Sec-Fetch-Site"] = "same-origin" + else: + headers["Sec-Fetch-Site"] = "same-site" + + # X-Statsig-ID and X-XAI-Request-ID + headers["x-statsig-id"] = StatsigGenerator.gen_id() + headers["x-xai-request-id"] = str(uuid.uuid4()) + + # Print headers without Cookie + safe_headers = dict(headers) + if "Cookie" in safe_headers: + safe_headers["Cookie"] = "" + logger.debug(f"Built headers: {orjson.dumps(safe_headers, indent=2)}") + + return headers + + +__all__ = ["build_headers"] diff --git 
a/app/services/reverse/utils/retry.py b/app/services/reverse/utils/retry.py new file mode 100644 index 00000000..0de15b6f --- /dev/null +++ b/app/services/reverse/utils/retry.py @@ -0,0 +1,229 @@ +""" +Reverse retry utilities. +""" + +import asyncio +import random +from typing import Callable, Any, Optional + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException + + +class RetryContext: + """Retry context.""" + + def __init__(self): + self.attempt = 0 + self.max_retry = int(get_config("retry.max_retry")) + self.retry_codes = get_config("retry.retry_status_codes") + self.last_error = None + self.last_status = None + self.total_delay = 0.0 + self.retry_budget = float(get_config("retry.retry_budget")) + + # Backoff parameters + self.backoff_base = float(get_config("retry.retry_backoff_base")) + self.backoff_factor = float(get_config("retry.retry_backoff_factor")) + self.backoff_max = float(get_config("retry.retry_backoff_max")) + + # Decorrelated jitter state + self._last_delay = self.backoff_base + + def should_retry(self, status_code: int) -> bool: + """Check if should retry.""" + if self.attempt >= self.max_retry: + return False + if status_code not in self.retry_codes: + return False + if self.total_delay >= self.retry_budget: + return False + return True + + def record_error(self, status_code: int, error: Exception): + """Record error information.""" + self.last_status = status_code + self.last_error = error + self.attempt += 1 + + def calculate_delay(self, status_code: int, retry_after: Optional[float] = None) -> float: + """ + Calculate backoff delay time. 
+ + Args: + status_code: HTTP status code + retry_after: Retry-After header value (seconds) + + Returns: + Delay time (seconds) + """ + # Use Retry-After if available + if retry_after is not None and retry_after > 0: + delay = min(retry_after, self.backoff_max) + self._last_delay = delay + return delay + + # Use decorrelated jitter for 429 + if status_code == 429: + # decorrelated jitter: delay = random(base, last_delay * 3) + delay = random.uniform(self.backoff_base, self._last_delay * 3) + delay = min(delay, self.backoff_max) + self._last_delay = delay + return delay + + # Use exponential backoff + full jitter for other status codes + exp_delay = self.backoff_base * (self.backoff_factor**self.attempt) + delay = random.uniform(0, min(exp_delay, self.backoff_max)) + return delay + + def record_delay(self, delay: float): + """Record delay time.""" + self.total_delay += delay + + +def extract_retry_after(error: Exception) -> Optional[float]: + """ + Extract Retry-After value from exception. + + Args: + error: Exception object + + Returns: + Retry-After value (seconds), or None + """ + if not isinstance(error, UpstreamException): + return None + + details = error.details or {} + + # Try to get Retry-After from details + retry_after = details.get("retry_after") + if retry_after is not None: + try: + return float(retry_after) + except (ValueError, TypeError): + pass + + # Try to get Retry-After from headers + headers = details.get("headers", {}) + if isinstance(headers, dict): + retry_after = headers.get("Retry-After") or headers.get("retry-after") + if retry_after is not None: + try: + return float(retry_after) + except (ValueError, TypeError): + pass + + return None + + +async def retry_on_status( + func: Callable, + *args, + extract_status: Callable[[Exception], Optional[int]] = None, + on_retry: Callable[[int, int, Exception, float], None] = None, + **kwargs, +) -> Any: + """ + Generic retry function. 
+ + Args: + func: Retry function + *args: Function arguments + extract_status: Function to extract status code from exception + on_retry: Callback function for retry (attempt, status_code, error, delay) + **kwargs: Function keyword arguments + + Returns: + Function execution result + + Raises: + Last failed exception + """ + ctx = RetryContext() + + # Status code extractor + if extract_status is None: + + def extract_status(e: Exception) -> Optional[int]: + if isinstance(e, UpstreamException): + # Try to get status code from details, fallback to status_code attribute + if e.details and "status" in e.details: + return e.details["status"] + return getattr(e, "status_code", None) + return None + + while ctx.attempt <= ctx.max_retry: + try: + result = await func(*args, **kwargs) + + # Record log + if ctx.attempt > 0: + logger.info( + f"Retry succeeded after {ctx.attempt} attempts, " + f"total delay: {ctx.total_delay:.2f}s" + ) + + return result + + except Exception as e: + # Extract status code + status_code = extract_status(e) + + if status_code is None: + # Error cannot be identified as retryable + logger.error(f"Non-retryable error: {e}") + raise + + # Record error + ctx.record_error(status_code, e) + + # Check if should retry + if ctx.should_retry(status_code): + # Extract Retry-After + retry_after = extract_retry_after(e) + + # Calculate delay + delay = ctx.calculate_delay(status_code, retry_after) + + # Check if exceeds budget + if ctx.total_delay + delay > ctx.retry_budget: + logger.warning( + f"Retry budget exhausted: {ctx.total_delay:.2f}s + {delay:.2f}s > {ctx.retry_budget}s" + ) + raise + + ctx.record_delay(delay) + + logger.warning( + f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, " + f"waiting {delay:.2f}s (total: {ctx.total_delay:.2f}s)" + + (f", Retry-After: {retry_after}s" if retry_after else "") + ) + + # Callback + if on_retry: + on_retry(ctx.attempt, status_code, e, delay) + + await asyncio.sleep(delay) + continue + else: + # Not 
retryable or retry budget exhausted + if status_code in ctx.retry_codes: + logger.error( + f"Retry exhausted after {ctx.attempt} attempts, " + f"last status: {status_code}, total delay: {ctx.total_delay:.2f}s" + ) + else: + logger.error(f"Non-retryable status code: {status_code}") + + # Raise last failed exception + raise + + +__all__ = [ + "RetryContext", + "retry_on_status", + "extract_retry_after", +] diff --git a/app/services/reverse/utils/statsig.py b/app/services/reverse/utils/statsig.py new file mode 100644 index 00000000..28280a03 --- /dev/null +++ b/app/services/reverse/utils/statsig.py @@ -0,0 +1,59 @@ +""" +Statsig ID generator for reverse interfaces. +""" + +import base64 +import random +import string + +from app.core.config import get_config +from app.core.logger import logger + + +STATIC_STATSIG_ID = "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk=" + + +class StatsigGenerator: + """Statsig ID generator for reverse interfaces.""" + + @staticmethod + def _rand(length: int, alphanumeric: bool = False) -> str: + """Generate random string.""" + chars = ( + string.ascii_lowercase + string.digits + if alphanumeric + else string.ascii_lowercase + ) + return "".join(random.choices(chars, k=length)) + + @staticmethod + def gen_id() -> str: + """ + Generate Statsig ID. + + Returns: + Base64 encoded ID. 
+ """ + dynamic = get_config("chat.dynamic_statsig") + + # Dynamic Statsig ID + if dynamic: + logger.debug("Generating dynamic Statsig ID") + + if random.choice([True, False]): + rand = StatsigGenerator._rand(5, alphanumeric=True) + message = f"e:TypeError: Cannot read properties of null (reading 'children['{rand}']')" + else: + rand = StatsigGenerator._rand(10) + message = ( + f"e:TypeError: Cannot read properties of undefined (reading '{rand}')" + ) + + return base64.b64encode(message.encode()).decode() + + # Static Statsig ID + logger.debug("Generating static Statsig ID") + return STATIC_STATSIG_ID + + +__all__ = ["StatsigGenerator"] From d3a1b09d3a6184a84632a98f62bf8507f3d39928 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:41:43 +0800 Subject: [PATCH 02/27] refactor: integrate reverse request handlers for image and video services --- app/api/v1/image.py | 26 +-- app/services/grok/services/chat.py | 195 ++++------------------ app/services/grok/services/media.py | 113 +++---------- app/services/reverse/__init__.py | 4 + app/services/reverse/app_chat.py | 208 ++++++++++++++++++++++++ app/services/reverse/assets_download.py | 2 +- app/services/reverse/media_post.py | 108 ++++++++++++ app/services/reverse/utils/headers.py | 5 +- 8 files changed, 379 insertions(+), 282 deletions(-) create mode 100644 app/services/reverse/app_chat.py create mode 100644 app/services/reverse/media_post.py diff --git a/app/api/v1/image.py b/app/api/v1/image.py index ec207c18..bc47b16d 100644 --- a/app/api/v1/image.py +++ b/app/api/v1/image.py @@ -644,25 +644,7 @@ async def edit_image( parent_post_id ) - raw_payload = { - "temporary": bool(get_config("chat.temporary")), - "modelName": model_info.grok_model, - "message": edit_request.prompt, - "enableImageGeneration": True, - "returnImageBytes": False, - "returnRawGrokInXaiRequest": False, - "enableImageStreaming": True, - "imageGenerationCount": 2, - "forceConcise": False, - 
"toolOverrides": {"imageGen": True}, - "enableSideBySide": True, - "sendFinalMetadata": True, - "isReasoning": False, - "disableTextFollowUps": True, - "responseMetadata": {"modelConfigOverride": model_config_override}, - "disableMemory": False, - "forceSideBySide": False, - } + tool_overrides = {"imageGen": True} # 流式模式 if edit_request.stream: @@ -673,7 +655,8 @@ async def edit_image( model=model_info.grok_model, mode=None, stream=True, - raw_payload=raw_payload, + tool_overrides=tool_overrides, + model_config_override=model_config_override, ) processor = ImageStreamProcessor( @@ -703,7 +686,8 @@ async def _call_edit(): model=model_info.grok_model, mode=None, stream=True, - raw_payload=raw_payload, + tool_overrides=tool_overrides, + model_config_override=model_config_override, ) processor = ImageCollectProcessor( model_info.model_id, token, response_format=response_format diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index dcd36ccc..26112f69 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -2,7 +2,6 @@ Grok Chat 服务 """ -import orjson from typing import Dict, List, Any from dataclasses import dataclass @@ -12,22 +11,17 @@ from app.core.config import get_config from app.core.exceptions import ( AppException, - UpstreamException, ValidationException, ErrorType, ) from app.services.grok.models.model import ModelService from app.services.grok.services.assets import UploadService from app.services.grok.processors import StreamProcessor, CollectProcessor -from app.services.grok.utils.retry import retry_on_status -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie +from app.services.reverse import AppChatReverse from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.token import get_token_manager, EffortType -CHAT_API = "https://grok.com/rest/app-chat/conversations/new" - - @dataclass class ChatRequest: """聊天请求数据""" @@ -122,38 +116,6 @@ 
def extract( class ChatRequestBuilder: """请求构造器""" - @staticmethod - def build_headers(token: str) -> Dict[str, str]: - """构造请求头""" - user_agent = get_config("security.user_agent") - headers = { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "zh-CN,zh;q=0.9", - "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", - "Cache-Control": "no-cache", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Pragma": "no-cache", - "Priority": "u=1, i", - "Referer": "https://grok.com/", - "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', - "Sec-Ch-Ua-Arch": "arm", - "Sec-Ch-Ua-Bitness": "64", - "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Model": "", - "Sec-Ch-Ua-Platform": '"macOS"', - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-origin", - "User-Agent": user_agent, - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - @staticmethod def build_payload( message: str, @@ -163,53 +125,20 @@ def build_payload( image_attachments: List[str] = None, ) -> Dict[str, Any]: """构造请求体""" - merged_attachments = [] - if file_attachments: - merged_attachments.extend(file_attachments) - if image_attachments: - merged_attachments.extend(image_attachments) - - payload = { - "temporary": get_config("chat.temporary"), - "modelName": model, - "message": message, - "fileAttachments": merged_attachments, - "imageAttachments": [], - "disableSearch": False, - "enableImageGeneration": True, - "returnImageBytes": False, - "enableImageStreaming": True, - "imageGenerationCount": 2, - "forceConcise": False, - "toolOverrides": {}, - "enableSideBySide": True, - "sendFinalMetadata": True, - "responseMetadata": { - "modelConfigOverride": {"modelMap": {}}, - "requestModelDetails": {"modelId": model}, - }, - "disableMemory": 
get_config("chat.disable_memory"), - "deviceEnvInfo": { - "darkModeEnabled": False, - "devicePixelRatio": 2, - "screenWidth": 2056, - "screenHeight": 1329, - "viewportWidth": 2056, - "viewportHeight": 1083, - }, - } - - if mode: - payload["modelMode"] = mode - - return payload + return AppChatReverse.build_payload( + message=message, + model=model, + mode=mode, + file_attachments=file_attachments, + image_attachments=image_attachments, + ) class GrokChatService: """Grok API 调用服务""" - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") + def __init__(self): + pass async def chat( self, @@ -220,105 +149,37 @@ async def chat( stream: bool = None, file_attachments: List[str] = None, image_attachments: List[str] = None, - raw_payload: Dict[str, Any] = None, + tool_overrides: Dict[str, Any] = None, + model_config_override: Dict[str, Any] = None, ): """发送聊天请求""" if stream is None: stream = get_config("chat.stream") - headers = ChatRequestBuilder.build_headers(token) - payload = ( - raw_payload - if raw_payload is not None - else ChatRequestBuilder.build_payload( - message, model, mode, file_attachments, image_attachments - ) - ) - proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None - timeout = get_config("network.timeout") - logger.debug( f"Chat request: model={model}, mode={mode}, stream={stream}, attachments={len(file_attachments or [])}" ) - # 建立连接 - async def establish_connection(): - browser = get_config("security.browser") - session = AsyncSession(impersonate=browser) - try: - response = await session.post( - CHAT_API, - headers=headers, - data=orjson.dumps(payload), - timeout=timeout, - stream=True, - proxies=proxies, - ) - - if response.status_code != 200: - content = "" - try: - content = await response.text() - except Exception: - pass - - logger.error( - f"Chat failed: status={response.status_code}, token={token[:10]}..." 
- ) - - await session.close() - raise UpstreamException( - message=f"Grok API request failed: {response.status_code}", - details={"status": response.status_code, "body": content}, - ) - - logger.info(f"Chat connected: model={model}, stream={stream}") - return session, response - - except UpstreamException: - raise - except Exception as e: - logger.error(f"Chat request error: {e}") - await session.close() - raise UpstreamException( - message=f"Chat connection failed: {str(e)}", - details={"error": str(e)}, - ) - - # 重试机制 - def extract_status(e: Exception) -> int | None: - if isinstance(e, UpstreamException) and e.details: - return e.details.get("status") - return None - - session = None - response = None + browser = get_config("security.browser") + session = AsyncSession(impersonate=browser) try: - session, response = await retry_on_status( - establish_connection, extract_status=extract_status + stream_response = await AppChatReverse.request( + session, + token, + message=message, + model=model, + mode=mode, + file_attachments=file_attachments, + image_attachments=image_attachments, + tool_overrides=tool_overrides, + model_config_override=model_config_override, ) - except Exception as e: - status_code = extract_status(e) - if status_code: - token_mgr = await get_token_manager() - reason = str(e) - if isinstance(e, UpstreamException) and e.details: - body = e.details.get("body") - if body: - reason = f"{reason} | body: {body}" - await token_mgr.record_fail(token, status_code, reason) + logger.info(f"Chat connected: model={model}, stream={stream}") + except Exception: + await session.close() raise - # 流式传输 - async def stream_response(): - try: - async for line in response.aiter_lines(): - yield line - finally: - if session: - await session.close() - - return stream_response() + return stream_response async def chat_openai(self, token: str, request: ChatRequest): """OpenAI 兼容接口""" diff --git a/app/services/grok/services/media.py b/app/services/grok/services/media.py 
index 79d9b677..50aa0254 100644 --- a/app/services/grok/services/media.py +++ b/app/services/grok/services/media.py @@ -3,9 +3,7 @@ """ import asyncio -from typing import AsyncGenerator, Optional - -import orjson +from typing import AsyncGenerator from curl_cffi.requests import AsyncSession from app.core.logger import logger @@ -19,11 +17,8 @@ from app.services.grok.models.model import ModelService from app.services.token import get_token_manager, EffortType from app.services.grok.processors import VideoStreamProcessor, VideoCollectProcessor -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie from app.services.grok.utils.stream import wrap_stream_with_usage - -CREATE_POST_API = "https://grok.com/rest/media/post/create" -CHAT_API = "https://grok.com/rest/app-chat/conversations/new" +from app.services.reverse import AppChatReverse, MediaPostReverse _MEDIA_SEMAPHORE = None _MEDIA_SEM_VALUE = 0 @@ -42,47 +37,9 @@ def _get_semaphore() -> asyncio.Semaphore: class VideoService: """视频生成服务""" - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") + def __init__(self): self.timeout = get_config("network.timeout") - def _build_headers( - self, token: str, referer: str = "https://grok.com/imagine" - ) -> dict: - """构建请求头""" - user_agent = get_config("security.user_agent") - headers = { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "zh-CN,zh;q=0.9", - "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", - "Cache-Control": "no-cache", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Pragma": "no-cache", - "Priority": "u=1, i", - "Referer": referer, - "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', - "Sec-Ch-Ua-Arch": "arm", - "Sec-Ch-Ua-Bitness": "64", - "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Model": "", - 
"Sec-Ch-Ua-Platform": '"macOS"', - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-origin", - "User-Agent": user_agent, - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - - def _build_proxies(self) -> Optional[dict]: - """构建代理""" - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - async def create_post( self, token: str, @@ -92,28 +49,15 @@ async def create_post( ) -> str: """创建媒体帖子,返回 post ID""" try: - headers = self._build_headers(token) - - # 根据类型构建不同的载荷 - if media_type == "MEDIA_POST_TYPE_IMAGE" and media_url: - payload = {"mediaType": media_type, "mediaUrl": media_url} - else: - payload = {"mediaType": media_type, "prompt": prompt} + if media_type == "MEDIA_POST_TYPE_IMAGE" and not media_url: + raise ValidationException("media_url is required for image posts") async with AsyncSession() as session: - response = await session.post( - CREATE_POST_API, - headers=headers, - json=payload, - impersonate=get_config("security.browser"), - timeout=30, - proxies=self._build_proxies(), - ) - - if response.status_code != 200: - logger.error(f"Create post failed: {response.status_code}") - raise UpstreamException( - f"Failed to create post: {response.status_code}" + response = await MediaPostReverse.request( + session, + token, + media_type, + media_url or "", ) post_id = response.json().get("post", {}).get("id", "") @@ -198,40 +142,25 @@ async def _generate_internal( """内部生成逻辑""" session = None try: - headers = self._build_headers(token) payload = self._build_payload( prompt, post_id, aspect_ratio, video_length, resolution_name, preset ) - session = AsyncSession(impersonate=get_config("security.browser")) - response = await session.post( - CHAT_API, - headers=headers, - data=orjson.dumps(payload), - timeout=self.timeout, - stream=True, - proxies=self._build_proxies(), + session = AsyncSession() + stream_response = await AppChatReverse.request( + session, + token, + 
message=payload.get("message"), + model=payload.get("modelName"), + tool_overrides=payload.get("toolOverrides"), + model_config_override=( + (payload.get("responseMetadata") or {}).get("modelConfigOverride") + ), ) - if response.status_code != 200: - logger.error( - f"Video generation failed: status={response.status_code}, post_id={post_id}" - ) - raise UpstreamException( - message=f"Video generation failed: {response.status_code}", - details={"status": response.status_code}, - ) - logger.info(f"Video generation started: post_id={post_id}") - async def stream_response(): - try: - async for line in response.aiter_lines(): - yield line - finally: - await session.close() - - return stream_response() + return stream_response except Exception as e: if session: diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py index b86b1fdf..158e51de 100644 --- a/app/services/reverse/__init__.py +++ b/app/services/reverse/__init__.py @@ -1,17 +1,21 @@ """Reverse interfaces for Grok endpoints.""" +from .app_chat import AppChatReverse from .assets_delete import AssetsDeleteReverse from .assets_download import AssetsDownloadReverse from .assets_list import AssetsListReverse from .assets_upload import AssetsUploadReverse +from .media_post import MediaPostReverse from .utils.headers import build_headers from .utils.statsig import StatsigGenerator __all__ = [ + "AppChatReverse", "AssetsDeleteReverse", "AssetsDownloadReverse", "AssetsListReverse", "AssetsUploadReverse", + "MediaPostReverse", "StatsigGenerator", "build_headers", ] diff --git a/app/services/reverse/app_chat.py b/app/services/reverse/app_chat.py new file mode 100644 index 00000000..9d3c28bd --- /dev/null +++ b/app/services/reverse/app_chat.py @@ -0,0 +1,208 @@ +""" +Reverse interface: app chat conversations. 
+""" + +import orjson +from typing import Any, Dict, List +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +CHAT_API = "https://grok.com/rest/app-chat/conversations/new" + + +class AppChatReverse: + """/rest/app-chat/conversations/new reverse interface.""" + + @staticmethod + def build_payload( + message: str, + model: str, + mode: str = None, + file_attachments: List[str] = None, + image_attachments: List[str] = None, + tool_overrides: Dict[str, Any] = None, + model_config_override: Dict[str, Any] = None, + ) -> Dict[str, Any]: + """Build chat payload for Grok app-chat API.""" + + attachments = [] + if file_attachments: + attachments.extend(file_attachments) + if image_attachments: + attachments.extend(image_attachments) + + payload = { + "deviceEnvInfo": { + "darkModeEnabled": False, + "devicePixelRatio": 2, + "screenWidth": 2056, + "screenHeight": 1329, + "viewportWidth": 2056, + "viewportHeight": 1083, + }, + "disableMemory": get_config("chat.disable_memory"), + "disableSearch": False, + "disableSelfHarmShortCircuit": False, + "disableTextFollowUps": False, + "enableImageGeneration": True, + "enableImageStreaming": True, + "enableSideBySide": True, + "fileAttachments": attachments, + "forceConcise": False, + "forceSideBySide": False, + "imageAttachments": [], + "imageGenerationCount": 2, + "isAsyncChat": False, + "isReasoning": False, + "message": message, + "modelMode": mode, + "modelName": model, + "responseMetadata": { + "requestModelDetails": {"modelId": model}, + }, + "returnImageBytes": False, + "returnRawGrokInXaiRequest": False, + "sendFinalMetadata": True, + "temporary": get_config("chat.temporary"), + "toolOverrides": tool_overrides or {}, + } + + if 
model_config_override: + payload["responseMetadata"]["modelConfigOverride"] = model_config_override + + return payload + + @staticmethod + async def request( + session: AsyncSession, + token: str, + message: str, + model: str, + mode: str = None, + file_attachments: List[str] = None, + image_attachments: List[str] = None, + tool_overrides: Dict[str, Any] = None, + model_config_override: Dict[str, Any] = None, + ) -> Any: + """Send app chat request to Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + message: str, the message to send. + model: str, the model to use. + mode: str, the mode to use. + file_attachments: List[str], the file attachments to send. + image_attachments: List[str], the image attachments to send. + tool_overrides: Dict[str, Any], the tool overrides to use. + model_config_override: Dict[str, Any], the model config override to use. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = AppChatReverse.build_payload( + message=message, + model=model, + mode=mode, + file_attachments=file_attachments, + image_attachments=image_attachments, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + CHAT_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + stream=True, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + + # Get response content + content = "" + try: + content = await response.text() + 
except Exception: + pass + + logger.error( + f"AppChatReverse: Chat failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AppChatReverse: Chat failed, {response.status_code}", + details={"status": response.status_code, "body": content}, + ) + + return response + + response = await retry_on_status(_do_request) + + # Stream response + async def stream_response(): + try: + async for line in response.aiter_lines(): + yield line + finally: + await session.close() + + return stream_response() + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "app_chat_auth_failed" + ) + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AppChatReverse: Chat failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AppChatReverse: Chat failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AppChatReverse"] diff --git a/app/services/reverse/assets_download.py b/app/services/reverse/assets_download.py index 7b4addd5..df0b32d1 100644 --- a/app/services/reverse/assets_download.py +++ b/app/services/reverse/assets_download.py @@ -62,7 +62,7 @@ async def request(session: AsyncSession, token: str, file_path: str) -> Any: headers = build_headers( cookie_token=token, content_type=content_type, - origin="https://grok.com", + origin="https://assets.grok.com", referer="https://grok.com/", ) ## Align with browser download navigation headers diff --git a/app/services/reverse/media_post.py b/app/services/reverse/media_post.py new file mode 100644 index 00000000..f7d358ea --- /dev/null +++ b/app/services/reverse/media_post.py @@ -0,0 +1,108 @@ +""" +Reverse 
interface: media post create. +""" + +import orjson +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +MEDIA_POST_API = "https://grok.com/rest/media/post/create" + + +class MediaPostReverse: + """/rest/media/post/create reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, mediaType: str, mediaUrl: str) -> Any: + """Create media post in Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + mediaType: str, the media type. + mediaUrl: str, the media URL. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com", + ) + + # Build payload + payload = { + "mediaType": mediaType, + "mediaUrl": mediaUrl, + } + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + MEDIA_POST_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"MediaPostReverse: Media post create failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"MediaPostReverse: Media post create failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + 
return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "media_post_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"MediaPostReverse: Media post create failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"MediaPostReverse: Media post create failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["MediaPostReverse"] diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py index 874425fd..fd7e74cc 100644 --- a/app/services/reverse/utils/headers.py +++ b/app/services/reverse/utils/headers.py @@ -2,6 +2,7 @@ import uuid import orjson +from urllib.parse import urlparse from typing import Dict, Optional from app.core.config import get_config @@ -80,7 +81,9 @@ def build_headers( headers["Sec-Fetch-Dest"] = "empty" # Sec-Fetch-Site - if headers["Origin"] == headers["Referer"]: + origin_domain = urlparse(headers.get("Origin", "")).hostname + referer_domain = urlparse(headers.get("Referer", "")).hostname + if origin_domain and referer_domain and origin_domain == referer_domain: headers["Sec-Fetch-Site"] = "same-origin" else: headers["Sec-Fetch-Site"] = "same-site" From d947ef69bafa89fc6a0ef04f529f11a9f64b2409 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Mon, 9 Feb 2026 22:03:27 +0800 Subject: [PATCH 03/27] refactor: simplify usage service by removing model_name parameter and integrating reverse request handler for rate limits --- app/api/v1/admin.py | 8 +-- app/services/grok/services/usage.py | 108 +++------------------------- app/services/reverse/__init__.py | 2 + 
app/services/reverse/rate_limits.py | 108 ++++++++++++++++++++++++++++ app/services/token/manager.py | 4 +- app/services/token/service.py | 7 +- 6 files changed, 124 insertions(+), 113 deletions(-) create mode 100644 app/services/reverse/rate_limits.py diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index c27103ad..c2dc4808 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -864,9 +864,7 @@ async def refresh_tokens_api(data: dict): batch_size = get_config("performance.usage_batch_size") async def _refresh_one(t): - return await mgr.sync_usage( - t, "grok-3", consume_on_fail=False, is_usage=False - ) + return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False) raw_results = await run_in_batches( unique_tokens, @@ -918,9 +916,7 @@ async def _run(): try: async def _refresh_one(t: str): - return await mgr.sync_usage( - t, "grok-3", consume_on_fail=False, is_usage=False - ) + return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False) async def _on_item(item: str, res: dict): task.record(bool(res.get("ok"))) diff --git a/app/services/grok/services/usage.py b/app/services/grok/services/usage.py index 7550c822..e2734330 100644 --- a/app/services/grok/services/usage.py +++ b/app/services/grok/services/usage.py @@ -9,11 +9,7 @@ from app.core.logger import logger from app.core.config import get_config -from app.core.exceptions import UpstreamException -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie -from app.services.grok.utils.retry import retry_on_status - -LIMITS_API = "https://grok.com/rest/rate-limits" +from app.services.reverse import RateLimitsReverse _USAGE_SEMAPHORE = asyncio.Semaphore(25) _USAGE_SEM_VALUE = 25 @@ -22,52 +18,12 @@ class UsageService: """用量查询服务""" - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - self.timeout = get_config("network.timeout") - - def _build_headers(self, token: str) -> dict: - """构建请求头""" - user_agent = 
get_config("security.user_agent") - headers = { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "zh-CN,zh;q=0.9", - "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", - "Cache-Control": "no-cache", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Pragma": "no-cache", - "Priority": "u=1, i", - "Referer": "https://grok.com/", - "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', - "Sec-Ch-Ua-Arch": "arm", - "Sec-Ch-Ua-Bitness": "64", - "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Model": "", - "Sec-Ch-Ua-Platform": '"macOS"', - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-origin", - "User-Agent": user_agent, - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - - def _build_proxies(self) -> dict: - """构建代理配置""" - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - - async def get(self, token: str, model_name: str = "grok-4-1-thinking-1129") -> Dict: + async def get(self, token: str) -> Dict: """ 获取速率限制信息 Args: token: 认证 Token - model_name: 模型名称 Returns: 响应数据 @@ -86,61 +42,15 @@ async def get(self, token: str, model_name: str = "grok-4-1-thinking-1129") -> D _USAGE_SEM_VALUE = value _USAGE_SEMAPHORE = asyncio.Semaphore(value) async with _USAGE_SEMAPHORE: - # 定义状态码提取器 - def extract_status(e: Exception) -> int | None: - if isinstance(e, UpstreamException) and e.details: - return e.details.get("status") - return None - - # 定义实际的请求函数 - async def do_request(): - try: - headers = self._build_headers(token) - payload = {"requestKind": "DEFAULT", "modelName": model_name} - browser = get_config("security.browser") - - async with AsyncSession() as session: - response = await session.post( - LIMITS_API, - headers=headers, - json=payload, - impersonate=browser, - timeout=self.timeout, - 
proxies=self._build_proxies(), - ) - - if response.status_code == 200: - data = response.json() - remaining = data.get("remainingTokens", 0) - logger.info( - f"Usage sync success: remaining={remaining}, token={token[:10]}..." - ) - return data - - logger.error( - f"Usage sync failed: status={response.status_code}, token={token[:10]}..." - ) - - raise UpstreamException( - message=f"Failed to get usage stats: {response.status_code}", - details={"status": response.status_code}, - ) - - except Exception as e: - if isinstance(e, UpstreamException): - raise - logger.error(f"Usage error: {e}") - raise UpstreamException( - message=f"Usage service error: {str(e)}", - details={"error": str(e)}, - ) - - # 带重试的执行 try: - result = await retry_on_status( - do_request, extract_status=extract_status + async with AsyncSession() as session: + response = await RateLimitsReverse.request(session, token) + data = response.json() + remaining = data.get("remainingTokens", 0) + logger.info( + f"Usage sync success: remaining={remaining}, token={token[:10]}..." ) - return result + return data except Exception: # 最后一次失败已经被记录 diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py index 158e51de..cb4e88aa 100644 --- a/app/services/reverse/__init__.py +++ b/app/services/reverse/__init__.py @@ -6,6 +6,7 @@ from .assets_list import AssetsListReverse from .assets_upload import AssetsUploadReverse from .media_post import MediaPostReverse +from .rate_limits import RateLimitsReverse from .utils.headers import build_headers from .utils.statsig import StatsigGenerator @@ -16,6 +17,7 @@ "AssetsListReverse", "AssetsUploadReverse", "MediaPostReverse", + "RateLimitsReverse", "StatsigGenerator", "build_headers", ] diff --git a/app/services/reverse/rate_limits.py b/app/services/reverse/rate_limits.py new file mode 100644 index 00000000..aa852605 --- /dev/null +++ b/app/services/reverse/rate_limits.py @@ -0,0 +1,108 @@ +""" +Reverse interface: rate limits. 
+""" + +import orjson +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +RATE_LIMITS_API = "https://grok.com/rest/rate-limits" + + +class RateLimitsReverse: + """/rest/rate-limits reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> Any: + """Fetch rate limits from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = { + "requestKind": "DEFAULT", + "modelName": "grok-4-1-thinking-1129", + } + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + RATE_LIMITS_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"RateLimitsReverse: Request failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"RateLimitsReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, 
UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "rate_limits_auth_failed" + ) + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"RateLimitsReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"RateLimitsReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["RateLimitsReverse"] diff --git a/app/services/token/manager.py b/app/services/token/manager.py index b5e28f14..e8e6d41a 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -329,7 +329,6 @@ async def consume( async def sync_usage( self, token_str: str, - model_name: str, fallback_effort: EffortType = EffortType.LOW, consume_on_fail: bool = True, is_usage: bool = True, @@ -341,7 +340,6 @@ async def sync_usage( Args: token_str: Token 字符串(可带 sso= 前缀) - model_name: 模型名称(用于 API 查询) fallback_effort: 降级时的消耗力度 consume_on_fail: 失败时是否降级扣费 is_usage: 是否记录为一次使用(影响 use_count) @@ -367,7 +365,7 @@ async def sync_usage( from app.services.grok.services.usage import UsageService usage_service = UsageService() - result = await usage_service.get(token_str, model_name=model_name) + result = await usage_service.get(token_str) if result and "remainingTokens" in result: old_quota = target_token.quota diff --git a/app/services/token/service.py b/app/services/token/service.py index 75b23a67..63e635c3 100644 --- a/app/services/token/service.py +++ b/app/services/token/service.py @@ -43,22 +43,19 @@ async def consume(token: str, effort: EffortType = EffortType.LOW) -> bool: return await manager.consume(token, effort) @staticmethod - async def sync_usage( - token: str, model: str, effort: EffortType = EffortType.LOW - ) -> bool: + async def sync_usage(token: str, 
effort: EffortType = EffortType.LOW) -> bool: """ 同步 Token 使用量(优先 API,降级本地) Args: token: Token 字符串 - model: 模型名称 effort: 降级时的消耗力度 Returns: 是否成功 """ manager = await get_token_manager() - return await manager.sync_usage(token, model, effort) + return await manager.sync_usage(token, effort) @staticmethod async def record_fail(token: str, status_code: int = 401, reason: str = "") -> bool: From 370edb5f92a1e87588599cb9ea74b08e0dd6b1c8 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:15:04 +0800 Subject: [PATCH 04/27] refactor: integrate reverse request handlers for NSFW --- app/services/grok/services/nsfw.py | 148 ++++----------------------- app/services/reverse/__init__.py | 8 ++ app/services/reverse/nsfw_mgmt.py | 132 ++++++++++++++++++++++++ app/services/reverse/set_birth.py | 117 ++++++++++++++++++++++ app/services/reverse/utils/grpc.py | 156 +++++++++++++++++++++++++++++ 5 files changed, 430 insertions(+), 131 deletions(-) create mode 100644 app/services/reverse/nsfw_mgmt.py create mode 100644 app/services/reverse/set_birth.py create mode 100644 app/services/reverse/utils/grpc.py diff --git a/app/services/grok/services/nsfw.py b/app/services/grok/services/nsfw.py index 26a4f261..b56cd626 100644 --- a/app/services/grok/services/nsfw.py +++ b/app/services/grok/services/nsfw.py @@ -6,23 +6,15 @@ from dataclasses import dataclass from typing import Optional -import datetime -import random from curl_cffi.requests import AsyncSession -from app.core.config import get_config from app.core.logger import logger from app.services.grok.protocols.grpc_web import ( - encode_grpc_web_payload, - parse_grpc_web_response, - get_grpc_status, -) -from app.services.grok.utils.headers import build_sso_cookie - -NSFW_API = "https://grok.com/auth_mgmt.AuthManagement/UpdateUserFeatureControls" -BIRTH_DATE_API = "https://grok.com/rest/auth/set-birth-date" - +from app.core.config import get_config +from app.core.exceptions import 
UpstreamException +from app.services.reverse import NsfwMgmtReverse, SetBirthReverse +from app.services.reverse.utils.grpc import GrpcStatus @dataclass class NSFWResult: @@ -38,139 +30,33 @@ class NSFWResult: class NSFWService: """NSFW 模式服务""" - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - self.timeout = float(get_config("network.timeout")) - - def _build_proxies(self) -> Optional[dict]: - """构建代理配置""" - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - - @staticmethod - def _random_birth_date() -> str: - """生成随机出生日期(20-40岁)""" - today = datetime.date.today() - birth_year = today.year - random.randint(20, 40) - birth_month = random.randint(1, 12) - birth_day = random.randint(1, 28) - hour = random.randint(0, 23) - minute = random.randint(0, 59) - second = random.randint(0, 59) - microsecond = random.randint(0, 999) - return f"{birth_year:04d}-{birth_month:02d}-{birth_day:02d}T{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}Z" - - def _build_headers(self, token: str) -> dict: - """构造 gRPC-Web 请求头""" - cookie = build_sso_cookie(token, include_rw=True) - user_agent = get_config("security.user_agent") - return { - "accept": "*/*", - "content-type": "application/grpc-web+proto", - "origin": "https://grok.com", - "referer": "https://grok.com/", - "user-agent": user_agent, - "x-grpc-web": "1", - "x-user-agent": "connect-es/2.1.1", - "cookie": cookie, - } - - def _build_birth_headers(self, token: str) -> dict: - """构造设置出生日期请求头""" - cookie = build_sso_cookie(token, include_rw=True) - user_agent = get_config("security.user_agent") - return { - "accept": "*/*", - "content-type": "application/json", - "origin": "https://grok.com", - "referer": "https://grok.com/?_s=account", - "user-agent": user_agent, - "cookie": cookie, - } - - @staticmethod - def _build_payload() -> bytes: - """构造请求 payload""" - # protobuf (match captured HAR): - # 0a 02 10 01 -> field 1 (len=2) with inner bool=true - # 
12 1a -> field 2, length 26 - # 0a 18 -> nested message with name string - name = b"always_show_nsfw_content" - inner = b"\x0a" + bytes([len(name)]) + name - protobuf = b"\x0a\x02\x10\x01\x12" + bytes([len(inner)]) + inner - return encode_grpc_web_payload(protobuf) - - async def _set_birth_date( - self, session: AsyncSession, token: str - ) -> tuple[bool, int, Optional[str]]: - """设置出生日期""" - headers = self._build_birth_headers(token) - payload = {"birthDate": self._random_birth_date()} - - try: - response = await session.post( - BIRTH_DATE_API, - json=payload, - headers=headers, - timeout=self.timeout, - proxies=self._build_proxies(), - ) - if response.status_code in (200, 204): - return True, response.status_code, None - return False, response.status_code, f"HTTP {response.status_code}" - except Exception as e: - return False, 0, str(e)[:100] - async def enable(self, token: str) -> NSFWResult: """为单个 token 开启 NSFW 模式""" - headers = self._build_headers(token) - payload = self._build_payload() - logger.debug(f"NSFW payload: len={len(payload)} hex={payload.hex()}") - try: browser = get_config("security.browser") async with AsyncSession(impersonate=browser) as session: # 先设置出生日期 - ok, birth_status, birth_err = await self._set_birth_date(session, token) - if not ok: + try: + await SetBirthReverse.request(session, token) + except UpstreamException as e: + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) return NSFWResult( success=False, - http_status=birth_status, - error=f"Set birth date failed: {birth_err}", + http_status=status or 0, + error=f"Set birth date failed: {str(e)}", ) # 开启 NSFW - response = await session.post( - NSFW_API, - data=payload, - headers=headers, - timeout=self.timeout, - proxies=self._build_proxies(), - ) - - if response.status_code != 200: - return NSFWResult( - success=False, - http_status=response.status_code, - error=f"HTTP {response.status_code}", - ) 
- - # 解析 gRPC-Web 响应 - _, trailers = parse_grpc_web_response( - response.content, content_type=response.headers.get("content-type") - ) - - grpc_status = get_grpc_status(trailers) - logger.debug( - f"NSFW response: http={response.status_code} grpc={grpc_status.code} " - f"msg={grpc_status.message} trailers={trailers}" - ) - - # HTTP 200 且无 grpc-status(空响应)或 grpc-status=0 都算成功 - success = grpc_status.code == -1 or grpc_status.ok + grpc_status: GrpcStatus = await NsfwMgmtReverse.request(session, token) + success = grpc_status.code in (-1, 0) return NSFWResult( success=success, - http_status=response.status_code, + http_status=200, grpc_status=grpc_status.code, grpc_message=grpc_status.message or None, ) diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py index cb4e88aa..af7a4fdb 100644 --- a/app/services/reverse/__init__.py +++ b/app/services/reverse/__init__.py @@ -6,7 +6,11 @@ from .assets_list import AssetsListReverse from .assets_upload import AssetsUploadReverse from .media_post import MediaPostReverse +from .nsfw_mgmt import NsfwMgmtReverse from .rate_limits import RateLimitsReverse +from .set_birth import SetBirthReverse +from .livekit_tokens import LivekitTokenReverse +from .ws_livekit import LivekitWebSocketReverse from .utils.headers import build_headers from .utils.statsig import StatsigGenerator @@ -17,7 +21,11 @@ "AssetsListReverse", "AssetsUploadReverse", "MediaPostReverse", + "NsfwMgmtReverse", "RateLimitsReverse", + "SetBirthReverse", + "LivekitTokenReverse", + "LivekitWebSocketReverse", "StatsigGenerator", "build_headers", ] diff --git a/app/services/reverse/nsfw_mgmt.py b/app/services/reverse/nsfw_mgmt.py new file mode 100644 index 00000000..349e2417 --- /dev/null +++ b/app/services/reverse/nsfw_mgmt.py @@ -0,0 +1,132 @@ +""" +Reverse interface: NSFW feature controls (gRPC-Web). 
+""" + +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status +from app.services.reverse.utils.grpc import GrpcClient, GrpcStatus + +NSFW_MGMT_API = "https://grok.com/auth_mgmt.AuthManagement/UpdateUserFeatureControls" + + +class NsfwMgmtReverse: + """/auth_mgmt.AuthManagement/UpdateUserFeatureControls reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> GrpcStatus: + """Enable NSFW feature control via gRPC-Web. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + GrpcStatus: Parsed gRPC status. + """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + origin="https://grok.com", + referer="https://grok.com/?_s=data", + ) + headers["Content-Type"] = "application/grpc-web+proto" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + headers["x-grpc-web"] = "1" + headers["x-user-agent"] = "connect-es/2.1.1" + headers["Cache-Control"] = "no-cache" + headers["Pragma"] = "no-cache" + + # Build payload + name = "always_show_nsfw_content".encode("utf-8") + inner = b"\x0a" + bytes([len(name)]) + name + protobuf = b"\x0a\x02\x10\x01\x12" + bytes([len(inner)]) + inner + payload = GrpcClient.encode_payload(protobuf) + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + NSFW_MGMT_API, + headers=headers, + data=payload, + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if 
response.status_code != 200: + logger.error( + f"NsfwMgmtReverse: Request failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"NsfwMgmtReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + response = await retry_on_status(_do_request) + + _, trailers = GrpcClient.parse_response( + response.content, + content_type=response.headers.get("content-type"), + headers=response.headers, + ) + grpc_status = GrpcClient.get_status(trailers) + + if grpc_status.code not in (-1, 0): + raise UpstreamException( + message=f"NsfwMgmtReverse: gRPC failed, {grpc_status.code}", + details={ + "status": grpc_status.http_equiv, + "grpc_status": grpc_status.code, + "grpc_message": grpc_status.message, + }, + ) + + return grpc_status + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "nsfw_mgmt_auth_failed" + ) + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"NsfwMgmtReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"NsfwMgmtReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["NsfwMgmtReverse"] diff --git a/app/services/reverse/set_birth.py b/app/services/reverse/set_birth.py new file mode 100644 index 00000000..f1b72211 --- /dev/null +++ b/app/services/reverse/set_birth.py @@ -0,0 +1,117 @@ +""" +Reverse interface: set birth date. 
+""" + +import datetime +import random +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +SET_BIRTH_API = "https://grok.com/rest/auth/set-birth-date" + + +class SetBirthReverse: + """/rest/auth/set-birth-date reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> Any: + """Set birth date in Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/?_s=home", + ) + + # Build payload + today = datetime.date.today() + birth_year = today.year - random.randint(20, 40) + birth_month = random.randint(1, 12) + birth_day = random.randint(1, 28) + hour = random.randint(0, 23) + minute = random.randint(0, 59) + second = random.randint(0, 59) + microsecond = random.randint(0, 999) + payload = { + "birthDate": f"{birth_year:04d}-{birth_month:02d}-{birth_day:02d}" + f"T{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}Z" + } + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + SET_BIRTH_API, + headers=headers, + json=payload, + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code not in (200, 204): + logger.error( + f"SetBirthReverse: Request failed, 
{response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"SetBirthReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "set_birth_auth_failed" + ) + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"SetBirthReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"SetBirthReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["SetBirthReverse"] diff --git a/app/services/reverse/utils/grpc.py b/app/services/reverse/utils/grpc.py new file mode 100644 index 00000000..74b745e1 --- /dev/null +++ b/app/services/reverse/utils/grpc.py @@ -0,0 +1,156 @@ +""" +gRPC-Web helpers for reverse interfaces. 
+""" + +import base64 +import re +import struct +from dataclasses import dataclass +from typing import Dict, List, Mapping, Tuple +from urllib.parse import unquote + + +_B64_RE = re.compile(rb"^[A-Za-z0-9+/=\r\n]+$") + + +@dataclass(frozen=True) +class GrpcStatus: + code: int + message: str = "" + + @property + def ok(self) -> bool: + return self.code == 0 + + @property + def http_equiv(self) -> int: + mapping = { + 0: 200, + 16: 401, + 7: 403, + 8: 429, + 4: 504, + 14: 503, + } + return mapping.get(self.code, 502) + + +class GrpcClient: + """gRPC-Web helpers wrapper.""" + + @staticmethod + def encode_payload(data: bytes) -> bytes: + """Encode gRPC-Web data frame.""" + return b"\x00" + struct.pack(">I", len(data)) + data + + @staticmethod + def _maybe_decode_grpc_web_text(body: bytes, content_type: str | None) -> bytes: + ct = (content_type or "").lower() + if "grpc-web-text" in ct: + compact = b"".join(body.split()) + return base64.b64decode(compact, validate=False) + + head = body[: min(len(body), 2048)] + if head and _B64_RE.fullmatch(head): + compact = b"".join(body.split()) + try: + return base64.b64decode(compact, validate=True) + except Exception: + return body + return body + + @staticmethod + def _parse_trailer_block(payload: bytes) -> Dict[str, str]: + text = payload.decode("utf-8", errors="replace") + lines = [ln for ln in re.split(r"\r\n|\n", text) if ln] + + trailers: Dict[str, str] = {} + for ln in lines: + if ":" not in ln: + continue + k, v = ln.split(":", 1) + trailers[k.strip().lower()] = v.strip() + + if "grpc-message" in trailers: + trailers["grpc-message"] = unquote(trailers["grpc-message"]) + + return trailers + + @classmethod + def parse_response( + cls, + body: bytes, + content_type: str | None = None, + headers: Mapping[str, str] | None = None, + ) -> Tuple[List[bytes], Dict[str, str]]: + decoded = cls._maybe_decode_grpc_web_text(body, content_type) + + messages: List[bytes] = [] + trailers: Dict[str, str] = {} + + i = 0 + n = len(decoded) 
+ while i < n: + if n - i < 5: + break + + flag = decoded[i] + length = int.from_bytes(decoded[i + 1 : i + 5], "big") + i += 5 + + if n - i < length: + break + + payload = decoded[i : i + length] + i += length + + if flag & 0x80: + trailers.update(cls._parse_trailer_block(payload)) + elif flag & 0x01: + raise ValueError("grpc-web compressed flag not supported") + else: + messages.append(payload) + + if headers: + lower = {k.lower(): v for k, v in headers.items()} + if "grpc-status" in lower and "grpc-status" not in trailers: + trailers["grpc-status"] = str(lower["grpc-status"]).strip() + if "grpc-message" in lower and "grpc-message" not in trailers: + trailers["grpc-message"] = unquote(str(lower["grpc-message"]).strip()) + + return messages, trailers + + @staticmethod + def get_status(trailers: Mapping[str, str]) -> GrpcStatus: + raw = str(trailers.get("grpc-status", "")).strip() + msg = str(trailers.get("grpc-message", "")).strip() + try: + code = int(raw) + except Exception: + code = -1 + return GrpcStatus(code=code, message=msg) + + +def encode_grpc_web_payload(data: bytes) -> bytes: + return GrpcClient.encode_payload(data) + + +def parse_grpc_web_response( + body: bytes, + content_type: str | None = None, + headers: Mapping[str, str] | None = None, +) -> Tuple[List[bytes], Dict[str, str]]: + return GrpcClient.parse_response(body, content_type=content_type, headers=headers) + + +def get_grpc_status(trailers: Mapping[str, str]) -> GrpcStatus: + return GrpcClient.get_status(trailers) + + +__all__ = [ + "encode_grpc_web_payload", + "parse_grpc_web_response", + "get_grpc_status", + "GrpcStatus", + "GrpcClient", +] From 06df48b87acb2c802b5854ee6a0b89fb7b704828 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:15:24 +0800 Subject: [PATCH 05/27] refactor: integrate reverse request handlers for NSFW --- app/services/grok/services/nsfw.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/app/services/grok/services/nsfw.py b/app/services/grok/services/nsfw.py index b56cd626..a919d54e 100644 --- a/app/services/grok/services/nsfw.py +++ b/app/services/grok/services/nsfw.py @@ -10,7 +10,6 @@ from curl_cffi.requests import AsyncSession from app.core.logger import logger -from app.services.grok.protocols.grpc_web import ( from app.core.config import get_config from app.core.exceptions import UpstreamException from app.services.reverse import NsfwMgmtReverse, SetBirthReverse From 413493d5b524de1eeea029f40eba8cbe1087e766 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:39:55 +0800 Subject: [PATCH 06/27] refactor: integrate reverse request handlers for image and voice services --- app/services/grok/services/image.py | 14 +- app/services/grok/services/voice.py | 104 ++-------- app/services/reverse/__init__.py | 3 +- app/services/reverse/assets_upload.py | 2 +- app/services/reverse/utils/grpc.py | 5 +- app/services/reverse/utils/headers.py | 4 +- app/services/reverse/utils/statsig.py | 7 +- app/services/reverse/utils/websocket.py | 144 ++++++++++++++ app/services/reverse/ws_livekit.py | 253 ++++++++++++++++++++++++ 9 files changed, 417 insertions(+), 119 deletions(-) create mode 100644 app/services/reverse/utils/websocket.py create mode 100644 app/services/reverse/ws_livekit.py diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 573e1a21..218c334b 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -10,14 +10,13 @@ import time import uuid from typing import AsyncGenerator, Dict, Optional -from urllib.parse import urlparse import aiohttp -from aiohttp_socks import ProxyConnector from app.core.config import get_config from app.core.logger import logger from app.services.grok.utils.headers import build_sso_cookie +from app.services.reverse.utils.websocket import resolve_proxy WS_URL = "wss://grok.com/ws/imagine/listen" 
@@ -36,16 +35,7 @@ def __init__(self): def _resolve_proxy(self) -> tuple[aiohttp.BaseConnector, Optional[str]]: proxy_url = get_config("network.base_proxy_url") - if not proxy_url: - return aiohttp.TCPConnector(ssl=self._ssl_context), None - - scheme = urlparse(proxy_url).scheme.lower() - if scheme.startswith("socks"): - logger.info(f"Using SOCKS proxy: {proxy_url}") - return ProxyConnector.from_url(proxy_url, ssl=self._ssl_context), None - - logger.info(f"Using HTTP proxy: {proxy_url}") - return aiohttp.TCPConnector(ssl=self._ssl_context), proxy_url + return resolve_proxy(proxy_url, self._ssl_context) def _get_ws_headers(self, token: str) -> Dict[str, str]: cookie = build_sso_cookie(token, include_rw=True) diff --git a/app/services/grok/services/voice.py b/app/services/grok/services/voice.py index 006fc547..208bf954 100644 --- a/app/services/grok/services/voice.py +++ b/app/services/grok/services/voice.py @@ -2,25 +2,17 @@ Grok Voice Mode Service """ -import orjson -from typing import Dict, Any +from typing import Any, Dict from curl_cffi.requests import AsyncSession -from app.core.logger import logger from app.core.config import get_config -from app.core.exceptions import UpstreamException -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie - -LIVEKIT_TOKEN_API = "https://grok.com/rest/livekit/tokens" +from app.services.reverse import LivekitTokenReverse class VoiceService: """Voice Mode Service (LiveKit)""" - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - async def get_token( self, token: str, @@ -28,86 +20,12 @@ async def get_token( personality: str = "assistant", speed: float = 1.0, ) -> Dict[str, Any]: - """ - Get LiveKit token - - Args: - token: Auth token - Returns: - Dict containing token and livekitUrl - """ - logger.debug( - f"Voice token request: voice={voice}, personality={personality}, speed={speed}" - ) - headers = self._build_headers(token) - payload = 
self._build_payload(voice, personality, speed) - - proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None - - try: - browser = get_config("security.browser") - timeout = get_config("network.timeout") - async with AsyncSession(impersonate=browser) as session: - response = await session.post( - LIVEKIT_TOKEN_API, - headers=headers, - data=orjson.dumps(payload), - timeout=timeout, - proxies=proxies, - ) - - if response.status_code != 200: - body = response.text[:200] - logger.error( - f"Voice token failed: status={response.status_code}, body={body}" - ) - raise UpstreamException( - message=f"Failed to get voice token: {response.status_code}", - details={"status": response.status_code, "body": response.text}, - ) - - result = response.json() - logger.info(f"Voice token obtained: voice={voice}") - return result - - except Exception as e: - logger.error(f"Voice service error: {e}") - if isinstance(e, UpstreamException): - raise - raise UpstreamException(f"Voice service error: {str(e)}") - - def _build_headers(self, token: str) -> Dict[str, str]: - headers = { - "Accept": "*/*", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Referer": "https://grok.com/", - # Statsig ID is crucial - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - - def _build_payload( - self, - voice: str = "ara", - personality: str = "assistant", - speed: float = 1.0, - ) -> Dict[str, Any]: - """Construct payload with voice settings""" - return { - "sessionPayload": orjson.dumps( - { - "voice": voice, - "personality": personality, - "playback_speed": speed, - "enable_vision": False, - "turn_detection": {"type": "server_vad"}, - } - ).decode(), - "requestAgentDispatch": False, - "livekitUrl": "wss://livekit.grok.com", - "params": {"enable_markdown_transcript": "true"}, - } + browser = get_config("security.browser") + async with AsyncSession(impersonate=browser) as session: + return await 
LivekitTokenReverse.request( + session, + token=token, + voice=voice, + personality=personality, + speed=speed, + ) diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py index af7a4fdb..69594b5b 100644 --- a/app/services/reverse/__init__.py +++ b/app/services/reverse/__init__.py @@ -9,8 +9,7 @@ from .nsfw_mgmt import NsfwMgmtReverse from .rate_limits import RateLimitsReverse from .set_birth import SetBirthReverse -from .livekit_tokens import LivekitTokenReverse -from .ws_livekit import LivekitWebSocketReverse +from .ws_livekit import LivekitTokenReverse, LivekitWebSocketReverse from .utils.headers import build_headers from .utils.statsig import StatsigGenerator diff --git a/app/services/reverse/assets_upload.py b/app/services/reverse/assets_upload.py index e4f65acf..0466be6c 100644 --- a/app/services/reverse/assets_upload.py +++ b/app/services/reverse/assets_upload.py @@ -6,9 +6,9 @@ from curl_cffi.requests import AsyncSession from app.core.logger import logger +from app.core.config import get_config from app.core.exceptions import UpstreamException from app.services.token.service import TokenService -from app.core.config import get_config from app.services.reverse.utils.headers import build_headers from app.services.reverse.utils.retry import retry_on_status diff --git a/app/services/reverse/utils/grpc.py b/app/services/reverse/utils/grpc.py index 74b745e1..2cb26883 100644 --- a/app/services/reverse/utils/grpc.py +++ b/app/services/reverse/utils/grpc.py @@ -10,9 +10,6 @@ from urllib.parse import unquote -_B64_RE = re.compile(rb"^[A-Za-z0-9+/=\r\n]+$") - - @dataclass(frozen=True) class GrpcStatus: code: int @@ -51,7 +48,7 @@ def _maybe_decode_grpc_web_text(body: bytes, content_type: str | None) -> bytes: return base64.b64decode(compact, validate=False) head = body[: min(len(body), 2048)] - if head and _B64_RE.fullmatch(head): + if head and re.compile(rb"^[A-Za-z0-9+/=\r\n]+$").fullmatch(head): compact = b"".join(body.split()) try: return 
base64.b64decode(compact, validate=True) diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py index fd7e74cc..e26de015 100644 --- a/app/services/reverse/utils/headers.py +++ b/app/services/reverse/utils/headers.py @@ -5,8 +5,8 @@ from urllib.parse import urlparse from typing import Dict, Optional -from app.core.config import get_config from app.core.logger import logger +from app.core.config import get_config from app.services.reverse.utils.statsig import StatsigGenerator @@ -14,7 +14,7 @@ def _build_sso_cookie(sso_token: str) -> str: """ Build SSO Cookie string. """ - # Remove "sso=" prefix if present + # Format sso_token = sso_token[4:] if sso_token.startswith("sso=") else sso_token # SSO Cookie diff --git a/app/services/reverse/utils/statsig.py b/app/services/reverse/utils/statsig.py index 28280a03..69e81968 100644 --- a/app/services/reverse/utils/statsig.py +++ b/app/services/reverse/utils/statsig.py @@ -6,11 +6,8 @@ import random import string -from app.core.config import get_config from app.core.logger import logger - - -STATIC_STATSIG_ID = "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk=" +from app.core.config import get_config class StatsigGenerator: @@ -53,7 +50,7 @@ def gen_id() -> str: # Static Statsig ID logger.debug("Generating static Statsig ID") - return STATIC_STATSIG_ID + return "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk=" __all__ = ["StatsigGenerator"] diff --git a/app/services/reverse/utils/websocket.py b/app/services/reverse/utils/websocket.py new file mode 100644 index 00000000..a4a8ffb9 --- /dev/null +++ b/app/services/reverse/utils/websocket.py @@ -0,0 +1,144 @@ +""" +WebSocket helpers for reverse interfaces. 
+""" + +import ssl +import certifi +import aiohttp +from aiohttp_socks import ProxyConnector +from typing import Mapping, Optional +from urllib.parse import urlparse + +from app.core.logger import logger +from app.core.config import get_config + + +def _default_ssl_context() -> ssl.SSLContext: + context = ssl.create_default_context() + context.load_verify_locations(certifi.where()) + return context + + +def _normalize_socks_proxy(proxy_url: str) -> tuple[str, Optional[bool]]: + scheme = urlparse(proxy_url).scheme.lower() + rdns: Optional[bool] = None + base_scheme = scheme + + if scheme == "socks5h": + base_scheme = "socks5" + rdns = True + elif scheme == "socks4a": + base_scheme = "socks4" + rdns = True + + if base_scheme != scheme: + proxy_url = proxy_url.replace(f"{scheme}://", f"{base_scheme}://", 1) + + return proxy_url, rdns + + +def resolve_proxy( + proxy_url: str | None, ssl_context: ssl.SSLContext +) -> tuple[aiohttp.BaseConnector, Optional[str]]: + """Resolve proxy connector. + + Args: + proxy_url: str, the proxy URL. + ssl_context: ssl.SSLContext, the SSL context. + + Returns: + tuple[aiohttp.BaseConnector, Optional[str]]: The proxy connector and the proxy URL. 
+ """ + if not proxy_url: + return aiohttp.TCPConnector(ssl=ssl_context), None + + scheme = urlparse(proxy_url).scheme.lower() + if scheme.startswith("socks"): + normalized, rdns = _normalize_socks_proxy(proxy_url) + logger.info(f"Using SOCKS proxy: {proxy_url}") + try: + if rdns is not None: + return ( + ProxyConnector.from_url(normalized, rdns=rdns, ssl=ssl_context), + None, + ) + except TypeError: + return ProxyConnector.from_url(normalized, ssl=ssl_context), None + return ProxyConnector.from_url(normalized, ssl=ssl_context), None + + logger.info(f"Using HTTP proxy: {proxy_url}") + return aiohttp.TCPConnector(ssl=ssl_context), proxy_url + + +class WebSocketConnection: + """WebSocket connection wrapper.""" + + def __init__( + self, + session: aiohttp.ClientSession, + ws: aiohttp.ClientWebSocketResponse, + ) -> None: + self.session = session + self.ws = ws + + async def close(self) -> None: + if not self.ws.closed: + await self.ws.close() + await self.session.close() + + async def __aenter__(self) -> aiohttp.ClientWebSocketResponse: + return self.ws + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + +class WebSocketClient: + """WebSocket client with proxy support.""" + + def __init__(self, proxy: str | None = None) -> None: + self.proxy = proxy or get_config("network.base_proxy_url") + self._ssl_context = _default_ssl_context() + + async def connect( + self, + url: str, + headers: Mapping[str, str] | None = None, + timeout: float | aiohttp.ClientTimeout | None = None, + ) -> WebSocketConnection: + """Connect to the WebSocket. + + Args: + url: str, the URL to connect to. + headers: Mapping[str, str], the headers to send. + timeout: float | aiohttp.ClientTimeout | None, the timeout. + + Returns: + WebSocketConnection: The WebSocket connection. 
+ """ + # Resolve proxy + connector, proxy = resolve_proxy(self.proxy, self._ssl_context) + + # Build client timeout + client_timeout = ( + timeout + if isinstance(timeout, aiohttp.ClientTimeout) + else aiohttp.ClientTimeout(total=timeout) + ) + + # Create session + session = aiohttp.ClientSession(connector=connector, timeout=client_timeout) + try: + ws = await session.ws_connect( + url, + headers=headers, + proxy=proxy, + ssl=self._ssl_context, + ) + return WebSocketConnection(session, ws) + except Exception: + await session.close() + raise + + +__all__ = ["WebSocketClient", "WebSocketConnection", "resolve_proxy"] diff --git a/app/services/reverse/ws_livekit.py b/app/services/reverse/ws_livekit.py new file mode 100644 index 00000000..52eb84ea --- /dev/null +++ b/app/services/reverse/ws_livekit.py @@ -0,0 +1,253 @@ +""" +Reverse interface: LiveKit token + WebSocket. +""" + +import orjson +from typing import Any, Dict +from urllib.parse import urlencode +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status +from app.services.reverse.utils.websocket import WebSocketClient, WebSocketConnection + +LIVEKIT_TOKEN_API = "https://grok.com/rest/livekit/tokens" +LIVEKIT_WS_URL = "wss://livekit.grok.com" + + +class LivekitTokenReverse: + """/rest/livekit/tokens reverse interface.""" + + @staticmethod + async def request( + session: AsyncSession, + token: str, + voice: str = "ara", + personality: str = "assistant", + speed: float = 1.0, + livekit_url: str = LIVEKIT_WS_URL, + ) -> Dict[str, Any]: + """Fetch LiveKit token. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + voice: str, the voice to use for the request. 
+ personality: str, the personality to use for the request. + speed: float, the speed to use for the request. + livekit_url: str, the LiveKit URL to use for the request. + + Returns: + Dict[str, Any]: The LiveKit token. + """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = { + "sessionPayload": orjson.dumps( + { + "voice": voice, + "personality": personality, + "playback_speed": speed, + "enable_vision": False, + "turn_detection": {"type": "server_vad"}, + } + ).decode(), + "requestAgentDispatch": False, + "livekitUrl": livekit_url, + "params": {"enable_markdown_transcript": "true"}, + } + + # Curl Config + timeout = get_config("network.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + LIVEKIT_TOKEN_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + body = response.text[:200] + logger.error( + f"LivekitTokenReverse: Request failed, {response.status_code}, body={body}" + ) + raise UpstreamException( + message=f"LivekitTokenReverse: Request failed, {response.status_code}", + details={"status": response.status_code, "body": response.text}, + ) + + return response + + return await retry_on_status(_do_request) + + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "livekit_token_auth_failed" + ) + except Exception: + pass + raise + + # Handle 
other non-upstream exceptions + logger.error( + f"LivekitTokenReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"LivekitTokenReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +class LivekitWebSocketReverse: + """LiveKit WebSocket reverse interface.""" + + def __init__(self) -> None: + self._client = WebSocketClient() + + def build_url( + self, + access_token: str, + *, + livekit_url: str = LIVEKIT_WS_URL, + auto_subscribe: bool = True, + sdk: str = "js", + version: str = "2.11.4", + protocol: int = 15, + ) -> str: + """Build LiveKit WebSocket URL. + + Args: + access_token: str, the LiveKit access token. + livekit_url: str, the LiveKit URL to use for the request. + auto_subscribe: bool, whether to auto subscribe to the WebSocket. + sdk: str, the SDK to use for the request. + version: str, the version to use for the request. + protocol: int, the protocol to use for the request. + + Returns: + str: The LiveKit WebSocket URL. 
+ """ + # Build base URL + base = livekit_url.rstrip("/") + if not base.endswith("/rtc"): + base = f"{base}/rtc" + + # Build parameters + params = { + "access_token": access_token, + "auto_subscribe": str(int(auto_subscribe)), + "sdk": sdk, + "version": version, + "protocol": str(protocol), + } + + return f"{base}?{urlencode(params)}" + + def _build_headers(self, extra: Dict[str, str] | None = None) -> Dict[str, str]: + """Build LiveKit WebSocket headers.""" + # Build headers + headers = { + "Origin": "https://grok.com", + "User-Agent": get_config("security.user_agent"), + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Cache-Control": "no-cache", + "Pragma": "no-cache", + } + + # Update headers + if extra: + headers.update(extra) + return headers + + async def connect( + self, + access_token: str, + *, + livekit_url: str = LIVEKIT_WS_URL, + auto_subscribe: bool = True, + sdk: str = "js", + version: str = "2.11.4", + protocol: int = 15, + headers: Dict[str, str] | None = None, + timeout: float | None = None, + ) -> WebSocketConnection: + """Connect to the LiveKit WebSocket. + + Args: + access_token: str, the LiveKit access token. + livekit_url: str, the LiveKit URL to use for the request. + auto_subscribe: bool, whether to auto subscribe to the WebSocket. + sdk: str, the SDK to use for the request. + version: str, the version to use for the request. + protocol: int, the protocol to use for the request. + headers: Dict[str, str], the headers to send. + timeout: float, the timeout to use for the request. + + Returns: + WebSocketConnection: The LiveKit WebSocket connection. 
+ """ + # Build URL + url = self.build_url( + access_token, + livekit_url=livekit_url, + auto_subscribe=auto_subscribe, + sdk=sdk, + version=version, + protocol=protocol, + ) + + # Build WebSocket headers + ws_headers = self._build_headers(headers) + + # Build timeout + if timeout is None: + timeout = get_config("network.timeout") + + # Connect to the LiveKit WebSocket + try: + return await self._client.connect(url, headers=ws_headers, timeout=timeout) + except Exception as e: + logger.error(f"LivekitWebSocketReverse: Connect failed, {e}") + raise UpstreamException( + f"LivekitWebSocketReverse: Connect failed, {str(e)}" + ) + + +__all__ = [ + "LivekitTokenReverse", + "LivekitWebSocketReverse", + "LIVEKIT_TOKEN_API", + "LIVEKIT_WS_URL", +] From 4573a1928e63fc6eb751436431f17c4439abe3f1 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:46:51 +0800 Subject: [PATCH 07/27] fix: response handling in LivekitTokenReverse --- app/services/reverse/utils/headers.py | 2 +- app/services/reverse/ws_livekit.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py index e26de015..7b415778 100644 --- a/app/services/reverse/utils/headers.py +++ b/app/services/reverse/utils/headers.py @@ -96,7 +96,7 @@ def build_headers( safe_headers = dict(headers) if "Cookie" in safe_headers: safe_headers["Cookie"] = "" - logger.debug(f"Built headers: {orjson.dumps(safe_headers, indent=2)}") + logger.debug(f"Built headers: {orjson.dumps(safe_headers).decode()}") return headers diff --git a/app/services/reverse/ws_livekit.py b/app/services/reverse/ws_livekit.py index 52eb84ea..e810fa3c 100644 --- a/app/services/reverse/ws_livekit.py +++ b/app/services/reverse/ws_livekit.py @@ -99,7 +99,8 @@ async def _do_request(): return response - return await retry_on_status(_do_request) + response = await retry_on_status(_do_request) + return response.json() 
except Exception as e: From 06357ad6c5cab4b46cebe0aaccda5a60bc768348 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:46:22 +0800 Subject: [PATCH 08/27] refactor: remove gRPC-Web protocol implementation and integrate WebSocket reverse handlers for image service --- app/services/grok/protocols/grpc_web.py | 157 -------------- app/services/grok/services/image.py | 274 +----------------------- app/services/grok/services/voice.py | 3 +- app/services/reverse/__init__.py | 2 + app/services/reverse/utils/grpc.py | 27 +-- app/services/reverse/utils/headers.py | 56 +++-- app/services/reverse/utils/websocket.py | 32 ++- app/services/reverse/ws_imagine.py | 262 ++++++++++++++++++++++ app/services/reverse/ws_livekit.py | 108 ++-------- 9 files changed, 345 insertions(+), 576 deletions(-) create mode 100644 app/services/reverse/ws_imagine.py diff --git a/app/services/grok/protocols/grpc_web.py b/app/services/grok/protocols/grpc_web.py index 0724727d..e69de29b 100644 --- a/app/services/grok/protocols/grpc_web.py +++ b/app/services/grok/protocols/grpc_web.py @@ -1,157 +0,0 @@ -""" -gRPC-Web 协议工具 - -提供 framing 编码/解码、trailer 解析等通用功能。 -支持 application/grpc-web+proto 和 application/grpc-web-text (base64) 两种格式。 -""" - -from __future__ import annotations - -import base64 -import re -import struct -from dataclasses import dataclass -from typing import Dict, List, Mapping, Tuple -from urllib.parse import unquote - - -_B64_RE = re.compile(rb"^[A-Za-z0-9+/=\r\n]+$") - - -def encode_grpc_web_payload(data: bytes) -> bytes: - """ - 编码 gRPC-Web data frame - - Frame format: - 1-byte flags + 4-byte big-endian length + message bytes - """ - return b"\x00" + struct.pack(">I", len(data)) + data - - -def _maybe_decode_grpc_web_text(body: bytes, content_type: str | None) -> bytes: - """处理 grpc-web-text 模式的 base64 解码""" - ct = (content_type or "").lower() - if "grpc-web-text" in ct: - compact = b"".join(body.split()) - return 
base64.b64decode(compact, validate=False) - - # 启发式:body 仅包含 base64 字符才尝试解码 - head = body[: min(len(body), 2048)] - if head and _B64_RE.fullmatch(head): - compact = b"".join(body.split()) - try: - return base64.b64decode(compact, validate=True) - except Exception: - return body - return body - - -def _parse_trailer_block(payload: bytes) -> Dict[str, str]: - """解析 trailer frame 内容""" - text = payload.decode("utf-8", errors="replace") - lines = [ln for ln in re.split(r"\r\n|\n", text) if ln] - - trailers: Dict[str, str] = {} - for ln in lines: - if ":" not in ln: - continue - k, v = ln.split(":", 1) - trailers[k.strip().lower()] = v.strip() - - # grpc-message 可能是 percent-encoding - if "grpc-message" in trailers: - trailers["grpc-message"] = unquote(trailers["grpc-message"]) - - return trailers - - -def parse_grpc_web_response( - body: bytes, - content_type: str | None = None, - headers: Mapping[str, str] | None = None, -) -> Tuple[List[bytes], Dict[str, str]]: - """ - 解析 gRPC-Web 响应 - - Returns: - (messages, trailers): data frames 列表和合并后的 trailers - """ - decoded = _maybe_decode_grpc_web_text(body, content_type) - - messages: List[bytes] = [] - trailers: Dict[str, str] = {} - - i = 0 - n = len(decoded) - while i < n: - if n - i < 5: - break - - flag = decoded[i] - length = int.from_bytes(decoded[i + 1 : i + 5], "big") - i += 5 - - if n - i < length: - break - - payload = decoded[i : i + length] - i += length - - if flag & 0x80: # trailer frame - trailers.update(_parse_trailer_block(payload)) - elif flag & 0x01: # compressed (不支持) - raise ValueError("grpc-web compressed flag not supported") - else: - messages.append(payload) - - # 兼容:grpc-status 可能在 response headers 中 - if headers: - lower = {k.lower(): v for k, v in headers.items()} - if "grpc-status" in lower and "grpc-status" not in trailers: - trailers["grpc-status"] = str(lower["grpc-status"]).strip() - if "grpc-message" in lower and "grpc-message" not in trailers: - trailers["grpc-message"] = 
unquote(str(lower["grpc-message"]).strip()) - - return messages, trailers - - -@dataclass(frozen=True) -class GrpcStatus: - code: int - message: str = "" - - @property - def ok(self) -> bool: - return self.code == 0 - - @property - def http_equiv(self) -> int: - """映射到类 HTTP 状态码""" - mapping = { - 0: 200, # OK - 16: 401, # UNAUTHENTICATED - 7: 403, # PERMISSION_DENIED - 8: 429, # RESOURCE_EXHAUSTED - 4: 504, # DEADLINE_EXCEEDED - 14: 503, # UNAVAILABLE - } - return mapping.get(self.code, 502) - - -def get_grpc_status(trailers: Mapping[str, str]) -> GrpcStatus: - """从 trailers 提取 gRPC 状态""" - raw = str(trailers.get("grpc-status", "")).strip() - msg = str(trailers.get("grpc-message", "")).strip() - try: - code = int(raw) - except Exception: - code = -1 - return GrpcStatus(code=code, message=msg) - - -__all__ = [ - "encode_grpc_web_payload", - "parse_grpc_web_response", - "get_grpc_status", - "GrpcStatus", -] diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 218c334b..73139767 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -2,278 +2,10 @@ Grok Imagine WebSocket image service. 
""" -import asyncio -import certifi -import json -import re -import ssl -import time -import uuid -from typing import AsyncGenerator, Dict, Optional +from app.services.reverse.ws_imagine import ImagineWebSocketReverse -import aiohttp -from app.core.config import get_config -from app.core.logger import logger -from app.services.grok.utils.headers import build_sso_cookie -from app.services.reverse.utils.websocket import resolve_proxy - -WS_URL = "wss://grok.com/ws/imagine/listen" - - -class _BlockedError(Exception): - pass - - -class ImageService: - """Grok Imagine WebSocket image service.""" - - def __init__(self): - self._ssl_context = ssl.create_default_context() - self._ssl_context.load_verify_locations(certifi.where()) - self._url_pattern = re.compile(r"/images/([a-f0-9-]+)\.(png|jpg|jpeg)") - - def _resolve_proxy(self) -> tuple[aiohttp.BaseConnector, Optional[str]]: - proxy_url = get_config("network.base_proxy_url") - return resolve_proxy(proxy_url, self._ssl_context) - - def _get_ws_headers(self, token: str) -> Dict[str, str]: - cookie = build_sso_cookie(token, include_rw=True) - user_agent = get_config("security.user_agent") - return { - "Cookie": cookie, - "Origin": "https://grok.com", - "User-Agent": user_agent, - "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", - "Cache-Control": "no-cache", - "Pragma": "no-cache", - } - - def _extract_image_id(self, url: str) -> Optional[str]: - match = self._url_pattern.search(url or "") - return match.group(1) if match else None - - def _is_final_image(self, url: str, blob_size: int) -> bool: - return (url or "").lower().endswith( - (".jpg", ".jpeg") - ) and blob_size > get_config("image.image_ws_final_min_bytes") - - def _classify_image(self, url: str, blob: str) -> Optional[Dict[str, object]]: - if not url or not blob: - return None - - image_id = self._extract_image_id(url) or uuid.uuid4().hex - blob_size = len(blob) - is_final = self._is_final_image(url, blob_size) - - stage = ( - "final" - if is_final - else ( - 
"medium" - if blob_size > get_config("image.image_ws_medium_min_bytes") - else "preview" - ) - ) - - return { - "type": "image", - "image_id": image_id, - "stage": stage, - "blob": blob, - "blob_size": blob_size, - "url": url, - "is_final": is_final, - } - - async def stream( - self, - token: str, - prompt: str, - aspect_ratio: str = "2:3", - n: int = 1, - enable_nsfw: bool = True, - max_retries: int = None, - ) -> AsyncGenerator[Dict[str, object], None]: - retries = max(1, max_retries if max_retries is not None else 1) - logger.info( - f"Image generation: prompt='{prompt[:50]}...', n={n}, ratio={aspect_ratio}, nsfw={enable_nsfw}" - ) - - for attempt in range(retries): - try: - yielded_any = False - async for item in self._stream_once( - token, prompt, aspect_ratio, n, enable_nsfw - ): - yielded_any = True - yield item - return - except _BlockedError: - if yielded_any or attempt + 1 >= retries: - if not yielded_any: - yield { - "type": "error", - "error_code": "blocked", - "error": "blocked_no_final_image", - } - return - logger.warning(f"WebSocket blocked, retry {attempt + 1}/{retries}") - except Exception as e: - logger.error(f"WebSocket stream failed: {e}") - return - - async def _stream_once( - self, - token: str, - prompt: str, - aspect_ratio: str, - n: int, - enable_nsfw: bool, - ) -> AsyncGenerator[Dict[str, object], None]: - request_id = str(uuid.uuid4()) - headers = self._get_ws_headers(token) - timeout = float(get_config("network.timeout")) - blocked_seconds = float(get_config("image.image_ws_blocked_seconds")) - - try: - connector, proxy = self._resolve_proxy() - except Exception as e: - logger.error(f"WebSocket proxy setup failed: {e}") - return - - try: - async with aiohttp.ClientSession(connector=connector) as session: - async with session.ws_connect( - WS_URL, - headers=headers, - heartbeat=20, - receive_timeout=timeout, - proxy=proxy, - ) as ws: - message = { - "type": "conversation.item.create", - "timestamp": int(time.time() * 1000), - "item": { - 
"type": "message", - "content": [ - { - "requestId": request_id, - "text": prompt, - "type": "input_text", - "properties": { - "section_count": 0, - "is_kids_mode": False, - "enable_nsfw": enable_nsfw, - "skip_upsampler": False, - "is_initial": False, - "aspect_ratio": aspect_ratio, - }, - } - ], - }, - } - - await ws.send_json(message) - logger.info(f"WebSocket request sent: {prompt[:80]}...") - - images = {} - completed = 0 - start_time = last_activity = time.time() - medium_received_time = None - - while time.time() - start_time < timeout: - try: - ws_msg = await asyncio.wait_for(ws.receive(), timeout=5.0) - except asyncio.TimeoutError: - if ( - medium_received_time - and completed == 0 - and time.time() - medium_received_time - > min(10, blocked_seconds) - ): - raise _BlockedError() - if completed > 0 and time.time() - last_activity > 10: - logger.info( - f"WebSocket idle timeout, collected {completed} images" - ) - break - continue - - if ws_msg.type == aiohttp.WSMsgType.TEXT: - last_activity = time.time() - msg = json.loads(ws_msg.data) - msg_type = msg.get("type") - - if msg_type == "image": - info = self._classify_image( - msg.get("url", ""), msg.get("blob", "") - ) - if not info: - continue - - image_id = info["image_id"] - existing = images.get(image_id, {}) - - if ( - info["stage"] == "medium" - and medium_received_time is None - ): - medium_received_time = time.time() - - if info["is_final"] and not existing.get("is_final"): - completed += 1 - logger.debug( - f"Final image received: id={image_id}, size={info['blob_size']}" - ) - - images[image_id] = { - "is_final": info["is_final"] - or existing.get("is_final") - } - yield info - - elif msg_type == "error": - logger.warning( - f"WebSocket error: {msg.get('err_code', '')} - {msg.get('err_msg', '')}" - ) - yield { - "type": "error", - "error_code": msg.get("err_code", ""), - "error": msg.get("err_msg", ""), - } - return - - if completed >= n: - logger.info( - f"WebSocket collected {completed} final 
images" - ) - break - - if ( - medium_received_time - and completed == 0 - and time.time() - medium_received_time > blocked_seconds - ): - raise _BlockedError() - - elif ws_msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.ERROR, - ): - logger.warning(f"WebSocket closed/error: {ws_msg.type}") - yield { - "type": "error", - "error_code": "ws_closed", - "error": f"websocket closed: {ws_msg.type}", - } - break - - except aiohttp.ClientError as e: - logger.error(f"WebSocket connection error: {e}") - yield {"type": "error", "error_code": "connection_failed", "error": str(e)} - - -image_service = ImageService() +ImageService = ImagineWebSocketReverse +image_service = ImagineWebSocketReverse() __all__ = ["image_service", "ImageService"] diff --git a/app/services/grok/services/voice.py b/app/services/grok/services/voice.py index 208bf954..81515dc2 100644 --- a/app/services/grok/services/voice.py +++ b/app/services/grok/services/voice.py @@ -22,10 +22,11 @@ async def get_token( ) -> Dict[str, Any]: browser = get_config("security.browser") async with AsyncSession(impersonate=browser) as session: - return await LivekitTokenReverse.request( + response = await LivekitTokenReverse.request( session, token=token, voice=voice, personality=personality, speed=speed, ) + return response.json() diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py index 69594b5b..08734a8e 100644 --- a/app/services/reverse/__init__.py +++ b/app/services/reverse/__init__.py @@ -10,6 +10,7 @@ from .rate_limits import RateLimitsReverse from .set_birth import SetBirthReverse from .ws_livekit import LivekitTokenReverse, LivekitWebSocketReverse +from .ws_imagine import ImagineWebSocketReverse from .utils.headers import build_headers from .utils.statsig import StatsigGenerator @@ -25,6 +26,7 @@ "SetBirthReverse", "LivekitTokenReverse", "LivekitWebSocketReverse", + "ImagineWebSocketReverse", "StatsigGenerator", "build_headers", ] diff --git 
a/app/services/reverse/utils/grpc.py b/app/services/reverse/utils/grpc.py index 2cb26883..07ee84dd 100644 --- a/app/services/reverse/utils/grpc.py +++ b/app/services/reverse/utils/grpc.py @@ -6,7 +6,7 @@ import re import struct from dataclasses import dataclass -from typing import Dict, List, Mapping, Tuple +from typing import Dict, List, Mapping, Optional, Tuple from urllib.parse import unquote @@ -41,7 +41,7 @@ def encode_payload(data: bytes) -> bytes: return b"\x00" + struct.pack(">I", len(data)) + data @staticmethod - def _maybe_decode_grpc_web_text(body: bytes, content_type: str | None) -> bytes: + def _maybe_decode_grpc_web_text(body: bytes, content_type: Optional[str]) -> bytes: ct = (content_type or "").lower() if "grpc-web-text" in ct: compact = b"".join(body.split()) @@ -77,8 +77,8 @@ def _parse_trailer_block(payload: bytes) -> Dict[str, str]: def parse_response( cls, body: bytes, - content_type: str | None = None, - headers: Mapping[str, str] | None = None, + content_type: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, ) -> Tuple[List[bytes], Dict[str, str]]: decoded = cls._maybe_decode_grpc_web_text(body, content_type) @@ -128,26 +128,7 @@ def get_status(trailers: Mapping[str, str]) -> GrpcStatus: return GrpcStatus(code=code, message=msg) -def encode_grpc_web_payload(data: bytes) -> bytes: - return GrpcClient.encode_payload(data) - - -def parse_grpc_web_response( - body: bytes, - content_type: str | None = None, - headers: Mapping[str, str] | None = None, -) -> Tuple[List[bytes], Dict[str, str]]: - return GrpcClient.parse_response(body, content_type=content_type, headers=headers) - - -def get_grpc_status(trailers: Mapping[str, str]) -> GrpcStatus: - return GrpcClient.get_status(trailers) - - __all__ = [ - "encode_grpc_web_payload", - "parse_grpc_web_response", - "get_grpc_status", "GrpcStatus", "GrpcClient", ] diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py index 7b415778..03a8f253 100644 
--- a/app/services/reverse/utils/headers.py +++ b/app/services/reverse/utils/headers.py @@ -10,9 +10,15 @@ from app.services.reverse.utils.statsig import StatsigGenerator -def _build_sso_cookie(sso_token: str) -> str: +def build_sso_cookie(sso_token: str) -> str: """ Build SSO Cookie string. + + Args: + sso_token: str, the SSO token. + + Returns: + str: The SSO Cookie string. """ # Format sso_token = sso_token[4:] if sso_token.startswith("sso=") else sso_token @@ -28,20 +34,44 @@ def _build_sso_cookie(sso_token: str) -> str: return cookie -def build_headers( - cookie_token: str, - content_type: Optional[str] = None, - origin: Optional[str] = None, - referer: Optional[str] = None, -) -> Dict[str, str]: +def build_ws_headers(token: Optional[str] = None, origin: Optional[str] = None, extra: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """ + Build headers for WebSocket requests. + + Args: + token: Optional[str], the SSO token for Cookie. Defaults to None. + origin: Optional[str], the Origin value. Defaults to "https://grok.com" if not provided. + extra: Optional[Dict[str, str]], extra headers to merge. Defaults to None. + + Returns: + Dict[str, str]: The headers dictionary. + """ + headers = { + "Origin": origin or "https://grok.com", + "User-Agent": get_config("security.user_agent"), + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Cache-Control": "no-cache", + "Pragma": "no-cache", + } + + if token: + headers["Cookie"] = build_sso_cookie(token) + + if extra: + headers.update(extra) + + return headers + + +def build_headers(cookie_token: str, content_type: Optional[str] = None, origin: Optional[str] = None, referer: Optional[str] = None) -> Dict[str, str]: """ Build headers for reverse interfaces. Args: - cookie_token: The SSO token. - content_type: Optional Content-Type value. - origin: Optional Origin value. Defaults to "https://grok.com" if not provided. - referer: Optional Referer value. Defaults to "https://grok.com/" if not provided. 
+ cookie_token: str, the SSO token. + content_type: Optional[str], the Content-Type value. + origin: Optional[str], the Origin value. Defaults to "https://grok.com" if not provided. + referer: Optional[str], the Referer value. Defaults to "https://grok.com/" if not provided. Returns: Dict[str, str]: The headers dictionary. @@ -64,7 +94,7 @@ def build_headers( } # Cookie - headers["Cookie"] = _build_sso_cookie(cookie_token) + headers["Cookie"] = build_sso_cookie(cookie_token) # Content-Type and Accept/Sec-Fetch-Dest if content_type and content_type == "application/json": @@ -101,4 +131,4 @@ def build_headers( return headers -__all__ = ["build_headers"] +__all__ = ["build_headers", "build_sso_cookie", "build_ws_headers"] diff --git a/app/services/reverse/utils/websocket.py b/app/services/reverse/utils/websocket.py index a4a8ffb9..67f15164 100644 --- a/app/services/reverse/utils/websocket.py +++ b/app/services/reverse/utils/websocket.py @@ -37,14 +37,12 @@ def _normalize_socks_proxy(proxy_url: str) -> tuple[str, Optional[bool]]: return proxy_url, rdns -def resolve_proxy( - proxy_url: str | None, ssl_context: ssl.SSLContext -) -> tuple[aiohttp.BaseConnector, Optional[str]]: +def resolve_proxy(proxy_url: Optional[str] = None, ssl_context: ssl.SSLContext = _default_ssl_context()) -> tuple[aiohttp.BaseConnector, Optional[str]]: """Resolve proxy connector. Args: - proxy_url: str, the proxy URL. - ssl_context: ssl.SSLContext, the SSL context. + proxy_url: Optional[str], the proxy URL. Defaults to None. + ssl_context: ssl.SSLContext, the SSL context. Defaults to _default_ssl_context(). Returns: tuple[aiohttp.BaseConnector, Optional[str]]: The proxy connector and the proxy URL. 
@@ -73,11 +71,7 @@ def resolve_proxy( class WebSocketConnection: """WebSocket connection wrapper.""" - def __init__( - self, - session: aiohttp.ClientSession, - ws: aiohttp.ClientWebSocketResponse, - ) -> None: + def __init__(self, session: aiohttp.ClientSession, ws: aiohttp.ClientWebSocketResponse) -> None: self.session = session self.ws = ws @@ -96,22 +90,22 @@ async def __aexit__(self, exc_type, exc, tb) -> None: class WebSocketClient: """WebSocket client with proxy support.""" - def __init__(self, proxy: str | None = None) -> None: + def __init__(self, proxy: Optional[str] = None) -> None: self.proxy = proxy or get_config("network.base_proxy_url") self._ssl_context = _default_ssl_context() async def connect( self, url: str, - headers: Mapping[str, str] | None = None, - timeout: float | aiohttp.ClientTimeout | None = None, + headers: Optional[Mapping[str, str]] = None, + ws_kwargs: Optional[Mapping[str, object]] = None, ) -> WebSocketConnection: """Connect to the WebSocket. Args: url: str, the URL to connect to. - headers: Mapping[str, str], the headers to send. - timeout: float | aiohttp.ClientTimeout | None, the timeout. + headers: Optional[Mapping[str, str]], the headers to send. Defaults to None. + ws_kwargs: Optional[Mapping[str, object]], extra ws_connect kwargs. Defaults to None. Returns: WebSocketConnection: The WebSocket connection. 
@@ -120,20 +114,18 @@ async def connect( connector, proxy = resolve_proxy(self.proxy, self._ssl_context) # Build client timeout - client_timeout = ( - timeout - if isinstance(timeout, aiohttp.ClientTimeout) - else aiohttp.ClientTimeout(total=timeout) - ) + client_timeout = aiohttp.ClientTimeout(total=get_config("network.timeout")) # Create session session = aiohttp.ClientSession(connector=connector, timeout=client_timeout) try: + extra_kwargs = dict(ws_kwargs or {}) ws = await session.ws_connect( url, headers=headers, proxy=proxy, ssl=self._ssl_context, + **extra_kwargs, ) return WebSocketConnection(session, ws) except Exception: diff --git a/app/services/reverse/ws_imagine.py b/app/services/reverse/ws_imagine.py new file mode 100644 index 00000000..d4c3c7cf --- /dev/null +++ b/app/services/reverse/ws_imagine.py @@ -0,0 +1,262 @@ +""" +Reverse interface: Imagine WebSocket image stream. +""" + +import asyncio +import orjson +import re +import time +import uuid +from typing import AsyncGenerator, Dict, Optional + +import aiohttp + +from app.core.config import get_config +from app.core.logger import logger +from app.services.reverse.utils.headers import build_ws_headers +from app.services.reverse.utils.websocket import WebSocketClient + +WS_IMAGINE_URL = "wss://grok.com/ws/imagine/listen" + + +class _BlockedError(Exception): + pass + + +class ImagineWebSocketReverse: + """Imagine WebSocket reverse interface.""" + + def __init__(self) -> None: + self._url_pattern = re.compile(r"/images/([a-f0-9-]+)\.(png|jpg|jpeg)") + self._client = WebSocketClient() + + def _extract_image_id(self, url: str) -> Optional[str]: + match = self._url_pattern.search(url or "") + return match.group(1) if match else None + + def _is_final_image(self, url: str, blob_size: int, final_min_bytes: int) -> bool: + return (url or "").lower().endswith((".jpg", ".jpeg")) and blob_size > final_min_bytes + + def _classify_image(self, url: str, blob: str, final_min_bytes: int, medium_min_bytes: int) -> 
Optional[Dict[str, object]]: + if not url or not blob: + return None + + image_id = self._extract_image_id(url) or uuid.uuid4().hex + blob_size = len(blob) + is_final = self._is_final_image(url, blob_size, final_min_bytes) + + stage = ( + "final" + if is_final + else ("medium" if blob_size > medium_min_bytes else "preview") + ) + + return { + "type": "image", + "image_id": image_id, + "stage": stage, + "blob": blob, + "blob_size": blob_size, + "url": url, + "is_final": is_final, + } + + def _build_request_message(self, request_id: str, prompt: str, aspect_ratio: str, enable_nsfw: bool) -> Dict[str, object]: + return { + "type": "conversation.item.create", + "timestamp": int(time.time() * 1000), + "item": { + "type": "message", + "content": [ + { + "requestId": request_id, + "text": prompt, + "type": "input_text", + "properties": { + "section_count": 0, + "is_kids_mode": False, + "enable_nsfw": enable_nsfw, + "skip_upsampler": False, + "is_initial": False, + "aspect_ratio": aspect_ratio, + }, + } + ], + }, + } + + async def stream( + self, + token: str, + prompt: str, + aspect_ratio: str = "2:3", + n: int = 1, + enable_nsfw: bool = True, + max_retries: Optional[int] = None, + ) -> AsyncGenerator[Dict[str, object], None]: + retries = max(1, max_retries if max_retries is not None else 1) + logger.info( + f"Image generation: prompt='{prompt[:50]}...', n={n}, ratio={aspect_ratio}, nsfw={enable_nsfw}" + ) + + for attempt in range(retries): + try: + yielded_any = False + async for item in self._stream_once( + token, prompt, aspect_ratio, n, enable_nsfw + ): + yielded_any = True + yield item + return + except _BlockedError: + if yielded_any or attempt + 1 >= retries: + if not yielded_any: + yield { + "type": "error", + "error_code": "blocked", + "error": "blocked_no_final_image", + } + return + logger.warning(f"WebSocket blocked, retry {attempt + 1}/{retries}") + except Exception as e: + logger.error(f"WebSocket stream failed: {e}") + return + + async def _stream_once( + 
self, + token: str, + prompt: str, + aspect_ratio: str, + n: int, + enable_nsfw: bool, + ) -> AsyncGenerator[Dict[str, object], None]: + request_id = str(uuid.uuid4()) + headers = build_ws_headers(token=token) + timeout = float(get_config("network.timeout")) + blocked_seconds = float(get_config("image.image_ws_blocked_seconds")) + blocked_grace = min(10.0, blocked_seconds) + final_min_bytes = int(get_config("image.image_ws_final_min_bytes")) + medium_min_bytes = int(get_config("image.image_ws_medium_min_bytes")) + + try: + conn = await self._client.connect( + WS_IMAGINE_URL, + headers=headers, + ws_kwargs={ + "heartbeat": 20, + "receive_timeout": timeout, + }, + ) + except Exception as e: + logger.error(f"WebSocket connect failed: {e}") + yield { + "type": "error", + "error_code": "connection_failed", + "error": str(e), + } + return + + try: + async with conn as ws: + message = self._build_request_message( + request_id, prompt, aspect_ratio, enable_nsfw + ) + await ws.send_json(message) + logger.info(f"WebSocket request sent: {prompt[:80]}...") + + final_ids: set[str] = set() + completed = 0 + start_time = last_activity = time.monotonic() + medium_received_time: Optional[float] = None + + while time.monotonic() - start_time < timeout: + try: + ws_msg = await asyncio.wait_for(ws.receive(), timeout=5.0) + except asyncio.TimeoutError: + now = time.monotonic() + if ( + medium_received_time + and completed == 0 + and now - medium_received_time > blocked_grace + ): + raise _BlockedError() + if completed > 0 and now - last_activity > 10: + logger.info( + f"WebSocket idle timeout, collected {completed} images" + ) + break + continue + + if ws_msg.type == aiohttp.WSMsgType.TEXT: + last_activity = time.monotonic() + try: + msg = orjson.loads(ws_msg.data) + except orjson.JSONDecodeError as e: + logger.warning(f"WebSocket message decode failed: {e}") + continue + + msg_type = msg.get("type") + + if msg_type == "image": + info = self._classify_image( + msg.get("url", ""), + 
msg.get("blob", ""), + final_min_bytes, + medium_min_bytes, + ) + if not info: + continue + + image_id = info["image_id"] + if info["stage"] == "medium" and medium_received_time is None: + medium_received_time = time.monotonic() + + if info["is_final"] and image_id not in final_ids: + final_ids.add(image_id) + completed += 1 + logger.debug( + f"Final image received: id={image_id}, size={info['blob_size']}" + ) + + yield info + + elif msg_type == "error": + logger.warning( + f"WebSocket error: {msg.get('err_code', '')} - {msg.get('err_msg', '')}" + ) + yield { + "type": "error", + "error_code": msg.get("err_code", ""), + "error": msg.get("err_msg", ""), + } + return + + if completed >= n: + logger.info(f"WebSocket collected {completed} final images") + break + + if ( + medium_received_time + and completed == 0 + and time.monotonic() - medium_received_time > blocked_seconds + ): + raise _BlockedError() + + elif ws_msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.ERROR, + ): + logger.warning(f"WebSocket closed/error: {ws_msg.type}") + yield { + "type": "error", + "error_code": "ws_closed", + "error": f"websocket closed: {ws_msg.type}", + } + break + + except aiohttp.ClientError as e: + logger.error(f"WebSocket connection error: {e}") + yield {"type": "error", "error_code": "connection_failed", "error": str(e)} + + +__all__ = ["ImagineWebSocketReverse", "WS_IMAGINE_URL"] diff --git a/app/services/reverse/ws_livekit.py b/app/services/reverse/ws_livekit.py index e810fa3c..095e31b2 100644 --- a/app/services/reverse/ws_livekit.py +++ b/app/services/reverse/ws_livekit.py @@ -11,7 +11,7 @@ from app.core.config import get_config from app.core.exceptions import UpstreamException from app.services.token.service import TokenService -from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.headers import build_headers, build_ws_headers from app.services.reverse.utils.retry import retry_on_status from 
app.services.reverse.utils.websocket import WebSocketClient, WebSocketConnection @@ -29,7 +29,6 @@ async def request( voice: str = "ara", personality: str = "assistant", speed: float = 1.0, - livekit_url: str = LIVEKIT_WS_URL, ) -> Dict[str, Any]: """Fetch LiveKit token. @@ -39,7 +38,6 @@ async def request( voice: str, the voice to use for the request. personality: str, the personality to use for the request. speed: float, the speed to use for the request. - livekit_url: str, the LiveKit URL to use for the request. Returns: Dict[str, Any]: The LiveKit token. @@ -69,7 +67,7 @@ async def request( } ).decode(), "requestAgentDispatch": False, - "livekitUrl": livekit_url, + "livekitUrl": LIVEKIT_WS_URL, "params": {"enable_markdown_transcript": "true"}, } @@ -100,8 +98,7 @@ async def _do_request(): return response response = await retry_on_status(_do_request) - return response.json() - + return response except Exception as e: # Handle upstream exception @@ -137,108 +134,37 @@ class LivekitWebSocketReverse: def __init__(self) -> None: self._client = WebSocketClient() - def build_url( - self, - access_token: str, - *, - livekit_url: str = LIVEKIT_WS_URL, - auto_subscribe: bool = True, - sdk: str = "js", - version: str = "2.11.4", - protocol: int = 15, - ) -> str: - """Build LiveKit WebSocket URL. + async def connect(self, token: str) -> WebSocketConnection: + """Connect to the LiveKit WebSocket. Args: - access_token: str, the LiveKit access token. - livekit_url: str, the LiveKit URL to use for the request. - auto_subscribe: bool, whether to auto subscribe to the WebSocket. - sdk: str, the SDK to use for the request. - version: str, the version to use for the request. - protocol: int, the protocol to use for the request. + token: str, the SSO token. Returns: - str: The LiveKit WebSocket URL. + WebSocketConnection: The LiveKit WebSocket connection. 
""" - # Build base URL - base = livekit_url.rstrip("/") + # Format URL + base = LIVEKIT_WS_URL.rstrip("/") if not base.endswith("/rtc"): base = f"{base}/rtc" # Build parameters params = { - "access_token": access_token, - "auto_subscribe": str(int(auto_subscribe)), - "sdk": sdk, - "version": version, - "protocol": str(protocol), + "access_token": token, + "auto_subscribe": "1", + "sdk": "js", + "version": "2.11.4", + "protocol": "15", } - return f"{base}?{urlencode(params)}" - - def _build_headers(self, extra: Dict[str, str] | None = None) -> Dict[str, str]: - """Build LiveKit WebSocket headers.""" - # Build headers - headers = { - "Origin": "https://grok.com", - "User-Agent": get_config("security.user_agent"), - "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", - "Cache-Control": "no-cache", - "Pragma": "no-cache", - } - - # Update headers - if extra: - headers.update(extra) - return headers - - async def connect( - self, - access_token: str, - *, - livekit_url: str = LIVEKIT_WS_URL, - auto_subscribe: bool = True, - sdk: str = "js", - version: str = "2.11.4", - protocol: int = 15, - headers: Dict[str, str] | None = None, - timeout: float | None = None, - ) -> WebSocketConnection: - """Connect to the LiveKit WebSocket. - - Args: - access_token: str, the LiveKit access token. - livekit_url: str, the LiveKit URL to use for the request. - auto_subscribe: bool, whether to auto subscribe to the WebSocket. - sdk: str, the SDK to use for the request. - version: str, the version to use for the request. - protocol: int, the protocol to use for the request. - headers: Dict[str, str], the headers to send. - timeout: float, the timeout to use for the request. - - Returns: - WebSocketConnection: The LiveKit WebSocket connection. 
- """ # Build URL - url = self.build_url( - access_token, - livekit_url=livekit_url, - auto_subscribe=auto_subscribe, - sdk=sdk, - version=version, - protocol=protocol, - ) + url = f"{base}?{urlencode(params)}" # Build WebSocket headers - ws_headers = self._build_headers(headers) - - # Build timeout - if timeout is None: - timeout = get_config("network.timeout") + ws_headers = build_ws_headers() - # Connect to the LiveKit WebSocket try: - return await self._client.connect(url, headers=ws_headers, timeout=timeout) + return await self._client.connect(url, headers=ws_headers) except Exception as e: logger.error(f"LivekitWebSocketReverse: Connect failed, {e}") raise UpstreamException( From 90fc01680765a87b65a0db4f26f1e80ca896da93 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Wed, 11 Feb 2026 11:32:54 +0800 Subject: [PATCH 09/27] merge: remove unused header, retry, and Statsig utility files --- app/services/grok/utils/headers.py | 31 ---- app/services/grok/utils/retry.py | 264 ----------------------------- app/services/grok/utils/statsig.py | 51 ------ app/services/reverse/app_chat.py | 15 +- 4 files changed, 13 insertions(+), 348 deletions(-) delete mode 100644 app/services/grok/utils/headers.py delete mode 100644 app/services/grok/utils/retry.py delete mode 100644 app/services/grok/utils/statsig.py diff --git a/app/services/grok/utils/headers.py b/app/services/grok/utils/headers.py deleted file mode 100644 index 7a5e1c2a..00000000 --- a/app/services/grok/utils/headers.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -Common header helpers for Grok services. 
-""" - -from __future__ import annotations - -import uuid -from typing import Dict - -from app.core.config import get_config -from app.services.grok.utils.statsig import StatsigService - - -def _normalize_token(token: str) -> str: - return token[4:] if token.startswith("sso=") else token - - -def build_sso_cookie(token: str, include_rw: bool = False) -> str: - token = _normalize_token(token) - cf = get_config("security.cf_clearance") - cookie = f"sso={token}" - if include_rw: - cookie = f"{cookie}; sso-rw={token}" - if cf: - cookie = f"{cookie};cf_clearance={cf}" - return cookie - - -def apply_statsig(headers: Dict[str, str]) -> None: - headers["x-statsig-id"] = StatsigService.gen_id() - headers["x-xai-request-id"] = str(uuid.uuid4()) diff --git a/app/services/grok/utils/retry.py b/app/services/grok/utils/retry.py deleted file mode 100644 index 162c4f8c..00000000 --- a/app/services/grok/utils/retry.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -Grok API 重试工具 - -提供可配置的重试机制,支持: -- 指数退避 + decorrelated jitter -- Retry-After header 支持 -- 429 专用退避策略 -- 重试预算控制 -""" - -import asyncio -import random -from typing import Callable, Any, Optional -from functools import wraps - -from app.core.logger import logger -from app.core.config import get_config -from app.core.exceptions import UpstreamException - - -class RetryContext: - """重试上下文""" - - def __init__(self): - self.attempt = 0 - self.max_retry = int(get_config("retry.max_retry")) - self.retry_codes = get_config("retry.retry_status_codes") - self.last_error = None - self.last_status = None - self.total_delay = 0.0 - self.retry_budget = float(get_config("retry.retry_budget")) - - # 退避参数 - self.backoff_base = float(get_config("retry.retry_backoff_base")) - self.backoff_factor = float(get_config("retry.retry_backoff_factor")) - self.backoff_max = float(get_config("retry.retry_backoff_max")) - - # decorrelated jitter 状态 - self._last_delay = self.backoff_base - - def should_retry(self, status_code: int) -> bool: - """判断是否重试""" - if 
self.attempt >= self.max_retry: - return False - if status_code not in self.retry_codes: - return False - if self.total_delay >= self.retry_budget: - return False - return True - - def record_error(self, status_code: int, error: Exception): - """记录错误信息""" - self.last_status = status_code - self.last_error = error - self.attempt += 1 - - def calculate_delay( - self, status_code: int, retry_after: Optional[float] = None - ) -> float: - """ - 计算退避延迟时间 - - Args: - status_code: HTTP 状态码 - retry_after: Retry-After header 值(秒) - - Returns: - 延迟时间(秒) - """ - # 优先使用 Retry-After - if retry_after is not None and retry_after > 0: - delay = min(retry_after, self.backoff_max) - self._last_delay = delay - return delay - - # 429 使用 decorrelated jitter - if status_code == 429: - # decorrelated jitter: delay = random(base, last_delay * 3) - delay = random.uniform(self.backoff_base, self._last_delay * 3) - delay = min(delay, self.backoff_max) - self._last_delay = delay - return delay - - # 其他状态码使用指数退避 + full jitter - exp_delay = self.backoff_base * (self.backoff_factor**self.attempt) - delay = random.uniform(0, min(exp_delay, self.backoff_max)) - return delay - - def record_delay(self, delay: float): - """记录延迟时间""" - self.total_delay += delay - - -def extract_retry_after(error: Exception) -> Optional[float]: - """ - 从异常中提取 Retry-After 值 - - Args: - error: 异常对象 - - Returns: - Retry-After 秒数,或 None - """ - if not isinstance(error, UpstreamException): - return None - - details = error.details or {} - - # 尝试从 details 中获取 - retry_after = details.get("retry_after") - if retry_after is not None: - try: - return float(retry_after) - except (ValueError, TypeError): - pass - - # 尝试从 headers 中获取 - headers = details.get("headers", {}) - if isinstance(headers, dict): - retry_after = headers.get("Retry-After") or headers.get("retry-after") - if retry_after is not None: - try: - return float(retry_after) - except (ValueError, TypeError): - pass - - return None - - -async def retry_on_status( - 
func: Callable, - *args, - extract_status: Callable[[Exception], Optional[int]] = None, - on_retry: Callable[[int, int, Exception, float], None] = None, - **kwargs, -) -> Any: - """ - 通用重试函数 - - Args: - func: 重试的异步函数 - *args: 函数参数 - extract_status: 异常提取状态码的函数 - on_retry: 重试时的回调函数 (attempt, status_code, error, delay) - **kwargs: 函数关键字参数 - - Returns: - 函数执行结果 - - Raises: - 最后一次失败的异常 - """ - ctx = RetryContext() - - # 状态码提取器 - if extract_status is None: - - def extract_status(e: Exception) -> Optional[int]: - if isinstance(e, UpstreamException): - # 优先从 details 获取,回退到 status_code 属性 - if e.details and "status" in e.details: - return e.details["status"] - return getattr(e, "status_code", None) - return None - - while ctx.attempt <= ctx.max_retry: - try: - result = await func(*args, **kwargs) - - # 记录日志 - if ctx.attempt > 0: - logger.info( - f"Retry succeeded after {ctx.attempt} attempts, " - f"total delay: {ctx.total_delay:.2f}s" - ) - - return result - - except Exception as e: - # 提取状态码 - status_code = extract_status(e) - - if status_code is None: - # 错误无法识别 - logger.error(f"Non-retryable error: {e}") - raise - - # 记录错误 - ctx.record_error(status_code, e) - - # 判断是否重试 - if ctx.should_retry(status_code): - # 提取 Retry-After - retry_after = extract_retry_after(e) - - # 计算延迟 - delay = ctx.calculate_delay(status_code, retry_after) - - # 检查是否超出预算 - if ctx.total_delay + delay > ctx.retry_budget: - logger.warning( - f"Retry budget exhausted: {ctx.total_delay:.2f}s + {delay:.2f}s > {ctx.retry_budget}s" - ) - raise - - ctx.record_delay(delay) - - logger.warning( - f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, " - f"waiting {delay:.2f}s (total: {ctx.total_delay:.2f}s)" - + (f", Retry-After: {retry_after}s" if retry_after else "") - ) - - # 回调 - if on_retry: - on_retry(ctx.attempt, status_code, e, delay) - - await asyncio.sleep(delay) - continue - else: - # 不可重试或重试次数耗尽 - if status_code in ctx.retry_codes: - logger.error( - f"Retry exhausted after {ctx.attempt} 
attempts, " - f"last status: {status_code}, total delay: {ctx.total_delay:.2f}s" - ) - else: - logger.error(f"Non-retryable status code: {status_code}") - - # 抛出最后一次的错误 - raise - - -def with_retry( - extract_status: Callable[[Exception], Optional[int]] = None, - on_retry: Callable[[int, int, Exception, float], None] = None, -): - """ - 重试装饰器 - - Usage: - @with_retry() - async def my_api_call(): - ... - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(*args, **kwargs): - return await retry_on_status( - func, *args, extract_status=extract_status, on_retry=on_retry, **kwargs - ) - - return wrapper - - return decorator - - -__all__ = [ - "RetryContext", - "retry_on_status", - "with_retry", - "extract_retry_after", -] diff --git a/app/services/grok/utils/statsig.py b/app/services/grok/utils/statsig.py deleted file mode 100644 index c2cd15f8..00000000 --- a/app/services/grok/utils/statsig.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Statsig ID 生成服务 -""" - -import base64 -import random -import string - -from app.core.config import get_config - - -class StatsigService: - """Statsig ID 生成服务""" - - @staticmethod - def _rand(length: int, alphanumeric: bool = False) -> str: - """生成随机字符串""" - chars = ( - string.ascii_lowercase + string.digits - if alphanumeric - else string.ascii_lowercase - ) - return "".join(random.choices(chars, k=length)) - - @staticmethod - def gen_id() -> str: - """ - 生成 Statsig ID - - Returns: - Base64 编码的 ID - """ - dynamic = get_config("chat.dynamic_statsig") - - if not dynamic: - return "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk=" - - # 随机格式 - if random.choice([True, False]): - rand = StatsigService._rand(5, alphanumeric=True) - message = f"e:TypeError: Cannot read properties of null (reading 'children['{rand}']')" - else: - rand = StatsigService._rand(10) - message = ( - f"e:TypeError: Cannot read properties of undefined (reading '{rand}')" - ) - - return 
base64.b64encode(message.encode()).decode() - - -__all__ = ["StatsigService"] diff --git a/app/services/reverse/app_chat.py b/app/services/reverse/app_chat.py index 9d3c28bd..77ce6856 100644 --- a/app/services/reverse/app_chat.py +++ b/app/services/reverse/app_chat.py @@ -3,7 +3,7 @@ """ import orjson -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from curl_cffi.requests import AsyncSession from app.core.logger import logger @@ -165,7 +165,18 @@ async def _do_request(): return response - response = await retry_on_status(_do_request) + def extract_status(e: Exception) -> Optional[int]: + if isinstance(e, UpstreamException): + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 429: + return None + return status + return None + + response = await retry_on_status(_do_request, extract_status=extract_status) # Stream response async def stream_response(): From 9967bdd90e0d747056702fc613e3b0db1d6d898f Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Wed, 11 Feb 2026 15:14:38 +0800 Subject: [PATCH 10/27] refactor: optimize base64 validation and enhance header logging --- app/services/reverse/utils/grpc.py | 5 ++++- app/services/reverse/utils/headers.py | 9 +++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/app/services/reverse/utils/grpc.py b/app/services/reverse/utils/grpc.py index 07ee84dd..446ba4bb 100644 --- a/app/services/reverse/utils/grpc.py +++ b/app/services/reverse/utils/grpc.py @@ -9,6 +9,9 @@ from typing import Dict, List, Mapping, Optional, Tuple from urllib.parse import unquote +# Base64 正则 +B64_RE = re.compile(rb"^[A-Za-z0-9+/=\r\n]+$") + @dataclass(frozen=True) class GrpcStatus: @@ -48,7 +51,7 @@ def _maybe_decode_grpc_web_text(body: bytes, content_type: Optional[str]) -> byt return base64.b64decode(compact, validate=False) head = body[: min(len(body), 2048)] - if head and 
re.compile(rb"^[A-Za-z0-9+/=\r\n]+$").fullmatch(head): + if head and B64_RE.fullmatch(head): compact = b"".join(body.split()) try: return base64.b64decode(compact, validate=True) diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py index 03a8f253..4388bf2f 100644 --- a/app/services/reverse/utils/headers.py +++ b/app/services/reverse/utils/headers.py @@ -123,10 +123,11 @@ def build_headers(cookie_token: str, content_type: Optional[str] = None, origin: headers["x-xai-request-id"] = str(uuid.uuid4()) # Print headers without Cookie - safe_headers = dict(headers) - if "Cookie" in safe_headers: - safe_headers["Cookie"] = "" - logger.debug(f"Built headers: {orjson.dumps(safe_headers).decode()}") + if logger.isEnabledFor(10): + safe_headers = dict(headers) + if "Cookie" in safe_headers: + safe_headers["Cookie"] = "" + logger.debug(f"Built headers: {orjson.dumps(safe_headers).decode()}") return headers From eb00bbd57da706ee2280e6fc9658daffbf9bd213 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Wed, 11 Feb 2026 16:10:08 +0800 Subject: [PATCH 11/27] chore: remove unused gRPC-Web protocol files --- app/services/grok/protocols/__init__.py | 0 app/services/grok/protocols/grpc_web.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 app/services/grok/protocols/__init__.py delete mode 100644 app/services/grok/protocols/grpc_web.py diff --git a/app/services/grok/protocols/__init__.py b/app/services/grok/protocols/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/protocols/grpc_web.py b/app/services/grok/protocols/grpc_web.py deleted file mode 100644 index e69de29b..00000000 From 716e274b8952cafc5a88d5ffc7acdd0451f5acd0 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:35:25 +0800 Subject: [PATCH 12/27] refactor: update GROK services and remove unused components --- 
app/api/v1/admin.py | 309 +++-------- app/api/v1/chat.py | 2 +- app/api/v1/image.py | 384 ++----------- app/services/grok/batch_services/__init__.py | 7 + app/services/grok/batch_services/assets.py | 107 ++++ app/services/grok/batch_services/nsfw.py | 48 ++ app/services/grok/batch_services/usage.py | 36 ++ app/services/grok/processors/__init__.py | 12 +- app/services/grok/processors/base.py | 2 +- .../{chat_processors.py => chat.py} | 22 +- app/services/grok/processors/image.py | 506 ++++++++++++++++++ .../grok/processors/image_processors.py | 248 --------- .../grok/processors/image_ws_processors.py | 268 ---------- .../{video_processors.py => video.py} | 16 +- app/services/grok/services/image.py | 300 ++++++++++- app/services/grok/services/image_edit.py | 197 +++++++ .../grok/services/{media.py => video.py} | 34 +- app/services/grok/utils/download.py | 5 + 18 files changed, 1366 insertions(+), 1137 deletions(-) create mode 100644 app/services/grok/batch_services/__init__.py create mode 100644 app/services/grok/batch_services/assets.py create mode 100644 app/services/grok/batch_services/nsfw.py create mode 100644 app/services/grok/batch_services/usage.py rename app/services/grok/processors/{chat_processors.py => chat.py} (96%) create mode 100644 app/services/grok/processors/image.py delete mode 100644 app/services/grok/processors/image_processors.py delete mode 100644 app/services/grok/processors/image_ws_processors.py rename app/services/grok/processors/{video_processors.py => video.py} (96%) create mode 100644 app/services/grok/services/image_edit.py rename app/services/grok/services/{media.py => video.py} (93%) create mode 100644 app/services/grok/utils/download.py diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index c2dc4808..8c05b051 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -16,7 +16,11 @@ from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage from app.core.exceptions import AppException from 
app.services.token.manager import get_token_manager -from app.services.grok.utils.batch import run_in_batches +from app.services.grok.batch_services import ( + BatchUsageService, + BatchNSFWService, + BatchAssetsService, +) import os import time import uuid @@ -27,10 +31,8 @@ from app.core.logger import logger from app.api.v1.image import resolve_aspect_ratio from app.services.grok.services.voice import VoiceService -from app.services.grok.services.image import image_service +from app.services.grok.services.image import ImageGenerationService from app.services.grok.models.model import ModelService -from app.services.grok.processors.image_ws_processors import ImageWSCollectProcessor -from app.services.token import EffortType TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static" @@ -363,7 +365,6 @@ async def _run(prompt: str, aspect_ratio: str): return token_mgr = await get_token_manager() - enable_nsfw = bool(get_config("image.image_ws_nsfw", True)) sequence = 0 run_id = uuid.uuid4().hex @@ -399,26 +400,23 @@ async def _run(prompt: str, aspect_ratio: str): await asyncio.sleep(2) continue - upstream = image_service.stream( + start_at = time.time() + result = await ImageGenerationService().generate( + token_mgr=token_mgr, token=token, + model_info=model_info, prompt=prompt, - aspect_ratio=aspect_ratio, - n=6, - enable_nsfw=enable_nsfw, - ) - - processor = ImageWSCollectProcessor( - model_info.model_id, - token, n=6, response_format="b64_json", + size="1024x1024", + aspect_ratio=aspect_ratio, + stream=False, + use_ws=True, ) - - start_at = time.time() - images = await processor.process(upstream) elapsed_ms = int((time.time() - start_at) * 1000) - if images and all(img and img != "error" for img in images): + images = [img for img in result.data if img and img != "error"] + if images: # 一次发送所有 6 张图片 for img_b64 in images: sequence += 1 @@ -433,17 +431,6 @@ async def _run(prompt: str, aspect_ratio: str): "run_id": run_id, } ) - - # 消耗 token(6 张图片按高成本计算) - try: - 
effort = ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW - ) - await token_mgr.consume(token, effort) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") else: await _send( { @@ -602,7 +589,6 @@ async def event_stream(): return token_mgr = await get_token_manager() - enable_nsfw = bool(get_config("image.image_ws_nsfw", True)) sequence = 0 run_id = uuid.uuid4().hex @@ -645,26 +631,23 @@ async def event_stream(): await asyncio.sleep(2) continue - upstream = image_service.stream( + start_at = time.time() + result = await ImageGenerationService().generate( + token_mgr=token_mgr, token=token, + model_info=model_info, prompt=prompt, - aspect_ratio=ratio, - n=6, - enable_nsfw=enable_nsfw, - ) - - processor = ImageWSCollectProcessor( - model_info.model_id, - token, n=6, response_format="b64_json", + size="1024x1024", + aspect_ratio=ratio, + stream=False, + use_ws=True, ) - - start_at = time.time() - images = await processor.process(upstream) elapsed_ms = int((time.time() - start_at) * 1000) - if images and all(img and img != "error" for img in images): + images = [img for img in result.data if img and img != "error"] + if images: for img_b64 in images: sequence += 1 yield _sse_event( @@ -678,16 +661,6 @@ async def event_stream(): "run_id": run_id, } ) - - try: - effort = ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW - ) - await token_mgr.consume(token, effort) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") else: yield _sse_event( { @@ -863,12 +836,9 @@ async def refresh_tokens_api(data: dict): max_concurrent = get_config("performance.usage_max_concurrent") batch_size = get_config("performance.usage_batch_size") - async def _refresh_one(t): - return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False) - - raw_results = await run_in_batches( + raw_results = await BatchUsageService.refresh( unique_tokens, - 
_refresh_one, + mgr, max_concurrent=max_concurrent, batch_size=batch_size, ) @@ -915,15 +885,12 @@ async def refresh_tokens_api_async(data: dict): async def _run(): try: - async def _refresh_one(t: str): - return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False) - async def _on_item(item: str, res: dict): task.record(bool(res.get("ok"))) - raw_results = await run_in_batches( + raw_results = await BatchUsageService.refresh( unique_tokens, - _refresh_one, + mgr, max_concurrent=max_concurrent, batch_size=batch_size, on_item=_on_item, @@ -979,11 +946,8 @@ async def _on_item(item: str, res: dict): @router.post("/api/v1/admin/tokens/nsfw/enable", dependencies=[Depends(verify_api_key)]) async def enable_nsfw_api(data: dict): """批量开启 NSFW (Unhinged) 模式""" - from app.services.grok.services.nsfw import NSFWService - try: mgr = await get_token_manager() - nsfw_service = NSFWService() # 收集 token 列表 tokens = _collect_tokens(data) @@ -1010,23 +974,11 @@ async def enable_nsfw_api(data: dict): max_concurrent = get_config("performance.nsfw_max_concurrent") batch_size = get_config("performance.nsfw_batch_size") - # 定义 worker - async def _enable(token: str): - result = await nsfw_service.enable(token) - # 成功后添加 nsfw tag - if result.success: - await mgr.add_tag(token, "nsfw") - return { - "success": result.success, - "http_status": result.http_status, - "grpc_status": result.grpc_status, - "grpc_message": result.grpc_message, - "error": result.error, - } - - # 执行批量操作 - raw_results = await run_in_batches( - unique_tokens, _enable, max_concurrent=max_concurrent, batch_size=batch_size + raw_results = await BatchNSFWService.enable( + unique_tokens, + mgr, + max_concurrent=max_concurrent, + batch_size=batch_size, ) # 构造返回结果(mask token) @@ -1073,10 +1025,7 @@ async def _enable(token: str): ) async def enable_nsfw_api_async(data: dict): """批量开启 NSFW (Unhinged) 模式(异步批量 + SSE 进度)""" - from app.services.grok.services.nsfw import NSFWService - mgr = await get_token_manager() - 
nsfw_service = NSFWService() tokens = _collect_tokens(data) @@ -1103,25 +1052,13 @@ async def enable_nsfw_api_async(data: dict): async def _run(): try: - async def _enable(token: str): - result = await nsfw_service.enable(token) - if result.success: - await mgr.add_tag(token, "nsfw") - return { - "success": result.success, - "http_status": result.http_status, - "grpc_status": result.grpc_status, - "grpc_message": result.grpc_message, - "error": result.error, - } - async def _on_item(item: str, res: dict): ok = bool(res.get("ok") and res.get("data", {}).get("success")) task.record(ok) - raw_results = await run_in_batches( + raw_results = await BatchNSFWService.enable( unique_tokens, - _enable, + mgr, max_concurrent=max_concurrent, batch_size=batch_size, on_item=_on_item, @@ -1184,9 +1121,8 @@ async def admin_cache_page(): @router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)]) async def get_cache_stats_api(request: Request): """获取缓存统计""" - from app.services.grok.services.assets import DownloadService, ListService + from app.services.grok.utils.download import DownloadService from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches try: dl_service = DownloadService() @@ -1238,51 +1174,14 @@ async def get_cache_stats_api(request: Request): truncated = False original_count = 0 - async def _fetch_assets(token: str): - list_service = ListService() - try: - return await list_service.count(token) - finally: - await list_service.close() - - async def _fetch_detail(token: str): - account = account_map.get(token) - try: - count = await _fetch_assets(token) - return { - "detail": { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": count, - "status": "ok", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - }, - "count": count, - } - except Exception as e: - return { - "detail": { - "token": token, - "token_masked": 
account["token_masked"] if account else token, - "count": 0, - "status": f"error: {str(e)}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - }, - "count": 0, - } - if selected_tokens: selected_tokens, truncated, original_count = _truncate_tokens( selected_tokens, max_tokens, "Assets fetch" ) total = 0 - raw_results = await run_in_batches( + raw_results = await BatchAssetsService.fetch_details( selected_tokens, - _fetch_detail, + account_map, max_concurrent=max_concurrent, batch_size=batch_size, ) @@ -1318,9 +1217,9 @@ async def _fetch_detail(token: str): if len(tokens) > max_tokens: tokens = tokens[:max_tokens] truncated = True - raw_results = await run_in_batches( + raw_results = await BatchAssetsService.fetch_details( tokens, - _fetch_detail, + account_map, max_concurrent=max_concurrent, batch_size=batch_size, ) @@ -1351,23 +1250,28 @@ async def _fetch_detail(token: str): else: token = selected_token if token: - try: - count = await _fetch_assets(token) - match = next((a for a in accounts if a["token"] == token), None) + raw_results = await BatchAssetsService.fetch_details( + [token], + account_map, + max_concurrent=1, + batch_size=1, + ) + res = raw_results.get(token, {}) + data = res.get("data", {}) + detail = data.get("detail") if res.get("ok") else None + if detail: online_stats = { - "count": count, - "status": "ok", - "token": token, - "token_masked": match["token_masked"] if match else token, - "last_asset_clear_at": match["last_asset_clear_at"] - if match - else None, + "count": data.get("count", 0), + "status": detail.get("status", "ok"), + "token": detail.get("token"), + "token_masked": detail.get("token_masked"), + "last_asset_clear_at": detail.get("last_asset_clear_at"), } - except Exception as e: + else: match = next((a for a in accounts if a["token"] == token), None) online_stats = { "count": 0, - "status": f"error: {str(e)}", + "status": f"error: {res.get('error')}", "token": token, "token_masked": 
match["token_masked"] if match else token, "last_asset_clear_at": match["last_asset_clear_at"] @@ -1404,9 +1308,8 @@ async def _fetch_detail(token: str): ) async def load_online_cache_api_async(data: dict): """在线资产统计(异步批量 + SSE 进度)""" - from app.services.grok.services.assets import DownloadService, ListService + from app.services.grok.utils.download import DownloadService from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches mgr = await get_token_manager() @@ -1462,44 +1365,16 @@ async def _run(): image_stats = dl_service.get_stats("image") video_stats = dl_service.get_stats("video") - async def _fetch_detail(token: str): - account = account_map.get(token) - list_service = ListService() - try: - count = await list_service.count(token) - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": count, - "status": "ok", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - return {"ok": True, "detail": detail, "count": count} - except Exception as e: - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": 0, - "status": f"error: {str(e)}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - return {"ok": False, "detail": detail, "count": 0} - finally: - await list_service.close() - async def _on_item(item: str, res: dict): ok = bool(res.get("data", {}).get("ok")) task.record(ok) - raw_results = await run_in_batches( + raw_results = await BatchAssetsService.fetch_details( selected_tokens, - _fetch_detail, + account_map, max_concurrent=max_concurrent, batch_size=batch_size, + include_ok=True, on_item=_on_item, should_cancel=lambda: task.cancelled, ) @@ -1555,7 +1430,7 @@ async def _on_item(item: str, res: dict): @router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)]) async def clear_local_cache_api(data: dict): 
"""清理本地缓存""" - from app.services.grok.services.assets import DownloadService + from app.services.grok.utils.download import DownloadService cache_type = data.get("type", "image") @@ -1575,7 +1450,7 @@ async def list_local_cache_api( page_size: int = 1000, ): """列出本地缓存文件""" - from app.services.grok.services.assets import DownloadService + from app.services.grok.utils.download import DownloadService try: if type_: @@ -1590,7 +1465,7 @@ async def list_local_cache_api( @router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)]) async def delete_local_cache_item_api(data: dict): """删除单个本地缓存文件""" - from app.services.grok.services.assets import DownloadService + from app.services.grok.utils.download import DownloadService cache_type = data.get("type", "image") name = data.get("name") @@ -1607,15 +1482,10 @@ async def delete_local_cache_item_api(data: dict): @router.post("/api/v1/admin/cache/online/clear", dependencies=[Depends(verify_api_key)]) async def clear_online_cache_api(data: dict): """清理在线缓存""" - from app.services.grok.services.assets import DeleteService from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches - - delete_service = None try: mgr = await get_token_manager() tokens = data.get("tokens") - delete_service = DeleteService() if isinstance(tokens, list): token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] @@ -1637,17 +1507,9 @@ async def clear_online_cache_api(data: dict): ) batch_size = max(1, int(get_config("performance.assets_batch_size"))) - async def _clear_one(t: str): - try: - result = await delete_service.delete_all(t) - await mgr.mark_asset_clear(t) - return {"status": "success", "result": result} - except Exception as e: - return {"status": "error", "error": str(e)} - - raw_results = await run_in_batches( + raw_results = await BatchAssetsService.clear_online( token_list, - _clear_one, + mgr, max_concurrent=max_concurrent, 
batch_size=batch_size, ) @@ -1670,14 +1532,19 @@ async def _clear_one(t: str): status_code=400, detail="No available token to perform cleanup" ) - result = await delete_service.delete_all(token) - await mgr.mark_asset_clear(token) - return {"status": "success", "result": result} + raw_results = await BatchAssetsService.clear_online( + [token], + mgr, + max_concurrent=1, + batch_size=1, + ) + res = raw_results.get(token, {}) + data = res.get("data", {}) + if res.get("ok") and data.get("status") == "success": + return {"status": "success", "result": data.get("result")} + return {"status": "error", "error": data.get("error") or res.get("error")} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - finally: - if delete_service: - await delete_service.close() @router.post( @@ -1685,9 +1552,7 @@ async def _clear_one(t: str): ) async def clear_online_cache_api_async(data: dict): """清理在线缓存(异步批量 + SSE 进度)""" - from app.services.grok.services.assets import DeleteService from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches mgr = await get_token_manager() tokens = data.get("tokens") @@ -1709,26 +1574,17 @@ async def clear_online_cache_api_async(data: dict): task = create_task(len(token_list)) async def _run(): - delete_service = DeleteService() try: - - async def _clear_one(t: str): - try: - result = await delete_service.delete_all(t) - await mgr.mark_asset_clear(t) - return {"ok": True, "result": result} - except Exception as e: - return {"ok": False, "error": str(e)} - async def _on_item(item: str, res: dict): ok = bool(res.get("data", {}).get("ok")) task.record(ok) - raw_results = await run_in_batches( + raw_results = await BatchAssetsService.clear_online( token_list, - _clear_one, + mgr, max_concurrent=max_concurrent, batch_size=batch_size, + include_ok=True, on_item=_on_item, should_cancel=lambda: task.cancelled, ) @@ -1767,7 +1623,6 @@ async def _on_item(item: str, res: dict): except 
Exception as e: task.fail_task(str(e)) finally: - await delete_service.close() asyncio.create_task(expire_task(task.id, 300)) asyncio.create_task(_run()) diff --git a/app/api/v1/chat.py b/app/api/v1/chat.py index 6c420567..ba5a820f 100644 --- a/app/api/v1/chat.py +++ b/app/api/v1/chat.py @@ -260,7 +260,7 @@ async def chat_completions(request: ChatCompletionRequest): # 检测视频模型 model_info = ModelService.get(request.model) if model_info and model_info.is_video: - from app.services.grok.services.media import VideoService + from app.services.grok.services.video import VideoService # 提取视频配置 (默认值在 Pydantic 模型中处理) v_conf = request.video_config or VideoConfig() diff --git a/app/api/v1/image.py b/app/api/v1/image.py index bc47b16d..6e8ab670 100644 --- a/app/api/v1/image.py +++ b/app/api/v1/image.py @@ -2,11 +2,7 @@ Image Generation API 路由 """ -import asyncio import base64 -import math -import random -import re import time from pathlib import Path from typing import List, Optional, Union @@ -15,21 +11,12 @@ from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel, Field, ValidationError -from app.services.grok.services.chat import GrokChatService -from app.services.grok.services.image import image_service -from app.services.grok.services.assets import UploadService -from app.services.grok.services.media import VideoService +from app.services.grok.services.image import ImageGenerationService +from app.services.grok.services.image_edit import ImageEditService from app.services.grok.models.model import ModelService -from app.services.grok.processors import ( - ImageStreamProcessor, - ImageCollectProcessor, - ImageWSStreamProcessor, - ImageWSCollectProcessor, -) -from app.services.token import get_token_manager, EffortType +from app.services.token import get_token_manager from app.core.exceptions import ValidationException, AppException, ErrorType from app.core.config import get_config -from app.core.logger import logger router = 
APIRouter(tags=["Images"]) @@ -202,30 +189,6 @@ def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]): ) -def _get_effort(model_info) -> EffortType: - """获取模型消耗级别""" - return ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW - ) - - -async def _wrap_stream_with_usage(stream, token_mgr, token, model_info): - """包装流式响应,成功完成时记录使用""" - success = False - try: - async for chunk in stream: - yield chunk - success = True - finally: - if success: - try: - await token_mgr.consume(token, _get_effort(model_info)) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - - async def _get_token(model: str): """获取可用 token""" token_mgr = await get_token_manager() @@ -248,46 +211,6 @@ async def _get_token(model: str): return token_mgr, token -async def call_grok( - token_mgr, - token: str, - prompt: str, - model_info, - file_attachments: Optional[List[str]] = None, - response_format: str = "b64_json", -) -> List[str]: - """调用 Grok 获取图片,返回 base64 列表""" - chat_service = GrokChatService() - success = False - - try: - response = await chat_service.chat( - token=token, - message=prompt, - model=model_info.grok_model, - mode=model_info.model_mode, - stream=True, - file_attachments=file_attachments, - ) - - processor = ImageCollectProcessor( - model_info.model_id, token, response_format=response_format - ) - images = await processor.process(response) - success = True - return images - - except Exception as e: - logger.error(f"Grok image call failed: {e}") - return [] - finally: - if success: - try: - await token_mgr.consume(token, _get_effort(model_info)) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - - @router.post("/images/generations") async def create_image(request: ImageGenerationRequest): """ @@ -321,160 +244,30 @@ async def create_image(request: ImageGenerationRequest): token_mgr, token = await _get_token(request.model) model_info = ModelService.get(request.model) 
use_ws = bool(get_config("image.image_ws")) + aspect_ratio = resolve_aspect_ratio(request.size) + + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=request.prompt, + n=request.n, + response_format=response_format, + size=request.size, + aspect_ratio=aspect_ratio, + stream=bool(request.stream), + use_ws=use_ws, + ) - # 流式模式 - if request.stream: - if use_ws: - aspect_ratio = resolve_aspect_ratio(request.size) - enable_nsfw = bool(get_config("image.image_ws_nsfw")) - upstream = image_service.stream( - token=token, - prompt=request.prompt, - aspect_ratio=aspect_ratio, - n=request.n, - enable_nsfw=enable_nsfw, - ) - processor = ImageWSStreamProcessor( - model_info.model_id, - token, - n=request.n, - response_format=response_format, - size=request.size, - ) - - return StreamingResponse( - _wrap_stream_with_usage( - processor.process(upstream), token_mgr, token, model_info - ), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - chat_service = GrokChatService() - response = await chat_service.chat( - token=token, - message=f"Image Generation: {request.prompt}", - model=model_info.grok_model, - mode=model_info.model_mode, - stream=True, - ) - - processor = ImageStreamProcessor( - model_info.model_id, token, n=request.n, response_format=response_format - ) - + if result.stream: return StreamingResponse( - _wrap_stream_with_usage( - processor.process(response), token_mgr, token, model_info - ), + result.data, media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) - # 非流式模式 - n = request.n - - usage_override = None - if use_ws: - aspect_ratio = resolve_aspect_ratio(request.size) - enable_nsfw = bool(get_config("image.image_ws_nsfw")) - all_images = [] - seen = set() - expected_per_call = 6 - calls_needed = max(1, math.ceil(n / expected_per_call)) - calls_needed = min(calls_needed, n) - - async def 
_fetch_batch(call_target: int): - upstream = image_service.stream( - token=token, - prompt=request.prompt, - aspect_ratio=aspect_ratio, - n=call_target, - enable_nsfw=enable_nsfw, - ) - processor = ImageWSCollectProcessor( - model_info.model_id, - token, - n=call_target, - response_format=response_format, - ) - return await processor.process(upstream) - - tasks = [] - for i in range(calls_needed): - remaining = n - (i * expected_per_call) - call_target = min(expected_per_call, remaining) - tasks.append(_fetch_batch(call_target)) - - results = await asyncio.gather(*tasks, return_exceptions=True) - for batch in results: - if isinstance(batch, Exception): - logger.warning(f"WS batch failed: {batch}") - continue - for img in batch: - if img not in seen: - seen.add(img) - all_images.append(img) - if len(all_images) >= n: - break - if len(all_images) >= n: - break - try: - await token_mgr.consume(token, _get_effort(model_info)) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - usage_override = { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, - } - else: - calls_needed = (n + 1) // 2 - - if calls_needed == 1: - # 单次调用 - all_images = await call_grok( - token_mgr, - token, - f"Image Generation: {request.prompt}", - model_info, - response_format=response_format, - ) - else: - # 并发调用 - tasks = [ - call_grok( - token_mgr, - token, - f"Image Generation: {request.prompt}", - model_info, - response_format=response_format, - ) - for _ in range(calls_needed) - ] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 收集成功的图片 - all_images = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"Concurrent call failed: {result}") - elif isinstance(result, list): - all_images.extend(result) - - # 随机选取 n 张图片 - if len(all_images) >= n: - selected_images = random.sample(all_images, n) - else: - # 全部返回,error 填充缺失 - selected_images = 
all_images.copy() - while len(selected_images) < n: - selected_images.append("error") - - # 构建响应 - data = [{response_field: img} for img in selected_images] - usage = usage_override or { + data = [{response_field: img} for img in result.data] + usage = result.usage_override or { "total_tokens": 0, "input_tokens": 0, "output_tokens": 0, @@ -588,134 +381,25 @@ async def edit_image( token_mgr, token = await _get_token(edit_request.model) model_info = ModelService.get(edit_request.model) - # 上传图片 - image_urls: List[str] = [] - upload_service = UploadService() - try: - for image in images: - file_id, file_uri = await upload_service.upload(image, token) - if file_uri: - if file_uri.startswith("http"): - image_urls.append(file_uri) - else: - image_urls.append(f"https://assets.grok.com/{file_uri.lstrip('/')}") - finally: - await upload_service.close() - - if not image_urls: - raise AppException( - message="Image upload failed", - error_type=ErrorType.SERVER.value, - code="upload_failed", - ) - - parent_post_id = None - try: - media_service = VideoService() - parent_post_id = await media_service.create_image_post(token, image_urls[0]) - logger.debug(f"Parent post ID: {parent_post_id}") - except Exception as e: - logger.warning(f"Create image post failed: {e}") - - if not parent_post_id: - for url in image_urls: - match = re.search(r"/generated/([a-f0-9-]+)/", url) - if match: - parent_post_id = match.group(1) - logger.debug(f"Parent post ID: {parent_post_id}") - break - match = re.search(r"/users/[^/]+/([a-f0-9-]+)/content", url) - if match: - parent_post_id = match.group(1) - logger.debug(f"Parent post ID: {parent_post_id}") - break - - model_config_override = { - "modelMap": { - "imageEditModel": "imagine", - "imageEditModelConfig": { - "imageReferences": image_urls, - }, - } - } - - if parent_post_id: - model_config_override["modelMap"]["imageEditModelConfig"]["parentPostId"] = ( - parent_post_id - ) - - tool_overrides = {"imageGen": True} - - # 流式模式 - if 
edit_request.stream: - chat_service = GrokChatService() - response = await chat_service.chat( - token=token, - message=edit_request.prompt, - model=model_info.grok_model, - mode=None, - stream=True, - tool_overrides=tool_overrides, - model_config_override=model_config_override, - ) - - processor = ImageStreamProcessor( - model_info.model_id, - token, - n=edit_request.n, - response_format=response_format, - ) + result = await ImageEditService().edit( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=edit_request.prompt, + images=images, + n=edit_request.n, + response_format=response_format, + stream=bool(edit_request.stream), + ) + if result.stream: return StreamingResponse( - _wrap_stream_with_usage( - processor.process(response), token_mgr, token, model_info - ), + result.data, media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) - # 非流式模式 - n = edit_request.n - calls_needed = (n + 1) // 2 - - async def _call_edit(): - chat_service = GrokChatService() - response = await chat_service.chat( - token=token, - message=edit_request.prompt, - model=model_info.grok_model, - mode=None, - stream=True, - tool_overrides=tool_overrides, - model_config_override=model_config_override, - ) - processor = ImageCollectProcessor( - model_info.model_id, token, response_format=response_format - ) - return await processor.process(response) - - if calls_needed == 1: - all_images = await _call_edit() - else: - tasks = [_call_edit() for _ in range(calls_needed)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - all_images = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"Concurrent call failed: {result}") - elif isinstance(result, list): - all_images.extend(result) - - # 选择图片 - if len(all_images) >= n: - selected_images = random.sample(all_images, n) - else: - selected_images = all_images.copy() - while len(selected_images) < n: - selected_images.append("error") - - 
data = [{response_field: img} for img in selected_images] + data = [{response_field: img} for img in result.data] return JSONResponse( content={ diff --git a/app/services/grok/batch_services/__init__.py b/app/services/grok/batch_services/__init__.py new file mode 100644 index 00000000..49796375 --- /dev/null +++ b/app/services/grok/batch_services/__init__.py @@ -0,0 +1,7 @@ +"""Batch services.""" + +from .usage import BatchUsageService +from .nsfw import BatchNSFWService +from .assets import BatchAssetsService + +__all__ = ["BatchUsageService", "BatchNSFWService", "BatchAssetsService"] diff --git a/app/services/grok/batch_services/assets.py b/app/services/grok/batch_services/assets.py new file mode 100644 index 00000000..3ad213e2 --- /dev/null +++ b/app/services/grok/batch_services/assets.py @@ -0,0 +1,107 @@ +""" +Batch assets service. +""" + +from typing import Callable, Awaitable, Dict, Any, Optional + +from app.services.grok.services.assets import ListService, DeleteService +from app.services.grok.utils.batch import run_in_batches + + +class BatchAssetsService: + """Batch assets orchestration.""" + + @staticmethod + async def fetch_details( + tokens: list[str], + account_map: Dict[str, Dict[str, Any]], + *, + max_concurrent: int, + batch_size: int, + include_ok: bool = False, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, + ) -> Dict[str, Dict[str, Any]]: + account_map = account_map or {} + + async def _fetch_detail(token: str): + account = account_map.get(token) + list_service = ListService() + try: + count = await list_service.count(token) + detail = { + "token": token, + "token_masked": account["token_masked"] if account else token, + "count": count, + "status": "ok", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if include_ok: + return {"ok": True, "detail": detail, "count": count} + return {"detail": detail, "count": count} + 
except Exception as e: + detail = { + "token": token, + "token_masked": account["token_masked"] if account else token, + "count": 0, + "status": f"error: {str(e)}", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if include_ok: + return {"ok": False, "detail": detail, "count": 0} + return {"detail": detail, "count": 0} + finally: + await list_service.close() + + return await run_in_batches( + tokens, + _fetch_detail, + max_concurrent=max_concurrent, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + + @staticmethod + async def clear_online( + tokens: list[str], + mgr, + *, + max_concurrent: int, + batch_size: int, + include_ok: bool = False, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, + ) -> Dict[str, Dict[str, Any]]: + delete_service = DeleteService() + + async def _clear_one(token: str): + try: + result = await delete_service.delete_all(token) + await mgr.mark_asset_clear(token) + if include_ok: + return {"ok": True, "result": result} + return {"status": "success", "result": result} + except Exception as e: + if include_ok: + return {"ok": False, "error": str(e)} + return {"status": "error", "error": str(e)} + + try: + return await run_in_batches( + tokens, + _clear_one, + max_concurrent=max_concurrent, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + finally: + await delete_service.close() + + +__all__ = ["BatchAssetsService"] diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py new file mode 100644 index 00000000..3c8ae06b --- /dev/null +++ b/app/services/grok/batch_services/nsfw.py @@ -0,0 +1,48 @@ +""" +Batch NSFW service. 
+""" + +from typing import Callable, Awaitable, Dict, Any, Optional + +from app.services.grok.services.nsfw import NSFWService +from app.services.grok.utils.batch import run_in_batches + + +class BatchNSFWService: + """Batch NSFW orchestration.""" + + @staticmethod + async def enable( + tokens: list[str], + mgr, + *, + max_concurrent: int, + batch_size: int, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, + ) -> Dict[str, Dict[str, Any]]: + nsfw_service = NSFWService() + + async def _enable(token: str): + result = await nsfw_service.enable(token) + if result.success: + await mgr.add_tag(token, "nsfw") + return { + "success": result.success, + "http_status": result.http_status, + "grpc_status": result.grpc_status, + "grpc_message": result.grpc_message, + "error": result.error, + } + + return await run_in_batches( + tokens, + _enable, + max_concurrent=max_concurrent, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + + +__all__ = ["BatchNSFWService"] diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py new file mode 100644 index 00000000..29e1b4db --- /dev/null +++ b/app/services/grok/batch_services/usage.py @@ -0,0 +1,36 @@ +""" +Batch usage service. 
+""" + +from typing import Callable, Awaitable, Dict, Any, Optional + +from app.services.grok.utils.batch import run_in_batches + + +class BatchUsageService: + """Batch usage orchestration.""" + + @staticmethod + async def refresh( + tokens: list[str], + mgr, + *, + max_concurrent: int, + batch_size: int, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, + ) -> Dict[str, Dict[str, Any]]: + async def _refresh_one(t: str): + return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False) + + return await run_in_batches( + tokens, + _refresh_one, + max_concurrent=max_concurrent, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + + +__all__ = ["BatchUsageService"] diff --git a/app/services/grok/processors/__init__.py b/app/services/grok/processors/__init__.py index 04773f67..72fb3d58 100644 --- a/app/services/grok/processors/__init__.py +++ b/app/services/grok/processors/__init__.py @@ -3,10 +3,14 @@ """ from .base import BaseProcessor, StreamIdleTimeoutError -from .chat_processors import StreamProcessor, CollectProcessor -from .video_processors import VideoStreamProcessor, VideoCollectProcessor -from .image_processors import ImageStreamProcessor, ImageCollectProcessor -from .image_ws_processors import ImageWSStreamProcessor, ImageWSCollectProcessor +from .chat import StreamProcessor, CollectProcessor +from .video import VideoStreamProcessor, VideoCollectProcessor +from .image import ( + ImageStreamProcessor, + ImageCollectProcessor, + ImageWSStreamProcessor, + ImageWSCollectProcessor, +) __all__ = [ "BaseProcessor", diff --git a/app/services/grok/processors/base.py b/app/services/grok/processors/base.py index 76b4838d..f1c208e0 100644 --- a/app/services/grok/processors/base.py +++ b/app/services/grok/processors/base.py @@ -8,7 +8,7 @@ from app.core.config import get_config from app.core.logger import logger -from app.services.grok.services.assets 
import DownloadService +from app.services.grok.utils.download import DownloadService ASSET_URL = "https://assets.grok.com/" diff --git a/app/services/grok/processors/chat_processors.py b/app/services/grok/processors/chat.py similarity index 96% rename from app/services/grok/processors/chat_processors.py rename to app/services/grok/processors/chat.py index 000e09b8..f55f8530 100644 --- a/app/services/grok/processors/chat_processors.py +++ b/app/services/grok/processors/chat.py @@ -1,5 +1,5 @@ """ -聊天响应处理器 +Chat response processors. """ import asyncio @@ -24,7 +24,7 @@ class StreamProcessor(BaseProcessor): - """流式响应处理器""" + """Stream response processor.""" def __init__(self, model: str, token: str = "", think: bool = None): super().__init__(model, token) @@ -43,7 +43,7 @@ def __init__(self, model: str, token: str = "", think: bool = None): self.show_think = think def _filter_token(self, token: str) -> str: - """过滤 token 中的特殊标签(如 ...),支持跨 token 的标签过滤""" + """Filter special tags (supports cross-token tag filtering).""" if not self.filter_tags: return token @@ -92,7 +92,7 @@ def _filter_token(self, token: str) -> str: return "".join(result) def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: - """构建 SSE 响应""" + """Build SSE response.""" delta = {} if role: delta["role"] = role @@ -115,7 +115,7 @@ def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: async def process( self, response: AsyncIterable[bytes] ) -> AsyncGenerator[str, None]: - """处理流式响应""" + """Process stream response.""" idle_timeout = get_config("timeout.stream_idle_timeout") try: @@ -139,7 +139,7 @@ async def process( yield self._sse(role="assistant") self.role_sent = True - # 图像生成进度 + # Image generation progress if img := resp.get("streamingImageGenerationResponse"): if self.show_think: if not self.think_opened: @@ -160,7 +160,7 @@ async def process( yield self._sse("\n") self.think_opened = False - # 处理生成的图片 + # Handle generated images for url in 
_collect_image_urls(mr): parts = url.split("/") img_id = parts[-2] if len(parts) >= 2 else "image" @@ -194,7 +194,7 @@ async def process( self.fingerprint = meta["llm_info"]["modelHash"] continue - # 普通 token + # Normal token if (token := resp.get("token")) is not None: if token: filtered = self._filter_token(token) @@ -242,7 +242,7 @@ async def process( class CollectProcessor(BaseProcessor): - """非流式响应处理器""" + """Non-stream response processor.""" def __init__(self, model: str, token: str = ""): super().__init__(model, token) @@ -250,7 +250,7 @@ def __init__(self, model: str, token: str = ""): self.filter_tags = get_config("chat.filter_tags") def _filter_content(self, content: str) -> str: - """过滤内容中的特殊标签""" + """Filter special tags in content.""" if not content or not self.filter_tags: return content @@ -262,7 +262,7 @@ def _filter_content(self, content: str) -> str: return result async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: - """处理并收集完整响应""" + """Process and collect full response.""" response_id = "" fingerprint = "" content = "" diff --git a/app/services/grok/processors/image.py b/app/services/grok/processors/image.py new file mode 100644 index 00000000..a53b94a8 --- /dev/null +++ b/app/services/grok/processors/image.py @@ -0,0 +1,506 @@ +""" +Image response processors (HTTP + WebSocket). 
+""" + +import asyncio +import base64 +import random +import time +from pathlib import Path +from typing import AsyncGenerator, AsyncIterable, List, Dict, Optional + +import orjson +from curl_cffi.requests.errors import RequestsError + +from app.core.config import get_config +from app.core.logger import logger +from app.core.storage import DATA_DIR +from app.core.exceptions import UpstreamException +from .base import ( + BaseProcessor, + StreamIdleTimeoutError, + _with_idle_timeout, + _normalize_stream_line, + _collect_image_urls, + _is_http2_stream_error, +) + + +class ImageStreamProcessor(BaseProcessor): + """HTTP image stream processor.""" + + def __init__( + self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" + ): + super().__init__(model, token) + self.partial_index = 0 + self.n = n + self.target_index = random.randint(0, 1) if n == 1 else None + self.response_format = response_format + if response_format == "url": + self.response_field = "url" + elif response_format == "base64": + self.response_field = "base64" + else: + self.response_field = "b64_json" + + def _sse(self, event: str, data: dict) -> str: + """Build SSE response.""" + return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" + + async def process( + self, response: AsyncIterable[bytes] + ) -> AsyncGenerator[str, None]: + """Process stream response.""" + final_images = [] + idle_timeout = get_config("timeout.stream_idle_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_stream_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + # Image generation progress + if img := resp.get("streamingImageGenerationResponse"): + image_index = img.get("imageIndex", 0) + progress = img.get("progress", 0) + + if self.n == 1 and image_index != self.target_index: + continue + + out_index = 0 if 
self.n == 1 else image_index + + yield self._sse( + "image_generation.partial_image", + { + "type": "image_generation.partial_image", + self.response_field: "", + "index": out_index, + "progress": progress, + }, + ) + continue + + # modelResponse + if mr := resp.get("modelResponse"): + if urls := _collect_image_urls(mr): + for url in urls: + if self.response_format == "url": + processed = await self.process_url(url, "image") + if processed: + final_images.append(processed) + continue + try: + dl_service = self._get_dl() + base64_data = await dl_service.to_base64( + url, self.token, "image" + ) + if base64_data: + if "," in base64_data: + b64 = base64_data.split(",", 1)[1] + else: + b64 = base64_data + final_images.append(b64) + except Exception as e: + logger.warning( + f"Failed to convert image to base64, falling back to URL: {e}" + ) + processed = await self.process_url(url, "image") + if processed: + final_images.append(processed) + continue + + for index, b64 in enumerate(final_images): + if self.n == 1: + if index != self.target_index: + continue + out_index = 0 + else: + out_index = index + + yield self._sse( + "image_generation.completed", + { + "type": "image_generation.completed", + self.response_field: b64, + "index": out_index, + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": { + "text_tokens": 0, + "image_tokens": 0, + }, + }, + }, + ) + except asyncio.CancelledError: + logger.debug("Image stream cancelled by client") + except StreamIdleTimeoutError as e: + raise UpstreamException( + message=f"Image stream idle timeout after {e.idle_seconds}s", + status_code=504, + details={ + "error": str(e), + "type": "stream_idle_timeout", + "idle_seconds": e.idle_seconds, + }, + ) + except RequestsError as e: + if _is_http2_stream_error(e): + logger.warning(f"HTTP/2 stream error in image: {e}") + raise UpstreamException( + message="Upstream connection closed unexpectedly", + status_code=502, + details={"error": 
str(e), "type": "http2_stream_error"}, + ) + logger.error(f"Image stream request error: {e}") + raise UpstreamException( + message=f"Upstream request failed: {e}", + status_code=502, + details={"error": str(e)}, + ) + except Exception as e: + logger.error( + f"Image stream processing error: {e}", + extra={"error_type": type(e).__name__}, + ) + raise + finally: + await self.close() + + +class ImageCollectProcessor(BaseProcessor): + """HTTP image non-stream processor.""" + + def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): + super().__init__(model, token) + self.response_format = response_format + + async def process(self, response: AsyncIterable[bytes]) -> List[str]: + """Process and collect images.""" + images = [] + idle_timeout = get_config("timeout.stream_idle_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_stream_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + if mr := resp.get("modelResponse"): + if urls := _collect_image_urls(mr): + for url in urls: + if self.response_format == "url": + processed = await self.process_url(url, "image") + if processed: + images.append(processed) + continue + try: + dl_service = self._get_dl() + base64_data = await dl_service.to_base64( + url, self.token, "image" + ) + if base64_data: + if "," in base64_data: + b64 = base64_data.split(",", 1)[1] + else: + b64 = base64_data + images.append(b64) + except Exception as e: + logger.warning( + f"Failed to convert image to base64, falling back to URL: {e}" + ) + processed = await self.process_url(url, "image") + if processed: + images.append(processed) + + except asyncio.CancelledError: + logger.debug("Image collect cancelled by client") + except StreamIdleTimeoutError as e: + logger.warning(f"Image collect idle timeout: {e}") + except RequestsError as e: + if 
_is_http2_stream_error(e): + logger.warning(f"HTTP/2 stream error in image collect: {e}") + else: + logger.error(f"Image collect request error: {e}") + except Exception as e: + logger.error( + f"Image collect processing error: {e}", + extra={"error_type": type(e).__name__}, + ) + finally: + await self.close() + + return images + + +class ImageWSBaseProcessor(BaseProcessor): + """WebSocket image processor base.""" + + def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): + super().__init__(model, token) + self.response_format = response_format + if response_format == "url": + self.response_field = "url" + elif response_format == "base64": + self.response_field = "base64" + else: + self.response_field = "b64_json" + self._image_dir: Optional[Path] = None + + def _ensure_image_dir(self) -> Path: + if self._image_dir is None: + base_dir = DATA_DIR / "tmp" / "image" + base_dir.mkdir(parents=True, exist_ok=True) + self._image_dir = base_dir + return self._image_dir + + def _strip_base64(self, blob: str) -> str: + if not blob: + return "" + if "," in blob and "base64" in blob.split(",", 1)[0]: + return blob.split(",", 1)[1] + return blob + + def _filename(self, image_id: str, is_final: bool) -> str: + ext = "jpg" if is_final else "png" + return f"{image_id}.{ext}" + + def _build_file_url(self, filename: str) -> str: + app_url = get_config("app.app_url") + if app_url: + return f"{app_url.rstrip('/')}/v1/files/image/{filename}" + return f"/v1/files/image/{filename}" + + def _save_blob(self, image_id: str, blob: str, is_final: bool) -> str: + data = self._strip_base64(blob) + if not data: + return "" + image_dir = self._ensure_image_dir() + filename = self._filename(image_id, is_final) + filepath = image_dir / filename + with open(filepath, "wb") as f: + f.write(base64.b64decode(data)) + return self._build_file_url(filename) + + def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict: + if not existing: + return incoming + if 
incoming.get("is_final") and not existing.get("is_final"): + return incoming + if existing.get("is_final") and not incoming.get("is_final"): + return existing + if incoming.get("blob_size", 0) > existing.get("blob_size", 0): + return incoming + return existing + + def _to_output(self, image_id: str, item: Dict) -> str: + try: + if self.response_format == "url": + return self._save_blob( + image_id, item.get("blob", ""), item.get("is_final", False) + ) + return self._strip_base64(item.get("blob", "")) + except Exception as e: + logger.warning(f"Image output failed: {e}") + return "" + + +class ImageWSStreamProcessor(ImageWSBaseProcessor): + """WebSocket image stream processor.""" + + def __init__( + self, + model: str, + token: str = "", + n: int = 1, + response_format: str = "b64_json", + size: str = "1024x1024", + ): + super().__init__(model, token, "b64_json") + self.n = n + self.size = size + self._target_id: Optional[str] = None + self._index_map: Dict[str, int] = {} + self._partial_map: Dict[str, int] = {} + + def _assign_index(self, image_id: str) -> Optional[int]: + if image_id in self._index_map: + return self._index_map[image_id] + if len(self._index_map) >= self.n: + return None + self._index_map[image_id] = len(self._index_map) + return self._index_map[image_id] + + def _sse(self, event: str, data: dict) -> str: + return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" + + async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]: + images: Dict[str, Dict] = {} + + async for item in response: + if item.get("type") == "error": + message = item.get("error") or "Upstream error" + code = item.get("error_code") or "upstream_error" + yield self._sse( + "error", + { + "error": { + "message": message, + "type": "server_error", + "code": code, + } + }, + ) + return + if item.get("type") != "image": + continue + + image_id = item.get("image_id") + if not image_id: + continue + + if self.n == 1: + if self._target_id is None: + 
self._target_id = image_id + index = 0 if image_id == self._target_id else None + else: + index = self._assign_index(image_id) + + images[image_id] = self._pick_best(images.get(image_id), item) + + if index is None: + continue + + if item.get("stage") != "final": + partial_b64 = self._strip_base64(item.get("blob", "")) + if not partial_b64: + continue + partial_index = self._partial_map.get(image_id, 0) + if item.get("stage") == "medium": + partial_index = max(partial_index, 1) + self._partial_map[image_id] = partial_index + yield self._sse( + "image_generation.partial_image", + { + "type": "image_generation.partial_image", + "b64_json": partial_b64, + "created_at": int(time.time()), + "size": self.size, + "index": index, + "partial_image_index": partial_index, + }, + ) + + if self.n == 1: + if self._target_id and self._target_id in images: + selected = [(self._target_id, images[self._target_id])] + else: + selected = ( + [ + max( + images.items(), + key=lambda x: ( + x[1].get("is_final", False), + x[1].get("blob_size", 0), + ), + ) + ] + if images + else [] + ) + else: + selected = [ + (image_id, images[image_id]) + for image_id in self._index_map + if image_id in images + ] + + for image_id, item in selected: + output = self._strip_base64(item.get("blob", "")) + if not output: + continue + + if self.n == 1: + index = 0 + else: + index = self._index_map.get(image_id, 0) + yield self._sse( + "image_generation.completed", + { + "type": "image_generation.completed", + "b64_json": output, + "created_at": int(time.time()), + "size": self.size, + "index": index, + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + }, + }, + ) + + +class ImageWSCollectProcessor(ImageWSBaseProcessor): + """WebSocket image non-stream processor.""" + + def __init__( + self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" + ): + super().__init__(model, token, response_format) 
+ self.n = n + + async def process(self, response: AsyncIterable[dict]) -> List[str]: + images: Dict[str, Dict] = {} + + async for item in response: + if item.get("type") == "error": + message = item.get("error") or "Upstream error" + raise UpstreamException(message, details=item) + if item.get("type") != "image": + continue + image_id = item.get("image_id") + if not image_id: + continue + images[image_id] = self._pick_best(images.get(image_id), item) + + selected = sorted( + images.values(), + key=lambda x: (x.get("is_final", False), x.get("blob_size", 0)), + reverse=True, + ) + if self.n: + selected = selected[: self.n] + + results: List[str] = [] + for item in selected: + output = self._to_output(item.get("image_id", ""), item) + if output: + results.append(output) + + return results + + +__all__ = [ + "ImageStreamProcessor", + "ImageCollectProcessor", + "ImageWSStreamProcessor", + "ImageWSCollectProcessor", +] diff --git a/app/services/grok/processors/image_processors.py b/app/services/grok/processors/image_processors.py deleted file mode 100644 index 8f78ac3f..00000000 --- a/app/services/grok/processors/image_processors.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -图片生成响应处理器(HTTP) -""" - -import asyncio -import random -from typing import AsyncGenerator, AsyncIterable, List - -import orjson -from curl_cffi.requests.errors import RequestsError - -from app.core.config import get_config -from app.core.logger import logger -from app.core.exceptions import UpstreamException -from .base import ( - BaseProcessor, - StreamIdleTimeoutError, - _with_idle_timeout, - _normalize_stream_line, - _collect_image_urls, - _is_http2_stream_error, -) - - -class ImageStreamProcessor(BaseProcessor): - """图片生成流式响应处理器""" - - def __init__( - self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" - ): - super().__init__(model, token) - self.partial_index = 0 - self.n = n - self.target_index = random.randint(0, 1) if n == 1 else None - self.response_format = 
response_format - if response_format == "url": - self.response_field = "url" - elif response_format == "base64": - self.response_field = "base64" - else: - self.response_field = "b64_json" - - def _sse(self, event: str, data: dict) -> str: - """构建 SSE 响应""" - return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" - - async def process( - self, response: AsyncIterable[bytes] - ) -> AsyncGenerator[str, None]: - """处理流式响应""" - final_images = [] - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - # 图片生成进度 - if img := resp.get("streamingImageGenerationResponse"): - image_index = img.get("imageIndex", 0) - progress = img.get("progress", 0) - - if self.n == 1 and image_index != self.target_index: - continue - - out_index = 0 if self.n == 1 else image_index - - yield self._sse( - "image_generation.partial_image", - { - "type": "image_generation.partial_image", - self.response_field: "", - "index": out_index, - "progress": progress, - }, - ) - continue - - # modelResponse - if mr := resp.get("modelResponse"): - if urls := _collect_image_urls(mr): - for url in urls: - if self.response_format == "url": - processed = await self.process_url(url, "image") - if processed: - final_images.append(processed) - continue - try: - dl_service = self._get_dl() - base64_data = await dl_service.to_base64( - url, self.token, "image" - ) - if base64_data: - if "," in base64_data: - b64 = base64_data.split(",", 1)[1] - else: - b64 = base64_data - final_images.append(b64) - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - processed = await self.process_url(url, "image") - if processed: - final_images.append(processed) - continue - - 
for index, b64 in enumerate(final_images): - if self.n == 1: - if index != self.target_index: - continue - out_index = 0 - else: - out_index = index - - yield self._sse( - "image_generation.completed", - { - "type": "image_generation.completed", - self.response_field: b64, - "index": out_index, - "usage": { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - "input_tokens_details": { - "text_tokens": 0, - "image_tokens": 0, - }, - }, - }, - ) - except asyncio.CancelledError: - logger.debug("Image stream cancelled by client") - except StreamIdleTimeoutError as e: - raise UpstreamException( - message=f"Image stream idle timeout after {e.idle_seconds}s", - status_code=504, - details={ - "error": str(e), - "type": "stream_idle_timeout", - "idle_seconds": e.idle_seconds, - }, - ) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning(f"HTTP/2 stream error in image: {e}") - raise UpstreamException( - message="Upstream connection closed unexpectedly", - status_code=502, - details={"error": str(e), "type": "http2_stream_error"}, - ) - logger.error(f"Image stream request error: {e}") - raise UpstreamException( - message=f"Upstream request failed: {e}", - status_code=502, - details={"error": str(e)}, - ) - except Exception as e: - logger.error( - f"Image stream processing error: {e}", - extra={"error_type": type(e).__name__}, - ) - raise - finally: - await self.close() - - -class ImageCollectProcessor(BaseProcessor): - """图片生成非流式响应处理器""" - - def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): - super().__init__(model, token) - self.response_format = response_format - - async def process(self, response: AsyncIterable[bytes]) -> List[str]: - """处理并收集图片""" - images = [] - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - 
except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if mr := resp.get("modelResponse"): - if urls := _collect_image_urls(mr): - for url in urls: - if self.response_format == "url": - processed = await self.process_url(url, "image") - if processed: - images.append(processed) - continue - try: - dl_service = self._get_dl() - base64_data = await dl_service.to_base64( - url, self.token, "image" - ) - if base64_data: - if "," in base64_data: - b64 = base64_data.split(",", 1)[1] - else: - b64 = base64_data - images.append(b64) - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - processed = await self.process_url(url, "image") - if processed: - images.append(processed) - - except asyncio.CancelledError: - logger.debug("Image collect cancelled by client") - except StreamIdleTimeoutError as e: - logger.warning(f"Image collect idle timeout: {e}") - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning(f"HTTP/2 stream error in image collect: {e}") - else: - logger.error(f"Image collect request error: {e}") - except Exception as e: - logger.error( - f"Image collect processing error: {e}", - extra={"error_type": type(e).__name__}, - ) - finally: - await self.close() - - return images - - -__all__ = ["ImageStreamProcessor", "ImageCollectProcessor"] diff --git a/app/services/grok/processors/image_ws_processors.py b/app/services/grok/processors/image_ws_processors.py deleted file mode 100644 index 788a442b..00000000 --- a/app/services/grok/processors/image_ws_processors.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -图片生成响应处理器(WebSocket) -""" - -import base64 -import time -from pathlib import Path -from typing import AsyncGenerator, AsyncIterable, List, Dict, Optional - -import orjson - -from app.core.config import get_config -from app.core.logger import logger -from app.core.storage import DATA_DIR -from app.core.exceptions import UpstreamException -from 
.base import BaseProcessor - - -class ImageWSBaseProcessor(BaseProcessor): - """WebSocket 图片处理基类""" - - def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): - super().__init__(model, token) - self.response_format = response_format - if response_format == "url": - self.response_field = "url" - elif response_format == "base64": - self.response_field = "base64" - else: - self.response_field = "b64_json" - self._image_dir: Optional[Path] = None - - def _ensure_image_dir(self) -> Path: - if self._image_dir is None: - base_dir = DATA_DIR / "tmp" / "image" - base_dir.mkdir(parents=True, exist_ok=True) - self._image_dir = base_dir - return self._image_dir - - def _strip_base64(self, blob: str) -> str: - if not blob: - return "" - if "," in blob and "base64" in blob.split(",", 1)[0]: - return blob.split(",", 1)[1] - return blob - - def _filename(self, image_id: str, is_final: bool) -> str: - ext = "jpg" if is_final else "png" - return f"{image_id}.{ext}" - - def _build_file_url(self, filename: str) -> str: - app_url = get_config("app.app_url") - if app_url: - return f"{app_url.rstrip('/')}/v1/files/image/{filename}" - return f"/v1/files/image/{filename}" - - def _save_blob(self, image_id: str, blob: str, is_final: bool) -> str: - data = self._strip_base64(blob) - if not data: - return "" - image_dir = self._ensure_image_dir() - filename = self._filename(image_id, is_final) - filepath = image_dir / filename - with open(filepath, "wb") as f: - f.write(base64.b64decode(data)) - return self._build_file_url(filename) - - def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict: - if not existing: - return incoming - if incoming.get("is_final") and not existing.get("is_final"): - return incoming - if existing.get("is_final") and not incoming.get("is_final"): - return existing - if incoming.get("blob_size", 0) > existing.get("blob_size", 0): - return incoming - return existing - - def _to_output(self, image_id: str, item: Dict) -> str: 
- try: - if self.response_format == "url": - return self._save_blob( - image_id, item.get("blob", ""), item.get("is_final", False) - ) - return self._strip_base64(item.get("blob", "")) - except Exception as e: - logger.warning(f"Image output failed: {e}") - return "" - - -class ImageWSStreamProcessor(ImageWSBaseProcessor): - """WebSocket 图片流式响应处理器""" - - def __init__( - self, - model: str, - token: str = "", - n: int = 1, - response_format: str = "b64_json", - size: str = "1024x1024", - ): - super().__init__(model, token, "b64_json") - self.n = n - self.size = size - self._target_id: Optional[str] = None - self._index_map: Dict[str, int] = {} - self._partial_map: Dict[str, int] = {} - - def _assign_index(self, image_id: str) -> Optional[int]: - if image_id in self._index_map: - return self._index_map[image_id] - if len(self._index_map) >= self.n: - return None - self._index_map[image_id] = len(self._index_map) - return self._index_map[image_id] - - def _sse(self, event: str, data: dict) -> str: - return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" - - async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]: - images: Dict[str, Dict] = {} - - async for item in response: - if item.get("type") == "error": - message = item.get("error") or "Upstream error" - code = item.get("error_code") or "upstream_error" - yield self._sse( - "error", - { - "error": { - "message": message, - "type": "server_error", - "code": code, - } - }, - ) - return - if item.get("type") != "image": - continue - - image_id = item.get("image_id") - if not image_id: - continue - - if self.n == 1: - if self._target_id is None: - self._target_id = image_id - index = 0 if image_id == self._target_id else None - else: - index = self._assign_index(image_id) - - images[image_id] = self._pick_best(images.get(image_id), item) - - if index is None: - continue - - if item.get("stage") != "final": - partial_b64 = self._strip_base64(item.get("blob", "")) - if not 
partial_b64: - continue - partial_index = self._partial_map.get(image_id, 0) - if item.get("stage") == "medium": - partial_index = max(partial_index, 1) - self._partial_map[image_id] = partial_index - yield self._sse( - "image_generation.partial_image", - { - "type": "image_generation.partial_image", - "b64_json": partial_b64, - "created_at": int(time.time()), - "size": self.size, - "index": index, - "partial_image_index": partial_index, - }, - ) - - if self.n == 1: - if self._target_id and self._target_id in images: - selected = [(self._target_id, images[self._target_id])] - else: - selected = ( - [ - max( - images.items(), - key=lambda x: ( - x[1].get("is_final", False), - x[1].get("blob_size", 0), - ), - ) - ] - if images - else [] - ) - else: - selected = [ - (image_id, images[image_id]) - for image_id in self._index_map - if image_id in images - ] - - for image_id, item in selected: - output = self._strip_base64(item.get("blob", "")) - if not output: - continue - - if self.n == 1: - index = 0 - else: - index = self._index_map.get(image_id, 0) - yield self._sse( - "image_generation.completed", - { - "type": "image_generation.completed", - "b64_json": output, - "created_at": int(time.time()), - "size": self.size, - "index": index, - "usage": { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, - }, - }, - ) - - -class ImageWSCollectProcessor(ImageWSBaseProcessor): - """WebSocket 图片非流式响应处理器""" - - def __init__( - self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" - ): - super().__init__(model, token, response_format) - self.n = n - - async def process(self, response: AsyncIterable[dict]) -> List[str]: - images: Dict[str, Dict] = {} - - async for item in response: - if item.get("type") == "error": - message = item.get("error") or "Upstream error" - raise UpstreamException(message, details=item) - if item.get("type") != "image": - continue - image_id = 
item.get("image_id") - if not image_id: - continue - images[image_id] = self._pick_best(images.get(image_id), item) - - selected = sorted( - images.values(), - key=lambda x: (x.get("is_final", False), x.get("blob_size", 0)), - reverse=True, - ) - if self.n: - selected = selected[: self.n] - - results: List[str] = [] - for item in selected: - output = self._to_output(item.get("image_id", ""), item) - if output: - results.append(output) - - return results - - -__all__ = ["ImageWSStreamProcessor", "ImageWSCollectProcessor"] diff --git a/app/services/grok/processors/video_processors.py b/app/services/grok/processors/video.py similarity index 96% rename from app/services/grok/processors/video_processors.py rename to app/services/grok/processors/video.py index a0ead8c3..89521b9e 100644 --- a/app/services/grok/processors/video_processors.py +++ b/app/services/grok/processors/video.py @@ -1,5 +1,5 @@ """ -视频响应处理器 +Video response processors. """ import asyncio @@ -22,7 +22,7 @@ class VideoStreamProcessor(BaseProcessor): - """视频流式响应处理器""" + """Video stream response processor.""" def __init__(self, model: str, token: str = "", think: bool = None): super().__init__(model, token) @@ -37,7 +37,7 @@ def __init__(self, model: str, token: str = "", think: bool = None): self.show_think = think def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: - """构建 SSE 响应""" + """Build SSE response.""" delta = {} if role: delta["role"] = role @@ -57,7 +57,7 @@ def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: return f"data: {orjson.dumps(chunk).decode()}\n\n" def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str: - """构建视频 HTML 标签""" + """Build video HTML tag.""" import html safe_video_url = html.escape(video_url) @@ -70,7 +70,7 @@ def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str: async def process( self, response: AsyncIterable[bytes] ) -> AsyncGenerator[str, None]: - """处理视频流式响应""" + 
"""Process video stream response.""" idle_timeout = get_config("timeout.video_idle_timeout") try: @@ -92,7 +92,7 @@ async def process( yield self._sse(role="assistant") self.role_sent = True - # 视频生成进度 + # Video generation progress if video_resp := resp.get("streamingVideoGenerationResponse"): progress = video_resp.get("progress", 0) @@ -175,7 +175,7 @@ async def process( class VideoCollectProcessor(BaseProcessor): - """视频非流式响应处理器""" + """Video non-stream response processor.""" def __init__(self, model: str, token: str = ""): super().__init__(model, token) @@ -188,7 +188,7 @@ def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str: ''' async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: - """处理并收集视频响应""" + """Process and collect video response.""" response_id = "" content = "" idle_timeout = get_config("timeout.video_idle_timeout") diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 73139767..4bba8d9f 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -1,11 +1,307 @@ """ -Grok Imagine WebSocket image service. +Grok image services. 
""" +import asyncio +import math +import random +from dataclasses import dataclass +from typing import Any, AsyncGenerator, List, Optional, Union + +from app.core.config import get_config +from app.core.logger import logger +from app.services.grok.processors import ( + ImageStreamProcessor, + ImageCollectProcessor, + ImageWSStreamProcessor, + ImageWSCollectProcessor, +) +from app.services.grok.services.chat import GrokChatService +from app.services.grok.utils.stream import wrap_stream_with_usage +from app.services.token import EffortType from app.services.reverse.ws_imagine import ImagineWebSocketReverse ImageService = ImagineWebSocketReverse image_service = ImagineWebSocketReverse() -__all__ = ["image_service", "ImageService"] + +@dataclass +class ImageGenerationResult: + stream: bool + data: Union[AsyncGenerator[str, None], List[str]] + usage_override: Optional[dict] = None + + +class ImageGenerationService: + """Image generation orchestration service.""" + + async def generate( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + size: str, + aspect_ratio: str, + stream: bool, + use_ws: bool, + ) -> ImageGenerationResult: + if stream: + if use_ws: + return await self._stream_ws( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + size=size, + aspect_ratio=aspect_ratio, + ) + return await self._stream_http( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + ) + + if use_ws: + return await self._collect_ws( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + aspect_ratio=aspect_ratio, + ) + + return await self._collect_http( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + ) + + async def _stream_ws( + self, + *, + 
token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + size: str, + aspect_ratio: str, + ) -> ImageGenerationResult: + enable_nsfw = bool(get_config("image.image_ws_nsfw")) + upstream = image_service.stream( + token=token, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=n, + enable_nsfw=enable_nsfw, + ) + processor = ImageWSStreamProcessor( + model_info.model_id, + token, + n=n, + response_format=response_format, + size=size, + ) + stream = wrap_stream_with_usage( + processor.process(upstream), + token_mgr, + token, + model_info.model_id, + ) + return ImageGenerationResult(stream=True, data=stream) + + async def _stream_http( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + ) -> ImageGenerationResult: + response = await GrokChatService().chat( + token=token, + message=f"Image Generation: {prompt}", + model=model_info.grok_model, + mode=model_info.model_mode, + stream=True, + ) + processor = ImageStreamProcessor( + model_info.model_id, + token, + n=n, + response_format=response_format, + ) + stream = wrap_stream_with_usage( + processor.process(response), + token_mgr, + token, + model_info.model_id, + ) + return ImageGenerationResult(stream=True, data=stream) + + async def _collect_ws( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + aspect_ratio: str, + ) -> ImageGenerationResult: + enable_nsfw = bool(get_config("image.image_ws_nsfw")) + all_images: List[str] = [] + seen = set() + expected_per_call = 6 + calls_needed = max(1, int(math.ceil(n / expected_per_call))) + calls_needed = min(calls_needed, n) + + async def _fetch_batch(call_target: int): + upstream = image_service.stream( + token=token, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=call_target, + enable_nsfw=enable_nsfw, + ) + processor = ImageWSCollectProcessor( + model_info.model_id, + token, + n=call_target, + 
response_format=response_format, + ) + return await processor.process(upstream) + + tasks = [] + for i in range(calls_needed): + remaining = n - (i * expected_per_call) + call_target = min(expected_per_call, remaining) + tasks.append(_fetch_batch(call_target)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + for batch in results: + if isinstance(batch, Exception): + logger.warning(f"WS batch failed: {batch}") + continue + for img in batch: + if img not in seen: + seen.add(img) + all_images.append(img) + if len(all_images) >= n: + break + if len(all_images) >= n: + break + + try: + await token_mgr.consume(token, self._get_effort(model_info)) + except Exception as e: + logger.warning(f"Failed to consume token: {e}") + + selected = self._select_images(all_images, n) + usage_override = { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + } + return ImageGenerationResult( + stream=False, data=selected, usage_override=usage_override + ) + + async def _collect_http( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + ) -> ImageGenerationResult: + calls_needed = (n + 1) // 2 + + async def _call_grok(): + success = False + try: + response = await GrokChatService().chat( + token=token, + message=f"Image Generation: {prompt}", + model=model_info.grok_model, + mode=model_info.model_mode, + stream=True, + ) + processor = ImageCollectProcessor( + model_info.model_id, token, response_format=response_format + ) + images = await processor.process(response) + success = True + return images + except Exception as e: + logger.error(f"Grok image call failed: {e}") + return [] + finally: + if success: + try: + await token_mgr.consume(token, self._get_effort(model_info)) + except Exception as e: + logger.warning(f"Failed to consume token: {e}") + + if calls_needed == 1: + all_images = await _call_grok() + else: + tasks = 
[_call_grok() for _ in range(calls_needed)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_images: List[str] = [] + for result in results: + if isinstance(result, Exception): + logger.error(f"Concurrent call failed: {result}") + elif isinstance(result, list): + all_images.extend(result) + + selected = self._select_images(all_images, n) + return ImageGenerationResult(stream=False, data=selected) + + @staticmethod + def _get_effort(model_info: Any) -> EffortType: + return ( + EffortType.HIGH + if (model_info and model_info.cost.value == "high") + else EffortType.LOW + ) + + @staticmethod + def _select_images(images: List[str], n: int) -> List[str]: + if len(images) >= n: + return random.sample(images, n) + selected = images.copy() + while len(selected) < n: + selected.append("error") + return selected + + +__all__ = [ + "image_service", + "ImageService", + "ImageGenerationService", + "ImageGenerationResult", +] diff --git a/app/services/grok/services/image_edit.py b/app/services/grok/services/image_edit.py new file mode 100644 index 00000000..ca460bc2 --- /dev/null +++ b/app/services/grok/services/image_edit.py @@ -0,0 +1,197 @@ +""" +Grok image edit service. 
+""" + +import asyncio +import random +import re +from dataclasses import dataclass +from typing import AsyncGenerator, List, Union, Any + +from app.core.exceptions import AppException, ErrorType +from app.core.logger import logger +from app.services.grok.processors import ImageCollectProcessor, ImageStreamProcessor +from app.services.grok.services.assets import UploadService +from app.services.grok.services.chat import GrokChatService +from app.services.grok.services.video import VideoService +from app.services.grok.utils.stream import wrap_stream_with_usage + + +@dataclass +class ImageEditResult: + stream: bool + data: Union[AsyncGenerator[str, None], List[str]] + + +class ImageEditService: + """Image edit orchestration service.""" + + async def edit( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + images: List[str], + n: int, + response_format: str, + stream: bool, + ) -> ImageEditResult: + image_urls = await self._upload_images(images, token) + parent_post_id = await self._get_parent_post_id(token, image_urls) + + model_config_override = { + "modelMap": { + "imageEditModel": "imagine", + "imageEditModelConfig": { + "imageReferences": image_urls, + }, + } + } + if parent_post_id: + model_config_override["modelMap"]["imageEditModelConfig"][ + "parentPostId" + ] = parent_post_id + + tool_overrides = {"imageGen": True} + + if stream: + response = await GrokChatService().chat( + token=token, + message=prompt, + model=model_info.grok_model, + mode=None, + stream=True, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + processor = ImageStreamProcessor( + model_info.model_id, + token, + n=n, + response_format=response_format, + ) + return ImageEditResult( + stream=True, + data=wrap_stream_with_usage( + processor.process(response), + token_mgr, + token, + model_info.model_id, + ), + ) + + images_out = await self._collect_images( + token=token, + prompt=prompt, + model_info=model_info, + n=n, + 
response_format=response_format, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + return ImageEditResult(stream=False, data=images_out) + + async def _upload_images(self, images: List[str], token: str) -> List[str]: + image_urls: List[str] = [] + upload_service = UploadService() + try: + for image in images: + _, file_uri = await upload_service.upload(image, token) + if file_uri: + if file_uri.startswith("http"): + image_urls.append(file_uri) + else: + image_urls.append( + f"https://assets.grok.com/{file_uri.lstrip('/')}" + ) + finally: + await upload_service.close() + + if not image_urls: + raise AppException( + message="Image upload failed", + error_type=ErrorType.SERVER.value, + code="upload_failed", + ) + + return image_urls + + async def _get_parent_post_id(self, token: str, image_urls: List[str]) -> str: + parent_post_id = None + try: + media_service = VideoService() + parent_post_id = await media_service.create_image_post(token, image_urls[0]) + logger.debug(f"Parent post ID: {parent_post_id}") + except Exception as e: + logger.warning(f"Create image post failed: {e}") + + if parent_post_id: + return parent_post_id + + for url in image_urls: + match = re.search(r"/generated/([a-f0-9-]+)/", url) + if match: + parent_post_id = match.group(1) + logger.debug(f"Parent post ID: {parent_post_id}") + break + match = re.search(r"/users/[^/]+/([a-f0-9-]+)/content", url) + if match: + parent_post_id = match.group(1) + logger.debug(f"Parent post ID: {parent_post_id}") + break + + return parent_post_id or "" + + async def _collect_images( + self, + *, + token: str, + prompt: str, + model_info: Any, + n: int, + response_format: str, + tool_overrides: dict, + model_config_override: dict, + ) -> List[str]: + calls_needed = (n + 1) // 2 + + async def _call_edit(): + response = await GrokChatService().chat( + token=token, + message=prompt, + model=model_info.grok_model, + mode=None, + stream=True, + tool_overrides=tool_overrides, + 
model_config_override=model_config_override, + ) + processor = ImageCollectProcessor( + model_info.model_id, token, response_format=response_format + ) + return await processor.process(response) + + if calls_needed == 1: + all_images = await _call_edit() + else: + tasks = [_call_edit() for _ in range(calls_needed)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_images: List[str] = [] + for result in results: + if isinstance(result, Exception): + logger.error(f"Concurrent call failed: {result}") + elif isinstance(result, list): + all_images.extend(result) + + if len(all_images) >= n: + return random.sample(all_images, n) + + selected_images = all_images.copy() + while len(selected_images) < n: + selected_images.append("error") + return selected_images + + +__all__ = ["ImageEditService", "ImageEditResult"] diff --git a/app/services/grok/services/media.py b/app/services/grok/services/video.py similarity index 93% rename from app/services/grok/services/media.py rename to app/services/grok/services/video.py index 50aa0254..bb53e535 100644 --- a/app/services/grok/services/media.py +++ b/app/services/grok/services/video.py @@ -1,5 +1,5 @@ """ -Grok 视频生成服务 +Grok video generation service. 
""" import asyncio @@ -25,7 +25,7 @@ def _get_semaphore() -> asyncio.Semaphore: - """获取或更新信号量""" + """Get or refresh the semaphore.""" global _MEDIA_SEMAPHORE, _MEDIA_SEM_VALUE value = max(1, int(get_config("performance.media_max_concurrent"))) if value != _MEDIA_SEM_VALUE: @@ -35,7 +35,7 @@ def _get_semaphore() -> asyncio.Semaphore: class VideoService: - """视频生成服务""" + """Video generation service.""" def __init__(self): self.timeout = get_config("network.timeout") @@ -47,7 +47,7 @@ async def create_post( media_type: str = "MEDIA_POST_TYPE_VIDEO", media_url: str = None, ) -> str: - """创建媒体帖子,返回 post ID""" + """Create media post and return post ID.""" try: if media_type == "MEDIA_POST_TYPE_IMAGE" and not media_url: raise ValidationException("media_url is required for image posts") @@ -74,7 +74,7 @@ async def create_post( raise UpstreamException(f"Create post error: {str(e)}") async def create_image_post(self, token: str, image_url: str) -> str: - """创建图片帖子,返回 post ID""" + """Create image post and return post ID.""" return await self.create_post( token, prompt="", media_type="MEDIA_POST_TYPE_IMAGE", media_url=image_url ) @@ -88,7 +88,7 @@ def _build_payload( resolution_name: str = "480p", preset: str = "normal", ) -> dict: - """构建视频生成载荷""" + """Build video generation payload.""" mode_map = { "fun": "--mode=extremely-crazy", "normal": "--mode=normal", @@ -139,7 +139,7 @@ async def _generate_internal( resolution_name: str, preset: str, ) -> AsyncGenerator[bytes, None]: - """内部生成逻辑""" + """Internal generation logic.""" session = None try: payload = self._build_payload( @@ -182,7 +182,7 @@ async def generate( resolution_name: str = "480p", preset: str = "normal", ) -> AsyncGenerator[bytes, None]: - """生成视频""" + """Generate video.""" logger.info( f"Video generation: prompt='{prompt[:50]}...', ratio={aspect_ratio}, length={video_length}s, preset={preset}" ) @@ -208,7 +208,7 @@ async def generate_from_image( resolution: str = "480p", preset: str = "normal", ) -> 
AsyncGenerator[bytes, None]: - """从图片生成视频""" + """Generate video from image.""" logger.info( f"Image to video: prompt='{prompt[:50]}...', image={image_url[:80]}" ) @@ -229,12 +229,12 @@ async def completions( resolution: str = "480p", preset: str = "normal", ): - """视频生成入口""" - # 获取 token(使用智能路由) + """Video generation entrypoint.""" + # Get token via intelligent routing. token_mgr = await get_token_manager() await token_mgr.reload_if_stale() - # 使用智能路由选择 token(根据视频需求与候选池) + # Select token based on video requirements and pool candidates. pool_candidates = ModelService.pool_candidates_for_model(model) token_info = token_mgr.get_token_for_video( resolution=resolution, @@ -250,7 +250,7 @@ async def completions( status_code=429, ) - # 从 TokenInfo 对象中提取 token 字符串 + # Extract token string from TokenInfo. token = token_info.token if token.startswith("sso="): token = token[4:] @@ -258,7 +258,7 @@ async def completions( think = {"enabled": True, "disabled": False}.get(thinking) is_stream = stream if stream is not None else get_config("chat.stream") - # 提取内容 + # Extract content. from app.services.grok.services.chat import MessageExtractor from app.services.grok.services.assets import UploadService @@ -267,7 +267,7 @@ async def completions( except ValueError as e: raise ValidationException(str(e)) - # 处理图片附件 + # Handle image attachments. image_url = None if attachments: upload_service = UploadService() @@ -281,7 +281,7 @@ async def completions( finally: await upload_service.close() - # 生成视频 + # Generate video. service = VideoService() if image_url: response = await service.generate_from_image( @@ -292,7 +292,7 @@ async def completions( token, prompt, aspect_ratio, video_length, resolution, preset ) - # 处理响应 + # Process response. 
if is_stream: processor = VideoStreamProcessor(model, token, think) return wrap_stream_with_usage( diff --git a/app/services/grok/utils/download.py b/app/services/grok/utils/download.py new file mode 100644 index 00000000..b3b930f5 --- /dev/null +++ b/app/services/grok/utils/download.py @@ -0,0 +1,5 @@ +"""Download service (compat wrapper).""" + +from app.services.grok.services.assets import DownloadService + +__all__ = ["DownloadService"] From 432acf41cf893752cd0285b0e9a6106dd11fd4ec Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 00:31:05 +0800 Subject: [PATCH 13/27] refactor: replace DownloadService with CacheService in admin API and update related methods --- app/api/v1/admin.py | 34 +- app/services/grok/batch_services/assets.py | 147 ++++- app/services/grok/batch_services/nsfw.py | 61 +- app/services/grok/batch_services/usage.py | 53 +- app/services/grok/processors/base.py | 2 +- app/services/grok/processors/chat.py | 4 +- app/services/grok/processors/image.py | 4 +- app/services/grok/services/assets.py | 666 --------------------- app/services/grok/services/chat.py | 4 +- app/services/grok/services/image_edit.py | 4 +- app/services/grok/services/nsfw.py | 68 --- app/services/grok/services/usage.py | 60 -- app/services/grok/services/video.py | 4 +- app/services/grok/utils/cache.py | 110 ++++ app/services/grok/utils/download.py | 232 ++++++- app/services/grok/utils/locks.py | 71 +++ app/services/grok/utils/upload.py | 137 +++++ app/services/token/manager.py | 11 +- 18 files changed, 835 insertions(+), 837 deletions(-) delete mode 100644 app/services/grok/services/assets.py delete mode 100644 app/services/grok/services/nsfw.py delete mode 100644 app/services/grok/services/usage.py create mode 100644 app/services/grok/utils/cache.py create mode 100644 app/services/grok/utils/locks.py create mode 100644 app/services/grok/utils/upload.py diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index 
8c05b051..d88b98d0 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -1121,13 +1121,13 @@ async def admin_cache_page(): @router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)]) async def get_cache_stats_api(request: Request): """获取缓存统计""" - from app.services.grok.utils.download import DownloadService + from app.services.grok.utils.cache import CacheService from app.services.token.manager import get_token_manager try: - dl_service = DownloadService() - image_stats = dl_service.get_stats("image") - video_stats = dl_service.get_stats("video") + cache_service = CacheService() + image_stats = cache_service.get_stats("image") + video_stats = cache_service.get_stats("video") mgr = await get_token_manager() pools = mgr.pools @@ -1308,7 +1308,7 @@ async def get_cache_stats_api(request: Request): ) async def load_online_cache_api_async(data: dict): """在线资产统计(异步批量 + SSE 进度)""" - from app.services.grok.utils.download import DownloadService + from app.services.grok.utils.cache import CacheService from app.services.token.manager import get_token_manager mgr = await get_token_manager() @@ -1361,9 +1361,9 @@ async def load_online_cache_api_async(data: dict): async def _run(): try: - dl_service = DownloadService() - image_stats = dl_service.get_stats("image") - video_stats = dl_service.get_stats("video") + cache_service = CacheService() + image_stats = cache_service.get_stats("image") + video_stats = cache_service.get_stats("video") async def _on_item(item: str, res: dict): ok = bool(res.get("data", {}).get("ok")) @@ -1430,13 +1430,13 @@ async def _on_item(item: str, res: dict): @router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)]) async def clear_local_cache_api(data: dict): """清理本地缓存""" - from app.services.grok.utils.download import DownloadService + from app.services.grok.utils.cache import CacheService cache_type = data.get("type", "image") try: - dl_service = DownloadService() - result = dl_service.clear(cache_type) + 
cache_service = CacheService() + result = cache_service.clear(cache_type) return {"status": "success", "result": result} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -1450,13 +1450,13 @@ async def list_local_cache_api( page_size: int = 1000, ): """列出本地缓存文件""" - from app.services.grok.utils.download import DownloadService + from app.services.grok.utils.cache import CacheService try: if type_: cache_type = type_ - dl_service = DownloadService() - result = dl_service.list_files(cache_type, page, page_size) + cache_service = CacheService() + result = cache_service.list_files(cache_type, page, page_size) return {"status": "success", **result} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -1465,15 +1465,15 @@ async def list_local_cache_api( @router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)]) async def delete_local_cache_item_api(data: dict): """删除单个本地缓存文件""" - from app.services.grok.utils.download import DownloadService + from app.services.grok.utils.cache import CacheService cache_type = data.get("type", "image") name = data.get("name") if not name: raise HTTPException(status_code=400, detail="Missing file name") try: - dl_service = DownloadService() - result = dl_service.delete_file(cache_type, name) + cache_service = CacheService() + result = cache_service.delete_file(cache_type, name) return {"status": "success", "result": result} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/app/services/grok/batch_services/assets.py b/app/services/grok/batch_services/assets.py index 3ad213e2..c22045c5 100644 --- a/app/services/grok/batch_services/assets.py +++ b/app/services/grok/batch_services/assets.py @@ -2,12 +2,155 @@ Batch assets service. 
""" -from typing import Callable, Awaitable, Dict, Any, Optional +import asyncio +from typing import Callable, Awaitable, Dict, Any, Optional, List -from app.services.grok.services.assets import ListService, DeleteService +from curl_cffi.requests import AsyncSession + +from app.core.config import get_config +from app.core.logger import logger +from app.services.reverse import AssetsListReverse, AssetsDeleteReverse +from app.services.grok.utils.locks import _get_assets_semaphore from app.services.grok.utils.batch import run_in_batches +class BaseAssetsService: + """Base assets service.""" + + def __init__(self): + self._session: Optional[AsyncSession] = None + + async def _get_session(self) -> AsyncSession: + if self._session is None: + self._session = AsyncSession() + return self._session + + async def close(self): + if self._session: + await self._session.close() + self._session = None + + +class ListService(BaseAssetsService): + """Assets list service.""" + + async def iter_assets(self, token: str): + params = { + "pageSize": 50, + "orderBy": "ORDER_BY_LAST_USE_TIME", + "source": "SOURCE_ANY", + "isLatest": "true", + } + page_token = None + seen_tokens = set() + + async with AsyncSession() as session: + while True: + if page_token: + if page_token in seen_tokens: + logger.warning("Pagination stopped: repeated page token") + break + seen_tokens.add(page_token) + params["pageToken"] = page_token + else: + params.pop("pageToken", None) + + response = await AssetsListReverse.request( + session, + token, + params, + ) + + result = response.json() + page_assets = result.get("assets", []) + yield page_assets + + page_token = result.get("nextPageToken") + if not page_token: + break + + async def list(self, token: str) -> List[Dict]: + assets = [] + async for page_assets in self.iter_assets(token): + assets.extend(page_assets) + logger.info(f"List success: {len(assets)} files") + return assets + + async def count(self, token: str) -> int: + total = 0 + async for 
page_assets in self.iter_assets(token): + total += len(page_assets) + logger.debug(f"Asset count: {total}") + return total + + +class DeleteService(BaseAssetsService): + """Assets delete service.""" + + async def delete(self, token: str, asset_id: str) -> bool: + async with _get_assets_semaphore(): + session = await self._get_session() + await AssetsDeleteReverse.request( + session, + token, + asset_id, + ) + + logger.debug(f"Deleted: {asset_id}") + return True + + async def delete_all(self, token: str) -> Dict[str, int]: + total = success = failed = 0 + list_service = ListService() + + try: + async for assets in list_service.iter_assets(token): + if not assets: + continue + + total += len(assets) + batch_result = await self._delete_batch(token, assets) + success += batch_result["success"] + failed += batch_result["failed"] + + if total == 0: + logger.info("No assets to delete") + return {"total": 0, "success": 0, "failed": 0, "skipped": True} + finally: + await list_service.close() + + logger.info(f"Delete all: total={total}, success={success}, failed={failed}") + return {"total": total, "success": success, "failed": failed} + + async def _delete_batch(self, token: str, assets: List[Dict]) -> Dict[str, int]: + batch_size = max(1, int(get_config("performance.assets_delete_batch_size"))) + success = failed = 0 + + for i in range(0, len(assets), batch_size): + batch = assets[i : i + batch_size] + results = await asyncio.gather( + *[ + self._delete_one(token, asset, idx) + for idx, asset in enumerate(batch) + ], + return_exceptions=True, + ) + success += sum(1 for r in results if r is True) + failed += sum(1 for r in results if r is not True) + + return {"success": success, "failed": failed} + + async def _delete_one(self, token: str, asset: Dict, index: int) -> bool: + await asyncio.sleep(0.01 * index) + asset_id = asset.get("assetId", "") + if not asset_id: + return False + try: + return await self.delete(token, asset_id) + except Exception: + return False + + class 
BatchAssetsService: """Batch assets orchestration.""" diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py index 3c8ae06b..b701bc3a 100644 --- a/app/services/grok/batch_services/nsfw.py +++ b/app/services/grok/batch_services/nsfw.py @@ -2,12 +2,69 @@ Batch NSFW service. """ +from dataclasses import dataclass from typing import Callable, Awaitable, Dict, Any, Optional -from app.services.grok.services.nsfw import NSFWService +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.reverse import NsfwMgmtReverse, SetBirthReverse +from app.services.reverse.utils.grpc import GrpcStatus from app.services.grok.utils.batch import run_in_batches +@dataclass +class NSFWResult: + """NSFW 操作结果""" + + success: bool + http_status: int + grpc_status: Optional[int] = None + grpc_message: Optional[str] = None + error: Optional[str] = None + + +class NSFWService: + """NSFW 模式服务""" + + async def enable(self, token: str) -> NSFWResult: + """为单个 token 开启 NSFW 模式""" + try: + browser = get_config("security.browser") + async with AsyncSession(impersonate=browser) as session: + # 先设置出生日期 + try: + await SetBirthReverse.request(session, token) + except UpstreamException as e: + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + return NSFWResult( + success=False, + http_status=status or 0, + error=f"Set birth date failed: {str(e)}", + ) + + # 开启 NSFW + grpc_status: GrpcStatus = await NsfwMgmtReverse.request(session, token) + success = grpc_status.code in (-1, 0) + + return NSFWResult( + success=success, + http_status=200, + grpc_status=grpc_status.code, + grpc_message=grpc_status.message or None, + ) + + except Exception as e: + logger.error(f"NSFW enable failed: {e}") + return NSFWResult(success=False, http_status=0, 
error=str(e)[:100]) + + class BatchNSFWService: """Batch NSFW orchestration.""" @@ -45,4 +102,4 @@ async def _enable(token: str): ) -__all__ = ["BatchNSFWService"] +__all__ = ["BatchNSFWService", "NSFWService", "NSFWResult"] diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py index 29e1b4db..edac6f4e 100644 --- a/app/services/grok/batch_services/usage.py +++ b/app/services/grok/batch_services/usage.py @@ -2,10 +2,61 @@ Batch usage service. """ +import asyncio from typing import Callable, Awaitable, Dict, Any, Optional +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.services.reverse import RateLimitsReverse from app.services.grok.utils.batch import run_in_batches +_USAGE_SEMAPHORE = asyncio.Semaphore(25) +_USAGE_SEM_VALUE = 25 + + +class UsageService: + """用量查询服务""" + + async def get(self, token: str) -> Dict: + """ + 获取速率限制信息 + + Args: + token: 认证 Token + + Returns: + 响应数据 + + Raises: + UpstreamException: 当获取失败且重试耗尽时 + """ + value = get_config("performance.usage_max_concurrent") + try: + value = int(value) + except Exception: + value = 25 + value = max(1, value) + global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE + if value != _USAGE_SEM_VALUE: + _USAGE_SEM_VALUE = value + _USAGE_SEMAPHORE = asyncio.Semaphore(value) + async with _USAGE_SEMAPHORE: + try: + async with AsyncSession() as session: + response = await RateLimitsReverse.request(session, token) + data = response.json() + remaining = data.get("remainingTokens", 0) + logger.info( + f"Usage sync success: remaining={remaining}, token={token[:10]}..." 
+ ) + return data + + except Exception: + # 最后一次失败已经被记录 + raise + class BatchUsageService: """Batch usage orchestration.""" @@ -33,4 +84,4 @@ async def _refresh_one(t: str): ) -__all__ = ["BatchUsageService"] +__all__ = ["BatchUsageService", "UsageService"] diff --git a/app/services/grok/processors/base.py b/app/services/grok/processors/base.py index f1c208e0..1d2d5c94 100644 --- a/app/services/grok/processors/base.py +++ b/app/services/grok/processors/base.py @@ -144,7 +144,7 @@ async def process_url(self, path: str, media_type: str = "image") -> str: if self.app_url: dl_service = self._get_dl() - await dl_service.download(path, self.token, media_type) + await dl_service.download_file(path, self.token, media_type) return f"{self.app_url.rstrip('/')}/v1/files/{media_type}{path}" else: return f"{ASSET_URL.rstrip('/')}{path}" diff --git a/app/services/grok/processors/chat.py b/app/services/grok/processors/chat.py index f55f8530..1e044370 100644 --- a/app/services/grok/processors/chat.py +++ b/app/services/grok/processors/chat.py @@ -168,7 +168,7 @@ async def process( if self.image_format == "base64": try: dl_service = self._get_dl() - base64_data = await dl_service.to_base64( + base64_data = await dl_service.parse_b64( url, self.token, "image" ) if base64_data: @@ -296,7 +296,7 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: if self.image_format == "base64": try: dl_service = self._get_dl() - base64_data = await dl_service.to_base64( + base64_data = await dl_service.parse_b64( url, self.token, "image" ) if base64_data: diff --git a/app/services/grok/processors/image.py b/app/services/grok/processors/image.py index a53b94a8..f8f1f29e 100644 --- a/app/services/grok/processors/image.py +++ b/app/services/grok/processors/image.py @@ -99,7 +99,7 @@ async def process( continue try: dl_service = self._get_dl() - base64_data = await dl_service.to_base64( + base64_data = await dl_service.parse_b64( url, self.token, "image" ) if base64_data: @@ 
-212,7 +212,7 @@ async def process(self, response: AsyncIterable[bytes]) -> List[str]: continue try: dl_service = self._get_dl() - base64_data = await dl_service.to_base64( + base64_data = await dl_service.parse_b64( url, self.token, "image" ) if base64_data: diff --git a/app/services/grok/services/assets.py b/app/services/grok/services/assets.py deleted file mode 100644 index 12d31d68..00000000 --- a/app/services/grok/services/assets.py +++ /dev/null @@ -1,666 +0,0 @@ -""" -Grok 文件资产服务 -""" - -import asyncio -import base64 -import hashlib -import os -import re -import time -from contextlib import asynccontextmanager -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple -from urllib.parse import urlparse - -try: - import fcntl -except ImportError: - fcntl = None - -import aiofiles -from curl_cffi.requests import AsyncSession - -from app.core.config import get_config -from app.core.exceptions import AppException, UpstreamException, ValidationException -from app.core.logger import logger -from app.core.storage import DATA_DIR -from app.services.reverse import ( - AssetsDeleteReverse, - AssetsDownloadReverse, - AssetsListReverse, - AssetsUploadReverse, -) - -# ==================== 常量 ==================== - -LOCK_DIR = DATA_DIR / ".locks" - -# 全局信号量(运行时动态初始化) -_ASSETS_SEMAPHORE = None -_ASSETS_SEM_VALUE = None - -# 常用 MIME 类型(业务数据,非配置) -MIME_TYPES = { - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".png": "image/png", - ".gif": "image/gif", - ".webp": "image/webp", - ".bmp": "image/bmp", - ".pdf": "application/pdf", - ".txt": "text/plain", - ".md": "text/markdown", - ".csv": "text/csv", - ".json": "application/json", - ".xml": "application/xml", - ".py": "text/x-python-script", - ".js": "application/javascript", - ".html": "text/html", - ".css": "text/css", - ".mp4": "video/mp4", - ".webm": "video/webm", -} - -IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} -VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"} - -# 
==================== 工具函数 ==================== - - -def _get_assets_semaphore() -> asyncio.Semaphore: - """获取全局并发控制信号量""" - value = max(1, int(get_config("performance.assets_max_concurrent"))) - - global _ASSETS_SEMAPHORE, _ASSETS_SEM_VALUE - if _ASSETS_SEMAPHORE is None or value != _ASSETS_SEM_VALUE: - _ASSETS_SEM_VALUE = value - _ASSETS_SEMAPHORE = asyncio.Semaphore(value) - return _ASSETS_SEMAPHORE - - -@asynccontextmanager -async def _file_lock(name: str, timeout: int = 10): - """文件锁""" - if fcntl is None: - yield - return - - LOCK_DIR.mkdir(parents=True, exist_ok=True) - lock_path = LOCK_DIR / f"{name}.lock" - fd = None - locked = False - start = time.monotonic() - - try: - fd = open(lock_path, "a+") - while True: - try: - fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) - locked = True - break - except BlockingIOError: - if time.monotonic() - start >= timeout: - break - await asyncio.sleep(0.05) - yield - finally: - if fd: - if locked: - try: - fcntl.flock(fd, fcntl.LOCK_UN) - except Exception: - pass - fd.close() - - -# ==================== 基础服务 ==================== - - -class BaseService: - """基础服务类""" - - def __init__(self): - self._session: Optional[AsyncSession] = None - - async def _get_session(self) -> AsyncSession: - """获取复用 Session""" - if self._session is None: - self._session = AsyncSession() - return self._session - - async def close(self): - """关闭 Session""" - if self._session: - await self._session.close() - self._session = None - - @staticmethod - def is_url(s: str) -> bool: - """检查是否为 URL""" - try: - r = urlparse(s) - return bool(r.scheme and r.netloc and r.scheme in ["http", "https"]) - except Exception: - return False - - @staticmethod - async def fetch(url: str) -> Tuple[str, str, str]: - """获取远程资源并转 Base64""" - try: - async with AsyncSession() as session: - response = await session.get(url, timeout=10) - if response.status_code >= 400: - raise UpstreamException( - message=f"Failed to fetch: {response.status_code}", - details={"url": url, 
"status": response.status_code}, - ) - - filename = url.split("/")[-1].split("?")[0] or "download" - content_type = response.headers.get( - "content-type", "application/octet-stream" - ).split(";")[0] - b64 = base64.b64encode(response.content).decode() - - logger.debug(f"Fetched: {url}") - return filename, b64, content_type - except Exception as e: - if isinstance(e, AppException): - raise - logger.error(f"Fetch failed: {url} - {e}") - raise UpstreamException(f"Fetch failed: {str(e)}", details={"url": url}) - - @staticmethod - def parse_b64(data_uri: str) -> Tuple[str, str, str]: - """解析 Base64 数据""" - if not data_uri.startswith("data:"): - return "file.bin", data_uri, "application/octet-stream" - - try: - header, b64 = data_uri.split(",", 1) - except ValueError: - return "file.bin", data_uri, "application/octet-stream" - - if ";base64" not in header: - return "file.bin", data_uri, "application/octet-stream" - - mime = header[5:].split(";", 1)[0] or "application/octet-stream" - b64 = re.sub(r"\s+", "", b64) - ext = mime.split("/")[-1] if "/" in mime else "bin" - return f"file.{ext}", b64, mime - - @staticmethod - def to_b64(file_path: Path, mime_type: str) -> str: - """文件转 base64 data URI""" - try: - if not file_path.exists(): - logger.warning(f"File not found for base64 conversion: {file_path}") - raise AppException( - f"File not found: {file_path}", code="file_not_found" - ) - - if not file_path.is_file(): - logger.warning(f"Path is not a file: {file_path}") - raise AppException( - f"Invalid file path: {file_path}", code="invalid_file_path" - ) - - b64_data = base64.b64encode(file_path.read_bytes()).decode() - return f"data:{mime_type};base64,{b64_data}" - except AppException: - raise - except Exception as e: - logger.error(f"File to base64 failed: {file_path} - {e}") - raise AppException( - f"Failed to read file: {file_path}", code="file_read_error" - ) - - -# ==================== 上传服务 ==================== - - -class UploadService(BaseService): - """文件上传服务""" - 
- async def upload(self, file_input: str, token: str) -> Tuple[str, str]: - """ - 上传文件到 Grok - - Returns: - (file_id, file_uri) - """ - async with _get_assets_semaphore(): - # 处理输入 - if self.is_url(file_input): - filename, b64, mime = await self.fetch(file_input) - else: - filename, b64, mime = self.parse_b64(file_input) - - logger.debug( - f"Upload prepare: filename={filename}, type={mime}, size={len(b64)}" - ) - - if not b64: - raise ValidationException("Invalid file input: empty content") - - # 执行上传 - session = await self._get_session() - response = await AssetsUploadReverse.request( - session, - token, - filename, - mime, - b64, - ) - - result = response.json() - file_id = result.get("fileMetadataId", "") - file_uri = result.get("fileUri", "") - logger.info(f"Upload success: {filename} -> {file_id}") - return file_id, file_uri - - -# ==================== 列表服务 ==================== - - -class ListService(BaseService): - """文件列表查询服务""" - - async def iter_assets(self, token: str): - """分页迭代资产列表""" - params = { - "pageSize": 50, - "orderBy": "ORDER_BY_LAST_USE_TIME", - "source": "SOURCE_ANY", - "isLatest": "true", - } - page_token = None - seen_tokens = set() - - async with AsyncSession() as session: - while True: - if page_token: - if page_token in seen_tokens: - logger.warning("Pagination stopped: repeated page token") - break - seen_tokens.add(page_token) - params["pageToken"] = page_token - else: - params.pop("pageToken", None) - - response = await AssetsListReverse.request( - session, - token, - params, - ) - - result = response.json() - page_assets = result.get("assets", []) - yield page_assets - - page_token = result.get("nextPageToken") - if not page_token: - break - - async def list(self, token: str) -> List[Dict]: - """查询文件列表""" - assets = [] - async for page_assets in self.iter_assets(token): - assets.extend(page_assets) - logger.info(f"List success: {len(assets)} files") - return assets - - async def count(self, token: str) -> int: - """统计资产数量""" - total 
= 0 - async for page_assets in self.iter_assets(token): - total += len(page_assets) - logger.debug(f"Asset count: {total}") - return total - - -# ==================== 删除服务 ==================== - - -class DeleteService(BaseService): - """文件删除服务""" - - async def delete(self, token: str, asset_id: str) -> bool: - """删除单个文件""" - async with _get_assets_semaphore(): - session = await self._get_session() - response = await AssetsDeleteReverse.request( - session, - token, - asset_id, - ) - - logger.debug(f"Deleted: {asset_id}") - return True - - async def delete_all(self, token: str) -> Dict[str, int]: - """删除所有文件""" - total = success = failed = 0 - list_service = ListService() - - try: - async for assets in list_service.iter_assets(token): - if not assets: - continue - - total += len(assets) - batch_result = await self._delete_batch(token, assets) - success += batch_result["success"] - failed += batch_result["failed"] - - if total == 0: - logger.info("No assets to delete") - return {"total": 0, "success": 0, "failed": 0, "skipped": True} - finally: - await list_service.close() - - logger.info(f"Delete all: total={total}, success={success}, failed={failed}") - return {"total": total, "success": success, "failed": failed} - - async def _delete_batch(self, token: str, assets: List[Dict]) -> Dict[str, int]: - """批量删除""" - batch_size = max(1, int(get_config("performance.assets_delete_batch_size"))) - success = failed = 0 - - for i in range(0, len(assets), batch_size): - batch = assets[i : i + batch_size] - results = await asyncio.gather( - *[ - self._delete_one(token, asset, idx) - for idx, asset in enumerate(batch) - ], - return_exceptions=True, - ) - success += sum(1 for r in results if r is True) - failed += sum(1 for r in results if r is not True) - - return {"success": success, "failed": failed} - - async def _delete_one(self, token: str, asset: Dict, index: int) -> bool: - """删除单个资产(带延迟)""" - await asyncio.sleep(0.01 * index) - asset_id = asset.get("assetId", "") - if 
not asset_id: - return False - try: - return await self.delete(token, asset_id) - except Exception: - return False - - -# ==================== 下载服务 ==================== - - -class DownloadService(BaseService): - """文件下载服务""" - - def __init__(self): - super().__init__() - self.base_dir = DATA_DIR / "tmp" - self.image_dir = self.base_dir / "image" - self.video_dir = self.base_dir / "video" - self.image_dir.mkdir(parents=True, exist_ok=True) - self.video_dir.mkdir(parents=True, exist_ok=True) - self._cleanup_running = False - - def _cache_path(self, file_path: str, media_type: str) -> Path: - """获取缓存路径""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - filename = file_path.lstrip("/").replace("/", "-") - return cache_dir / filename - - def _get_mime(self, cache_path: Path, response=None) -> str: - """获取 MIME 类型""" - if response: - return response.headers.get( - "content-type", "application/octet-stream" - ).split(";")[0] - return MIME_TYPES.get(cache_path.suffix.lower(), "application/octet-stream") - - async def download( - self, file_path: str, token: str, media_type: str = "image" - ) -> Tuple[Optional[Path], str]: - """下载文件到本地""" - async with _get_assets_semaphore(): - cache_path = self._cache_path(file_path, media_type) - - # 检查缓存 - if cache_path.exists(): - logger.debug(f"Cache hit: {cache_path}") - return cache_path, self._get_mime(cache_path) - - # 文件锁防止并发下载 - lock_name = f"dl_{media_type}_{hashlib.sha1(str(cache_path).encode()).hexdigest()[:16]}" - async with _file_lock(lock_name, timeout=10): - # 双重检查 - if cache_path.exists(): - return cache_path, self._get_mime(cache_path) - - # 执行下载 - mime = await self._download_file(file_path, token, cache_path) - logger.info(f"Downloaded: {file_path}") - - # 异步检查缓存限制 - asyncio.create_task(self.check_limit()) - - return cache_path, mime - - async def _download_file(self, file_path: str, token: str, cache_path: Path) -> str: - """执行下载""" - if not file_path.startswith("/"): - file_path = 
f"/{file_path}" - - session = await self._get_session() - response = await AssetsDownloadReverse.request(session, token, file_path) - - # 保存文件 - tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") - try: - async with aiofiles.open(tmp_path, "wb") as f: - # 尝试流式写入 - if hasattr(response, "aiter_content"): - async for chunk in response.aiter_content(): - if chunk: - await f.write(chunk) - else: - await f.write(response.content) - os.replace(tmp_path, cache_path) - finally: - if tmp_path.exists() and not cache_path.exists(): - try: - tmp_path.unlink() - except Exception: - pass - - return self._get_mime(cache_path, response) - - async def to_base64( - self, file_path: str, token: str, media_type: str = "image" - ) -> str: - """下载并转 base64""" - try: - cache_path, mime = await self.download(file_path, token, media_type) - if not cache_path or not cache_path.exists(): - logger.warning(f"Download failed for {file_path}: invalid path") - raise AppException( - "Download failed: invalid path", code="download_failed" - ) - - data_uri = self.to_b64(cache_path, mime) - - # 删除临时文件 - if data_uri: - try: - cache_path.unlink() - except Exception as e: - logger.debug(f"Failed to cleanup temp file {cache_path}: {e}") - - return data_uri - except Exception as e: - logger.error(f"Failed to convert {file_path} to base64: {e}") - raise - - def get_stats(self, media_type: str = "image") -> Dict[str, Any]: - """获取缓存统计""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - if not cache_dir.exists(): - return {"count": 0, "size_mb": 0.0} - - allowed = IMAGE_EXTS if media_type == "image" else VIDEO_EXTS - files = [ - f - for f in cache_dir.glob("*") - if f.is_file() and f.suffix.lower() in allowed - ] - total_size = sum(f.stat().st_size for f in files) - return {"count": len(files), "size_mb": round(total_size / 1024 / 1024, 2)} - - def list_files( - self, media_type: str = "image", page: int = 1, page_size: int = 1000 - ) -> Dict[str, Any]: - """列出缓存文件""" - 
cache_dir = self.image_dir if media_type == "image" else self.video_dir - if not cache_dir.exists(): - return {"total": 0, "page": page, "page_size": page_size, "items": []} - - allowed = IMAGE_EXTS if media_type == "image" else VIDEO_EXTS - files = [ - f - for f in cache_dir.glob("*") - if f.is_file() and f.suffix.lower() in allowed - ] - - # 构建文件列表 - items = [] - for f in files: - try: - stat = f.stat() - items.append( - { - "name": f.name, - "size_bytes": stat.st_size, - "mtime_ms": int(stat.st_mtime * 1000), - } - ) - except Exception: - continue - - items.sort(key=lambda x: x["mtime_ms"], reverse=True) - - # 分页 - total = len(items) - start = max(0, (page - 1) * page_size) - paged = items[start : start + page_size] - - # 添加 URL - for item in paged: - item["view_url"] = f"/v1/files/{media_type}/{item['name']}" - - return {"total": total, "page": page, "page_size": page_size, "items": paged} - - def delete_file(self, media_type: str, name: str) -> Dict[str, Any]: - """删除缓存文件""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - file_path = cache_dir / name.replace("/", "-") - - if file_path.exists(): - try: - file_path.unlink() - return {"deleted": True} - except Exception: - pass - return {"deleted": False} - - def clear(self, media_type: str = "image") -> Dict[str, Any]: - """清空缓存""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - if not cache_dir.exists(): - return {"count": 0, "size_mb": 0.0} - - files = list(cache_dir.glob("*")) - total_size = sum(f.stat().st_size for f in files if f.is_file()) - count = 0 - - for f in files: - if f.is_file(): - try: - f.unlink() - count += 1 - except Exception: - pass - - return {"count": count, "size_mb": round(total_size / 1024 / 1024, 2)} - - async def check_limit(self): - """检查并清理缓存""" - if self._cleanup_running or not get_config("cache.enable_auto_clean"): - return - - self._cleanup_running = True - try: - async with _file_lock("cache_cleanup", timeout=5): - limit_mb 
= get_config("cache.limit_mb") - all_files, total_size = self._collect_files() - current_mb = total_size / 1024 / 1024 - - if current_mb <= limit_mb: - return - - # 清理到 80% - logger.info( - f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..." - ) - all_files.sort(key=lambda x: x[1]) # 按时间排序 - - deleted_count = 0 - deleted_size = 0 - target_mb = limit_mb * 0.8 - - for f, _, size in all_files: - try: - f.unlink() - deleted_count += 1 - deleted_size += size - total_size -= size - if (total_size / 1024 / 1024) <= target_mb: - break - except Exception: - pass - - logger.info( - f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)" - ) - finally: - self._cleanup_running = False - - def _collect_files(self) -> Tuple[List[Tuple[Path, float, int]], int]: - """收集所有缓存文件""" - total_size = 0 - all_files = [] - - for d in [self.image_dir, self.video_dir]: - if d.exists(): - for f in d.glob("*"): - if f.is_file(): - try: - stat = f.stat() - total_size += stat.st_size - all_files.append((f, stat.st_mtime, stat.st_size)) - except Exception: - pass - - return all_files, total_size - - -__all__ = [ - "BaseService", - "UploadService", - "ListService", - "DeleteService", - "DownloadService", -] diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index fc8d2856..aa439239 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -16,7 +16,7 @@ UpstreamException, ) from app.services.grok.models.model import ModelService -from app.services.grok.services.assets import UploadService +from app.services.grok.utils.upload import UploadService from app.services.grok.processors import StreamProcessor, CollectProcessor from app.services.reverse import AppChatReverse from app.services.grok.utils.stream import wrap_stream_with_usage @@ -209,7 +209,7 @@ async def chat_openai(self, token: str, request: ChatRequest): upload_service = UploadService() try: for attach_type, attach_data in 
attachments: - file_id, _ = await upload_service.upload(attach_data, token) + file_id, _ = await upload_service.upload_file(attach_data, token) file_ids.append(file_id) logger.debug( f"Attachment uploaded: type={attach_type}, file_id={file_id}" diff --git a/app/services/grok/services/image_edit.py b/app/services/grok/services/image_edit.py index ca460bc2..0eb81777 100644 --- a/app/services/grok/services/image_edit.py +++ b/app/services/grok/services/image_edit.py @@ -11,7 +11,7 @@ from app.core.exceptions import AppException, ErrorType from app.core.logger import logger from app.services.grok.processors import ImageCollectProcessor, ImageStreamProcessor -from app.services.grok.services.assets import UploadService +from app.services.grok.utils.upload import UploadService from app.services.grok.services.chat import GrokChatService from app.services.grok.services.video import VideoService from app.services.grok.utils.stream import wrap_stream_with_usage @@ -98,7 +98,7 @@ async def _upload_images(self, images: List[str], token: str) -> List[str]: upload_service = UploadService() try: for image in images: - _, file_uri = await upload_service.upload(image, token) + _, file_uri = await upload_service.upload_file(image, token) if file_uri: if file_uri.startswith("http"): image_urls.append(file_uri) diff --git a/app/services/grok/services/nsfw.py b/app/services/grok/services/nsfw.py deleted file mode 100644 index a919d54e..00000000 --- a/app/services/grok/services/nsfw.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -NSFW (Unhinged) 模式服务 - -使用 gRPC-Web 协议开启账号的 NSFW 功能。 -""" - -from dataclasses import dataclass -from typing import Optional - -from curl_cffi.requests import AsyncSession - -from app.core.logger import logger -from app.core.config import get_config -from app.core.exceptions import UpstreamException -from app.services.reverse import NsfwMgmtReverse, SetBirthReverse -from app.services.reverse.utils.grpc import GrpcStatus - -@dataclass -class NSFWResult: - """NSFW 操作结果""" 
- - success: bool - http_status: int - grpc_status: Optional[int] = None - grpc_message: Optional[str] = None - error: Optional[str] = None - - -class NSFWService: - """NSFW 模式服务""" - - async def enable(self, token: str) -> NSFWResult: - """为单个 token 开启 NSFW 模式""" - try: - browser = get_config("security.browser") - async with AsyncSession(impersonate=browser) as session: - # 先设置出生日期 - try: - await SetBirthReverse.request(session, token) - except UpstreamException as e: - status = None - if e.details and "status" in e.details: - status = e.details["status"] - else: - status = getattr(e, "status_code", None) - return NSFWResult( - success=False, - http_status=status or 0, - error=f"Set birth date failed: {str(e)}", - ) - - # 开启 NSFW - grpc_status: GrpcStatus = await NsfwMgmtReverse.request(session, token) - success = grpc_status.code in (-1, 0) - - return NSFWResult( - success=success, - http_status=200, - grpc_status=grpc_status.code, - grpc_message=grpc_status.message or None, - ) - - except Exception as e: - logger.error(f"NSFW enable failed: {e}") - return NSFWResult(success=False, http_status=0, error=str(e)[:100]) - - -__all__ = ["NSFWService", "NSFWResult"] diff --git a/app/services/grok/services/usage.py b/app/services/grok/services/usage.py deleted file mode 100644 index e2734330..00000000 --- a/app/services/grok/services/usage.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Grok 用量服务 -""" - -import asyncio -from typing import Dict - -from curl_cffi.requests import AsyncSession - -from app.core.logger import logger -from app.core.config import get_config -from app.services.reverse import RateLimitsReverse - -_USAGE_SEMAPHORE = asyncio.Semaphore(25) -_USAGE_SEM_VALUE = 25 - - -class UsageService: - """用量查询服务""" - - async def get(self, token: str) -> Dict: - """ - 获取速率限制信息 - - Args: - token: 认证 Token - - Returns: - 响应数据 - - Raises: - UpstreamException: 当获取失败且重试耗尽时 - """ - value = get_config("performance.usage_max_concurrent") - try: - value = int(value) - except 
Exception: - value = 25 - value = max(1, value) - global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE - if value != _USAGE_SEM_VALUE: - _USAGE_SEM_VALUE = value - _USAGE_SEMAPHORE = asyncio.Semaphore(value) - async with _USAGE_SEMAPHORE: - try: - async with AsyncSession() as session: - response = await RateLimitsReverse.request(session, token) - data = response.json() - remaining = data.get("remainingTokens", 0) - logger.info( - f"Usage sync success: remaining={remaining}, token={token[:10]}..." - ) - return data - - except Exception: - # 最后一次失败已经被记录 - raise - - -__all__ = ["UsageService"] diff --git a/app/services/grok/services/video.py b/app/services/grok/services/video.py index bb53e535..cd65c8ad 100644 --- a/app/services/grok/services/video.py +++ b/app/services/grok/services/video.py @@ -260,7 +260,7 @@ async def completions( # Extract content. from app.services.grok.services.chat import MessageExtractor - from app.services.grok.services.assets import UploadService + from app.services.grok.utils.upload import UploadService try: prompt, attachments = MessageExtractor.extract(messages, is_video=True) @@ -274,7 +274,7 @@ async def completions( try: for attach_type, attach_data in attachments: if attach_type == "image": - _, file_uri = await upload_service.upload(attach_data, token) + _, file_uri = await upload_service.upload_file(attach_data, token) image_url = f"https://assets.grok.com/{file_uri}" logger.info(f"Image uploaded for video: {image_url}") break diff --git a/app/services/grok/utils/cache.py b/app/services/grok/utils/cache.py new file mode 100644 index 00000000..a728df15 --- /dev/null +++ b/app/services/grok/utils/cache.py @@ -0,0 +1,110 @@ +""" +Local cache utilities. 
+""" + +from typing import Any, Dict + +from app.core.storage import DATA_DIR + +IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} +VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"} + + +class CacheService: + """Local cache service.""" + + def __init__(self): + base_dir = DATA_DIR / "tmp" + self.image_dir = base_dir / "image" + self.video_dir = base_dir / "video" + self.image_dir.mkdir(parents=True, exist_ok=True) + self.video_dir.mkdir(parents=True, exist_ok=True) + + def _cache_dir(self, media_type: str): + return self.image_dir if media_type == "image" else self.video_dir + + def _allowed_exts(self, media_type: str): + return IMAGE_EXTS if media_type == "image" else VIDEO_EXTS + + def get_stats(self, media_type: str = "image") -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + if not cache_dir.exists(): + return {"count": 0, "size_mb": 0.0} + + allowed = self._allowed_exts(media_type) + files = [ + f for f in cache_dir.glob("*") if f.is_file() and f.suffix.lower() in allowed + ] + total_size = sum(f.stat().st_size for f in files) + return {"count": len(files), "size_mb": round(total_size / 1024 / 1024, 2)} + + def list_files( + self, media_type: str = "image", page: int = 1, page_size: int = 1000 + ) -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + if not cache_dir.exists(): + return {"total": 0, "page": page, "page_size": page_size, "items": []} + + allowed = self._allowed_exts(media_type) + files = [ + f for f in cache_dir.glob("*") if f.is_file() and f.suffix.lower() in allowed + ] + + items = [] + for f in files: + try: + stat = f.stat() + items.append( + { + "name": f.name, + "size_bytes": stat.st_size, + "mtime_ms": int(stat.st_mtime * 1000), + } + ) + except Exception: + continue + + items.sort(key=lambda x: x["mtime_ms"], reverse=True) + + total = len(items) + start = max(0, (page - 1) * page_size) + paged = items[start : start + page_size] + + for item in paged: + item["view_url"] = 
f"/v1/files/{media_type}/{item['name']}" + + return {"total": total, "page": page, "page_size": page_size, "items": paged} + + def delete_file(self, media_type: str, name: str) -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + file_path = cache_dir / name.replace("/", "-") + + if file_path.exists(): + try: + file_path.unlink() + return {"deleted": True} + except Exception: + pass + return {"deleted": False} + + def clear(self, media_type: str = "image") -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + if not cache_dir.exists(): + return {"count": 0, "size_mb": 0.0} + + files = list(cache_dir.glob("*")) + total_size = sum(f.stat().st_size for f in files if f.is_file()) + count = 0 + + for f in files: + if f.is_file(): + try: + f.unlink() + count += 1 + except Exception: + pass + + return {"count": count, "size_mb": round(total_size / 1024 / 1024, 2)} + + +__all__ = ["CacheService"] diff --git a/app/services/grok/utils/download.py b/app/services/grok/utils/download.py index b3b930f5..9179610b 100644 --- a/app/services/grok/utils/download.py +++ b/app/services/grok/utils/download.py @@ -1,5 +1,233 @@ -"""Download service (compat wrapper).""" +""" +Download service. 
+""" + +import asyncio +import base64 +import hashlib +import os +from pathlib import Path +from typing import List, Optional, Tuple +from urllib.parse import urlparse + +import aiofiles +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.storage import DATA_DIR +from app.core.config import get_config +from app.core.exceptions import AppException +from app.services.reverse import AssetsDownloadReverse +from app.services.grok.utils.locks import _get_assets_semaphore, _file_lock + + +class DownloadService: + """Assets download service.""" + + def __init__(self): + self._session: Optional[AsyncSession] = None + base_dir = DATA_DIR / "tmp" + self.image_dir = base_dir / "image" + self.video_dir = base_dir / "video" + self.image_dir.mkdir(parents=True, exist_ok=True) + self.video_dir.mkdir(parents=True, exist_ok=True) + self._cleanup_running = False + + async def create(self) -> AsyncSession: + """Create or reuse a session.""" + if self._session is None: + self._session = AsyncSession() + return self._session + + async def close(self): + """Close the session.""" + if self._session: + await self._session.close() + self._session = None + + @staticmethod + def _is_url(value: str) -> bool: + """Check if the value is a URL.""" + try: + parsed = urlparse(value) + return bool(parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]) + except Exception: + return False + + async def parse_b64(self, file_path: str, token: str, media_type: str = "image") -> str: + """Download and return data URI.""" + try: + cache_path, mime = await self.download_file(file_path, token, media_type) + if not cache_path or not cache_path.exists(): + logger.warning(f"Download failed for {file_path}: invalid path") + raise AppException( + "Download failed: invalid path", code="download_failed" + ) + + data_uri = await self.format_b64(cache_path, mime) + + if data_uri: + try: + cache_path.unlink() + except Exception as e: + logger.debug(f"Failed 
to cleanup temp file {cache_path}: {e}") + + return data_uri + except Exception as e: + logger.error(f"Failed to convert {file_path} to base64: {e}") + raise + + @staticmethod + async def format_b64(file_path: Path, mime_type: str) -> str: + """Format local file to data URI.""" + try: + if not file_path.exists(): + logger.warning(f"File not found for base64 conversion: {file_path}") + raise AppException( + f"File not found: {file_path}", code="file_not_found" + ) + + if not file_path.is_file(): + logger.warning(f"Path is not a file: {file_path}") + raise AppException( + f"Invalid file path: {file_path}", code="invalid_file_path" + ) + + async with aiofiles.open(file_path, "rb") as f: + data = await f.read() + b64_data = base64.b64encode(data).decode() + return f"data:{mime_type};base64,{b64_data}" + except AppException: + raise + except Exception as e: + logger.error(f"File to base64 failed: {file_path} - {e}") + raise AppException( + f"Failed to read file: {file_path}", code="file_read_error" + ) + + def check_format(self, file_path: str) -> str: + """Normalize file path for download.""" + if not isinstance(file_path, str) or not file_path.strip(): + raise AppException("Invalid file path", code="invalid_file_path") + if self._is_url(file_path): + file_path = urlparse(file_path).path or "" + if not file_path.startswith("/"): + file_path = f"/{file_path}" + return file_path + + async def download_file(self, file_path: str, token: str, media_type: str = "image") -> Tuple[Optional[Path], str]: + """Download asset to local cache. + + Args: + file_path: str, the path of the file to download. + token: str, the SSO token. + media_type: str, the media type of the file. + + Returns: + Tuple[Optional[Path], str]: The path of the downloaded file and the MIME type. 
+ """ + async with _get_assets_semaphore(): + file_path = self.check_format(file_path) + cache_dir = self.image_dir if media_type == "image" else self.video_dir + filename = file_path.lstrip("/").replace("/", "-") + cache_path = cache_dir / filename + + lock_name = ( + f"dl_{media_type}_{hashlib.sha1(str(cache_path).encode()).hexdigest()[:16]}" + ) + async with _file_lock(lock_name, timeout=10): + session = await self.create() + response = await AssetsDownloadReverse.request( + session, token, file_path + ) + + tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") + try: + async with aiofiles.open(tmp_path, "wb") as f: + if hasattr(response, "aiter_content"): + async for chunk in response.aiter_content(): + if chunk: + await f.write(chunk) + else: + await f.write(response.content) + os.replace(tmp_path, cache_path) + finally: + if tmp_path.exists() and not cache_path.exists(): + try: + tmp_path.unlink() + except Exception: + pass + + mime = response.headers.get( + "content-type", "application/octet-stream" + ).split(";")[0] + logger.info(f"Downloaded: {file_path}") + + asyncio.create_task(self.check_limit()) + + return cache_path, mime + + async def check_limit(self): + """Check cache limit and cleanup. + + Args: + self: DownloadService, the download service instance. 
+ + Returns: + None + """ + if self._cleanup_running or not get_config("cache.enable_auto_clean"): + return + + self._cleanup_running = True + try: + async with _file_lock("cache_cleanup", timeout=5): + limit_mb = get_config("cache.limit_mb") + total_size = 0 + all_files: List[Tuple[Path, float, int]] = [] + + for d in [self.image_dir, self.video_dir]: + if d.exists(): + for f in d.glob("*"): + if f.is_file(): + try: + stat = f.stat() + total_size += stat.st_size + all_files.append( + (f, stat.st_mtime, stat.st_size) + ) + except Exception: + pass + current_mb = total_size / 1024 / 1024 + + if current_mb <= limit_mb: + return + + logger.info( + f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..." + ) + all_files.sort(key=lambda x: x[1]) + + deleted_count = 0 + deleted_size = 0 + target_mb = limit_mb * 0.8 + + for f, _, size in all_files: + try: + f.unlink() + deleted_count += 1 + deleted_size += size + total_size -= size + if (total_size / 1024 / 1024) <= target_mb: + break + except Exception: + pass + + logger.info( + f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)" + ) + finally: + self._cleanup_running = False -from app.services.grok.services.assets import DownloadService __all__ = ["DownloadService"] diff --git a/app/services/grok/utils/locks.py b/app/services/grok/utils/locks.py new file mode 100644 index 00000000..654ae9e1 --- /dev/null +++ b/app/services/grok/utils/locks.py @@ -0,0 +1,71 @@ +""" +Shared locking helpers for assets operations. 
+""" + +import asyncio +import time +from contextlib import asynccontextmanager +from pathlib import Path + +from app.core.config import get_config +from app.core.storage import DATA_DIR + +try: + import fcntl +except ImportError: + fcntl = None + + +LOCK_DIR = DATA_DIR / ".locks" + +_ASSETS_SEMAPHORE = None +_ASSETS_SEM_VALUE = None + + +def _get_assets_semaphore() -> asyncio.Semaphore: + """Return global semaphore for assets operations.""" + value = max(1, int(get_config("performance.assets_max_concurrent"))) + + global _ASSETS_SEMAPHORE, _ASSETS_SEM_VALUE + if _ASSETS_SEMAPHORE is None or value != _ASSETS_SEM_VALUE: + _ASSETS_SEM_VALUE = value + _ASSETS_SEMAPHORE = asyncio.Semaphore(value) + return _ASSETS_SEMAPHORE + + +@asynccontextmanager +async def _file_lock(name: str, timeout: int = 10): + """File lock guard.""" + if fcntl is None: + yield + return + + LOCK_DIR.mkdir(parents=True, exist_ok=True) + lock_path = Path(LOCK_DIR) / f"{name}.lock" + fd = None + locked = False + start = time.monotonic() + + try: + fd = open(lock_path, "a+") + while True: + try: + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + locked = True + break + except BlockingIOError: + if time.monotonic() - start >= timeout: + break + await asyncio.sleep(0.05) + yield + finally: + if fd: + if locked: + try: + fcntl.flock(fd, fcntl.LOCK_UN) + except Exception: + pass + fd.close() + + +__all__ = ["_get_assets_semaphore", "_file_lock"] diff --git a/app/services/grok/utils/upload.py b/app/services/grok/utils/upload.py new file mode 100644 index 00000000..85302ac1 --- /dev/null +++ b/app/services/grok/utils/upload.py @@ -0,0 +1,137 @@ +""" +Upload service. 
+""" + +import base64 +import re +from typing import Optional, Tuple +from urllib.parse import urlparse + +from curl_cffi.requests import AsyncSession + +from app.core.exceptions import AppException, UpstreamException, ValidationException +from app.core.logger import logger +from app.services.reverse import AssetsUploadReverse +from app.services.grok.utils.locks import _get_assets_semaphore + + +class UploadService: + """Assets upload service.""" + + def __init__(self): + self._session: Optional[AsyncSession] = None + + async def create(self) -> AsyncSession: + """Create or reuse a session.""" + if self._session is None: + self._session = AsyncSession() + return self._session + + async def close(self): + """Close the session.""" + if self._session: + await self._session.close() + self._session = None + + @staticmethod + def _is_url(value: str) -> bool: + """Check if the value is a URL.""" + try: + parsed = urlparse(value) + return bool(parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]) + except Exception: + return False + + @staticmethod + async def parse_b64(url: str) -> Tuple[str, str, str]: + """Fetch URL content and return (filename, base64, mime).""" + try: + async with AsyncSession() as session: + response = await session.get(url, timeout=10) + if response.status_code >= 400: + raise UpstreamException( + message=f"Failed to fetch: {response.status_code}", + details={"url": url, "status": response.status_code}, + ) + + filename = url.split("/")[-1].split("?")[0] or "download" + content_type = response.headers.get( + "content-type", "application/octet-stream" + ).split(";")[0] + b64 = base64.b64encode(response.content).decode() + + logger.debug(f"Fetched: {url}") + return filename, b64, content_type + except Exception as e: + if isinstance(e, AppException): + raise + logger.error(f"Fetch failed: {url} - {e}") + raise UpstreamException(f"Fetch failed: {str(e)}", details={"url": url}) + + @staticmethod + def format_b64(data_uri: str) -> 
Tuple[str, str, str]: + """Format data URI to (filename, base64, mime).""" + if not data_uri.startswith("data:"): + return "file.bin", data_uri, "application/octet-stream" + + try: + header, b64 = data_uri.split(",", 1) + except ValueError: + return "file.bin", data_uri, "application/octet-stream" + + if ";base64" not in header: + return "file.bin", data_uri, "application/octet-stream" + + mime = header[5:].split(";", 1)[0] or "application/octet-stream" + b64 = re.sub(r"\s+", "", b64) + ext = mime.split("/")[-1] if "/" in mime else "bin" + return f"file.{ext}", b64, mime + + async def check_format(self, file_input: str) -> Tuple[str, str, str]: + """Check file input format and return (filename, base64, mime).""" + if not isinstance(file_input, str) or not file_input.strip(): + raise ValidationException("Invalid file input: empty content") + + if self._is_url(file_input): + return await self.parse_b64(file_input) + + return self.format_b64(file_input) + + async def upload_file(self, file_input: str, token: str) -> Tuple[str, str]: + """ + Upload file to Grok. + + Args: + file_input: str, the file input. + token: str, the SSO token. + + Returns: + Tuple[str, str]: The file ID and URI. 
+ """ + async with _get_assets_semaphore(): + filename, b64, mime = await self.check_format(file_input) + + logger.debug( + f"Upload prepare: filename={filename}, type={mime}, size={len(b64)}" + ) + + if not b64: + raise ValidationException("Invalid file input: empty content") + + session = await self.create() + response = await AssetsUploadReverse.request( + session, + token, + filename, + mime, + b64, + ) + + result = response.json() + file_id = result.get("fileMetadataId", "") + file_uri = result.get("fileUri", "") + logger.info(f"Upload success: {filename} -> {file_id}") + return file_id, file_uri + + +__all__ = ["UploadService"] diff --git a/app/services/token/manager.py b/app/services/token/manager.py index f4b547e9..4f056d8b 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -14,9 +14,10 @@ BASIC__DEFAULT_QUOTA, SUPER_DEFAULT_QUOTA, ) -from app.core.storage import get_storage +from app.core.storage import get_storage, LocalStorage from app.core.config import get_config from app.services.token.pool import TokenPool +from app.services.grok.batch_services.usage import UsageService DEFAULT_REFRESH_BATCH_SIZE = 10 @@ -70,8 +71,6 @@ async def _load(self): # 如果后端返回 None 或空数据,尝试从本地 data/token.json 初始化后端 if not data: - from app.core.storage import LocalStorage - local_storage = LocalStorage() local_data = await local_storage.load_tokens() if local_data: @@ -363,8 +362,6 @@ async def sync_usage( # 尝试 API 同步 try: - from app.services.grok.services.usage import UsageService - usage_service = UsageService() result = await usage_service.get(token_str) @@ -634,8 +631,6 @@ async def refresh_cooling_tokens(self) -> Dict[str, int]: Returns: {"checked": int, "refreshed": int, "recovered": int, "expired": int} """ - from app.services.grok.services.usage import UsageService - # 收集需要刷新的 token to_refresh: List[TokenInfo] = [] for pool in self.pools.values(): @@ -676,7 +671,7 @@ async def _refresh_one(token_info: TokenInfo) -> dict: # 重试逻辑:最多 2 次重试 for 
retry in range(3): # 0, 1, 2 try: - result = await usage_service.get(token_str, model_name="grok-3") + result = await usage_service.get(token_str) if result and "remainingTokens" in result: new_quota = result["remainingTokens"] From 0d0e7931815a357040f53ca8bd93d20bb80544f8 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 11:25:50 +0800 Subject: [PATCH 14/27] feat: enhance asset management configuration and services --- app/api/v1/admin.py | 67 ++---- app/services/grok/batch_services/__init__.py | 4 +- app/services/grok/batch_services/assets.py | 240 +++++++++---------- app/services/grok/defaults.py | 16 +- app/services/grok/processors/base.py | 18 +- app/services/grok/processors/chat.py | 54 +---- app/services/grok/processors/video.py | 55 +---- app/services/grok/utils/download.py | 170 ++++++++----- app/services/grok/utils/locks.py | 37 ++- app/services/grok/utils/upload.py | 133 ++++++++-- app/services/reverse/assets_delete.py | 2 +- app/services/reverse/assets_download.py | 2 +- app/services/reverse/assets_list.py | 2 +- app/services/reverse/assets_upload.py | 2 +- app/static/config/config.js | 33 ++- config.defaults.toml | 32 ++- 16 files changed, 476 insertions(+), 391 deletions(-) diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index d88b98d0..ecd5caa4 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -16,11 +16,8 @@ from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage from app.core.exceptions import AppException from app.services.token.manager import get_token_manager -from app.services.grok.batch_services import ( - BatchUsageService, - BatchNSFWService, - BatchAssetsService, -) +from app.services.grok.batch_services import BatchUsageService, BatchNSFWService +from app.services.grok.batch_services.assets import ListService, DeleteService import os import time import uuid @@ -130,6 +127,8 @@ def _truncate_tokens( return unique_tokens, truncated, 
original_count + + def _mask_token(token: str) -> str: """掩码 token 显示""" return f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token @@ -1167,19 +1166,14 @@ async def get_cache_stats_api(request: Request): } online_details = [] account_map = {a["token"]: a for a in accounts} - max_concurrent = max(1, int(get_config("performance.assets_max_concurrent"))) - batch_size = max(1, int(get_config("performance.assets_batch_size"))) - max_tokens = int(get_config("performance.assets_max_tokens")) - + batch_size = max(1, int(get_config("asset.list_batch_size"))) + max_concurrent = batch_size truncated = False original_count = 0 if selected_tokens: - selected_tokens, truncated, original_count = _truncate_tokens( - selected_tokens, max_tokens, "Assets fetch" - ) total = 0 - raw_results = await BatchAssetsService.fetch_details( + raw_results = await ListService.fetch_assets_details( selected_tokens, account_map, max_concurrent=max_concurrent, @@ -1214,10 +1208,7 @@ async def get_cache_stats_api(request: Request): total = 0 tokens = list(dict.fromkeys([account["token"] for account in accounts])) original_count = len(tokens) - if len(tokens) > max_tokens: - tokens = tokens[:max_tokens] - truncated = True - raw_results = await BatchAssetsService.fetch_details( + raw_results = await ListService.fetch_assets_details( tokens, account_map, max_concurrent=max_concurrent, @@ -1250,7 +1241,7 @@ async def get_cache_stats_api(request: Request): else: token = selected_token if token: - raw_results = await BatchAssetsService.fetch_details( + raw_results = await ListService.fetch_assets_details( [token], account_map, max_concurrent=1, @@ -1349,13 +1340,11 @@ async def load_online_cache_api_async(data: dict): else: raise HTTPException(status_code=400, detail="No tokens provided") - max_tokens = int(get_config("performance.assets_max_tokens")) - selected_tokens, truncated, original_count = _truncate_tokens( - selected_tokens, max_tokens, "Assets load" - ) + truncated = False + 
original_count = len(selected_tokens) - max_concurrent = get_config("performance.assets_max_concurrent") - batch_size = get_config("performance.assets_batch_size") + batch_size = get_config("asset.list_batch_size") + max_concurrent = batch_size task = create_task(len(selected_tokens)) @@ -1369,7 +1358,7 @@ async def _on_item(item: str, res: dict): ok = bool(res.get("data", {}).get("ok")) task.record(ok) - raw_results = await BatchAssetsService.fetch_details( + raw_results = await ListService.fetch_assets_details( selected_tokens, account_map, max_concurrent=max_concurrent, @@ -1496,18 +1485,14 @@ async def clear_online_cache_api(data: dict): token_list = list(dict.fromkeys(token_list)) # 最大数量限制 - max_tokens = int(get_config("performance.assets_max_tokens")) - token_list, truncated, original_count = _truncate_tokens( - token_list, max_tokens, "Clear online cache" - ) + truncated = False + original_count = len(token_list) results = {} - max_concurrent = max( - 1, int(get_config("performance.assets_max_concurrent")) - ) - batch_size = max(1, int(get_config("performance.assets_batch_size"))) + batch_size = max(1, int(get_config("asset.delete_batch_size"))) + max_concurrent = batch_size - raw_results = await BatchAssetsService.clear_online( + raw_results = await DeleteService.clear_assets( token_list, mgr, max_concurrent=max_concurrent, @@ -1532,7 +1517,7 @@ async def clear_online_cache_api(data: dict): status_code=400, detail="No available token to perform cleanup" ) - raw_results = await BatchAssetsService.clear_online( + raw_results = await DeleteService.clear_assets( [token], mgr, max_concurrent=1, @@ -1563,13 +1548,11 @@ async def clear_online_cache_api_async(data: dict): if not token_list: raise HTTPException(status_code=400, detail="No tokens provided") - max_tokens = int(get_config("performance.assets_max_tokens")) - token_list, truncated, original_count = _truncate_tokens( - token_list, max_tokens, "Clear online cache async" - ) + truncated = False + 
original_count = len(token_list) - max_concurrent = get_config("performance.assets_max_concurrent") - batch_size = get_config("performance.assets_batch_size") + batch_size = get_config("asset.delete_batch_size") + max_concurrent = batch_size task = create_task(len(token_list)) @@ -1579,7 +1562,7 @@ async def _on_item(item: str, res: dict): ok = bool(res.get("data", {}).get("ok")) task.record(ok) - raw_results = await BatchAssetsService.clear_online( + raw_results = await DeleteService.clear_assets( token_list, mgr, max_concurrent=max_concurrent, diff --git a/app/services/grok/batch_services/__init__.py b/app/services/grok/batch_services/__init__.py index 49796375..7a4107ed 100644 --- a/app/services/grok/batch_services/__init__.py +++ b/app/services/grok/batch_services/__init__.py @@ -2,6 +2,4 @@ from .usage import BatchUsageService from .nsfw import BatchNSFWService -from .assets import BatchAssetsService - -__all__ = ["BatchUsageService", "BatchNSFWService", "BatchAssetsService"] +__all__ = ["BatchUsageService", "BatchNSFWService"] diff --git a/app/services/grok/batch_services/assets.py b/app/services/grok/batch_services/assets.py index c22045c5..14b2ccc2 100644 --- a/app/services/grok/batch_services/assets.py +++ b/app/services/grok/batch_services/assets.py @@ -3,14 +3,13 @@ """ import asyncio -from typing import Callable, Awaitable, Dict, Any, Optional, List +from typing import Dict, List, Optional from curl_cffi.requests import AsyncSession from app.core.config import get_config from app.core.logger import logger from app.services.reverse import AssetsListReverse, AssetsDeleteReverse -from app.services.grok.utils.locks import _get_assets_semaphore from app.services.grok.utils.batch import run_in_batches @@ -31,10 +30,34 @@ async def close(self): self._session = None +_LIST_SEMAPHORE = None +_LIST_SEM_VALUE = None +_DELETE_SEMAPHORE = None +_DELETE_SEM_VALUE = None + + +def _get_list_semaphore() -> asyncio.Semaphore: + value = max(1, 
int(get_config("asset.list_concurrent"))) + global _LIST_SEMAPHORE, _LIST_SEM_VALUE + if _LIST_SEMAPHORE is None or value != _LIST_SEM_VALUE: + _LIST_SEM_VALUE = value + _LIST_SEMAPHORE = asyncio.Semaphore(value) + return _LIST_SEMAPHORE + + +def _get_delete_semaphore() -> asyncio.Semaphore: + value = max(1, int(get_config("asset.delete_concurrent"))) + global _DELETE_SEMAPHORE, _DELETE_SEM_VALUE + if _DELETE_SEMAPHORE is None or value != _DELETE_SEM_VALUE: + _DELETE_SEM_VALUE = value + _DELETE_SEMAPHORE = asyncio.Semaphore(value) + return _DELETE_SEMAPHORE + + class ListService(BaseAssetsService): """Assets list service.""" - async def iter_assets(self, token: str): + async def list(self, token: str) -> Dict[str, List[str] | int]: params = { "pageSize": 50, "orderBy": "ORDER_BY_LAST_USE_TIME", @@ -43,135 +66,61 @@ async def iter_assets(self, token: str): } page_token = None seen_tokens = set() + asset_ids: List[str] = [] + session = await self._get_session() + while True: + if page_token: + if page_token in seen_tokens: + logger.warning("Pagination stopped: repeated page token") + break + seen_tokens.add(page_token) + params["pageToken"] = page_token + else: + params.pop("pageToken", None) - async with AsyncSession() as session: - while True: - if page_token: - if page_token in seen_tokens: - logger.warning("Pagination stopped: repeated page token") - break - seen_tokens.add(page_token) - params["pageToken"] = page_token - else: - params.pop("pageToken", None) - + async with _get_list_semaphore(): response = await AssetsListReverse.request( session, token, params, ) - result = response.json() - page_assets = result.get("assets", []) - yield page_assets - - page_token = result.get("nextPageToken") - if not page_token: - break - - async def list(self, token: str) -> List[Dict]: - assets = [] - async for page_assets in self.iter_assets(token): - assets.extend(page_assets) - logger.info(f"List success: {len(assets)} files") - return assets - - async def count(self, 
token: str) -> int: - total = 0 - async for page_assets in self.iter_assets(token): - total += len(page_assets) - logger.debug(f"Asset count: {total}") - return total - - -class DeleteService(BaseAssetsService): - """Assets delete service.""" - - async def delete(self, token: str, asset_id: str) -> bool: - async with _get_assets_semaphore(): - session = await self._get_session() - await AssetsDeleteReverse.request( - session, - token, - asset_id, - ) - - logger.debug(f"Deleted: {asset_id}") - return True - - async def delete_all(self, token: str) -> Dict[str, int]: - total = success = failed = 0 - list_service = ListService() - - try: - async for assets in list_service.iter_assets(token): - if not assets: - continue - - total += len(assets) - batch_result = await self._delete_batch(token, assets) - success += batch_result["success"] - failed += batch_result["failed"] - - if total == 0: - logger.info("No assets to delete") - return {"total": 0, "success": 0, "failed": 0, "skipped": True} - finally: - await list_service.close() - - logger.info(f"Delete all: total={total}, success={success}, failed={failed}") - return {"total": total, "success": success, "failed": failed} - - async def _delete_batch(self, token: str, assets: List[Dict]) -> Dict[str, int]: - batch_size = max(1, int(get_config("performance.assets_delete_batch_size"))) - success = failed = 0 - - for i in range(0, len(assets), batch_size): - batch = assets[i : i + batch_size] - results = await asyncio.gather( - *[ - self._delete_one(token, asset, idx) - for idx, asset in enumerate(batch) - ], - return_exceptions=True, - ) - success += sum(1 for r in results if r is True) - failed += sum(1 for r in results if r is not True) - - return {"success": success, "failed": failed} + result = response.json() + page_assets = result.get("assets", []) + if page_assets: + for asset in page_assets: + asset_id = asset.get("assetId") + if asset_id: + asset_ids.append(asset_id) - async def _delete_one(self, token: str, 
asset: Dict, index: int) -> bool: - await asyncio.sleep(0.01 * index) - asset_id = asset.get("assetId", "") - if not asset_id: - return False - try: - return await self.delete(token, asset_id) - except Exception: - return False + page_token = result.get("nextPageToken") + if not page_token: + break - -class BatchAssetsService: - """Batch assets orchestration.""" + logger.info(f"List success: {len(asset_ids)} files") + return {"asset_ids": asset_ids, "count": len(asset_ids)} @staticmethod - async def fetch_details( + async def fetch_assets_details( tokens: list[str], - account_map: Dict[str, Dict[str, Any]], + account_map: dict, *, max_concurrent: int, batch_size: int, include_ok: bool = False, - on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, - should_cancel: Optional[Callable[[], bool]] = None, - ) -> Dict[str, Dict[str, Any]]: + on_item=None, + should_cancel=None, + ) -> dict: + """Batch fetch assets details for tokens.""" account_map = account_map or {} + shared_service = ListService() async def _fetch_detail(token: str): account = account_map.get(token) - list_service = ListService() try: - count = await list_service.count(token) + result = await shared_service.list(token) + asset_ids = result.get("asset_ids", []) + count = result.get("count", len(asset_ids)) detail = { "token": token, "token_masked": account["token_masked"] if account else token, @@ -197,34 +146,68 @@ async def _fetch_detail(token: str): if include_ok: return {"ok": False, "detail": detail, "count": 0} return {"detail": detail, "count": 0} - finally: - await list_service.close() - - return await run_in_batches( - tokens, - _fetch_detail, - max_concurrent=max_concurrent, - batch_size=batch_size, - on_item=on_item, - should_cancel=should_cancel, - ) + + try: + return await run_in_batches( + tokens, + _fetch_detail, + max_concurrent=max_concurrent, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + finally: + await shared_service.close() + 
+ +class DeleteService(BaseAssetsService): + """Assets delete service.""" + + async def delete(self, token: str, asset_ids: List[str]) -> Dict[str, int]: + if not asset_ids: + logger.info("No assets to delete") + return {"total": 0, "success": 0, "failed": 0, "skipped": True} + + total = len(asset_ids) + success = 0 + failed = 0 + session = await self._get_session() + + async def _delete_one(asset_id: str): + async with _get_delete_semaphore(): + await AssetsDeleteReverse.request(session, token, asset_id) + + tasks = [_delete_one(asset_id) for asset_id in asset_ids if asset_id] + results = await asyncio.gather(*tasks, return_exceptions=True) + for res in results: + if isinstance(res, Exception): + failed += 1 + else: + success += 1 + + logger.info(f"Delete all: total={total}, success={success}, failed={failed}") + return {"total": total, "success": success, "failed": failed} @staticmethod - async def clear_online( + async def clear_assets( tokens: list[str], mgr, *, max_concurrent: int, batch_size: int, include_ok: bool = False, - on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, - should_cancel: Optional[Callable[[], bool]] = None, - ) -> Dict[str, Dict[str, Any]]: + on_item=None, + should_cancel=None, + ) -> dict: + """Batch clear assets for tokens.""" delete_service = DeleteService() + list_service = ListService() async def _clear_one(token: str): try: - result = await delete_service.delete_all(token) + result = await list_service.list(token) + asset_ids = result.get("asset_ids", []) + result = await delete_service.delete(token, asset_ids) await mgr.mark_asset_clear(token) if include_ok: return {"ok": True, "result": result} @@ -245,6 +228,7 @@ async def _clear_one(token: str): ) finally: await delete_service.close() + await list_service.close() -__all__ = ["BatchAssetsService"] +__all__ = ["ListService", "DeleteService"] diff --git a/app/services/grok/defaults.py b/app/services/grok/defaults.py index 03f1c10f..b273dcc6 100644 --- 
a/app/services/grok/defaults.py +++ b/app/services/grok/defaults.py @@ -62,11 +62,19 @@ "enable_auto_clean": True, "limit_mb": 1024, }, + "asset": { + "upload_concurrent": 30, + "upload_timeout": 60, + "download_concurrent": 30, + "download_timeout": 60, + "list_concurrent": 10, + "list_timeout": 60, + "list_batch_size": 10, + "delete_concurrent": 10, + "delete_timeout": 60, + "delete_batch_size": 10, + }, "performance": { - "assets_max_concurrent": 25, - "assets_delete_batch_size": 10, - "assets_batch_size": 10, - "assets_max_tokens": 1000, "media_max_concurrent": 50, "usage_max_concurrent": 25, "usage_batch_size": 50, diff --git a/app/services/grok/processors/base.py b/app/services/grok/processors/base.py index 1d2d5c94..cab0631a 100644 --- a/app/services/grok/processors/base.py +++ b/app/services/grok/processors/base.py @@ -11,8 +11,6 @@ from app.services.grok.utils.download import DownloadService -ASSET_URL = "https://assets.grok.com/" - T = TypeVar("T") @@ -134,20 +132,8 @@ async def close(self): async def process_url(self, path: str, media_type: str = "image") -> str: """处理资产 URL""" - if path.startswith("http"): - from urllib.parse import urlparse - - path = urlparse(path).path - - if not path.startswith("/"): - path = f"/{path}" - - if self.app_url: - dl_service = self._get_dl() - await dl_service.download_file(path, self.token, media_type) - return f"{self.app_url.rstrip('/')}/v1/files/{media_type}{path}" - else: - return f"{ASSET_URL.rstrip('/')}{path}" + dl_service = self._get_dl() + return await dl_service.resolve_url(path, self.token, media_type) __all__ = [ diff --git a/app/services/grok/processors/chat.py b/app/services/grok/processors/chat.py index 1e044370..e5e7bd61 100644 --- a/app/services/grok/processors/chat.py +++ b/app/services/grok/processors/chat.py @@ -33,7 +33,6 @@ def __init__(self, model: str, token: str = "", think: bool = None): self.think_opened: bool = False self.role_sent: bool = False self.filter_tags = 
get_config("chat.filter_tags") - self.image_format = get_config("app.image_format") self._tag_buffer: str = "" self._in_filter_tag: bool = False @@ -164,27 +163,11 @@ async def process( for url in _collect_image_urls(mr): parts = url.split("/") img_id = parts[-2] if len(parts) >= 2 else "image" - - if self.image_format == "base64": - try: - dl_service = self._get_dl() - base64_data = await dl_service.parse_b64( - url, self.token, "image" - ) - if base64_data: - yield self._sse(f"![{img_id}]({base64_data})\n") - else: - final_url = await self.process_url(url, "image") - yield self._sse(f"![{img_id}]({final_url})\n") - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - final_url = await self.process_url(url, "image") - yield self._sse(f"![{img_id}]({final_url})\n") - else: - final_url = await self.process_url(url, "image") - yield self._sse(f"![{img_id}]({final_url})\n") + dl_service = self._get_dl() + rendered = await dl_service.render_image( + url, self.token, img_id + ) + yield self._sse(f"{rendered}\n") if ( (meta := mr.get("metadata", {})) @@ -246,7 +229,6 @@ class CollectProcessor(BaseProcessor): def __init__(self, model: str, token: str = ""): super().__init__(model, token) - self.image_format = get_config("app.image_format") self.filter_tags = get_config("chat.filter_tags") def _filter_content(self, content: str) -> str: @@ -292,27 +274,11 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: for url in urls: parts = url.split("/") img_id = parts[-2] if len(parts) >= 2 else "image" - - if self.image_format == "base64": - try: - dl_service = self._get_dl() - base64_data = await dl_service.parse_b64( - url, self.token, "image" - ) - if base64_data: - content += f"![{img_id}]({base64_data})\n" - else: - final_url = await self.process_url(url, "image") - content += f"![{img_id}]({final_url})\n" - except Exception as e: - logger.warning( - f"Failed to convert image to base64, 
falling back to URL: {e}" - ) - final_url = await self.process_url(url, "image") - content += f"![{img_id}]({final_url})\n" - else: - final_url = await self.process_url(url, "image") - content += f"![{img_id}]({final_url})\n" + dl_service = self._get_dl() + rendered = await dl_service.render_image( + url, self.token, img_id + ) + content += f"{rendered}\n" if ( (meta := mr.get("metadata", {})) diff --git a/app/services/grok/processors/video.py b/app/services/grok/processors/video.py index 89521b9e..8ba68b31 100644 --- a/app/services/grok/processors/video.py +++ b/app/services/grok/processors/video.py @@ -29,7 +29,6 @@ def __init__(self, model: str, token: str = "", think: bool = None): self.response_id: Optional[str] = None self.think_opened: bool = False self.role_sent: bool = False - self.video_format = str(get_config("app.video_format")).lower() if think is None: self.show_think = get_config("chat.thinking") @@ -56,17 +55,6 @@ def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: } return f"data: {orjson.dumps(chunk).decode()}\n\n" - def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str: - """Build video HTML tag.""" - import html - - safe_video_url = html.escape(video_url) - safe_thumbnail_url = html.escape(thumbnail_url) - poster_attr = f' poster="{safe_thumbnail_url}"' if safe_thumbnail_url else "" - return f'''''' - async def process( self, response: AsyncIterable[bytes] ) -> AsyncGenerator[str, None]: @@ -111,20 +99,11 @@ async def process( self.think_opened = False if video_url: - final_video_url = await self.process_url(video_url, "video") - final_thumbnail_url = "" - if thumbnail_url: - final_thumbnail_url = await self.process_url( - thumbnail_url, "image" - ) - - if self.video_format == "url": - yield self._sse(final_video_url) - else: - video_html = self._build_video_html( - final_video_url, final_thumbnail_url - ) - yield self._sse(video_html) + dl_service = self._get_dl() + rendered = await 
dl_service.render_video( + video_url, self.token, thumbnail_url + ) + yield self._sse(rendered) logger.info(f"Video generated: {video_url}") continue @@ -179,13 +158,6 @@ class VideoCollectProcessor(BaseProcessor): def __init__(self, model: str, token: str = ""): super().__init__(model, token) - self.video_format = str(get_config("app.video_format")).lower() - - def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str: - poster_attr = f' poster="{thumbnail_url}"' if thumbnail_url else "" - return f'''''' async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: """Process and collect video response.""" @@ -212,19 +184,10 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: thumbnail_url = video_resp.get("thumbnailImageUrl", "") if video_url: - final_video_url = await self.process_url(video_url, "video") - final_thumbnail_url = "" - if thumbnail_url: - final_thumbnail_url = await self.process_url( - thumbnail_url, "image" - ) - - if self.video_format == "url": - content = final_video_url - else: - content = self._build_video_html( - final_video_url, final_thumbnail_url - ) + dl_service = self._get_dl() + content = await dl_service.render_video( + video_url, self.token, thumbnail_url + ) logger.info(f"Video generated: {video_url}") except asyncio.CancelledError: diff --git a/app/services/grok/utils/download.py b/app/services/grok/utils/download.py index 9179610b..fbb7872f 100644 --- a/app/services/grok/utils/download.py +++ b/app/services/grok/utils/download.py @@ -1,5 +1,7 @@ """ Download service. + +Download service for assets.grok.com. 
""" import asyncio @@ -18,7 +20,7 @@ from app.core.config import get_config from app.core.exceptions import AppException from app.services.reverse import AssetsDownloadReverse -from app.services.grok.utils.locks import _get_assets_semaphore, _file_lock +from app.services.grok.utils.locks import _get_download_semaphore, _file_lock class DownloadService: @@ -45,72 +47,119 @@ async def close(self): await self._session.close() self._session = None - @staticmethod - def _is_url(value: str) -> bool: - """Check if the value is a URL.""" + async def resolve_url( + self, path_or_url: str, token: str, media_type: str = "image" + ) -> str: + asset_url = path_or_url + path = path_or_url + if path_or_url.startswith("http"): + parsed = urlparse(path_or_url) + path = parsed.path or "" + asset_url = path_or_url + else: + if not path_or_url.startswith("/"): + path_or_url = f"/{path_or_url}" + path = path_or_url + asset_url = f"https://assets.grok.com{path_or_url}" + + app_url = get_config("app.app_url") + if app_url: + await self.download_file(asset_url, token, media_type) + return f"{app_url.rstrip('/')}/v1/files/{media_type}{path}" + return asset_url + + async def render_image( + self, url: str, token: str, image_id: str = "image" + ) -> str: + fmt = get_config("app.image_format") + fmt = fmt.lower() if isinstance(fmt, str) else "url" + if fmt not in ("base64", "url", "markdown"): + fmt = "url" try: - parsed = urlparse(value) - return bool(parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]) - except Exception: - return False + if fmt == "base64": + data_uri = await self.parse_b64(url, token, "image") + return f"![{image_id}]({data_uri})" + final_url = await self.resolve_url(url, token, "image") + return f"![{image_id}]({final_url})" + except Exception as e: + logger.warning(f"Image render failed, fallback to URL: {e}") + final_url = await self.resolve_url(url, token, "image") + return f"![{image_id}]({final_url})" + + async def render_video( + self, video_url: 
str, token: str, thumbnail_url: str = "" + ) -> str: + fmt = get_config("app.video_format") + fmt = fmt.lower() if isinstance(fmt, str) else "url" + if fmt not in ("url", "markdown", "html"): + fmt = "url" + final_video_url = await self.resolve_url(video_url, token, "video") + final_thumb_url = "" + if thumbnail_url: + final_thumb_url = await self.resolve_url(thumbnail_url, token, "image") + if fmt == "url": + return final_video_url + if fmt == "markdown": + return f"[video]({final_video_url})" + import html + + safe_video_url = html.escape(final_video_url) + safe_thumbnail_url = html.escape(final_thumb_url) + poster_attr = f' poster="{safe_thumbnail_url}"' if safe_thumbnail_url else "" + return f'''''' async def parse_b64(self, file_path: str, token: str, media_type: str = "image") -> str: """Download and return data URI.""" try: - cache_path, mime = await self.download_file(file_path, token, media_type) - if not cache_path or not cache_path.exists(): - logger.warning(f"Download failed for {file_path}: invalid path") - raise AppException( - "Download failed: invalid path", code="download_failed" - ) - - data_uri = await self.format_b64(cache_path, mime) - - if data_uri: - try: - cache_path.unlink() - except Exception as e: - logger.debug(f"Failed to cleanup temp file {cache_path}: {e}") + if not isinstance(file_path, str) or not file_path.strip(): + raise AppException("Invalid file path", code="invalid_file_path") + if file_path.startswith("data:"): + raise AppException("Invalid file path", code="invalid_file_path") + if not self._is_url(file_path): + raise AppException("Invalid file path", code="invalid_file_path") + + file_path = self._normalize_path(file_path) + lock_name = f"dl_b64_{hashlib.sha1(file_path.encode()).hexdigest()[:16]}" + lock_timeout = max(1, int(get_config("asset.download_timeout"))) + async with _get_download_semaphore(): + async with _file_lock(lock_name, timeout=lock_timeout): + session = await self.create() + response = await 
AssetsDownloadReverse.request( + session, token, file_path + ) + + if hasattr(response, "aiter_content"): + data = bytearray() + async for chunk in response.aiter_content(): + if chunk: + data.extend(chunk) + raw = bytes(data) + else: + raw = response.content + + content_type = response.headers.get( + "content-type", "application/octet-stream" + ).split(";")[0] + data_uri = f"data:{content_type};base64,{base64.b64encode(raw).decode()}" return data_uri except Exception as e: logger.error(f"Failed to convert {file_path} to base64: {e}") raise - @staticmethod - async def format_b64(file_path: Path, mime_type: str) -> str: - """Format local file to data URI.""" - try: - if not file_path.exists(): - logger.warning(f"File not found for base64 conversion: {file_path}") - raise AppException( - f"File not found: {file_path}", code="file_not_found" - ) - - if not file_path.is_file(): - logger.warning(f"Path is not a file: {file_path}") - raise AppException( - f"Invalid file path: {file_path}", code="invalid_file_path" - ) - - async with aiofiles.open(file_path, "rb") as f: - data = await f.read() - b64_data = base64.b64encode(data).decode() - return f"data:{mime_type};base64,{b64_data}" - except AppException: - raise - except Exception as e: - logger.error(f"File to base64 failed: {file_path} - {e}") - raise AppException( - f"Failed to read file: {file_path}", code="file_read_error" - ) - - def check_format(self, file_path: str) -> str: + def _normalize_path(self, file_path: str) -> str: """Normalize file path for download.""" if not isinstance(file_path, str) or not file_path.strip(): raise AppException("Invalid file path", code="invalid_file_path") - if self._is_url(file_path): - file_path = urlparse(file_path).path or "" + parsed = urlparse(file_path) + if not (parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]): + raise AppException("Invalid file path", code="invalid_file_path") + path = parsed.path or "" + if parsed.query: + path = 
f"{path}?{parsed.query}" + file_path = path if not file_path.startswith("/"): file_path = f"/{file_path}" return file_path @@ -126,8 +175,8 @@ async def download_file(self, file_path: str, token: str, media_type: str = "ima Returns: Tuple[Optional[Path], str]: The path of the downloaded file and the MIME type. """ - async with _get_assets_semaphore(): - file_path = self.check_format(file_path) + async with _get_download_semaphore(): + file_path = self._normalize_path(file_path) cache_dir = self.image_dir if media_type == "image" else self.video_dir filename = file_path.lstrip("/").replace("/", "-") cache_path = cache_dir / filename @@ -135,11 +184,10 @@ async def download_file(self, file_path: str, token: str, media_type: str = "ima lock_name = ( f"dl_{media_type}_{hashlib.sha1(str(cache_path).encode()).hexdigest()[:16]}" ) - async with _file_lock(lock_name, timeout=10): + lock_timeout = max(1, int(get_config("asset.download_timeout"))) + async with _file_lock(lock_name, timeout=lock_timeout): session = await self.create() - response = await AssetsDownloadReverse.request( - session, token, file_path - ) + response = await AssetsDownloadReverse.request(session, token, file_path) tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") try: @@ -163,11 +211,11 @@ async def download_file(self, file_path: str, token: str, media_type: str = "ima ).split(";")[0] logger.info(f"Downloaded: {file_path}") - asyncio.create_task(self.check_limit()) + asyncio.create_task(self._check_limit()) return cache_path, mime - async def check_limit(self): + async def _check_limit(self): """Check cache limit and cleanup. 
Args: diff --git a/app/services/grok/utils/locks.py b/app/services/grok/utils/locks.py index 654ae9e1..0ad227f5 100644 --- a/app/services/grok/utils/locks.py +++ b/app/services/grok/utils/locks.py @@ -18,19 +18,32 @@ LOCK_DIR = DATA_DIR / ".locks" -_ASSETS_SEMAPHORE = None -_ASSETS_SEM_VALUE = None +_UPLOAD_SEMAPHORE = None +_UPLOAD_SEM_VALUE = None +_DOWNLOAD_SEMAPHORE = None +_DOWNLOAD_SEM_VALUE = None -def _get_assets_semaphore() -> asyncio.Semaphore: - """Return global semaphore for assets operations.""" - value = max(1, int(get_config("performance.assets_max_concurrent"))) +def _get_upload_semaphore() -> asyncio.Semaphore: + """Return global semaphore for upload operations.""" + value = max(1, int(get_config("asset.upload_concurrent"))) - global _ASSETS_SEMAPHORE, _ASSETS_SEM_VALUE - if _ASSETS_SEMAPHORE is None or value != _ASSETS_SEM_VALUE: - _ASSETS_SEM_VALUE = value - _ASSETS_SEMAPHORE = asyncio.Semaphore(value) - return _ASSETS_SEMAPHORE + global _UPLOAD_SEMAPHORE, _UPLOAD_SEM_VALUE + if _UPLOAD_SEMAPHORE is None or value != _UPLOAD_SEM_VALUE: + _UPLOAD_SEM_VALUE = value + _UPLOAD_SEMAPHORE = asyncio.Semaphore(value) + return _UPLOAD_SEMAPHORE + + +def _get_download_semaphore() -> asyncio.Semaphore: + """Return global semaphore for download operations.""" + value = max(1, int(get_config("asset.download_concurrent"))) + + global _DOWNLOAD_SEMAPHORE, _DOWNLOAD_SEM_VALUE + if _DOWNLOAD_SEMAPHORE is None or value != _DOWNLOAD_SEM_VALUE: + _DOWNLOAD_SEM_VALUE = value + _DOWNLOAD_SEMAPHORE = asyncio.Semaphore(value) + return _DOWNLOAD_SEMAPHORE @asynccontextmanager @@ -57,6 +70,8 @@ async def _file_lock(name: str, timeout: int = 10): if time.monotonic() - start >= timeout: break await asyncio.sleep(0.05) + if not locked: + raise TimeoutError(f"Failed to acquire lock: {name}") yield finally: if fd: @@ -68,4 +83,4 @@ async def _file_lock(name: str, timeout: int = 10): fd.close() -__all__ = ["_get_assets_semaphore", "_file_lock"] +__all__ = 
["_get_upload_semaphore", "_get_download_semaphore", "_file_lock"] diff --git a/app/services/grok/utils/upload.py b/app/services/grok/utils/upload.py index 85302ac1..ee17cc42 100644 --- a/app/services/grok/utils/upload.py +++ b/app/services/grok/utils/upload.py @@ -1,18 +1,26 @@ """ Upload service. + +Upload service for assets.grok.com. """ import base64 +import hashlib +import mimetypes import re -from typing import Optional, Tuple +from pathlib import Path +from typing import AsyncIterator, Optional, Tuple from urllib.parse import urlparse +import aiofiles from curl_cffi.requests import AsyncSession +from app.core.config import get_config from app.core.exceptions import AppException, UpstreamException, ValidationException from app.core.logger import logger +from app.core.storage import DATA_DIR from app.services.reverse import AssetsUploadReverse -from app.services.grok.utils.locks import _get_assets_semaphore +from app.services.grok.utils.locks import _get_upload_semaphore, _file_lock class UploadService: @@ -20,6 +28,7 @@ class UploadService: def __init__(self): self._session: Optional[AsyncSession] = None + self._chunk_size = 64 * 1024 async def create(self) -> AsyncSession: """Create or reuse a session.""" @@ -38,16 +47,102 @@ def _is_url(value: str) -> bool: """Check if the value is a URL.""" try: parsed = urlparse(value) - return bool(parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]) + return bool( + parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"] + ) except Exception: return False @staticmethod - async def parse_b64(url: str) -> Tuple[str, str, str]: + def _infer_mime(filename: str, fallback: str = "application/octet-stream") -> str: + mime, _ = mimetypes.guess_type(filename) + return mime or fallback + + @staticmethod + async def _encode_b64_stream(chunks: AsyncIterator[bytes]) -> str: + parts = [] + remain = b"" + async for chunk in chunks: + if not chunk: + continue + chunk = remain + chunk + keep = 
len(chunk) % 3 + if keep: + remain = chunk[-keep:] + chunk = chunk[:-keep] + else: + remain = b"" + if chunk: + parts.append(base64.b64encode(chunk).decode()) + if remain: + parts.append(base64.b64encode(remain).decode()) + return "".join(parts) + + async def _read_local_file(self, local_type: str, name: str) -> Tuple[str, str, str]: + base_dir = DATA_DIR / "tmp" + if local_type == "video": + local_dir = base_dir / "video" + mime = "video/mp4" + else: + local_dir = base_dir / "image" + suffix = Path(name).suffix.lower() + if suffix == ".png": + mime = "image/png" + elif suffix == ".webp": + mime = "image/webp" + elif suffix == ".gif": + mime = "image/gif" + else: + mime = "image/jpeg" + + local_path = local_dir / name + lock_name = f"ul_local_{hashlib.sha1(str(local_path).encode()).hexdigest()[:16]}" + lock_timeout = max(1, int(get_config("asset.upload_timeout"))) + async with _file_lock(lock_name, timeout=lock_timeout): + if not local_path.exists(): + raise ValidationException(f"Local file not found: {local_path}") + if not local_path.is_file(): + raise ValidationException(f"Invalid local file: {local_path}") + + async def _iter_file() -> AsyncIterator[bytes]: + async with aiofiles.open(local_path, "rb") as f: + while True: + chunk = await f.read(self._chunk_size) + if not chunk: + break + yield chunk + + b64 = await self._encode_b64_stream(_iter_file()) + filename = name or "file" + return filename, b64, mime + + async def parse_b64(self, url: str) -> Tuple[str, str, str]: """Fetch URL content and return (filename, base64, mime).""" try: - async with AsyncSession() as session: - response = await session.get(url, timeout=10) + app_url = get_config("app.app_url") or "" + if app_url and self._is_url(url): + parsed = urlparse(url) + app_parsed = urlparse(app_url) + if ( + parsed.scheme == app_parsed.scheme + and parsed.netloc == app_parsed.netloc + and parsed.path.startswith("/v1/files/") + ): + parts = parsed.path.strip("/").split("/", 3) + if len(parts) >= 4: + 
local_type = parts[2] + name = parts[3].replace("/", "-") + return await self._read_local_file(local_type, name) + + lock_name = f"ul_url_{hashlib.sha1(url.encode()).hexdigest()[:16]}" + timeout = float(get_config("asset.upload_timeout")) + proxy_url = get_config("network.base_proxy_url") + proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None + + lock_timeout = max(1, int(get_config("asset.upload_timeout"))) + async with _file_lock(lock_name, timeout=lock_timeout): + session = await self.create() + response = await session.get(url, timeout=timeout, proxies=proxies) if response.status_code >= 400: raise UpstreamException( message=f"Failed to fetch: {response.status_code}", @@ -56,9 +151,14 @@ async def parse_b64(url: str) -> Tuple[str, str, str]: filename = url.split("/")[-1].split("?")[0] or "download" content_type = response.headers.get( - "content-type", "application/octet-stream" - ).split(";")[0] - b64 = base64.b64encode(response.content).decode() + "content-type", "" + ).split(";")[0].strip() + if not content_type: + content_type = self._infer_mime(filename) + if hasattr(response, "aiter_content"): + b64 = await self._encode_b64_stream(response.aiter_content()) + else: + b64 = base64.b64encode(response.content).decode() logger.debug(f"Fetched: {url}") return filename, b64, content_type @@ -72,18 +172,20 @@ async def parse_b64(url: str) -> Tuple[str, str, str]: def format_b64(data_uri: str) -> Tuple[str, str, str]: """Format data URI to (filename, base64, mime).""" if not data_uri.startswith("data:"): - return "file.bin", data_uri, "application/octet-stream" + raise ValidationException("Invalid file input: not a data URI") try: header, b64 = data_uri.split(",", 1) except ValueError: - return "file.bin", data_uri, "application/octet-stream" + raise ValidationException("Invalid data URI format") if ";base64" not in header: - return "file.bin", data_uri, "application/octet-stream" + raise ValidationException("Invalid data URI: missing base64 
marker") mime = header[5:].split(";", 1)[0] or "application/octet-stream" b64 = re.sub(r"\s+", "", b64) + if not mime or not b64: + raise ValidationException("Invalid data URI: empty content") ext = mime.split("/")[-1] if "/" in mime else "bin" return f"file.{ext}", b64, mime @@ -95,7 +197,10 @@ async def check_format(self, file_input: str) -> Tuple[str, str, str]: if self._is_url(file_input): return await self.parse_b64(file_input) - return self.format_b64(file_input) + if file_input.startswith("data:"): + return self.format_b64(file_input) + + raise ValidationException("Invalid file input: must be URL or base64") async def upload_file(self, file_input: str, token: str) -> Tuple[str, str]: """ @@ -108,7 +213,7 @@ async def upload_file(self, file_input: str, token: str) -> Tuple[str, str]: Returns: Tuple[str, str]: The file ID and URI. """ - async with _get_assets_semaphore(): + async with _get_upload_semaphore(): filename, b64, mime = await self.check_format(file_input) logger.debug( diff --git a/app/services/reverse/assets_delete.py b/app/services/reverse/assets_delete.py index 794f2ba1..a982f4b9 100644 --- a/app/services/reverse/assets_delete.py +++ b/app/services/reverse/assets_delete.py @@ -48,7 +48,7 @@ async def request(session: AsyncSession, token: str, asset_id: str) -> Any: ) # Curl Config - timeout = get_config("network.timeout") + timeout = get_config("asset.delete_timeout") browser = get_config("security.browser") async def _do_request(): diff --git a/app/services/reverse/assets_download.py b/app/services/reverse/assets_download.py index df0b32d1..e2491a04 100644 --- a/app/services/reverse/assets_download.py +++ b/app/services/reverse/assets_download.py @@ -74,7 +74,7 @@ async def request(session: AsyncSession, token: str, file_path: str) -> Any: headers["Upgrade-Insecure-Requests"] = "1" # Curl Config - timeout = get_config("network.timeout") + timeout = get_config("asset.download_timeout") browser = get_config("security.browser") async def 
_do_request(): diff --git a/app/services/reverse/assets_list.py b/app/services/reverse/assets_list.py index 07263725..9b7762ff 100644 --- a/app/services/reverse/assets_list.py +++ b/app/services/reverse/assets_list.py @@ -48,7 +48,7 @@ async def request(session: AsyncSession, token: str, params: Dict[str, Any]) -> ) # Curl Config - timeout = get_config("network.timeout") + timeout = get_config("asset.list_timeout") browser = get_config("security.browser") async def _do_request(): diff --git a/app/services/reverse/assets_upload.py b/app/services/reverse/assets_upload.py index 0466be6c..517e7598 100644 --- a/app/services/reverse/assets_upload.py +++ b/app/services/reverse/assets_upload.py @@ -57,7 +57,7 @@ async def request(session: AsyncSession, token: str, fileName: str, fileMimeType } # Curl Config - timeout = get_config("network.timeout") + timeout = get_config("asset.upload_timeout") browser = get_config("security.browser") async def _do_request(): diff --git a/app/static/config/config.js b/app/static/config/config.js index f0282a7c..ee60d64e 100644 --- a/app/static/config/config.js +++ b/app/static/config/config.js @@ -13,12 +13,18 @@ const NUMERIC_FIELDS = new Set([ 'fail_threshold', 'limit_mb', 'save_delay_ms', - 'assets_max_concurrent', + 'upload_concurrent', + 'upload_timeout', + 'download_concurrent', + 'download_timeout', + 'list_concurrent', + 'list_timeout', + 'list_batch_size', + 'delete_concurrent', + 'delete_timeout', + 'delete_batch_size', 'media_max_concurrent', 'usage_max_concurrent', - 'assets_delete_batch_size', - 'assets_batch_size', - 'assets_max_tokens', 'usage_batch_size', 'usage_max_tokens', 'reload_interval_sec', @@ -98,6 +104,19 @@ const LOCALE_MAP = { "enable_auto_clean": { title: "自动清理", desc: "是否启用缓存自动清理,开启后按上限自动回收。" }, "limit_mb": { title: "清理阈值", desc: "缓存大小阈值(MB),超过阈值会触发清理。" } }, + "asset": { + "label": "资产配置", + "upload_concurrent": { title: "上传并发", desc: "上传接口的最大并发数。推荐 30。" }, + "upload_timeout": { title: "上传超时", desc: 
"上传接口超时时间(秒)。推荐 60。" }, + "download_concurrent": { title: "下载并发", desc: "下载接口的最大并发数。推荐 30。" }, + "download_timeout": { title: "下载超时", desc: "下载接口超时时间(秒)。推荐 60。" }, + "list_concurrent": { title: "查询并发", desc: "资产查询接口的最大并发数。推荐 10。" }, + "list_timeout": { title: "查询超时", desc: "资产查询接口超时时间(秒)。推荐 60。" }, + "list_batch_size": { title: "查询批次大小", desc: "单次查询可处理的 Token 数量。推荐 10。" }, + "delete_concurrent": { title: "删除并发", desc: "资产删除接口的最大并发数。推荐 10。" }, + "delete_timeout": { title: "删除超时", desc: "资产删除接口超时时间(秒)。推荐 60。" }, + "delete_batch_size": { title: "删除批次大小", desc: "单次删除可处理的 Token 数量。推荐 10。" } + }, "performance": { "label": "并发性能", "media_max_concurrent": { title: "Media 并发上限", desc: "视频/媒体生成请求的并发上限。推荐 50。" }, @@ -106,11 +125,7 @@ const LOCALE_MAP = { "nsfw_max_tokens": { title: "NSFW 开启最大数量", desc: "单次批量开启 NSFW 的 Token 数量上限,防止误操作。推荐 1000。" }, "usage_max_concurrent": { title: "Token 刷新并发上限", desc: "批量刷新 Token 用量时的并发请求上限。推荐 25。" }, "usage_batch_size": { title: "Token 刷新批次大小", desc: "批量刷新 Token 用量的单批处理数量。推荐 50。" }, - "usage_max_tokens": { title: "Token 刷新最大数量", desc: "单次批量刷新 Token 用量时的处理数量上限。推荐 1000。" }, - "assets_max_concurrent": { title: "Assets 处理并发上限", desc: "批量查找/删除资产时的并发请求上限。推荐 25。" }, - "assets_batch_size": { title: "Assets 处理批次大小", desc: "批量查找/删除资产时的单批处理数量。推荐 10。" }, - "assets_max_tokens": { title: "Assets 处理最大数量", desc: "单次批量查找/删除资产时的处理数量上限。推荐 1000。" }, - "assets_delete_batch_size": { title: "Assets 单账号删除批量大小", desc: "单账号批量删除资产时的单批并发数量。推荐 10。" } + "usage_max_tokens": { title: "Token 刷新最大数量", desc: "单次批量刷新 Token 用量时的处理数量上限。推荐 1000。" } } }; diff --git a/config.defaults.toml b/config.defaults.toml index 8694424a..e15bed8e 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -109,17 +109,32 @@ enable_auto_clean = true # 缓存大小上限(MB) limit_mb = 1024 +# ==================== Asset ==================== +[asset] +# 上传并发数 +upload_concurrent = 30 +# 上传超时时间(秒) +upload_timeout = 60 +# 下载并发数 +download_concurrent = 30 +# 下载超时时间(秒) +download_timeout = 60 +# 资产查询并发数 
+list_concurrent = 10 +# 资产查询超时时间(秒) +list_timeout = 60 +# 资产查询批次大小(Token 维度) +list_batch_size = 10 +# 资产删除并发数 +delete_concurrent = 10 +# 资产删除超时时间(秒) +delete_timeout = 60 +# 资产删除批次大小(Token 维度) +delete_batch_size = 10 + # ==================== 并发性能 ==================== [performance] -# Assets 批量处理并发上限 -assets_max_concurrent = 25 -# Assets 批量删除大小 -assets_delete_batch_size = 10 -# Assets 单批处理数量 -assets_batch_size = 10 -# Assets 单次最大处理数量 -assets_max_tokens = 1000 # Media 生成并发上限 media_max_concurrent = 50 @@ -137,4 +152,3 @@ nsfw_max_concurrent = 10 nsfw_batch_size = 50 # NSFW 单次最大数量 nsfw_max_tokens = 1000 - From 314102dcbadc910fbf449a4b94baa69ce989e89d Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:00:12 +0800 Subject: [PATCH 15/27] feat: add NSFW configuration and refactor related services for improved management --- app/api/v1/admin.py | 29 +++-- app/core/logger.py | 4 + app/services/grok/batch_services/__init__.py | 4 +- app/services/grok/batch_services/assets.py | 7 +- app/services/grok/batch_services/nsfw.py | 129 +++++++++---------- app/services/grok/batch_services/usage.py | 6 +- app/services/grok/defaults.py | 7 +- app/services/grok/services/chat.py | 2 +- app/services/grok/services/video.py | 3 +- app/services/grok/services/voice.py | 2 +- app/services/grok/utils/__init__.py | 0 app/services/grok/utils/download.py | 2 +- app/services/grok/utils/upload.py | 2 +- app/services/reverse/accept_tos.py | 118 +++++++++++++++++ app/services/reverse/nsfw_mgmt.py | 12 +- app/services/reverse/rate_limits.py | 8 -- app/services/reverse/set_birth.py | 14 +- app/services/reverse/utils/grpc.py | 48 +++++++ app/services/reverse/utils/headers.py | 9 +- app/services/token/manager.py | 9 ++ app/services/token/service.py | 27 ++-- app/static/config/config.js | 12 +- config.defaults.toml | 14 +- data/config.toml | 27 +++- 24 files changed, 341 insertions(+), 154 deletions(-) delete mode 100644 
app/services/grok/utils/__init__.py create mode 100644 app/services/reverse/accept_tos.py diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index ecd5caa4..cc9c6fe4 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -8,7 +8,7 @@ WebSocketDisconnect, ) from fastapi.responses import HTMLResponse, StreamingResponse, RedirectResponse -from typing import Optional +from typing import Optional, List, Tuple from pydantic import BaseModel from app.core.auth import verify_api_key, verify_app_key, get_admin_api_key from app.core.config import config, get_config @@ -16,7 +16,8 @@ from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage from app.core.exceptions import AppException from app.services.token.manager import get_token_manager -from app.services.grok.batch_services import BatchUsageService, BatchNSFWService +from app.services.grok.batch_services import BatchUsageService +from app.services.grok.batch_services.nsfw import NSFWService from app.services.grok.batch_services.assets import ListService, DeleteService import os import time @@ -87,7 +88,7 @@ async def _delete_imagine_session(task_id: str) -> None: _IMAGINE_SESSIONS.pop(task_id, None) -async def _delete_imagine_sessions(task_ids: list[str]) -> int: +async def _delete_imagine_sessions(task_ids: List[str]) -> int: if not task_ids: return 0 removed = 0 @@ -99,7 +100,7 @@ async def _delete_imagine_sessions(task_ids: list[str]) -> int: return removed -def _collect_tokens(data: dict) -> list[str]: +def _collect_tokens(data: dict) -> List[str]: """从请求数据中收集 token 列表""" tokens = [] if isinstance(data.get("token"), str) and data["token"].strip(): @@ -110,8 +111,8 @@ def _collect_tokens(data: dict) -> list[str]: def _truncate_tokens( - tokens: list[str], max_tokens: int, operation: str = "operation" -) -> tuple[list[str], bool, int]: + tokens: List[str], max_tokens: int, operation: str = "operation" +) -> Tuple[List[str], bool, int]: """去重并截断 token 列表,返回 (unique_tokens, truncated, 
original_count)""" unique_tokens = list(dict.fromkeys(tokens)) original_count = len(unique_tokens) @@ -538,7 +539,7 @@ async def admin_imagine_start(data: ImagineStartRequest): class ImagineStopRequest(BaseModel): - task_ids: list[str] + task_ids: List[str] @router.post("/api/v1/admin/imagine/stop", dependencies=[Depends(verify_api_key)]) @@ -970,10 +971,10 @@ async def enable_nsfw_api(data: dict): ) # 批量执行配置 - max_concurrent = get_config("performance.nsfw_max_concurrent") - batch_size = get_config("performance.nsfw_batch_size") + max_concurrent = get_config("nsfw.concurrent") + batch_size = get_config("nsfw.batch_size") - raw_results = await BatchNSFWService.enable( + raw_results = await NSFWService.batch( unique_tokens, mgr, max_concurrent=max_concurrent, @@ -1043,8 +1044,8 @@ async def enable_nsfw_api_async(data: dict): tokens, max_tokens, "NSFW enable" ) - max_concurrent = get_config("performance.nsfw_max_concurrent") - batch_size = get_config("performance.nsfw_batch_size") + max_concurrent = get_config("nsfw.concurrent") + batch_size = get_config("nsfw.batch_size") task = create_task(len(unique_tokens)) @@ -1055,7 +1056,7 @@ async def _on_item(item: str, res: dict): ok = bool(res.get("ok") and res.get("data", {}).get("success")) task.record(ok) - raw_results = await BatchNSFWService.enable( + raw_results = await NSFWService.batch( unique_tokens, mgr, max_concurrent=max_concurrent, @@ -1328,7 +1329,7 @@ async def load_online_cache_api_async(data: dict): tokens = data.get("tokens") scope = data.get("scope") - selected_tokens: list[str] = [] + selected_tokens: List[str] = [] if isinstance(tokens, list): selected_tokens = [str(t).strip() for t in tokens if str(t).strip()] diff --git a/app/core/logger.py b/app/core/logger.py index a49b219d..0b0290f7 100644 --- a/app/core/logger.py +++ b/app/core/logger.py @@ -9,6 +9,10 @@ from pathlib import Path from loguru import logger +# Provide logging.Logger compatibility for legacy calls +if not hasattr(logger, 
"isEnabledFor"): + logger.isEnabledFor = lambda _level: True + # 日志目录 DEFAULT_LOG_DIR = Path(__file__).parent.parent.parent / "logs" LOG_DIR = Path(os.getenv("LOG_DIR", str(DEFAULT_LOG_DIR))) diff --git a/app/services/grok/batch_services/__init__.py b/app/services/grok/batch_services/__init__.py index 7a4107ed..58666583 100644 --- a/app/services/grok/batch_services/__init__.py +++ b/app/services/grok/batch_services/__init__.py @@ -1,5 +1,5 @@ """Batch services.""" from .usage import BatchUsageService -from .nsfw import BatchNSFWService -__all__ = ["BatchUsageService", "BatchNSFWService"] + +__all__ = ["BatchUsageService"] diff --git a/app/services/grok/batch_services/assets.py b/app/services/grok/batch_services/assets.py index 14b2ccc2..5ba9ab35 100644 --- a/app/services/grok/batch_services/assets.py +++ b/app/services/grok/batch_services/assets.py @@ -9,7 +9,8 @@ from app.core.config import get_config from app.core.logger import logger -from app.services.reverse import AssetsListReverse, AssetsDeleteReverse +from app.services.reverse.assets_list import AssetsListReverse +from app.services.reverse.assets_delete import AssetsDeleteReverse from app.services.grok.utils.batch import run_in_batches @@ -102,7 +103,7 @@ async def list(self, token: str) -> Dict[str, List[str] | int]: @staticmethod async def fetch_assets_details( - tokens: list[str], + tokens: List[str], account_map: dict, *, max_concurrent: int, @@ -190,7 +191,7 @@ async def _delete_one(asset_id: str): @staticmethod async def clear_assets( - tokens: list[str], + tokens: List[str], mgr, *, max_concurrent: int, diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py index b701bc3a..f53128b3 100644 --- a/app/services/grok/batch_services/nsfw.py +++ b/app/services/grok/batch_services/nsfw.py @@ -2,7 +2,6 @@ Batch NSFW service. 
""" -from dataclasses import dataclass from typing import Callable, Awaitable, Dict, Any, Optional from curl_cffi.requests import AsyncSession @@ -10,66 +9,16 @@ from app.core.logger import logger from app.core.config import get_config from app.core.exceptions import UpstreamException -from app.services.reverse import NsfwMgmtReverse, SetBirthReverse -from app.services.reverse.utils.grpc import GrpcStatus +from app.services.reverse.accept_tos import AcceptTosReverse +from app.services.reverse.nsfw_mgmt import NsfwMgmtReverse +from app.services.reverse.set_birth import SetBirthReverse from app.services.grok.utils.batch import run_in_batches -@dataclass -class NSFWResult: - """NSFW 操作结果""" - - success: bool - http_status: int - grpc_status: Optional[int] = None - grpc_message: Optional[str] = None - error: Optional[str] = None - - class NSFWService: """NSFW 模式服务""" - - async def enable(self, token: str) -> NSFWResult: - """为单个 token 开启 NSFW 模式""" - try: - browser = get_config("security.browser") - async with AsyncSession(impersonate=browser) as session: - # 先设置出生日期 - try: - await SetBirthReverse.request(session, token) - except UpstreamException as e: - status = None - if e.details and "status" in e.details: - status = e.details["status"] - else: - status = getattr(e, "status_code", None) - return NSFWResult( - success=False, - http_status=status or 0, - error=f"Set birth date failed: {str(e)}", - ) - - # 开启 NSFW - grpc_status: GrpcStatus = await NsfwMgmtReverse.request(session, token) - success = grpc_status.code in (-1, 0) - - return NSFWResult( - success=success, - http_status=200, - grpc_status=grpc_status.code, - grpc_message=grpc_status.message or None, - ) - - except Exception as e: - logger.error(f"NSFW enable failed: {e}") - return NSFWResult(success=False, http_status=0, error=str(e)[:100]) - - -class BatchNSFWService: - """Batch NSFW orchestration.""" - @staticmethod - async def enable( + async def batch( tokens: list[str], mgr, *, @@ -78,19 +27,63 @@ 
async def enable( on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, should_cancel: Optional[Callable[[], bool]] = None, ) -> Dict[str, Dict[str, Any]]: - nsfw_service = NSFWService() - + """Batch enable NSFW.""" async def _enable(token: str): - result = await nsfw_service.enable(token) - if result.success: - await mgr.add_tag(token, "nsfw") - return { - "success": result.success, - "http_status": result.http_status, - "grpc_status": result.grpc_status, - "grpc_message": result.grpc_message, - "error": result.error, - } + try: + browser = get_config("security.browser") + async with AsyncSession(impersonate=browser) as session: + async def _record_fail(err: UpstreamException, reason: str): + status = None + if err.details and "status" in err.details: + status = err.details["status"] + else: + status = getattr(err, "status_code", None) + if status in (401, 403): + await mgr.record_fail(token, status, reason) + return status or 0 + + try: + await AcceptTosReverse.request(session, token) + except UpstreamException as e: + status = await _record_fail(e, "tos_auth_failed") + return { + "success": False, + "http_status": status, + "error": f"Accept ToS failed: {str(e)}", + } + + try: + await SetBirthReverse.request(session, token) + except UpstreamException as e: + status = await _record_fail(e, "set_birth_auth_failed") + return { + "success": False, + "http_status": status, + "error": f"Set birth date failed: {str(e)}", + } + + try: + grpc_status = await NsfwMgmtReverse.request(session, token) + success = grpc_status.code in (-1, 0) + except UpstreamException as e: + status = await _record_fail(e, "nsfw_mgmt_auth_failed") + return { + "success": False, + "http_status": status, + "error": f"NSFW enable failed: {str(e)}", + } + if success: + await mgr.add_tag(token, "nsfw") + return { + "success": success, + "http_status": 200, + "grpc_status": grpc_status.code, + "grpc_message": grpc_status.message or None, + "error": None, + } + except Exception 
as e: + logger.error(f"NSFW enable failed: {e}") + return {"success": False, "http_status": 0, "error": str(e)[:100]} return await run_in_batches( tokens, @@ -102,4 +95,4 @@ async def _enable(token: str): ) -__all__ = ["BatchNSFWService", "NSFWService", "NSFWResult"] +__all__ = ["NSFWService"] diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py index edac6f4e..3bf66ba3 100644 --- a/app/services/grok/batch_services/usage.py +++ b/app/services/grok/batch_services/usage.py @@ -3,13 +3,13 @@ """ import asyncio -from typing import Callable, Awaitable, Dict, Any, Optional +from typing import Callable, Awaitable, Dict, Any, Optional, List from curl_cffi.requests import AsyncSession from app.core.logger import logger from app.core.config import get_config -from app.services.reverse import RateLimitsReverse +from app.services.reverse.rate_limits import RateLimitsReverse from app.services.grok.utils.batch import run_in_batches _USAGE_SEMAPHORE = asyncio.Semaphore(25) @@ -63,7 +63,7 @@ class BatchUsageService: @staticmethod async def refresh( - tokens: list[str], + tokens: List[str], mgr, *, max_concurrent: int, diff --git a/app/services/grok/defaults.py b/app/services/grok/defaults.py index b273dcc6..4307fa2e 100644 --- a/app/services/grok/defaults.py +++ b/app/services/grok/defaults.py @@ -74,13 +74,16 @@ "delete_timeout": 60, "delete_batch_size": 10, }, + "nsfw": { + "concurrent": 10, + "batch_size": 50, + "timeout": 60, + }, "performance": { "media_max_concurrent": 50, "usage_max_concurrent": 25, "usage_batch_size": 50, "usage_max_tokens": 1000, - "nsfw_max_concurrent": 10, - "nsfw_batch_size": 50, "nsfw_max_tokens": 1000, }, } diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index aa439239..0dd27ce5 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -18,7 +18,7 @@ from app.services.grok.models.model import ModelService from 
app.services.grok.utils.upload import UploadService from app.services.grok.processors import StreamProcessor, CollectProcessor -from app.services.reverse import AppChatReverse +from app.services.reverse.app_chat import AppChatReverse from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.token import get_token_manager, EffortType diff --git a/app/services/grok/services/video.py b/app/services/grok/services/video.py index cd65c8ad..56ff5928 100644 --- a/app/services/grok/services/video.py +++ b/app/services/grok/services/video.py @@ -18,7 +18,8 @@ from app.services.token import get_token_manager, EffortType from app.services.grok.processors import VideoStreamProcessor, VideoCollectProcessor from app.services.grok.utils.stream import wrap_stream_with_usage -from app.services.reverse import AppChatReverse, MediaPostReverse +from app.services.reverse.app_chat import AppChatReverse +from app.services.reverse.media_post import MediaPostReverse _MEDIA_SEMAPHORE = None _MEDIA_SEM_VALUE = 0 diff --git a/app/services/grok/services/voice.py b/app/services/grok/services/voice.py index 81515dc2..a6dc5ed6 100644 --- a/app/services/grok/services/voice.py +++ b/app/services/grok/services/voice.py @@ -7,7 +7,7 @@ from curl_cffi.requests import AsyncSession from app.core.config import get_config -from app.services.reverse import LivekitTokenReverse +from app.services.reverse.ws_livekit import LivekitTokenReverse class VoiceService: diff --git a/app/services/grok/utils/__init__.py b/app/services/grok/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/utils/download.py b/app/services/grok/utils/download.py index fbb7872f..a13c7179 100644 --- a/app/services/grok/utils/download.py +++ b/app/services/grok/utils/download.py @@ -19,7 +19,7 @@ from app.core.storage import DATA_DIR from app.core.config import get_config from app.core.exceptions import AppException -from app.services.reverse import 
AssetsDownloadReverse +from app.services.reverse.assets_download import AssetsDownloadReverse from app.services.grok.utils.locks import _get_download_semaphore, _file_lock diff --git a/app/services/grok/utils/upload.py b/app/services/grok/utils/upload.py index ee17cc42..96707923 100644 --- a/app/services/grok/utils/upload.py +++ b/app/services/grok/utils/upload.py @@ -19,7 +19,7 @@ from app.core.exceptions import AppException, UpstreamException, ValidationException from app.core.logger import logger from app.core.storage import DATA_DIR -from app.services.reverse import AssetsUploadReverse +from app.services.reverse.assets_upload import AssetsUploadReverse from app.services.grok.utils.locks import _get_upload_semaphore, _file_lock diff --git a/app/services/reverse/accept_tos.py b/app/services/reverse/accept_tos.py new file mode 100644 index 00000000..203e1f62 --- /dev/null +++ b/app/services/reverse/accept_tos.py @@ -0,0 +1,118 @@ +""" +Reverse interface: accept ToS (gRPC-Web). +""" + +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status +from app.services.reverse.utils.grpc import GrpcClient, GrpcStatus + +ACCEPT_TOS_API = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion" + + +class AcceptTosReverse: + """/auth_mgmt.AuthManagement/SetTosAcceptedVersion reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> GrpcStatus: + """Accept ToS via gRPC-Web. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + GrpcStatus: Parsed gRPC status. 
+ """ + try: + # Get proxies + base_proxy = get_config("network.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + origin="https://accounts.x.ai", + referer="https://accounts.x.ai/accept-tos", + ) + headers["Content-Type"] = "application/grpc-web+proto" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + headers["x-grpc-web"] = "1" + headers["x-user-agent"] = "connect-es/2.1.1" + headers["Cache-Control"] = "no-cache" + headers["Pragma"] = "no-cache" + + # Build payload + payload = GrpcClient.encode_payload(b"\x10\x01") + + # Curl Config + timeout = get_config("nsfw.timeout") + browser = get_config("security.browser") + + async def _do_request(): + response = await session.post( + ACCEPT_TOS_API, + headers=headers, + data=payload, + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"AcceptTosReverse: Request failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AcceptTosReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + logger.debug(f"AcceptTosReverse: Request successful, {response.status_code}") + + return response + + response = await retry_on_status(_do_request) + + _, trailers = GrpcClient.parse_response( + response.content, + content_type=response.headers.get("content-type"), + headers=response.headers, + ) + grpc_status = GrpcClient.get_status(trailers) + + if grpc_status.code not in (-1, 0): + raise UpstreamException( + message=f"AcceptTosReverse: gRPC failed, {grpc_status.code}", + details={ + "status": grpc_status.http_equiv, + "grpc_status": grpc_status.code, + "grpc_message": grpc_status.message, + }, + ) + + return grpc_status + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + raise + + # Handle 
other non-upstream exceptions + logger.error( + f"AcceptTosReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AcceptTosReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AcceptTosReverse"] diff --git a/app/services/reverse/nsfw_mgmt.py b/app/services/reverse/nsfw_mgmt.py index 349e2417..8056b231 100644 --- a/app/services/reverse/nsfw_mgmt.py +++ b/app/services/reverse/nsfw_mgmt.py @@ -7,7 +7,6 @@ from app.core.logger import logger from app.core.config import get_config from app.core.exceptions import UpstreamException -from app.services.token.service import TokenService from app.services.reverse.utils.headers import build_headers from app.services.reverse.utils.retry import retry_on_status from app.services.reverse.utils.grpc import GrpcClient, GrpcStatus @@ -55,7 +54,7 @@ async def request(session: AsyncSession, token: str) -> GrpcStatus: payload = GrpcClient.encode_payload(protobuf) # Curl Config - timeout = get_config("network.timeout") + timeout = get_config("nsfw.timeout") browser = get_config("security.browser") async def _do_request(): @@ -78,6 +77,8 @@ async def _do_request(): details={"status": response.status_code}, ) + logger.debug(f"NsfwMgmtReverse: Request successful, {response.status_code}") + return response response = await retry_on_status(_do_request) @@ -109,13 +110,6 @@ async def _do_request(): status = e.details["status"] else: status = getattr(e, "status_code", None) - if status == 401: - try: - await TokenService.record_fail( - token, status, "nsfw_mgmt_auth_failed" - ) - except Exception: - pass raise # Handle other non-upstream exceptions diff --git a/app/services/reverse/rate_limits.py b/app/services/reverse/rate_limits.py index aa852605..5efd3c02 100644 --- a/app/services/reverse/rate_limits.py +++ b/app/services/reverse/rate_limits.py @@ -9,7 +9,6 @@ from app.core.logger import logger from app.core.config import 
get_config from app.core.exceptions import UpstreamException -from app.services.token.service import TokenService from app.services.reverse.utils.headers import build_headers from app.services.reverse.utils.retry import retry_on_status @@ -85,13 +84,6 @@ async def _do_request(): status = e.details["status"] else: status = getattr(e, "status_code", None) - if status == 401: - try: - await TokenService.record_fail( - token, status, "rate_limits_auth_failed" - ) - except Exception: - pass raise # Handle other non-upstream exceptions diff --git a/app/services/reverse/set_birth.py b/app/services/reverse/set_birth.py index f1b72211..556a331c 100644 --- a/app/services/reverse/set_birth.py +++ b/app/services/reverse/set_birth.py @@ -10,7 +10,6 @@ from app.core.logger import logger from app.core.config import get_config from app.core.exceptions import UpstreamException -from app.services.token.service import TokenService from app.services.reverse.utils.headers import build_headers from app.services.reverse.utils.retry import retry_on_status @@ -46,7 +45,7 @@ async def request(session: AsyncSession, token: str) -> Any: # Build payload today = datetime.date.today() - birth_year = today.year - random.randint(20, 40) + birth_year = today.year - random.randint(20, 48) birth_month = random.randint(1, 12) birth_day = random.randint(1, 28) hour = random.randint(0, 23) @@ -59,7 +58,7 @@ async def request(session: AsyncSession, token: str) -> Any: } # Curl Config - timeout = get_config("network.timeout") + timeout = get_config("nsfw.timeout") browser = get_config("security.browser") async def _do_request(): @@ -82,6 +81,8 @@ async def _do_request(): details={"status": response.status_code}, ) + logger.debug(f"SetBirthReverse: Request successful, {response.status_code}") + return response return await retry_on_status(_do_request) @@ -94,13 +95,6 @@ async def _do_request(): status = e.details["status"] else: status = getattr(e, "status_code", None) - if status == 401: - try: - await 
TokenService.record_fail( - token, status, "set_birth_auth_failed" - ) - except Exception: - pass raise # Handle other non-upstream exceptions diff --git a/app/services/reverse/utils/grpc.py b/app/services/reverse/utils/grpc.py index 446ba4bb..39eb6787 100644 --- a/app/services/reverse/utils/grpc.py +++ b/app/services/reverse/utils/grpc.py @@ -3,12 +3,15 @@ """ import base64 +import json import re import struct from dataclasses import dataclass from typing import Dict, List, Mapping, Optional, Tuple from urllib.parse import unquote +from app.core.logger import logger + # Base64 正则 B64_RE = re.compile(rb"^[A-Za-z0-9+/=\r\n]+$") @@ -38,6 +41,22 @@ def http_equiv(self) -> int: class GrpcClient: """gRPC-Web helpers wrapper.""" + @staticmethod + def _safe_headers(headers: Optional[Mapping[str, str]]) -> Dict[str, str]: + if not headers: + return {} + safe: Dict[str, str] = {} + for k, v in headers.items(): + if k.lower() in ("set-cookie", "cookie", "authorization"): + safe[k] = "" + else: + safe[k] = str(v) + return safe + + @staticmethod + def _b64(data: bytes) -> str: + return base64.b64encode(data).decode() + @staticmethod def encode_payload(data: bytes) -> bytes: """Encode gRPC-Web data frame.""" @@ -118,6 +137,35 @@ def parse_response( if "grpc-message" in lower and "grpc-message" not in trailers: trailers["grpc-message"] = unquote(str(lower["grpc-message"]).strip()) + # Log full response details on gRPC error + raw_status = str(trailers.get("grpc-status", "")).strip() + try: + status_code = int(raw_status) + except Exception: + status_code = -1 + + if status_code not in (0, -1): + try: + payload = { + "grpc_status": status_code, + "grpc_message": trailers.get("grpc-message", ""), + "content_type": content_type or "", + "headers": cls._safe_headers(headers), + "trailers": trailers, + "messages_b64": [cls._b64(m) for m in messages], + "body_b64": cls._b64(body), + } + logger.error( + "gRPC response error: {}", + json.dumps(payload, ensure_ascii=False), + 
extra={"error_type": "GrpcError"}, + ) + except Exception as e: + logger.error( + f"gRPC response error: failed to log payload ({e})", + extra={"error_type": "GrpcError"}, + ) + return messages, trailers @staticmethod diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py index 4388bf2f..03a8f253 100644 --- a/app/services/reverse/utils/headers.py +++ b/app/services/reverse/utils/headers.py @@ -123,11 +123,10 @@ def build_headers(cookie_token: str, content_type: Optional[str] = None, origin: headers["x-xai-request-id"] = str(uuid.uuid4()) # Print headers without Cookie - if logger.isEnabledFor(10): - safe_headers = dict(headers) - if "Cookie" in safe_headers: - safe_headers["Cookie"] = "" - logger.debug(f"Built headers: {orjson.dumps(safe_headers).decode()}") + safe_headers = dict(headers) + if "Cookie" in safe_headers: + safe_headers["Cookie"] = "" + logger.debug(f"Built headers: {orjson.dumps(safe_headers).decode()}") return headers diff --git a/app/services/token/manager.py b/app/services/token/manager.py index 4f056d8b..68bfd847 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -16,6 +16,7 @@ ) from app.core.storage import get_storage, LocalStorage from app.core.config import get_config +from app.core.exceptions import UpstreamException from app.services.token.pool import TokenPool from app.services.grok.batch_services.usage import UsageService @@ -382,6 +383,14 @@ async def sync_usage( return True except Exception as e: + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status in (401, 403): + await self.record_fail(token_str, status, "rate_limits_auth_failed") logger.warning( f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})" ) diff --git a/app/services/token/service.py b/app/services/token/service.py index 63e635c3..b441fbeb 100644 --- 
a/app/services/token/service.py +++ b/app/services/token/service.py @@ -2,7 +2,6 @@ from typing import List, Optional, Dict -from app.services.token.manager import get_token_manager from app.services.token.models import TokenInfo, EffortType @@ -13,6 +12,12 @@ class TokenService: 提供简化的 API,隐藏内部实现细节 """ + @staticmethod + async def _get_manager(): + from app.services.token.manager import get_token_manager + + return await get_token_manager() + @staticmethod async def get_token(pool_name: str = "ssoBasic") -> Optional[str]: """ @@ -24,7 +29,7 @@ async def get_token(pool_name: str = "ssoBasic") -> Optional[str]: Returns: Token 字符串(不含 sso= 前缀)或 None """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return manager.get_token(pool_name) @staticmethod @@ -39,7 +44,7 @@ async def consume(token: str, effort: EffortType = EffortType.LOW) -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.consume(token, effort) @staticmethod @@ -54,7 +59,7 @@ async def sync_usage(token: str, effort: EffortType = EffortType.LOW) -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.sync_usage(token, effort) @staticmethod @@ -70,7 +75,7 @@ async def record_fail(token: str, status_code: int = 401, reason: str = "") -> b Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.record_fail(token, status_code, reason) @staticmethod @@ -85,7 +90,7 @@ async def add_token(token: str, pool_name: str = "ssoBasic") -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.add(token, pool_name) @staticmethod @@ -99,7 +104,7 @@ async def remove_token(token: str) -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return 
await manager.remove(token) @staticmethod @@ -113,13 +118,13 @@ async def reset_token(token: str) -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.reset_token(token) @staticmethod async def reset_all(): """重置所有 Token""" - manager = await get_token_manager() + manager = await TokenService._get_manager() await manager.reset_all() @staticmethod @@ -130,7 +135,7 @@ async def get_stats() -> Dict[str, dict]: Returns: 各池的统计信息 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return manager.get_stats() @staticmethod @@ -144,7 +149,7 @@ async def list_tokens(pool_name: str = "ssoBasic") -> List[TokenInfo]: Returns: Token 列表 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return manager.get_pool_tokens(pool_name) diff --git a/app/static/config/config.js b/app/static/config/config.js index ee60d64e..815a177b 100644 --- a/app/static/config/config.js +++ b/app/static/config/config.js @@ -33,8 +33,8 @@ const NUMERIC_FIELDS = new Set([ 'image_ws_blocked_seconds', 'image_ws_final_min_bytes', 'image_ws_medium_min_bytes', - 'nsfw_max_concurrent', - 'nsfw_batch_size', + 'concurrent', + 'batch_size', 'nsfw_max_tokens' ]); @@ -117,11 +117,15 @@ const LOCALE_MAP = { "delete_timeout": { title: "删除超时", desc: "资产删除接口超时时间(秒)。推荐 60。" }, "delete_batch_size": { title: "删除批次大小", desc: "单次删除可处理的 Token 数量。推荐 10。" } }, + "nsfw": { + "label": "NSFW 配置", + "concurrent": { title: "并发上限", desc: "批量开启 NSFW 模式时的并发请求上限。推荐 10。" }, + "batch_size": { title: "批次大小", desc: "批量开启 NSFW 模式的单批处理数量。推荐 50。" }, + "timeout": { title: "请求超时", desc: "NSFW 开启相关请求的超时时间(秒)。推荐 60。" } + }, "performance": { "label": "并发性能", "media_max_concurrent": { title: "Media 并发上限", desc: "视频/媒体生成请求的并发上限。推荐 50。" }, - "nsfw_max_concurrent": { title: "NSFW 开启并发上限", desc: "批量开启 NSFW 模式时的并发请求上限。推荐 10。" }, - "nsfw_batch_size": { title: "NSFW 开启批量大小", desc: "批量开启 NSFW 模式的单批处理数量。推荐 50。" 
}, "nsfw_max_tokens": { title: "NSFW 开启最大数量", desc: "单次批量开启 NSFW 的 Token 数量上限,防止误操作。推荐 1000。" }, "usage_max_concurrent": { title: "Token 刷新并发上限", desc: "批量刷新 Token 用量时的并发请求上限。推荐 25。" }, "usage_batch_size": { title: "Token 刷新批次大小", desc: "批量刷新 Token 用量的单批处理数量。推荐 50。" }, diff --git a/config.defaults.toml b/config.defaults.toml index e15bed8e..7b158566 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -132,6 +132,16 @@ delete_timeout = 60 # 资产删除批次大小(Token 维度) delete_batch_size = 10 +# ==================== NSFW ==================== +[nsfw] + +# NSFW 批量开启并发上限 +concurrent = 10 +# NSFW 批量开启批次大小 +batch_size = 50 +# NSFW 请求超时时间(秒) +timeout = 60 + # ==================== 并发性能 ==================== [performance] @@ -146,9 +156,5 @@ usage_batch_size = 50 # Token 用量单次最大数量 usage_max_tokens = 1000 -# NSFW 批量开启并发上限 -nsfw_max_concurrent = 10 -# NSFW 批量开启批次大小 -nsfw_batch_size = 50 # NSFW 单次最大数量 nsfw_max_tokens = 1000 diff --git a/data/config.toml b/data/config.toml index 93016ea5..f5face4f 100644 --- a/data/config.toml +++ b/data/config.toml @@ -54,15 +54,30 @@ reload_interval_sec = 30 enable_auto_clean = true limit_mb = 1024 +[asset] +upload_concurrent = 30 +upload_timeout = 60 +download_concurrent = 30 +download_timeout = 60 +list_concurrent = 10 +list_timeout = 60 +list_batch_size = 10 +delete_concurrent = 10 +delete_timeout = 60 +delete_batch_size = 10 + +[nsfw] +concurrent = 10 +batch_size = 50 +timeout = 60 + [performance] -assets_max_concurrent = 25 -assets_delete_batch_size = 10 -assets_batch_size = 10 -assets_max_tokens = 1000 media_max_concurrent = 50 usage_max_concurrent = 25 usage_batch_size = 50 usage_max_tokens = 1000 -nsfw_max_concurrent = 10 -nsfw_batch_size = 50 nsfw_max_tokens = 1000 +assets_max_concurrent = 25 +assets_delete_batch_size = 10 +assets_batch_size = 10 +assets_max_tokens = 1000 From 6228788956ad8d345c59445308c8cc3d1608a529 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:27:45 
+0800 Subject: [PATCH 16/27] refactor: consolidate usage configuration and update related services for improved clarity and performance --- app/api/v1/admin.py | 123 ++++--------------- app/services/grok/batch_services/__init__.py | 4 +- app/services/grok/batch_services/usage.py | 9 +- app/services/grok/defaults.py | 9 +- app/services/grok/services/__init__.py | 0 app/services/reverse/rate_limits.py | 4 +- app/static/config/config.js | 18 ++- config.defaults.toml | 20 ++- data/config.toml | 9 +- 9 files changed, 58 insertions(+), 138 deletions(-) delete mode 100644 app/services/grok/services/__init__.py diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index cc9c6fe4..83279aa4 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -16,7 +16,7 @@ from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage from app.core.exceptions import AppException from app.services.token.manager import get_token_manager -from app.services.grok.batch_services import BatchUsageService +from app.services.grok.batch_services.usage import UsageService from app.services.grok.batch_services.nsfw import NSFWService from app.services.grok.batch_services.assets import ListService, DeleteService import os @@ -110,22 +110,9 @@ def _collect_tokens(data: dict) -> List[str]: return tokens -def _truncate_tokens( - tokens: List[str], max_tokens: int, operation: str = "operation" -) -> Tuple[List[str], bool, int]: - """去重并截断 token 列表,返回 (unique_tokens, truncated, original_count)""" - unique_tokens = list(dict.fromkeys(tokens)) - original_count = len(unique_tokens) - truncated = False - - if len(unique_tokens) > max_tokens: - unique_tokens = unique_tokens[:max_tokens] - truncated = True - logger.warning( - f"{operation}: truncated from {original_count} to {max_tokens} tokens" - ) - - return unique_tokens, truncated, original_count +def _dedupe_tokens(tokens: List[str]) -> List[str]: + """去重 token 列表(保持原顺序)""" + return list(dict.fromkeys(tokens)) @@ -826,17 +813,14 @@ 
async def refresh_tokens_api(data: dict): if not tokens: raise HTTPException(status_code=400, detail="No tokens provided") - # 去重并截断 - max_tokens = int(get_config("performance.usage_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "Usage refresh" - ) + # 去重 + unique_tokens = _dedupe_tokens(tokens) # 批量执行配置 - max_concurrent = get_config("performance.usage_max_concurrent") - batch_size = get_config("performance.usage_batch_size") + max_concurrent = get_config("usage.concurrent") + batch_size = get_config("usage.batch_size") - raw_results = await BatchUsageService.refresh( + raw_results = await UsageService.batch( unique_tokens, mgr, max_concurrent=max_concurrent, @@ -851,10 +835,6 @@ async def refresh_tokens_api(data: dict): results[token] = False response = {"status": "success", "results": results} - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) return response except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -871,14 +851,11 @@ async def refresh_tokens_api_async(data: dict): if not tokens: raise HTTPException(status_code=400, detail="No tokens provided") - # 去重并截断 - max_tokens = int(get_config("performance.usage_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "Usage refresh" - ) + # 去重 + unique_tokens = _dedupe_tokens(tokens) - max_concurrent = get_config("performance.usage_max_concurrent") - batch_size = get_config("performance.usage_batch_size") + max_concurrent = get_config("usage.concurrent") + batch_size = get_config("usage.batch_size") task = create_task(len(unique_tokens)) @@ -888,7 +865,7 @@ async def _run(): async def _on_item(item: str, res: dict): task.record(bool(res.get("ok"))) - raw_results = await BatchUsageService.refresh( + raw_results = await UsageService.batch( unique_tokens, mgr, max_concurrent=max_concurrent, @@ -923,12 +900,7 @@ async def _on_item(item: str, res: 
dict): }, "results": results, } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) + task.finish(result) except Exception as e: task.fail_task(str(e)) finally: @@ -964,11 +936,8 @@ async def enable_nsfw_api(data: dict): if not tokens: raise HTTPException(status_code=400, detail="No tokens available") - # 去重并截断 - max_tokens = int(get_config("performance.nsfw_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "NSFW enable" - ) + # 去重 + unique_tokens = _dedupe_tokens(tokens) # 批量执行配置 max_concurrent = get_config("nsfw.concurrent") @@ -1006,11 +975,6 @@ async def enable_nsfw_api(data: dict): } # 添加截断提示 - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - return response except HTTPException: @@ -1038,11 +1002,8 @@ async def enable_nsfw_api_async(data: dict): if not tokens: raise HTTPException(status_code=400, detail="No tokens available") - # 去重并截断 - max_tokens = int(get_config("performance.nsfw_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "NSFW enable" - ) + # 去重 + unique_tokens = _dedupe_tokens(tokens) max_concurrent = get_config("nsfw.concurrent") batch_size = get_config("nsfw.batch_size") @@ -1092,12 +1053,7 @@ async def _on_item(item: str, res: dict): }, "results": results, } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) + task.finish(result) except Exception as e: task.fail_task(str(e)) finally: @@ -1169,9 +1125,6 @@ async def get_cache_stats_api(request: Request): account_map = {a["token"]: a for a in accounts} batch_size = max(1, int(get_config("asset.list_batch_size"))) max_concurrent = batch_size - truncated = False - original_count = 0 - if selected_tokens: total = 0 raw_results = await ListService.fetch_assets_details( @@ 
-1208,7 +1161,6 @@ async def get_cache_stats_api(request: Request): elif scope == "all": total = 0 tokens = list(dict.fromkeys([account["token"] for account in accounts])) - original_count = len(tokens) raw_results = await ListService.fetch_assets_details( tokens, account_map, @@ -1286,10 +1238,6 @@ async def get_cache_stats_api(request: Request): "online_scope": scope or "none", "online_details": online_details, } - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) return response except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -1341,9 +1289,6 @@ async def load_online_cache_api_async(data: dict): else: raise HTTPException(status_code=400, detail="No tokens provided") - truncated = False - original_count = len(selected_tokens) - batch_size = get_config("asset.list_batch_size") max_concurrent = batch_size @@ -1397,12 +1342,7 @@ async def _on_item(item: str, res: dict): "online_scope": scope or "none", "online_details": online_details, } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) + task.finish(result) except Exception as e: task.fail_task(str(e)) finally: @@ -1485,10 +1425,6 @@ async def clear_online_cache_api(data: dict): # 去重并保持顺序 token_list = list(dict.fromkeys(token_list)) - # 最大数量限制 - truncated = False - original_count = len(token_list) - results = {} batch_size = max(1, int(get_config("asset.delete_batch_size"))) max_concurrent = batch_size @@ -1505,12 +1441,7 @@ async def clear_online_cache_api(data: dict): else: results[token] = {"status": "error", "error": res.get("error")} - response = {"status": "success", "results": results} - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - return response + return {"status": "success", "results": results} token = data.get("token") or mgr.get_token() if not token: @@ -1549,9 +1480,6 @@ async def 
clear_online_cache_api_async(data: dict): if not token_list: raise HTTPException(status_code=400, detail="No tokens provided") - truncated = False - original_count = len(token_list) - batch_size = get_config("asset.delete_batch_size") max_concurrent = batch_size @@ -1598,12 +1526,7 @@ async def _on_item(item: str, res: dict): }, "results": results, } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) + task.finish(result) except Exception as e: task.fail_task(str(e)) finally: diff --git a/app/services/grok/batch_services/__init__.py b/app/services/grok/batch_services/__init__.py index 58666583..d71f450e 100644 --- a/app/services/grok/batch_services/__init__.py +++ b/app/services/grok/batch_services/__init__.py @@ -1,5 +1,5 @@ """Batch services.""" -from .usage import BatchUsageService +from .usage import UsageService -__all__ = ["BatchUsageService"] +__all__ = ["UsageService"] diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py index 3bf66ba3..ddefb742 100644 --- a/app/services/grok/batch_services/usage.py +++ b/app/services/grok/batch_services/usage.py @@ -32,7 +32,7 @@ async def get(self, token: str) -> Dict: Raises: UpstreamException: 当获取失败且重试耗尽时 """ - value = get_config("performance.usage_max_concurrent") + value = get_config("usage.concurrent") try: value = int(value) except Exception: @@ -58,11 +58,8 @@ async def get(self, token: str) -> Dict: raise -class BatchUsageService: - """Batch usage orchestration.""" - @staticmethod - async def refresh( + async def batch( tokens: List[str], mgr, *, @@ -84,4 +81,4 @@ async def _refresh_one(t: str): ) -__all__ = ["BatchUsageService", "UsageService"] +__all__ = ["UsageService"] diff --git a/app/services/grok/defaults.py b/app/services/grok/defaults.py index 4307fa2e..3a66dd2d 100644 --- a/app/services/grok/defaults.py +++ b/app/services/grok/defaults.py @@ -79,12 +79,13 @@ 
"batch_size": 50, "timeout": 60, }, + "usage": { + "concurrent": 10, + "batch_size": 50, + "timeout": 60, + }, "performance": { "media_max_concurrent": 50, - "usage_max_concurrent": 25, - "usage_batch_size": 50, - "usage_max_tokens": 1000, - "nsfw_max_tokens": 1000, }, } diff --git a/app/services/grok/services/__init__.py b/app/services/grok/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/reverse/rate_limits.py b/app/services/reverse/rate_limits.py index 5efd3c02..48b8c220 100644 --- a/app/services/reverse/rate_limits.py +++ b/app/services/reverse/rate_limits.py @@ -49,7 +49,9 @@ async def request(session: AsyncSession, token: str) -> Any: } # Curl Config - timeout = get_config("network.timeout") + timeout = get_config("usage.timeout") + if timeout is None: + timeout = get_config("network.timeout") browser = get_config("security.browser") async def _do_request(): diff --git a/app/static/config/config.js b/app/static/config/config.js index 815a177b..d1ddeeb6 100644 --- a/app/static/config/config.js +++ b/app/static/config/config.js @@ -24,9 +24,6 @@ const NUMERIC_FIELDS = new Set([ 'delete_timeout', 'delete_batch_size', 'media_max_concurrent', - 'usage_max_concurrent', - 'usage_batch_size', - 'usage_max_tokens', 'reload_interval_sec', 'stream_idle_timeout', 'video_idle_timeout', @@ -34,8 +31,7 @@ const NUMERIC_FIELDS = new Set([ 'image_ws_final_min_bytes', 'image_ws_medium_min_bytes', 'concurrent', - 'batch_size', - 'nsfw_max_tokens' + 'batch_size' ]); const LOCALE_MAP = { @@ -123,13 +119,15 @@ const LOCALE_MAP = { "batch_size": { title: "批次大小", desc: "批量开启 NSFW 模式的单批处理数量。推荐 50。" }, "timeout": { title: "请求超时", desc: "NSFW 开启相关请求的超时时间(秒)。推荐 60。" } }, + "usage": { + "label": "Usage 配置", + "concurrent": { title: "并发上限", desc: "批量刷新用量时的并发请求上限。推荐 10。" }, + "batch_size": { title: "批次大小", desc: "批量刷新用量的单批处理数量。推荐 50。" }, + "timeout": { title: "请求超时", desc: "用量查询接口的超时时间(秒)。推荐 60。" } + }, "performance": { "label": "并发性能", - 
"media_max_concurrent": { title: "Media 并发上限", desc: "视频/媒体生成请求的并发上限。推荐 50。" }, - "nsfw_max_tokens": { title: "NSFW 开启最大数量", desc: "单次批量开启 NSFW 的 Token 数量上限,防止误操作。推荐 1000。" }, - "usage_max_concurrent": { title: "Token 刷新并发上限", desc: "批量刷新 Token 用量时的并发请求上限。推荐 25。" }, - "usage_batch_size": { title: "Token 刷新批次大小", desc: "批量刷新 Token 用量的单批处理数量。推荐 50。" }, - "usage_max_tokens": { title: "Token 刷新最大数量", desc: "单次批量刷新 Token 用量时的处理数量上限。推荐 1000。" } + "media_max_concurrent": { title: "Media 并发上限", desc: "视频/媒体生成请求的并发上限。推荐 50。" } } }; diff --git a/config.defaults.toml b/config.defaults.toml index 7b158566..959df3ee 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -134,7 +134,6 @@ delete_batch_size = 10 # ==================== NSFW ==================== [nsfw] - # NSFW 批量开启并发上限 concurrent = 10 # NSFW 批量开启批次大小 @@ -142,19 +141,18 @@ batch_size = 50 # NSFW 请求超时时间(秒) timeout = 60 +# ==================== Usage ==================== +[usage] +# Usage 批量开启并发上限 +concurrent = 10 +# Usage 批量开启批次大小 +batch_size = 50 +# Usage 请求超时时间(秒) +timeout = 60 + # ==================== 并发性能 ==================== [performance] # Media 生成并发上限 media_max_concurrent = 50 - -# Token 用量刷新并发上限 -usage_max_concurrent = 25 -# Token 用量刷新批次大小 -usage_batch_size = 50 -# Token 用量单次最大数量 -usage_max_tokens = 1000 - -# NSFW 单次最大数量 -nsfw_max_tokens = 1000 diff --git a/data/config.toml b/data/config.toml index f5face4f..92187997 100644 --- a/data/config.toml +++ b/data/config.toml @@ -71,12 +71,13 @@ concurrent = 10 batch_size = 50 timeout = 60 +[usage] +concurrent = 10 +batch_size = 50 +timeout = 60 + [performance] media_max_concurrent = 50 -usage_max_concurrent = 25 -usage_batch_size = 50 -usage_max_tokens = 1000 -nsfw_max_tokens = 1000 assets_max_concurrent = 25 assets_delete_batch_size = 10 assets_batch_size = 10 From c62faf193b3e2a58c1ac8d30391bc7a58d6674ed Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:36:16 +0800 Subject: [PATCH 17/27] 
refactor: update import paths for ModelService and remove unused model files --- app/api/v1/admin.py | 2 +- app/api/v1/chat.py | 2 +- app/api/v1/image.py | 2 +- app/api/v1/models.py | 2 +- app/services/grok/models/__init__.py | 0 app/services/grok/services/chat.py | 2 +- app/services/grok/{models => services}/model.py | 0 app/services/grok/services/video.py | 2 +- app/services/grok/utils/stream.py | 2 +- 9 files changed, 7 insertions(+), 7 deletions(-) delete mode 100644 app/services/grok/models/__init__.py rename app/services/grok/{models => services}/model.py (100%) diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index 83279aa4..f76f959c 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -30,7 +30,7 @@ from app.api.v1.image import resolve_aspect_ratio from app.services.grok.services.voice import VoiceService from app.services.grok.services.image import ImageGenerationService -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static" diff --git a/app/api/v1/chat.py b/app/api/v1/chat.py index ba5a820f..b1a70354 100644 --- a/app/api/v1/chat.py +++ b/app/api/v1/chat.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, field_validator from app.services.grok.services.chat import ChatService -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService from app.core.exceptions import ValidationException diff --git a/app/api/v1/image.py b/app/api/v1/image.py index 6e8ab670..1d7a3354 100644 --- a/app/api/v1/image.py +++ b/app/api/v1/image.py @@ -13,7 +13,7 @@ from app.services.grok.services.image import ImageGenerationService from app.services.grok.services.image_edit import ImageEditService -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService from app.services.token import get_token_manager from app.core.exceptions 
import ValidationException, AppException, ErrorType from app.core.config import get_config diff --git a/app/api/v1/models.py b/app/api/v1/models.py index 13971669..babf35eb 100644 --- a/app/api/v1/models.py +++ b/app/api/v1/models.py @@ -4,7 +4,7 @@ from fastapi import APIRouter -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService router = APIRouter(tags=["Models"]) diff --git a/app/services/grok/models/__init__.py b/app/services/grok/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index 0dd27ce5..8fe1e1aa 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -15,7 +15,7 @@ ErrorType, UpstreamException, ) -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService from app.services.grok.utils.upload import UploadService from app.services.grok.processors import StreamProcessor, CollectProcessor from app.services.reverse.app_chat import AppChatReverse diff --git a/app/services/grok/models/model.py b/app/services/grok/services/model.py similarity index 100% rename from app/services/grok/models/model.py rename to app/services/grok/services/model.py diff --git a/app/services/grok/services/video.py b/app/services/grok/services/video.py index 56ff5928..b848045a 100644 --- a/app/services/grok/services/video.py +++ b/app/services/grok/services/video.py @@ -14,7 +14,7 @@ ValidationException, ErrorType, ) -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService from app.services.token import get_token_manager, EffortType from app.services.grok.processors import VideoStreamProcessor, VideoCollectProcessor from app.services.grok.utils.stream import wrap_stream_with_usage diff --git a/app/services/grok/utils/stream.py b/app/services/grok/utils/stream.py index 
c9e64dd2..053c18d9 100644 --- a/app/services/grok/utils/stream.py +++ b/app/services/grok/utils/stream.py @@ -5,7 +5,7 @@ from typing import AsyncGenerator from app.core.logger import logger -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService from app.services.token import EffortType From e5df9ca3236faf8115d0337ab8d772e92074b7d9 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:01:40 +0800 Subject: [PATCH 18/27] refactor: reorganize batch processing utilities and update import paths across services --- app/api/v1/admin.py | 2 +- app/core/{batch_tasks.py => batch.py} | 99 +++++++++++++++++++++- app/services/grok/batch_services/assets.py | 6 +- app/services/grok/batch_services/nsfw.py | 4 +- app/services/grok/batch_services/usage.py | 4 +- app/services/grok/utils/batch.py | 86 ------------------- 6 files changed, 103 insertions(+), 98 deletions(-) rename app/core/{batch_tasks.py => batch.py} (58%) delete mode 100644 app/services/grok/utils/batch.py diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index f76f959c..6daa3d95 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from app.core.auth import verify_api_key, verify_app_key, get_admin_api_key from app.core.config import config, get_config -from app.core.batch_tasks import create_task, get_task, expire_task +from app.core.batch import create_task, get_task, expire_task from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage from app.core.exceptions import AppException from app.services.token.manager import get_token_manager diff --git a/app/core/batch_tasks.py b/app/core/batch.py similarity index 58% rename from app/core/batch_tasks.py rename to app/core/batch.py index ff564ffb..4857250d 100644 --- a/app/core/batch_tasks.py +++ b/app/core/batch.py @@ -1,13 +1,94 @@ """ -Batch task manager for admin batch 
operations (SSE progress). -""" +Batch utilities. -from __future__ import annotations +- run_batch: generic batch concurrency runner +- BatchTask: SSE task manager for admin batch operations +""" import asyncio import time import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar + +from app.core.logger import logger + +T = TypeVar("T") + + +async def run_batch( + items: List[str], + worker: Callable[[str], Awaitable[T]], + *, + max_concurrent: int = 10, + batch_size: int = 50, + task: Optional["BatchTask"] = None, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, +) -> Dict[str, Dict[str, Any]]: + """ + 分批并发执行,单项失败不影响整体 + + Args: + items: 待处理项列表 + worker: 异步处理函数 + max_concurrent: 最大并发数 + batch_size: 每批大小 + + Returns: + {item: {"ok": bool, "data": ..., "error": ...}} + """ + try: + max_concurrent = int(max_concurrent) + except Exception: + max_concurrent = 10 + try: + batch_size = int(batch_size) + except Exception: + batch_size = 50 + + max_concurrent = max(1, max_concurrent) + batch_size = max(1, batch_size) + + sem = asyncio.Semaphore(max_concurrent) + + async def _one(item: str) -> tuple[str, dict]: + if (should_cancel and should_cancel()) or (task and task.cancelled): + return item, {"ok": False, "error": "cancelled", "cancelled": True} + async with sem: + try: + data = await worker(item) + result = {"ok": True, "data": data} + if task: + task.record(True) + if on_item: + try: + await on_item(item, result) + except Exception: + pass + return item, result + except Exception as e: + logger.warning(f"Batch item failed: {item[:16]}... 
- {e}") + result = {"ok": False, "error": str(e)} + if task: + task.record(False, error=str(e)) + if on_item: + try: + await on_item(item, result) + except Exception: + pass + return item, result + + results: Dict[str, dict] = {} + + # 分批执行,避免一次性创建所有 task + for i in range(0, len(items), batch_size): + if (should_cancel and should_cancel()) or (task and task.cancelled): + break + chunk = items[i : i + batch_size] + pairs = await asyncio.gather(*(_one(x) for x in chunk)) + results.update(dict(pairs)) + + return results class BatchTask: @@ -150,3 +231,13 @@ def delete_task(task_id: str) -> None: async def expire_task(task_id: str, delay: int = 300) -> None: await asyncio.sleep(delay) delete_task(task_id) + + +__all__ = [ + "run_batch", + "BatchTask", + "create_task", + "get_task", + "delete_task", + "expire_task", +] diff --git a/app/services/grok/batch_services/assets.py b/app/services/grok/batch_services/assets.py index 5ba9ab35..ca87b3c9 100644 --- a/app/services/grok/batch_services/assets.py +++ b/app/services/grok/batch_services/assets.py @@ -11,7 +11,7 @@ from app.core.logger import logger from app.services.reverse.assets_list import AssetsListReverse from app.services.reverse.assets_delete import AssetsDeleteReverse -from app.services.grok.utils.batch import run_in_batches +from app.core.batch import run_batch class BaseAssetsService: @@ -149,7 +149,7 @@ async def _fetch_detail(token: str): return {"detail": detail, "count": 0} try: - return await run_in_batches( + return await run_batch( tokens, _fetch_detail, max_concurrent=max_concurrent, @@ -219,7 +219,7 @@ async def _clear_one(token: str): return {"status": "error", "error": str(e)} try: - return await run_in_batches( + return await run_batch( tokens, _clear_one, max_concurrent=max_concurrent, diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py index f53128b3..facfbe5c 100644 --- a/app/services/grok/batch_services/nsfw.py +++ 
b/app/services/grok/batch_services/nsfw.py @@ -12,7 +12,7 @@ from app.services.reverse.accept_tos import AcceptTosReverse from app.services.reverse.nsfw_mgmt import NsfwMgmtReverse from app.services.reverse.set_birth import SetBirthReverse -from app.services.grok.utils.batch import run_in_batches +from app.core.batch import run_batch class NSFWService: @@ -85,7 +85,7 @@ async def _record_fail(err: UpstreamException, reason: str): logger.error(f"NSFW enable failed: {e}") return {"success": False, "http_status": 0, "error": str(e)[:100]} - return await run_in_batches( + return await run_batch( tokens, _enable, max_concurrent=max_concurrent, diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py index ddefb742..9d10c98a 100644 --- a/app/services/grok/batch_services/usage.py +++ b/app/services/grok/batch_services/usage.py @@ -10,7 +10,7 @@ from app.core.logger import logger from app.core.config import get_config from app.services.reverse.rate_limits import RateLimitsReverse -from app.services.grok.utils.batch import run_in_batches +from app.core.batch import run_batch _USAGE_SEMAPHORE = asyncio.Semaphore(25) _USAGE_SEM_VALUE = 25 @@ -71,7 +71,7 @@ async def batch( async def _refresh_one(t: str): return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False) - return await run_in_batches( + return await run_batch( tokens, _refresh_one, max_concurrent=max_concurrent, diff --git a/app/services/grok/utils/batch.py b/app/services/grok/utils/batch.py deleted file mode 100644 index adb64ea2..00000000 --- a/app/services/grok/utils/batch.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -批量执行工具 - -提供分批并发、单项失败隔离的通用批量处理能力。 -""" - -import asyncio -from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar - -from app.core.logger import logger - -T = TypeVar("T") - - -async def run_in_batches( - items: List[str], - worker: Callable[[str], Awaitable[T]], - *, - max_concurrent: int = 10, - batch_size: int = 50, - on_item: 
Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, - should_cancel: Optional[Callable[[], bool]] = None, -) -> Dict[str, Dict[str, Any]]: - """ - 分批并发执行,单项失败不影响整体 - - Args: - items: 待处理项列表 - worker: 异步处理函数 - max_concurrent: 最大并发数 - batch_size: 每批大小 - - Returns: - {item: {"ok": bool, "data": ..., "error": ...}} - """ - try: - max_concurrent = int(max_concurrent) - except Exception: - max_concurrent = 10 - try: - batch_size = int(batch_size) - except Exception: - batch_size = 50 - - max_concurrent = max(1, max_concurrent) - batch_size = max(1, batch_size) - - sem = asyncio.Semaphore(max_concurrent) - - async def _one(item: str) -> tuple[str, dict]: - if should_cancel and should_cancel(): - return item, {"ok": False, "error": "cancelled", "cancelled": True} - async with sem: - try: - data = await worker(item) - result = {"ok": True, "data": data} - if on_item: - try: - await on_item(item, result) - except Exception: - pass - return item, result - except Exception as e: - logger.warning(f"Batch item failed: {item[:16]}... 
- {e}") - result = {"ok": False, "error": str(e)} - if on_item: - try: - await on_item(item, result) - except Exception: - pass - return item, result - - results: Dict[str, dict] = {} - - # 分批执行,避免一次性创建所有 task - for i in range(0, len(items), batch_size): - if should_cancel and should_cancel(): - break - chunk = items[i : i + batch_size] - pairs = await asyncio.gather(*(_one(x) for x in chunk)) - results.update(dict(pairs)) - - return results - - -__all__ = ["run_in_batches"] From 6257b68527f4e80975fa396bae8fa8b190ba988f Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:21:11 +0800 Subject: [PATCH 19/27] refactor: remove unused files and streamline batch processing configuration in services --- app/__init__.py | 1 - app/api/v1/admin.py | 47 ----------------- app/core/batch.py | 54 ++++++++------------ app/services/__init__.py | 0 app/services/grok/__init__.py | 0 app/services/grok/batch_services/__init__.py | 5 -- app/services/grok/batch_services/assets.py | 8 +-- app/services/grok/batch_services/nsfw.py | 27 +++++++--- app/services/grok/batch_services/usage.py | 29 +++++------ app/services/reverse/rate_limits.py | 2 - 10 files changed, 58 insertions(+), 115 deletions(-) delete mode 100644 app/__init__.py delete mode 100644 app/services/__init__.py delete mode 100644 app/services/grok/__init__.py delete mode 100644 app/services/grok/batch_services/__init__.py diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index 9f7a4239..00000000 --- a/app/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""App Package""" diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py index 6daa3d95..a05db4a3 100644 --- a/app/api/v1/admin.py +++ b/app/api/v1/admin.py @@ -816,15 +816,9 @@ async def refresh_tokens_api(data: dict): # 去重 unique_tokens = _dedupe_tokens(tokens) - # 批量执行配置 - max_concurrent = get_config("usage.concurrent") - batch_size = get_config("usage.batch_size") - raw_results = await 
UsageService.batch( unique_tokens, mgr, - max_concurrent=max_concurrent, - batch_size=batch_size, ) results = {} @@ -854,9 +848,6 @@ async def refresh_tokens_api_async(data: dict): # 去重 unique_tokens = _dedupe_tokens(tokens) - max_concurrent = get_config("usage.concurrent") - batch_size = get_config("usage.batch_size") - task = create_task(len(unique_tokens)) async def _run(): @@ -868,8 +859,6 @@ async def _on_item(item: str, res: dict): raw_results = await UsageService.batch( unique_tokens, mgr, - max_concurrent=max_concurrent, - batch_size=batch_size, on_item=_on_item, should_cancel=lambda: task.cancelled, ) @@ -939,15 +928,9 @@ async def enable_nsfw_api(data: dict): # 去重 unique_tokens = _dedupe_tokens(tokens) - # 批量执行配置 - max_concurrent = get_config("nsfw.concurrent") - batch_size = get_config("nsfw.batch_size") - raw_results = await NSFWService.batch( unique_tokens, mgr, - max_concurrent=max_concurrent, - batch_size=batch_size, ) # 构造返回结果(mask token) @@ -1005,9 +988,6 @@ async def enable_nsfw_api_async(data: dict): # 去重 unique_tokens = _dedupe_tokens(tokens) - max_concurrent = get_config("nsfw.concurrent") - batch_size = get_config("nsfw.batch_size") - task = create_task(len(unique_tokens)) async def _run(): @@ -1020,8 +1000,6 @@ async def _on_item(item: str, res: dict): raw_results = await NSFWService.batch( unique_tokens, mgr, - max_concurrent=max_concurrent, - batch_size=batch_size, on_item=_on_item, should_cancel=lambda: task.cancelled, ) @@ -1123,15 +1101,11 @@ async def get_cache_stats_api(request: Request): } online_details = [] account_map = {a["token"]: a for a in accounts} - batch_size = max(1, int(get_config("asset.list_batch_size"))) - max_concurrent = batch_size if selected_tokens: total = 0 raw_results = await ListService.fetch_assets_details( selected_tokens, account_map, - max_concurrent=max_concurrent, - batch_size=batch_size, ) for token, res in raw_results.items(): if res.get("ok"): @@ -1164,8 +1138,6 @@ async def get_cache_stats_api(request: 
Request): raw_results = await ListService.fetch_assets_details( tokens, account_map, - max_concurrent=max_concurrent, - batch_size=batch_size, ) for token, res in raw_results.items(): if res.get("ok"): @@ -1197,8 +1169,6 @@ async def get_cache_stats_api(request: Request): raw_results = await ListService.fetch_assets_details( [token], account_map, - max_concurrent=1, - batch_size=1, ) res = raw_results.get(token, {}) data = res.get("data", {}) @@ -1289,9 +1259,6 @@ async def load_online_cache_api_async(data: dict): else: raise HTTPException(status_code=400, detail="No tokens provided") - batch_size = get_config("asset.list_batch_size") - max_concurrent = batch_size - task = create_task(len(selected_tokens)) async def _run(): @@ -1307,8 +1274,6 @@ async def _on_item(item: str, res: dict): raw_results = await ListService.fetch_assets_details( selected_tokens, account_map, - max_concurrent=max_concurrent, - batch_size=batch_size, include_ok=True, on_item=_on_item, should_cancel=lambda: task.cancelled, @@ -1426,14 +1391,9 @@ async def clear_online_cache_api(data: dict): token_list = list(dict.fromkeys(token_list)) results = {} - batch_size = max(1, int(get_config("asset.delete_batch_size"))) - max_concurrent = batch_size - raw_results = await DeleteService.clear_assets( token_list, mgr, - max_concurrent=max_concurrent, - batch_size=batch_size, ) for token, res in raw_results.items(): if res.get("ok"): @@ -1452,8 +1412,6 @@ async def clear_online_cache_api(data: dict): raw_results = await DeleteService.clear_assets( [token], mgr, - max_concurrent=1, - batch_size=1, ) res = raw_results.get(token, {}) data = res.get("data", {}) @@ -1480,9 +1438,6 @@ async def clear_online_cache_api_async(data: dict): if not token_list: raise HTTPException(status_code=400, detail="No tokens provided") - batch_size = get_config("asset.delete_batch_size") - max_concurrent = batch_size - task = create_task(len(token_list)) async def _run(): @@ -1494,8 +1449,6 @@ async def _on_item(item: str, 
res: dict): raw_results = await DeleteService.clear_assets( token_list, mgr, - max_concurrent=max_concurrent, - batch_size=batch_size, include_ok=True, on_item=_on_item, should_cancel=lambda: task.cancelled, diff --git a/app/core/batch.py b/app/core/batch.py index 4857250d..7c62c015 100644 --- a/app/core/batch.py +++ b/app/core/batch.py @@ -19,7 +19,6 @@ async def run_batch( items: List[str], worker: Callable[[str], Awaitable[T]], *, - max_concurrent: int = 10, batch_size: int = 50, task: Optional["BatchTask"] = None, on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, @@ -31,52 +30,43 @@ async def run_batch( Args: items: 待处理项列表 worker: 异步处理函数 - max_concurrent: 最大并发数 batch_size: 每批大小 Returns: {item: {"ok": bool, "data": ..., "error": ...}} """ - try: - max_concurrent = int(max_concurrent) - except Exception: - max_concurrent = 10 try: batch_size = int(batch_size) except Exception: batch_size = 50 - max_concurrent = max(1, max_concurrent) batch_size = max(1, batch_size) - sem = asyncio.Semaphore(max_concurrent) - async def _one(item: str) -> tuple[str, dict]: if (should_cancel and should_cancel()) or (task and task.cancelled): return item, {"ok": False, "error": "cancelled", "cancelled": True} - async with sem: - try: - data = await worker(item) - result = {"ok": True, "data": data} - if task: - task.record(True) - if on_item: - try: - await on_item(item, result) - except Exception: - pass - return item, result - except Exception as e: - logger.warning(f"Batch item failed: {item[:16]}... 
- {e}") - result = {"ok": False, "error": str(e)} - if task: - task.record(False, error=str(e)) - if on_item: - try: - await on_item(item, result) - except Exception: - pass - return item, result + try: + data = await worker(item) + result = {"ok": True, "data": data} + if task: + task.record(True) + if on_item: + try: + await on_item(item, result) + except Exception: + pass + return item, result + except Exception as e: + logger.warning(f"Batch item failed: {item[:16]}... - {e}") + result = {"ok": False, "error": str(e)} + if task: + task.record(False, error=str(e)) + if on_item: + try: + await on_item(item, result) + except Exception: + pass + return item, result results: Dict[str, dict] = {} diff --git a/app/services/__init__.py b/app/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/__init__.py b/app/services/grok/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/batch_services/__init__.py b/app/services/grok/batch_services/__init__.py deleted file mode 100644 index d71f450e..00000000 --- a/app/services/grok/batch_services/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Batch services.""" - -from .usage import UsageService - -__all__ = ["UsageService"] diff --git a/app/services/grok/batch_services/assets.py b/app/services/grok/batch_services/assets.py index ca87b3c9..7c3c31e3 100644 --- a/app/services/grok/batch_services/assets.py +++ b/app/services/grok/batch_services/assets.py @@ -106,8 +106,6 @@ async def fetch_assets_details( tokens: List[str], account_map: dict, *, - max_concurrent: int, - batch_size: int, include_ok: bool = False, on_item=None, should_cancel=None, @@ -115,6 +113,7 @@ async def fetch_assets_details( """Batch fetch assets details for tokens.""" account_map = account_map or {} shared_service = ListService() + batch_size = max(1, int(get_config("asset.list_batch_size"))) async def _fetch_detail(token: str): account = account_map.get(token) @@ -152,7 
+151,6 @@ async def _fetch_detail(token: str): return await run_batch( tokens, _fetch_detail, - max_concurrent=max_concurrent, batch_size=batch_size, on_item=on_item, should_cancel=should_cancel, @@ -194,8 +192,6 @@ async def clear_assets( tokens: List[str], mgr, *, - max_concurrent: int, - batch_size: int, include_ok: bool = False, on_item=None, should_cancel=None, @@ -203,6 +199,7 @@ async def clear_assets( """Batch clear assets for tokens.""" delete_service = DeleteService() list_service = ListService() + batch_size = max(1, int(get_config("asset.delete_batch_size"))) async def _clear_one(token: str): try: @@ -222,7 +219,6 @@ async def _clear_one(token: str): return await run_batch( tokens, _clear_one, - max_concurrent=max_concurrent, batch_size=batch_size, on_item=on_item, should_cancel=should_cancel, diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py index facfbe5c..e7761e2e 100644 --- a/app/services/grok/batch_services/nsfw.py +++ b/app/services/grok/batch_services/nsfw.py @@ -2,6 +2,7 @@ Batch NSFW service. 
""" +import asyncio from typing import Callable, Awaitable, Dict, Any, Optional from curl_cffi.requests import AsyncSession @@ -15,6 +16,19 @@ from app.core.batch import run_batch +_NSFW_SEMAPHORE = None +_NSFW_SEM_VALUE = None + + +def _get_nsfw_semaphore() -> asyncio.Semaphore: + value = max(1, int(get_config("nsfw.concurrent"))) + global _NSFW_SEMAPHORE, _NSFW_SEM_VALUE + if _NSFW_SEMAPHORE is None or value != _NSFW_SEM_VALUE: + _NSFW_SEM_VALUE = value + _NSFW_SEMAPHORE = asyncio.Semaphore(value) + return _NSFW_SEMAPHORE + + class NSFWService: """NSFW 模式服务""" @staticmethod @@ -22,12 +36,11 @@ async def batch( tokens: list[str], mgr, *, - max_concurrent: int, - batch_size: int, on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, should_cancel: Optional[Callable[[], bool]] = None, ) -> Dict[str, Dict[str, Any]]: """Batch enable NSFW.""" + batch_size = get_config("nsfw.batch_size") async def _enable(token: str): try: browser = get_config("security.browser") @@ -43,7 +56,8 @@ async def _record_fail(err: UpstreamException, reason: str): return status or 0 try: - await AcceptTosReverse.request(session, token) + async with _get_nsfw_semaphore(): + await AcceptTosReverse.request(session, token) except UpstreamException as e: status = await _record_fail(e, "tos_auth_failed") return { @@ -53,7 +67,8 @@ async def _record_fail(err: UpstreamException, reason: str): } try: - await SetBirthReverse.request(session, token) + async with _get_nsfw_semaphore(): + await SetBirthReverse.request(session, token) except UpstreamException as e: status = await _record_fail(e, "set_birth_auth_failed") return { @@ -63,7 +78,8 @@ async def _record_fail(err: UpstreamException, reason: str): } try: - grpc_status = await NsfwMgmtReverse.request(session, token) + async with _get_nsfw_semaphore(): + grpc_status = await NsfwMgmtReverse.request(session, token) success = grpc_status.code in (-1, 0) except UpstreamException as e: status = await _record_fail(e, 
"nsfw_mgmt_auth_failed") @@ -88,7 +104,6 @@ async def _record_fail(err: UpstreamException, reason: str): return await run_batch( tokens, _enable, - max_concurrent=max_concurrent, batch_size=batch_size, on_item=on_item, should_cancel=should_cancel, diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py index 9d10c98a..66aab105 100644 --- a/app/services/grok/batch_services/usage.py +++ b/app/services/grok/batch_services/usage.py @@ -12,8 +12,17 @@ from app.services.reverse.rate_limits import RateLimitsReverse from app.core.batch import run_batch -_USAGE_SEMAPHORE = asyncio.Semaphore(25) -_USAGE_SEM_VALUE = 25 +_USAGE_SEMAPHORE = None +_USAGE_SEM_VALUE = None + + +def _get_usage_semaphore() -> asyncio.Semaphore: + value = max(1, int(get_config("usage.concurrent"))) + global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE + if _USAGE_SEMAPHORE is None or value != _USAGE_SEM_VALUE: + _USAGE_SEM_VALUE = value + _USAGE_SEMAPHORE = asyncio.Semaphore(value) + return _USAGE_SEMAPHORE class UsageService: @@ -32,17 +41,7 @@ async def get(self, token: str) -> Dict: Raises: UpstreamException: 当获取失败且重试耗尽时 """ - value = get_config("usage.concurrent") - try: - value = int(value) - except Exception: - value = 25 - value = max(1, value) - global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE - if value != _USAGE_SEM_VALUE: - _USAGE_SEM_VALUE = value - _USAGE_SEMAPHORE = asyncio.Semaphore(value) - async with _USAGE_SEMAPHORE: + async with _get_usage_semaphore(): try: async with AsyncSession() as session: response = await RateLimitsReverse.request(session, token) @@ -63,18 +62,16 @@ async def batch( tokens: List[str], mgr, *, - max_concurrent: int, - batch_size: int, on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, should_cancel: Optional[Callable[[], bool]] = None, ) -> Dict[str, Dict[str, Any]]: + batch_size = get_config("usage.batch_size") async def _refresh_one(t: str): return await mgr.sync_usage(t, consume_on_fail=False, 
is_usage=False) return await run_batch( tokens, _refresh_one, - max_concurrent=max_concurrent, batch_size=batch_size, on_item=on_item, should_cancel=should_cancel, diff --git a/app/services/reverse/rate_limits.py b/app/services/reverse/rate_limits.py index 48b8c220..198164c4 100644 --- a/app/services/reverse/rate_limits.py +++ b/app/services/reverse/rate_limits.py @@ -50,8 +50,6 @@ async def request(session: AsyncSession, token: str) -> Any: # Curl Config timeout = get_config("usage.timeout") - if timeout is None: - timeout = get_config("network.timeout") browser = get_config("security.browser") async def _do_request(): From d04bc4ee4ca4f8366ef690187b81a3f7a5a0eb77 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Fri, 13 Feb 2026 20:22:52 +0800 Subject: [PATCH 20/27] refactor: streamline chat and video processing services, enhance validation, and remove unused components --- app/api/v1/chat.py | 334 ++++++---- app/core/exceptions.py | 9 + app/services/grok/batch_services/nsfw.py | 2 +- app/services/grok/processors/__init__.py | 6 - app/services/grok/processors/chat.py | 348 ----------- app/services/grok/processors/image.py | 23 +- app/services/grok/processors/video.py | 24 +- app/services/grok/services/chat.py | 582 ++++++++++++++---- app/services/grok/services/video.py | 29 +- .../{processors/base.py => utils/process.py} | 22 +- app/services/reverse/app_chat.py | 10 +- app/services/token/manager.py | 4 +- app/services/token/models.py | 4 +- 13 files changed, 757 insertions(+), 640 deletions(-) delete mode 100644 app/services/grok/processors/chat.py rename app/services/grok/{processors/base.py => utils/process.py} (87%) diff --git a/app/api/v1/chat.py b/app/api/v1/chat.py index b1a70354..f7b72ecd 100644 --- a/app/api/v1/chat.py +++ b/app/api/v1/chat.py @@ -3,106 +3,35 @@ """ from typing import Any, Dict, List, Optional, Union +import base64 +import binascii from fastapi import APIRouter from fastapi.responses import 
StreamingResponse, JSONResponse -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from app.services.grok.services.chat import ChatService from app.services.grok.services.model import ModelService from app.core.exceptions import ValidationException -router = APIRouter(tags=["Chat"]) - - -VALID_ROLES = ["developer", "system", "user", "assistant", "tool"] -# 角色别名映射 (OpenAI 兼容: function -> tool) -ROLE_ALIASES = {"function": "tool"} -USER_CONTENT_TYPES = ["text", "image_url", "input_audio", "file"] - - class MessageItem(BaseModel): """消息项""" role: str content: Union[str, List[Dict[str, Any]]] - tool_call_id: Optional[str] = None # tool 角色需要的字段 - name: Optional[str] = None # function 角色的函数名 - @field_validator("role") - @classmethod - def validate_role(cls, v): - # 大小写归一化 - v_lower = v.lower() if isinstance(v, str) else v - # 别名映射 - v_normalized = ROLE_ALIASES.get(v_lower, v_lower) - if v_normalized not in VALID_ROLES: - raise ValueError(f"role must be one of {VALID_ROLES}") - return v_normalized + model_config = {"extra": "ignore"} class VideoConfig(BaseModel): """视频生成配置""" - aspect_ratio: Optional[str] = Field( - "3:2", description="视频比例: 3:2, 16:9, 1:1 等" - ) + aspect_ratio: Optional[str] = Field("3:2", description="视频比例: 1280x720(16:9), 720x1280(9:16), 1792x1024(3:2), 1024x1792(2:3), 1024x1024(1:1)") video_length: Optional[int] = Field(6, description="视频时长(秒): 6 / 10 / 15") resolution_name: Optional[str] = Field("480p", description="视频分辨率: 480p, 720p") preset: Optional[str] = Field("custom", description="风格预设: fun, normal, spicy") - @field_validator("aspect_ratio") - @classmethod - def validate_aspect_ratio(cls, v): - allowed = ["2:3", "3:2", "1:1", "9:16", "16:9"] - if v and v not in allowed: - raise ValidationException( - message=f"aspect_ratio must be one of {allowed}", - param="video_config.aspect_ratio", - code="invalid_aspect_ratio", - ) - return v - - @field_validator("video_length") - @classmethod - def 
validate_video_length(cls, v): - if v is not None: - if v not in (6, 10, 15): - raise ValidationException( - message="video_length must be 6, 10, or 15 seconds", - param="video_config.video_length", - code="invalid_video_length", - ) - return v - - @field_validator("resolution_name") - @classmethod - def validate_resolution(cls, v): - allowed = ["480p", "720p"] - if v and v not in allowed: - raise ValidationException( - message=f"resolution_name must be one of {allowed}", - param="video_config.resolution_name", - code="invalid_resolution", - ) - return v - - @field_validator("preset") - @classmethod - def validate_preset(cls, v): - # 允许为空,默认 custom - if not v: - return "custom" - allowed = ["fun", "normal", "spicy", "custom"] - if v not in allowed: - raise ValidationException( - message=f"preset must be one of {allowed}", - param="video_config.preset", - code="invalid_preset", - ) - return v - class ChatCompletionRequest(BaseModel): """Chat Completions 请求""" @@ -110,34 +39,163 @@ class ChatCompletionRequest(BaseModel): model: str = Field(..., description="模型名称") messages: List[MessageItem] = Field(..., description="消息数组") stream: Optional[bool] = Field(None, description="是否流式输出") - thinking: Optional[str] = Field(None, description="思考模式: enabled/disabled/None") - + reasoning_effort: Optional[str] = Field(None, description="推理强度: none/minimal/low/medium/high/xhigh") + temperature: Optional[float] = Field(0.8, description="采样温度: 0-2") + top_p: Optional[float] = Field(0.95, description="nucleus 采样: 0-1") # 视频生成配置 video_config: Optional[VideoConfig] = Field(None, description="视频生成参数") + model_config = {"extra": "ignore"} - @field_validator("stream", mode="before") - @classmethod - def validate_stream(cls, v): - """确保 stream 参数被正确解析为布尔值""" - if v is None: - return None - if isinstance(v, bool): - return v - if isinstance(v, str): - if v.lower() in ("true", "1", "yes"): - return True - if v.lower() in ("false", "0", "no"): - return False - # 未识别的字符串值抛出错误 - raise 
ValueError( - f"Invalid stream value '{v}'. Must be a boolean or one of: true, false, 1, 0, yes, no" + +VALID_ROLES = {"developer", "system", "user", "assistant"} +USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"} + + +def _validate_media_input(value: str, field_name: str, param: str): + if not isinstance(value, str) or not value.strip(): + raise ValidationException( + message=f"{field_name} cannot be empty", + param=param, + code="empty_media", + ) + value = value.strip() + if value.startswith("data:"): + return + if value.startswith("http://") or value.startswith("https://"): + return + candidate = "".join(value.split()) + if len(candidate) >= 32 and len(candidate) % 4 == 0: + try: + base64.b64decode(candidate, validate=True) + raise ValidationException( + message=f"{field_name} base64 must be provided as a data URI (data:;base64,...)", + param=param, + code="invalid_media", ) - # 非布尔非字符串类型抛出错误 - raise ValueError( - f"Invalid stream value type '{type(v).__name__}'. Must be a boolean or string." 
+ except binascii.Error: + pass + raise ValidationException( + message=f"{field_name} must be a URL or data URI", + param=param, + code="invalid_media", + ) + + +def _normalize_stream(value: Any) -> Optional[bool]: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, str): + if value.lower() in ("true", "1", "yes"): + return True + if value.lower() in ("false", "0", "no"): + return False + raise ValidationException( + message="stream must be a boolean", + param="stream", + code="invalid_stream", + ) + + +def _validate_reasoning_effort(value: Any) -> Optional[str]: + allowed = {"none", "minimal", "low", "medium", "high", "xhigh"} + if value is None: + return None + if not isinstance(value, str) or value not in allowed: + raise ValidationException( + message=f"reasoning_effort must be one of {sorted(allowed)}", + param="reasoning_effort", + code="invalid_reasoning_effort", ) + return value - model_config = {"extra": "ignore"} + +def _validate_temperature(value: Any) -> float: + if value is None: + return 0.8 + try: + val = float(value) + except Exception: + raise ValidationException( + message="temperature must be a float", + param="temperature", + code="invalid_temperature", + ) + if not (0 <= val <= 2): + raise ValidationException( + message="temperature must be between 0 and 2", + param="temperature", + code="invalid_temperature", + ) + return val + + +def _validate_top_p(value: Any) -> float: + if value is None: + return 0.95 + try: + val = float(value) + except Exception: + raise ValidationException( + message="top_p must be a float", + param="top_p", + code="invalid_top_p", + ) + if not (0 <= val <= 1): + raise ValidationException( + message="top_p must be between 0 and 1", + param="top_p", + code="invalid_top_p", + ) + return val + + +def _normalize_video_config(config: Optional[VideoConfig]) -> VideoConfig: + if config is None: + config = VideoConfig() + + ratio_map = { + "1280x720": "16:9", + "720x1280": 
"9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", + "16:9": "16:9", + "9:16": "9:16", + "3:2": "3:2", + "2:3": "2:3", + "1:1": "1:1", + } + if config.aspect_ratio is None: + config.aspect_ratio = "3:2" + if config.aspect_ratio not in ratio_map: + raise ValidationException( + message=f"aspect_ratio must be one of {list(ratio_map.keys())}", + param="video_config.aspect_ratio", + code="invalid_aspect_ratio", + ) + config.aspect_ratio = ratio_map[config.aspect_ratio] + + if config.video_length not in (6, 10, 15): + raise ValidationException( + message="video_length must be 6, 10, or 15 seconds", + param="video_config.video_length", + code="invalid_video_length", + ) + if config.resolution_name not in ("480p", "720p"): + raise ValidationException( + message="resolution_name must be one of ['480p', '720p']", + param="video_config.resolution_name", + code="invalid_resolution", + ) + if config.preset not in ("fun", "normal", "spicy", "custom"): + raise ValidationException( + message="preset must be one of ['fun', 'normal', 'spicy', 'custom']", + param="video_config.preset", + code="invalid_preset", + ) + return config def validate_request(request: ChatCompletionRequest): @@ -152,6 +210,12 @@ def validate_request(request: ChatCompletionRequest): # 验证消息 for idx, msg in enumerate(request.messages): + if not isinstance(msg.role, str) or msg.role not in VALID_ROLES: + raise ValidationException( + message=f"role must be one of {sorted(VALID_ROLES)}", + param=f"messages.{idx}.role", + code="invalid_role", + ) content = msg.content # 字符串内容 @@ -174,6 +238,12 @@ def validate_request(request: ChatCompletionRequest): for block_idx, block in enumerate(content): # 检查空对象 + if not isinstance(block, dict): + raise ValidationException( + message="Content block must be an object", + param=f"messages.{idx}.content.{block_idx}", + code="invalid_block", + ) if not block: raise ValidationException( message="Content block cannot be empty", @@ -211,20 +281,13 @@ def 
validate_request(request: ChatCompletionRequest): param=f"messages.{idx}.content.{block_idx}.type", code="invalid_type", ) - elif msg.role in ("tool", "function"): - # tool/function 角色只支持 text 类型,但内容可以是 JSON 字符串 + else: if block_type != "text": raise ValidationException( message=f"The `{msg.role}` role only supports 'text' type, got '{block_type}'", param=f"messages.{idx}.content.{block_idx}.type", code="invalid_type", ) - elif block_type != "text": - raise ValidationException( - message=f"The `{msg.role}` role only supports 'text' type, got '{block_type}'", - param=f"messages.{idx}.content.{block_idx}.type", - code="invalid_type", - ) # 验证字段是否存在 & 非空 if block_type == "text": @@ -237,14 +300,61 @@ def validate_request(request: ChatCompletionRequest): ) elif block_type == "image_url": image_url = block.get("image_url") - if not image_url or not ( - isinstance(image_url, dict) and image_url.get("url") - ): + if not image_url or not isinstance(image_url, dict): raise ValidationException( message="image_url must have a 'url' field", param=f"messages.{idx}.content.{block_idx}.image_url", code="missing_url", ) + _validate_media_input( + image_url.get("url", ""), + "image_url.url", + f"messages.{idx}.content.{block_idx}.image_url.url", + ) + elif block_type == "input_audio": + audio = block.get("input_audio") + if not audio or not isinstance(audio, dict): + raise ValidationException( + message="input_audio must have a 'data' field", + param=f"messages.{idx}.content.{block_idx}.input_audio", + code="missing_audio", + ) + _validate_media_input( + audio.get("data", ""), + "input_audio.data", + f"messages.{idx}.content.{block_idx}.input_audio.data", + ) + elif block_type == "file": + file_data = block.get("file") + if not file_data or not isinstance(file_data, dict): + raise ValidationException( + message="file must have a 'file_data' field", + param=f"messages.{idx}.content.{block_idx}.file", + code="missing_file", + ) + _validate_media_input( + file_data.get("file_data", 
""), + "file.file_data", + f"messages.{idx}.content.{block_idx}.file.file_data", + ) + else: + raise ValidationException( + message="Message content must be a string or array", + param=f"messages.{idx}.content", + code="invalid_content", + ) + + request.stream = _normalize_stream(request.stream) + request.reasoning_effort = _validate_reasoning_effort(request.reasoning_effort) + request.temperature = _validate_temperature(request.temperature) + request.top_p = _validate_top_p(request.top_p) + + model_info = ModelService.get(request.model) + if model_info and model_info.is_video: + request.video_config = _normalize_video_config(request.video_config) + + +router = APIRouter(tags=["Chat"]) @router.post("/chat/completions") @@ -269,7 +379,7 @@ async def chat_completions(request: ChatCompletionRequest): model=request.model, messages=[msg.model_dump() for msg in request.messages], stream=request.stream, - thinking=request.thinking, + reasoning_effort=request.reasoning_effort, aspect_ratio=v_conf.aspect_ratio, video_length=v_conf.video_length, resolution=v_conf.resolution_name, @@ -280,7 +390,9 @@ async def chat_completions(request: ChatCompletionRequest): model=request.model, messages=[msg.model_dump() for msg in request.messages], stream=request.stream, - thinking=request.thinking, + reasoning_effort=request.reasoning_effort, + temperature=request.temperature, + top_p=request.top_p, ) if isinstance(result, dict): diff --git a/app/core/exceptions.py b/app/core/exceptions.py index 9ceff005..4ae9092c 100644 --- a/app/core/exceptions.py +++ b/app/core/exceptions.py @@ -101,6 +101,14 @@ def __init__(self, message: str, details: Any = None): self.details = details +class StreamIdleTimeoutError(Exception): + """流空闲超时错误""" + + def __init__(self, idle_seconds: float): + self.idle_seconds = idle_seconds + super().__init__(f"Stream idle timeout after {idle_seconds}s") + + # ============= 异常处理器 ============= @@ -219,6 +227,7 @@ def register_exception_handlers(app): 
"ValidationException", "AuthenticationException", "UpstreamException", + "StreamIdleTimeoutError", "error_response", "register_exception_handlers", ] diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py index e7761e2e..6fbdb9c1 100644 --- a/app/services/grok/batch_services/nsfw.py +++ b/app/services/grok/batch_services/nsfw.py @@ -51,7 +51,7 @@ async def _record_fail(err: UpstreamException, reason: str): status = err.details["status"] else: status = getattr(err, "status_code", None) - if status in (401, 403): + if status == 401: await mgr.record_fail(token, status, reason) return status or 0 diff --git a/app/services/grok/processors/__init__.py b/app/services/grok/processors/__init__.py index 72fb3d58..3cf7d613 100644 --- a/app/services/grok/processors/__init__.py +++ b/app/services/grok/processors/__init__.py @@ -2,8 +2,6 @@ OpenAI 响应格式处理器 """ -from .base import BaseProcessor, StreamIdleTimeoutError -from .chat import StreamProcessor, CollectProcessor from .video import VideoStreamProcessor, VideoCollectProcessor from .image import ( ImageStreamProcessor, @@ -13,10 +11,6 @@ ) __all__ = [ - "BaseProcessor", - "StreamIdleTimeoutError", - "StreamProcessor", - "CollectProcessor", "VideoStreamProcessor", "VideoCollectProcessor", "ImageStreamProcessor", diff --git a/app/services/grok/processors/chat.py b/app/services/grok/processors/chat.py deleted file mode 100644 index e5e7bd61..00000000 --- a/app/services/grok/processors/chat.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -Chat response processors. 
-""" - -import asyncio -import uuid -import re -from typing import Any, AsyncGenerator, AsyncIterable - -import orjson -from curl_cffi.requests.errors import RequestsError - -from app.core.config import get_config -from app.core.logger import logger -from app.core.exceptions import UpstreamException -from .base import ( - BaseProcessor, - StreamIdleTimeoutError, - _with_idle_timeout, - _normalize_stream_line, - _collect_image_urls, - _is_http2_stream_error, -) - - -class StreamProcessor(BaseProcessor): - """Stream response processor.""" - - def __init__(self, model: str, token: str = "", think: bool = None): - super().__init__(model, token) - self.response_id: str = None - self.fingerprint: str = "" - self.think_opened: bool = False - self.role_sent: bool = False - self.filter_tags = get_config("chat.filter_tags") - self._tag_buffer: str = "" - self._in_filter_tag: bool = False - - if think is None: - self.show_think = get_config("chat.thinking") - else: - self.show_think = think - - def _filter_token(self, token: str) -> str: - """Filter special tags (supports cross-token tag filtering).""" - if not self.filter_tags: - return token - - result = [] - i = 0 - while i < len(token): - char = token[i] - - if self._in_filter_tag: - self._tag_buffer += char - if char == ">": - if "/>" in self._tag_buffer: - self._in_filter_tag = False - self._tag_buffer = "" - else: - for tag in self.filter_tags: - if f"" in self._tag_buffer: - self._in_filter_tag = False - self._tag_buffer = "" - break - i += 1 - continue - - if char == "<": - remaining = token[i:] - tag_started = False - for tag in self.filter_tags: - if remaining.startswith(f"<{tag}"): - tag_started = True - break - if len(remaining) < len(tag) + 1: - for j in range(1, len(remaining) + 1): - if f"<{tag}".startswith(remaining[:j]): - tag_started = True - break - - if tag_started: - self._in_filter_tag = True - self._tag_buffer = char - i += 1 - continue - - result.append(char) - i += 1 - - return "".join(result) - - 
def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: - """Build SSE response.""" - delta = {} - if role: - delta["role"] = role - delta["content"] = "" - elif content: - delta["content"] = content - - chunk = { - "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "system_fingerprint": self.fingerprint, - "choices": [ - {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} - ], - } - return f"data: {orjson.dumps(chunk).decode()}\n\n" - - async def process( - self, response: AsyncIterable[bytes] - ) -> AsyncGenerator[str, None]: - """Process stream response.""" - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if (llm := resp.get("llmInfo")) and not self.fingerprint: - self.fingerprint = llm.get("modelHash", "") - if rid := resp.get("responseId"): - self.response_id = rid - - if not self.role_sent: - yield self._sse(role="assistant") - self.role_sent = True - - # Image generation progress - if img := resp.get("streamingImageGenerationResponse"): - if self.show_think: - if not self.think_opened: - yield self._sse("\n") - self.think_opened = True - idx = img.get("imageIndex", 0) + 1 - progress = img.get("progress", 0) - yield self._sse( - f"正在生成第{idx}张图片中,当前进度{progress}%\n" - ) - continue - - # modelResponse - if mr := resp.get("modelResponse"): - if self.think_opened and self.show_think: - if msg := mr.get("message"): - yield self._sse(msg + "\n") - yield self._sse("\n") - self.think_opened = False - - # Handle generated images - for url in _collect_image_urls(mr): - parts = url.split("/") - img_id = parts[-2] if len(parts) >= 2 else 
"image" - dl_service = self._get_dl() - rendered = await dl_service.render_image( - url, self.token, img_id - ) - yield self._sse(f"{rendered}\n") - - if ( - (meta := mr.get("metadata", {})) - .get("llm_info", {}) - .get("modelHash") - ): - self.fingerprint = meta["llm_info"]["modelHash"] - continue - - # Normal token - if (token := resp.get("token")) is not None: - if token: - filtered = self._filter_token(token) - if filtered: - yield self._sse(filtered) - - if self.think_opened: - yield self._sse("\n") - yield self._sse(finish="stop") - yield "data: [DONE]\n\n" - except asyncio.CancelledError: - logger.debug("Stream cancelled by client", extra={"model": self.model}) - except StreamIdleTimeoutError as e: - raise UpstreamException( - message=f"Stream idle timeout after {e.idle_seconds}s", - status_code=504, - details={ - "error": str(e), - "type": "stream_idle_timeout", - "idle_seconds": e.idle_seconds, - }, - ) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning(f"HTTP/2 stream error: {e}", extra={"model": self.model}) - raise UpstreamException( - message="Upstream connection closed unexpectedly", - status_code=502, - details={"error": str(e), "type": "http2_stream_error"}, - ) - logger.error(f"Stream request error: {e}", extra={"model": self.model}) - raise UpstreamException( - message=f"Upstream request failed: {e}", - status_code=502, - details={"error": str(e)}, - ) - except Exception as e: - logger.error( - f"Stream processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - raise - finally: - await self.close() - - -class CollectProcessor(BaseProcessor): - """Non-stream response processor.""" - - def __init__(self, model: str, token: str = ""): - super().__init__(model, token) - self.filter_tags = get_config("chat.filter_tags") - - def _filter_content(self, content: str) -> str: - """Filter special tags in content.""" - if not content or not self.filter_tags: - return content - - result = content - 
for tag in self.filter_tags: - pattern = rf"<{re.escape(tag)}[^>]*>.*?|<{re.escape(tag)}[^>]*/>" - result = re.sub(pattern, "", result, flags=re.DOTALL) - - return result - - async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: - """Process and collect full response.""" - response_id = "" - fingerprint = "" - content = "" - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if (llm := resp.get("llmInfo")) and not fingerprint: - fingerprint = llm.get("modelHash", "") - - if mr := resp.get("modelResponse"): - response_id = mr.get("responseId", "") - content = mr.get("message", "") - - if urls := _collect_image_urls(mr): - content += "\n" - for url in urls: - parts = url.split("/") - img_id = parts[-2] if len(parts) >= 2 else "image" - dl_service = self._get_dl() - rendered = await dl_service.render_image( - url, self.token, img_id - ) - content += f"{rendered}\n" - - if ( - (meta := mr.get("metadata", {})) - .get("llm_info", {}) - .get("modelHash") - ): - fingerprint = meta["llm_info"]["modelHash"] - - except asyncio.CancelledError: - logger.debug("Collect cancelled by client", extra={"model": self.model}) - except StreamIdleTimeoutError as e: - logger.warning(f"Collect idle timeout: {e}", extra={"model": self.model}) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning( - f"HTTP/2 stream error in collect: {e}", extra={"model": self.model} - ) - else: - logger.error(f"Collect request error: {e}", extra={"model": self.model}) - except Exception as e: - logger.error( - f"Collect processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - finally: - await self.close() - - content = 
self._filter_content(content) - - return { - "id": response_id, - "object": "chat.completion", - "created": self.created, - "model": self.model, - "system_fingerprint": fingerprint, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": content, - "refusal": None, - "annotations": [], - }, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - "prompt_tokens_details": { - "cached_tokens": 0, - "text_tokens": 0, - "audio_tokens": 0, - "image_tokens": 0, - }, - "completion_tokens_details": { - "text_tokens": 0, - "audio_tokens": 0, - "reasoning_tokens": 0, - }, - }, - } - - -__all__ = ["StreamProcessor", "CollectProcessor"] diff --git a/app/services/grok/processors/image.py b/app/services/grok/processors/image.py index f8f1f29e..bbfdb3bd 100644 --- a/app/services/grok/processors/image.py +++ b/app/services/grok/processors/image.py @@ -15,14 +15,13 @@ from app.core.config import get_config from app.core.logger import logger from app.core.storage import DATA_DIR -from app.core.exceptions import UpstreamException -from .base import ( +from app.core.exceptions import UpstreamException, StreamIdleTimeoutError +from app.services.grok.utils.process import ( BaseProcessor, - StreamIdleTimeoutError, _with_idle_timeout, - _normalize_stream_line, - _collect_image_urls, - _is_http2_stream_error, + _normalize_line, + _collect_images, + _is_http2_error, ) @@ -57,7 +56,7 @@ async def process( try: async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) + line = _normalize_line(line) if not line: continue try: @@ -90,7 +89,7 @@ async def process( # modelResponse if mr := resp.get("modelResponse"): - if urls := _collect_image_urls(mr): + if urls := _collect_images(mr): for url in urls: if self.response_format == "url": processed = await self.process_url(url, "image") @@ -155,7 +154,7 @@ async def process( }, ) except RequestsError as e: 
- if _is_http2_stream_error(e): + if _is_http2_error(e): logger.warning(f"HTTP/2 stream error in image: {e}") raise UpstreamException( message="Upstream connection closed unexpectedly", @@ -192,7 +191,7 @@ async def process(self, response: AsyncIterable[bytes]) -> List[str]: try: async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) + line = _normalize_line(line) if not line: continue try: @@ -203,7 +202,7 @@ async def process(self, response: AsyncIterable[bytes]) -> List[str]: resp = data.get("result", {}).get("response", {}) if mr := resp.get("modelResponse"): - if urls := _collect_image_urls(mr): + if urls := _collect_images(mr): for url in urls: if self.response_format == "url": processed = await self.process_url(url, "image") @@ -234,7 +233,7 @@ async def process(self, response: AsyncIterable[bytes]) -> List[str]: except StreamIdleTimeoutError as e: logger.warning(f"Image collect idle timeout: {e}") except RequestsError as e: - if _is_http2_stream_error(e): + if _is_http2_error(e): logger.warning(f"HTTP/2 stream error in image collect: {e}") else: logger.error(f"Image collect request error: {e}") diff --git a/app/services/grok/processors/video.py b/app/services/grok/processors/video.py index 8ba68b31..fb501cac 100644 --- a/app/services/grok/processors/video.py +++ b/app/services/grok/processors/video.py @@ -11,29 +11,25 @@ from app.core.config import get_config from app.core.logger import logger -from app.core.exceptions import UpstreamException -from .base import ( +from app.core.exceptions import UpstreamException, StreamIdleTimeoutError +from app.services.grok.utils.process import ( BaseProcessor, - StreamIdleTimeoutError, _with_idle_timeout, - _normalize_stream_line, - _is_http2_stream_error, + _normalize_line, + _is_http2_error, ) class VideoStreamProcessor(BaseProcessor): """Video stream response processor.""" - def __init__(self, model: str, token: str = "", think: bool = None): + def 
__init__(self, model: str, token: str = "", show_think: bool = None): super().__init__(model, token) self.response_id: Optional[str] = None self.think_opened: bool = False self.role_sent: bool = False - if think is None: - self.show_think = get_config("chat.thinking") - else: - self.show_think = think + self.show_think = bool(show_think) def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: """Build SSE response.""" @@ -63,7 +59,7 @@ async def process( try: async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) + line = _normalize_line(line) if not line: continue try: @@ -127,7 +123,7 @@ async def process( }, ) except RequestsError as e: - if _is_http2_stream_error(e): + if _is_http2_error(e): logger.warning( f"HTTP/2 stream error in video: {e}", extra={"model": self.model} ) @@ -167,7 +163,7 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: try: async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) + line = _normalize_line(line) if not line: continue try: @@ -199,7 +195,7 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: f"Video collect idle timeout: {e}", extra={"model": self.model} ) except RequestsError as e: - if _is_http2_stream_error(e): + if _is_http2_error(e): logger.warning( f"HTTP/2 stream error in video collect: {e}", extra={"model": self.model}, diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index 8fe1e1aa..e2a452b9 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -2,10 +2,14 @@ Grok Chat 服务 """ -from typing import Dict, List, Any -from dataclasses import dataclass +import asyncio +import re +import uuid +from typing import Dict, List, Any, AsyncGenerator, AsyncIterable +import orjson from curl_cffi.requests import AsyncSession +from curl_cffi.requests.errors import RequestsError 
from app.core.logger import logger from app.core.config import get_config @@ -14,39 +18,29 @@ ValidationException, ErrorType, UpstreamException, + StreamIdleTimeoutError, ) from app.services.grok.services.model import ModelService from app.services.grok.utils.upload import UploadService -from app.services.grok.processors import StreamProcessor, CollectProcessor +from app.services.grok.utils import process as proc_base from app.services.reverse.app_chat import AppChatReverse from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.token import get_token_manager, EffortType -@dataclass -class ChatRequest: - """聊天请求数据""" - - model: str - messages: List[Dict[str, Any]] - stream: bool = None - think: bool = None - - class MessageExtractor: """消息内容提取器""" @staticmethod - def extract( - messages: List[Dict[str, Any]], is_video: bool = False - ) -> tuple[str, List[tuple[str, str]]]: - """从 OpenAI 消息格式提取内容,返回 (text, attachments)""" + def extract(messages: List[Dict[str, Any]]) -> tuple[str, List[str], List[str]]: + """从 OpenAI 消息格式提取内容,返回 (text, file_attachments, image_attachments)""" texts = [] - attachments = [] + file_attachments: List[str] = [] + image_attachments: List[str] = [] extracted = [] for msg in messages: - role = msg.get("role", "") + role = msg.get("role", "") or "user" content = msg.get("content", "") parts = [] @@ -63,35 +57,21 @@ def extract( elif item_type == "image_url": image_data = item.get("image_url", {}) - url = ( - image_data.get("url", "") - if isinstance(image_data, dict) - else str(image_data) - ) + url = image_data.get("url", "") if url: - attachments.append(("image", url)) + image_attachments.append(url) elif item_type == "input_audio": - if is_video: - raise ValueError("视频模型不支持 input_audio 类型") audio_data = item.get("input_audio", {}) - data = ( - audio_data.get("data", "") - if isinstance(audio_data, dict) - else str(audio_data) - ) + data = audio_data.get("data", "") if data: - attachments.append(("audio", data)) 
+ file_attachments.append(data) elif item_type == "file": - if is_video: - raise ValueError("视频模型不支持 file 类型") file_data = item.get("file", {}) - url = file_data.get("url", "") or file_data.get("data", "") - if isinstance(file_data, str): - url = file_data - if url: - attachments.append(("file", url)) + raw = file_data.get("file_data", "") + if raw: + file_attachments.append(raw) if parts: extracted.append({"role": role, "text": "\n".join(parts)}) @@ -111,36 +91,12 @@ def extract( text = item["text"] texts.append(text if i == last_user_index else f"{role}: {text}") - return "\n\n".join(texts), attachments - - -class ChatRequestBuilder: - """请求构造器""" - - @staticmethod - def build_payload( - message: str, - model: str, - mode: str = None, - file_attachments: List[str] = None, - image_attachments: List[str] = None, - ) -> Dict[str, Any]: - """构造请求体""" - return AppChatReverse.build_payload( - message=message, - model=model, - mode=mode, - file_attachments=file_attachments, - image_attachments=image_attachments, - ) + return "\n\n".join(texts), file_attachments, image_attachments class GrokChatService: """Grok API 调用服务""" - def __init__(self): - pass - async def chat( self, token: str, @@ -149,7 +105,6 @@ async def chat( mode: str = None, stream: bool = None, file_attachments: List[str] = None, - image_attachments: List[str] = None, tool_overrides: Dict[str, Any] = None, model_config_override: Dict[str, Any] = None, ): @@ -171,7 +126,6 @@ async def chat( model=model, mode=mode, file_attachments=file_attachments, - image_attachments=image_attachments, tool_overrides=tool_overrides, model_config_override=model_config_override, ) @@ -182,44 +136,58 @@ async def chat( return stream_response - async def chat_openai(self, token: str, request: ChatRequest): + async def chat_openai( + self, + token: str, + model: str, + messages: List[Dict[str, Any]], + stream: bool = None, + reasoning_effort: str | None = None, + temperature: float = 0.8, + top_p: float = 0.95, + ): """OpenAI 
兼容接口""" - model_info = ModelService.get(request.model) + model_info = ModelService.get(model) if not model_info: - raise ValidationException(f"Unknown model: {request.model}") + raise ValidationException(f"Unknown model: {model}") grok_model = model_info.grok_model mode = model_info.model_mode - is_video = model_info.is_video - # 提取消息和附件 - try: - message, attachments = MessageExtractor.extract( - request.messages, is_video=is_video - ) - logger.debug( - f"Extracted message length={len(message)}, attachments={len(attachments)}" - ) - except ValueError as e: - raise ValidationException(str(e)) + message, file_attachments, image_attachments = MessageExtractor.extract(messages) + logger.debug( + "Extracted message length=%s, files=%s, images=%s", + len(message), + len(file_attachments), + len(image_attachments), + ) # 上传附件 - file_ids = [] - if attachments: + file_ids: List[str] = [] + image_ids: List[str] = [] + if file_attachments or image_attachments: upload_service = UploadService() try: - for attach_type, attach_data in attachments: + for attach_data in file_attachments: file_id, _ = await upload_service.upload_file(attach_data, token) file_ids.append(file_id) - logger.debug( - f"Attachment uploaded: type={attach_type}, file_id={file_id}" - ) + logger.debug(f"Attachment uploaded: type=file, file_id={file_id}") + for attach_data in image_attachments: + file_id, _ = await upload_service.upload_file(attach_data, token) + image_ids.append(file_id) + logger.debug(f"Attachment uploaded: type=image, file_id={file_id}") finally: await upload_service.close() - stream = ( - request.stream if request.stream is not None else get_config("chat.stream") - ) + all_attachments = file_ids + image_ids + stream = stream if stream is not None else get_config("chat.stream") + + model_config_override = { + "temperature": temperature, + "topP": top_p, + } + if reasoning_effort is not None: + model_config_override["reasoningEffort"] = reasoning_effort response = await self.chat( token, @@ 
-227,11 +195,11 @@ async def chat_openai(self, token: str, request: ChatRequest): grok_model, mode, stream, - file_attachments=file_ids, - image_attachments=[], + file_attachments=all_attachments, + model_config_override=model_config_override, ) - return response, stream, request.model + return response, stream, model class ChatService: @@ -242,29 +210,29 @@ async def completions( model: str, messages: List[Dict[str, Any]], stream: bool = None, - thinking: str = None, + reasoning_effort: str | None = None, + temperature: float = 0.8, + top_p: float = 0.95, ): """Chat Completions 入口""" # 获取 token token_mgr = await get_token_manager() await token_mgr.reload_if_stale() - # 解析参数(只需解析一次) - think = {"enabled": True, "disabled": False}.get(thinking) + # 解析参数 + if reasoning_effort is None: + show_think = get_config("chat.thinking") + else: + show_think = reasoning_effort != "none" is_stream = stream if stream is not None else get_config("chat.stream") - # 构造请求(只需构造一次) - chat_request = ChatRequest( - model=model, messages=messages, stream=is_stream, think=think - ) - # 跨 Token 重试循环 tried_tokens = set() max_token_retries = int(get_config("retry.max_retry")) last_error = None for attempt in range(max_token_retries): - # 选择 token(排除已失败的) + # 选择 token token = None for pool_name in ModelService.pool_candidates_for_model(model): token = token_mgr.get_token(pool_name, exclude=tried_tokens) @@ -296,12 +264,20 @@ async def completions( try: # 请求 Grok service = GrokChatService() - response, _, model_name = await service.chat_openai(token, chat_request) + response, _, model_name = await service.chat_openai( + token, + model, + messages, + stream=is_stream, + reasoning_effort=reasoning_effort, + temperature=temperature, + top_p=top_p, + ) # 处理响应 if is_stream: logger.debug(f"Processing stream response: model={model}") - processor = StreamProcessor(model_name, token, think) + processor = StreamProcessor(model_name, token, show_think) return wrap_stream_with_usage( 
processor.process(response), token_mgr, token, model ) @@ -349,10 +325,404 @@ async def completions( ) +class StreamProcessor(proc_base.BaseProcessor): + """Stream response processor.""" + + def __init__(self, model: str, token: str = "", show_think: bool = None): + super().__init__(model, token) + self.response_id: str = None + self.fingerprint: str = "" + self.think_opened: bool = False + self.role_sent: bool = False + self.filter_tags = get_config("chat.filter_tags") + self._tag_buffer: str = "" + self._in_filter_tag: bool = False + + self.show_think = bool(show_think) + + def _filter_token(self, token: str) -> str: + """Filter special tags (supports cross-token tag filtering).""" + if not self.filter_tags: + return token + + result = [] + i = 0 + while i < len(token): + char = token[i] + + if self._in_filter_tag: + self._tag_buffer += char + if char == ">": + if "/>" in self._tag_buffer: + self._in_filter_tag = False + self._tag_buffer = "" + else: + for tag in self.filter_tags: + if f"" in self._tag_buffer: + self._in_filter_tag = False + self._tag_buffer = "" + break + i += 1 + continue + + if char == "<": + remaining = token[i:] + tag_started = False + for tag in self.filter_tags: + if remaining.startswith(f"<{tag}"): + tag_started = True + break + if len(remaining) < len(tag) + 1: + for j in range(1, len(remaining) + 1): + if f"<{tag}".startswith(remaining[:j]): + tag_started = True + break + + if tag_started: + self._in_filter_tag = True + self._tag_buffer = char + i += 1 + continue + + result.append(char) + i += 1 + + return "".join(result) + + def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: + """Build SSE response.""" + delta = {} + if role: + delta["role"] = role + delta["content"] = "" + elif content: + delta["content"] = content + + chunk = { + "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "system_fingerprint": 
self.fingerprint, + "choices": [ + {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} + ], + } + return f"data: {orjson.dumps(chunk).decode()}\n\n" + + async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]: + """Process stream response. + + Args: + response: AsyncIterable[bytes], async iterable of bytes + + Returns: + AsyncGenerator[str, None], async generator of strings + """ + idle_timeout = get_config("timeout.stream_idle_timeout") + + try: + async for line in proc_base._with_idle_timeout( + response, idle_timeout, self.model + ): + line = proc_base._normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + is_thinking = bool(resp.get("isThinking")) + # isThinking controls tagging + # when absent, treat as False + + if (llm := resp.get("llmInfo")) and not self.fingerprint: + self.fingerprint = llm.get("modelHash", "") + if rid := resp.get("responseId"): + self.response_id = rid + + if not self.role_sent: + yield self._sse(role="assistant") + self.role_sent = True + + if img := resp.get("streamingImageGenerationResponse"): + if not self.show_think: + continue + if is_thinking and not self.think_opened: + yield self._sse("\n") + self.think_opened = True + if (not is_thinking) and self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + idx = img.get("imageIndex", 0) + 1 + progress = img.get("progress", 0) + yield self._sse( + f"正在生成第{idx}张图片中,当前进度{progress}%\n" + ) + continue + + if mr := resp.get("modelResponse"): + for url in proc_base._collect_images(mr): + parts = url.split("/") + img_id = parts[-2] if len(parts) >= 2 else "image" + dl_service = self._get_dl() + rendered = await dl_service.render_image( + url, self.token, img_id + ) + yield self._sse(f"{rendered}\n") + + if ( + (meta := mr.get("metadata", {})) + .get("llm_info", {}) + .get("modelHash") + ): + 
self.fingerprint = meta["llm_info"]["modelHash"] + continue + + if card := resp.get("cardAttachment"): + json_data = card.get("jsonData") + if isinstance(json_data, str) and json_data.strip(): + try: + card_data = orjson.loads(json_data) + except orjson.JSONDecodeError: + card_data = None + if isinstance(card_data, dict): + image = card_data.get("image") or {} + original = image.get("original") + title = image.get("title") or "" + if original: + title_safe = title.replace("\n", " ").strip() + if title_safe: + yield self._sse(f"![{title_safe}]({original})\n") + else: + yield self._sse(f"![image]({original})\n") + continue + + if (token := resp.get("token")) is not None: + if not token: + continue + filtered = self._filter_token(token) + if not filtered: + continue + if is_thinking: + if not self.show_think: + continue + if not self.think_opened: + yield self._sse("\n") + self.think_opened = True + else: + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + yield self._sse(filtered) + + if self.think_opened: + yield self._sse("\n") + yield self._sse(finish="stop") + yield "data: [DONE]\n\n" + except asyncio.CancelledError: + logger.debug("Stream cancelled by client", extra={"model": self.model}) + except StreamIdleTimeoutError as e: + raise UpstreamException( + message=f"Stream idle timeout after {e.idle_seconds}s", + status_code=504, + details={ + "error": str(e), + "type": "stream_idle_timeout", + "idle_seconds": e.idle_seconds, + }, + ) + except RequestsError as e: + if proc_base._is_http2_error(e): + logger.warning(f"HTTP/2 stream error: {e}", extra={"model": self.model}) + raise UpstreamException( + message="Upstream connection closed unexpectedly", + status_code=502, + details={"error": str(e), "type": "http2_stream_error"}, + ) + logger.error(f"Stream request error: {e}", extra={"model": self.model}) + raise UpstreamException( + message=f"Upstream request failed: {e}", + status_code=502, + details={"error": str(e)}, + ) + except 
Exception as e: + logger.error( + f"Stream processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + raise + finally: + await self.close() + + +class CollectProcessor(proc_base.BaseProcessor): + """Non-stream response processor.""" + + def __init__(self, model: str, token: str = ""): + super().__init__(model, token) + self.filter_tags = get_config("chat.filter_tags") + + def _filter_content(self, content: str) -> str: + """Filter special tags in content.""" + if not content or not self.filter_tags: + return content + + result = content + for tag in self.filter_tags: + pattern = rf"<{re.escape(tag)}[^>]*>.*?|<{re.escape(tag)}[^>]*/>" + result = re.sub(pattern, "", result, flags=re.DOTALL) + + return result + + async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: + """Process and collect full response.""" + response_id = "" + fingerprint = "" + content = "" + idle_timeout = get_config("timeout.stream_idle_timeout") + + try: + async for line in proc_base._with_idle_timeout( + response, idle_timeout, self.model + ): + line = proc_base._normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + if (llm := resp.get("llmInfo")) and not fingerprint: + fingerprint = llm.get("modelHash", "") + + if mr := resp.get("modelResponse"): + response_id = mr.get("responseId", "") + content = mr.get("message", "") + + card_map: dict[str, tuple[str, str]] = {} + for raw in mr.get("cardAttachmentsJson") or []: + if not isinstance(raw, str) or not raw.strip(): + continue + try: + card_data = orjson.loads(raw) + except orjson.JSONDecodeError: + continue + if not isinstance(card_data, dict): + continue + card_id = card_data.get("id") + image = card_data.get("image") or {} + original = image.get("original") + if not card_id or not original: + continue + title = image.get("title") or "" + card_map[card_id] = 
(title, original) + + if content and card_map: + def _render_card(match: re.Match) -> str: + card_id = match.group(1) + item = card_map.get(card_id) + if not item: + return "" + title, original = item + title_safe = title.replace("\n", " ").strip() or "image" + prefix = "" + if match.start() > 0: + prev = content[match.start() - 1] + if prev not in ("\n", "\r"): + prefix = "\n" + return f"{prefix}![{title_safe}]({original})" + + content = re.sub( + r']*card_id="([^"]+)"[^>]*>.*?', + _render_card, + content, + flags=re.DOTALL, + ) + + if urls := proc_base._collect_images(mr): + content += "\n" + for url in urls: + parts = url.split("/") + img_id = parts[-2] if len(parts) >= 2 else "image" + dl_service = self._get_dl() + rendered = await dl_service.render_image( + url, self.token, img_id + ) + content += f"{rendered}\n" + + if ( + (meta := mr.get("metadata", {})) + .get("llm_info", {}) + .get("modelHash") + ): + fingerprint = meta["llm_info"]["modelHash"] + + except asyncio.CancelledError: + logger.debug("Collect cancelled by client", extra={"model": self.model}) + except StreamIdleTimeoutError as e: + logger.warning(f"Collect idle timeout: {e}", extra={"model": self.model}) + except RequestsError as e: + if proc_base._is_http2_error(e): + logger.warning( + f"HTTP/2 stream error in collect: {e}", extra={"model": self.model} + ) + else: + logger.error(f"Collect request error: {e}", extra={"model": self.model}) + except Exception as e: + logger.error( + f"Collect processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + finally: + await self.close() + + content = self._filter_content(content) + + return { + "id": response_id, + "object": "chat.completion", + "created": self.created, + "model": self.model, + "system_fingerprint": fingerprint, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + "refusal": None, + "annotations": [], + }, + "finish_reason": "stop", + } + ], + "usage": { + 
"prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "prompt_tokens_details": { + "cached_tokens": 0, + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + }, + }, + } + + __all__ = [ "GrokChatService", - "ChatRequest", - "ChatRequestBuilder", "MessageExtractor", "ChatService", ] diff --git a/app/services/grok/services/video.py b/app/services/grok/services/video.py index b848045a..32c9f6e5 100644 --- a/app/services/grok/services/video.py +++ b/app/services/grok/services/video.py @@ -224,7 +224,7 @@ async def completions( model: str, messages: list, stream: bool = None, - thinking: str = None, + reasoning_effort: str | None = None, aspect_ratio: str = "3:2", video_length: int = 6, resolution: str = "480p", @@ -256,29 +256,30 @@ async def completions( if token.startswith("sso="): token = token[4:] - think = {"enabled": True, "disabled": False}.get(thinking) + if reasoning_effort is None: + show_think = get_config("chat.thinking") + else: + show_think = reasoning_effort != "none" is_stream = stream if stream is not None else get_config("chat.stream") # Extract content. from app.services.grok.services.chat import MessageExtractor from app.services.grok.utils.upload import UploadService - try: - prompt, attachments = MessageExtractor.extract(messages, is_video=True) - except ValueError as e: - raise ValidationException(str(e)) + prompt, file_attachments, image_attachments = MessageExtractor.extract( + messages, is_video=True + ) # Handle image attachments. 
image_url = None - if attachments: + if image_attachments: upload_service = UploadService() try: - for attach_type, attach_data in attachments: - if attach_type == "image": - _, file_uri = await upload_service.upload_file(attach_data, token) - image_url = f"https://assets.grok.com/{file_uri}" - logger.info(f"Image uploaded for video: {image_url}") - break + for attach_data in image_attachments: + _, file_uri = await upload_service.upload_file(attach_data, token) + image_url = f"https://assets.grok.com/{file_uri}" + logger.info(f"Image uploaded for video: {image_url}") + break finally: await upload_service.close() @@ -295,7 +296,7 @@ async def completions( # Process response. if is_stream: - processor = VideoStreamProcessor(model, token, think) + processor = VideoStreamProcessor(model, token, show_think) return wrap_stream_with_usage( processor.process(response), token_mgr, token, model ) diff --git a/app/services/grok/processors/base.py b/app/services/grok/utils/process.py similarity index 87% rename from app/services/grok/processors/base.py rename to app/services/grok/utils/process.py index cab0631a..12249491 100644 --- a/app/services/grok/processors/base.py +++ b/app/services/grok/utils/process.py @@ -8,19 +8,20 @@ from app.core.config import get_config from app.core.logger import logger +from app.core.exceptions import StreamIdleTimeoutError from app.services.grok.utils.download import DownloadService T = TypeVar("T") -def _is_http2_stream_error(e: Exception) -> bool: +def _is_http2_error(e: Exception) -> bool: """检查是否为 HTTP/2 流错误""" err_str = str(e).lower() return "http/2" in err_str or "curl: (92)" in err_str or "stream" in err_str -def _normalize_stream_line(line: Any) -> Optional[str]: +def _normalize_line(line: Any) -> Optional[str]: """规范化流式响应行,兼容 SSE data 前缀与空行""" if line is None: return None @@ -38,7 +39,7 @@ def _normalize_stream_line(line: Any) -> Optional[str]: return text -def _collect_image_urls(obj: Any) -> List[str]: +def _collect_images(obj: Any) 
-> List[str]: """递归收集响应中的图片 URL""" urls: List[str] = [] seen = set() @@ -69,14 +70,6 @@ def walk(value: Any): return urls -class StreamIdleTimeoutError(Exception): - """流空闲超时错误""" - - def __init__(self, idle_seconds: float): - self.idle_seconds = idle_seconds - super().__init__(f"Stream idle timeout after {idle_seconds}s") - - async def _with_idle_timeout( iterable: AsyncIterable[T], idle_timeout: float, model: str = "" ) -> AsyncGenerator[T, None]: @@ -138,9 +131,8 @@ async def process_url(self, path: str, media_type: str = "image") -> str: __all__ = [ "BaseProcessor", - "StreamIdleTimeoutError", "_with_idle_timeout", - "_normalize_stream_line", - "_collect_image_urls", - "_is_http2_stream_error", + "_normalize_line", + "_collect_images", + "_is_http2_error", ] diff --git a/app/services/reverse/app_chat.py b/app/services/reverse/app_chat.py index 77ce6856..6ddeb07e 100644 --- a/app/services/reverse/app_chat.py +++ b/app/services/reverse/app_chat.py @@ -25,17 +25,12 @@ def build_payload( model: str, mode: str = None, file_attachments: List[str] = None, - image_attachments: List[str] = None, tool_overrides: Dict[str, Any] = None, model_config_override: Dict[str, Any] = None, ) -> Dict[str, Any]: """Build chat payload for Grok app-chat API.""" - attachments = [] - if file_attachments: - attachments.extend(file_attachments) - if image_attachments: - attachments.extend(image_attachments) + attachments = file_attachments or [] payload = { "deviceEnvInfo": { @@ -86,7 +81,6 @@ async def request( model: str, mode: str = None, file_attachments: List[str] = None, - image_attachments: List[str] = None, tool_overrides: Dict[str, Any] = None, model_config_override: Dict[str, Any] = None, ) -> Any: @@ -99,7 +93,6 @@ async def request( model: str, the model to use. mode: str, the mode to use. file_attachments: List[str], the file attachments to send. - image_attachments: List[str], the image attachments to send. tool_overrides: Dict[str, Any], the tool overrides to use. 
model_config_override: Dict[str, Any], the model config override to use. @@ -125,7 +118,6 @@ async def request( model=model, mode=mode, file_attachments=file_attachments, - image_attachments=image_attachments, tool_overrides=tool_overrides, model_config_override=model_config_override, ) diff --git a/app/services/token/manager.py b/app/services/token/manager.py index 68bfd847..c543d814 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -389,7 +389,7 @@ async def sync_usage( status = e.details["status"] else: status = getattr(e, "status_code", None) - if status in (401, 403): + if status == 401: await self.record_fail(token_str, status, "rate_limits_auth_failed") logger.warning( f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})" @@ -424,7 +424,7 @@ async def record_fail( for pool in self.pools.values(): token = pool.get(raw_token) if token: - if status_code in (401, 403): + if status_code == 401: token.record_fail(status_code, reason) logger.warning( f"Token {raw_token[:10]}...: recorded {status_code} failure " diff --git a/app/services/token/models.py b/app/services/token/models.py index 0701ab7b..d2853d0a 100644 --- a/app/services/token/models.py +++ b/app/services/token/models.py @@ -130,8 +130,8 @@ def reset(self, default_quota: Optional[int] = None): def record_fail(self, status_code: int = 401, reason: str = ""): """记录失败,达到阈值后自动标记为 expired""" - # 401/403 错误计入失败(都表示认证/授权失败) - if status_code not in (401, 403): + # 仅 401 计入失败 + if status_code != 401: return self.fail_count += 1 From 7f0db522ae997d8519177c0bd3c77994a558a013 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Sun, 15 Feb 2026 00:39:30 +0800 Subject: [PATCH 21/27] refactor: enhance configuration management by adding public API key support, reorganizing chat and image settings, and removing deprecated components --- app/api/v1/admin.py | 1494 ---------------------- app/api/v1/admin/__init__.py | 15 + 
app/api/v1/admin/cache.py | 445 +++++++ app/api/v1/admin/config.py | 53 + app/api/v1/admin/token.py | 395 ++++++ app/api/v1/chat.py | 517 ++++++-- app/api/v1/image.py | 90 +- app/api/v1/models.py | 2 +- app/api/v1/pages.py | 94 ++ app/api/v1/public/__init__.py | 13 + app/api/v1/public/imagine.py | 449 +++++++ app/api/v1/public/voice.py | 80 ++ app/core/auth.py | 63 +- app/core/config.py | 67 +- app/core/exceptions.py | 1 - app/services/grok/batch_services/nsfw.py | 2 +- app/services/grok/defaults.py | 48 +- app/services/grok/processors/__init__.py | 20 - app/services/grok/processors/image.py | 505 -------- app/services/grok/processors/video.py | 235 ---- app/services/grok/services/chat.py | 94 +- app/services/grok/services/image.py | 566 ++++++-- app/services/grok/services/image_edit.py | 420 +++++- app/services/grok/services/model.py | 59 +- app/services/grok/services/video.py | 632 ++++++--- app/services/grok/services/voice.py | 2 +- app/services/grok/utils/download.py | 104 +- app/services/grok/utils/process.py | 14 + app/services/grok/utils/retry.py | 45 + app/services/grok/utils/upload.py | 2 +- app/services/reverse/accept_tos.py | 4 +- app/services/reverse/app_chat.py | 14 +- app/services/reverse/assets_delete.py | 6 +- app/services/reverse/assets_download.py | 6 +- app/services/reverse/assets_list.py | 6 +- app/services/reverse/assets_upload.py | 6 +- app/services/reverse/media_post.py | 13 +- app/services/reverse/nsfw_mgmt.py | 4 +- app/services/reverse/rate_limits.py | 4 +- app/services/reverse/set_birth.py | 4 +- app/services/reverse/utils/headers.py | 6 +- app/services/reverse/utils/statsig.py | 2 +- app/services/reverse/utils/websocket.py | 10 +- app/services/reverse/ws_imagine.py | 43 +- app/services/reverse/ws_livekit.py | 10 +- app/services/token/manager.py | 12 +- app/services/token/models.py | 10 +- app/services/token/scheduler.py | 2 +- app/static/cache/cache.html | 2 +- app/static/cache/cache.js | 16 +- app/static/common/admin-auth.js | 101 +- 
app/static/common/batch-sse.js | 4 +- app/static/common/header.html | 10 +- app/static/common/header.js | 2 +- app/static/common/public-header.html | 26 + app/static/common/public-header.js | 25 + app/static/config/config.html | 2 +- app/static/config/config.js | 131 +- app/static/imagine/imagine.html | 2 +- app/static/imagine/imagine.js | 63 +- app/static/login/login.html | 2 +- app/static/login/login.js | 14 +- app/static/public/login.html | 68 + app/static/public/login.js | 51 + app/static/token/token.html | 2 +- app/static/token/token.js | 12 +- app/static/voice/voice.html | 2 +- app/static/voice/voice.js | 11 +- config.defaults.toml | 114 +- data/config.toml | 52 +- docs/README.en.md | 347 ++--- main.py | 8 +- readme.md | 234 ++-- tests/test_model.py | 463 ------- 74 files changed, 4532 insertions(+), 3920 deletions(-) delete mode 100644 app/api/v1/admin.py create mode 100644 app/api/v1/admin/__init__.py create mode 100644 app/api/v1/admin/cache.py create mode 100644 app/api/v1/admin/config.py create mode 100644 app/api/v1/admin/token.py create mode 100644 app/api/v1/pages.py create mode 100644 app/api/v1/public/__init__.py create mode 100644 app/api/v1/public/imagine.py create mode 100644 app/api/v1/public/voice.py delete mode 100644 app/services/grok/processors/__init__.py delete mode 100644 app/services/grok/processors/image.py delete mode 100644 app/services/grok/processors/video.py create mode 100644 app/services/grok/utils/retry.py create mode 100644 app/static/common/public-header.html create mode 100644 app/static/common/public-header.js create mode 100644 app/static/public/login.html create mode 100644 app/static/public/login.js delete mode 100644 tests/test_model.py diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py deleted file mode 100644 index a05db4a3..00000000 --- a/app/api/v1/admin.py +++ /dev/null @@ -1,1494 +0,0 @@ -from fastapi import ( - APIRouter, - Depends, - HTTPException, - Request, - Query, - WebSocket, - WebSocketDisconnect, -) 
-from fastapi.responses import HTMLResponse, StreamingResponse, RedirectResponse -from typing import Optional, List, Tuple -from pydantic import BaseModel -from app.core.auth import verify_api_key, verify_app_key, get_admin_api_key -from app.core.config import config, get_config -from app.core.batch import create_task, get_task, expire_task -from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage -from app.core.exceptions import AppException -from app.services.token.manager import get_token_manager -from app.services.grok.batch_services.usage import UsageService -from app.services.grok.batch_services.nsfw import NSFWService -from app.services.grok.batch_services.assets import ListService, DeleteService -import os -import time -import uuid -from pathlib import Path -import aiofiles -import asyncio -import orjson -from app.core.logger import logger -from app.api.v1.image import resolve_aspect_ratio -from app.services.grok.services.voice import VoiceService -from app.services.grok.services.image import ImageGenerationService -from app.services.grok.services.model import ModelService - -TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static" - - -router = APIRouter() - -IMAGINE_SESSION_TTL = 600 -_IMAGINE_SESSIONS: dict[str, dict] = {} -_IMAGINE_SESSIONS_LOCK = asyncio.Lock() - - -async def _cleanup_imagine_sessions(now: float) -> None: - expired = [ - key - for key, info in _IMAGINE_SESSIONS.items() - if now - float(info.get("created_at") or 0) > IMAGINE_SESSION_TTL - ] - for key in expired: - _IMAGINE_SESSIONS.pop(key, None) - - -async def _create_imagine_session(prompt: str, aspect_ratio: str) -> str: - task_id = uuid.uuid4().hex - now = time.time() - async with _IMAGINE_SESSIONS_LOCK: - await _cleanup_imagine_sessions(now) - _IMAGINE_SESSIONS[task_id] = { - "prompt": prompt, - "aspect_ratio": aspect_ratio, - "created_at": now, - } - return task_id - - -async def _get_imagine_session(task_id: str) -> Optional[dict]: - if not task_id: - 
return None - now = time.time() - async with _IMAGINE_SESSIONS_LOCK: - await _cleanup_imagine_sessions(now) - info = _IMAGINE_SESSIONS.get(task_id) - if not info: - return None - created_at = float(info.get("created_at") or 0) - if now - created_at > IMAGINE_SESSION_TTL: - _IMAGINE_SESSIONS.pop(task_id, None) - return None - return dict(info) - - -async def _delete_imagine_session(task_id: str) -> None: - if not task_id: - return - async with _IMAGINE_SESSIONS_LOCK: - _IMAGINE_SESSIONS.pop(task_id, None) - - -async def _delete_imagine_sessions(task_ids: List[str]) -> int: - if not task_ids: - return 0 - removed = 0 - async with _IMAGINE_SESSIONS_LOCK: - for task_id in task_ids: - if task_id and task_id in _IMAGINE_SESSIONS: - _IMAGINE_SESSIONS.pop(task_id, None) - removed += 1 - return removed - - -def _collect_tokens(data: dict) -> List[str]: - """从请求数据中收集 token 列表""" - tokens = [] - if isinstance(data.get("token"), str) and data["token"].strip(): - tokens.append(data["token"].strip()) - if isinstance(data.get("tokens"), list): - tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) - return tokens - - -def _dedupe_tokens(tokens: List[str]) -> List[str]: - """去重 token 列表(保持原顺序)""" - return list(dict.fromkeys(tokens)) - - - - -def _mask_token(token: str) -> str: - """掩码 token 显示""" - return f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token - - -async def render_template(filename: str): - """渲染指定模板""" - template_path = TEMPLATE_DIR / filename - if not template_path.exists(): - return HTMLResponse(f"Template {filename} not found.", status_code=404) - - async with aiofiles.open(template_path, "r", encoding="utf-8") as f: - content = await f.read() - return HTMLResponse(content) - - -def _sse_event(payload: dict) -> str: - return f"data: {orjson.dumps(payload).decode()}\n\n" - - -def _verify_stream_api_key(request: Request) -> None: - api_key = get_admin_api_key() - if not api_key: - return - key = request.query_params.get("api_key") - if 
key != api_key: - raise HTTPException(status_code=401, detail="Invalid authentication token") - - -@router.get("/api/v1/admin/batch/{task_id}/stream") -async def stream_batch(task_id: str, request: Request): - _verify_stream_api_key(request) - task = get_task(task_id) - if not task: - raise HTTPException(status_code=404, detail="Task not found") - - async def event_stream(): - queue = task.attach() - try: - yield _sse_event({"type": "snapshot", **task.snapshot()}) - - final = task.final_event() - if final: - yield _sse_event(final) - return - - while True: - try: - event = await asyncio.wait_for(queue.get(), timeout=15) - except asyncio.TimeoutError: - yield ": ping\n\n" - final = task.final_event() - if final: - yield _sse_event(final) - return - continue - - yield _sse_event(event) - if event.get("type") in ("done", "error", "cancelled"): - return - finally: - task.detach(queue) - - return StreamingResponse(event_stream(), media_type="text/event-stream") - - -@router.post( - "/api/v1/admin/batch/{task_id}/cancel", dependencies=[Depends(verify_api_key)] -) -async def cancel_batch(task_id: str): - task = get_task(task_id) - if not task: - raise HTTPException(status_code=404, detail="Task not found") - task.cancel() - return {"status": "success"} - - -@router.get("/admin", response_class=HTMLResponse, include_in_schema=False) -async def admin_login_page(): - """管理后台登录页""" - return await render_template("login/login.html") - - -@router.get("/", include_in_schema=False) -async def root_redirect(): - return RedirectResponse(url="/admin") - - -@router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False) -async def admin_config_page(): - """配置管理页""" - return await render_template("config/config.html") - - -@router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False) -async def admin_token_page(): - """Token 管理页""" - return await render_template("token/token.html") - - -@router.get("/admin/voice", response_class=HTMLResponse, 
include_in_schema=False) -async def admin_voice_page(): - """Voice Live 调试页""" - return await render_template("voice/voice.html") - - -@router.get("/admin/imagine", response_class=HTMLResponse, include_in_schema=False) -async def admin_imagine_page(): - """Imagine 图片瀑布流""" - return await render_template("imagine/imagine.html") - - -class VoiceTokenResponse(BaseModel): - token: str - url: str - participant_name: str = "" - room_name: str = "" - - -@router.get( - "/api/v1/admin/voice/token", - dependencies=[Depends(verify_api_key)], - response_model=VoiceTokenResponse, -) -async def admin_voice_token( - voice: str = "ara", - personality: str = "assistant", - speed: float = 1.0, -): - """获取 Grok Voice Mode (LiveKit) Token""" - token_mgr = await get_token_manager() - sso_token = None - for pool_name in ("ssoBasic", "ssoSuper"): - sso_token = token_mgr.get_token(pool_name) - if sso_token: - break - - if not sso_token: - raise AppException( - "No available tokens for voice mode", - code="no_token", - status_code=503, - ) - - service = VoiceService() - try: - data = await service.get_token( - token=sso_token, - voice=voice, - personality=personality, - speed=speed, - ) - token = data.get("token") - if not token: - raise AppException( - "Upstream returned no voice token", - code="upstream_error", - status_code=502, - ) - - return VoiceTokenResponse( - token=token, - url="wss://livekit.grok.com", - participant_name="", - room_name="", - ) - - except Exception as e: - if isinstance(e, AppException): - raise - raise AppException( - f"Voice token error: {str(e)}", - code="voice_error", - status_code=500, - ) - - -async def _verify_imagine_ws_auth(websocket: WebSocket) -> tuple[bool, Optional[str]]: - task_id = websocket.query_params.get("task_id") - if task_id: - info = await _get_imagine_session(task_id) - if info: - return True, task_id - - api_key = get_admin_api_key() - if not api_key: - return True, None - key = websocket.query_params.get("api_key") - return key == 
api_key, None - - -@router.websocket("/api/v1/admin/imagine/ws") -async def admin_imagine_ws(websocket: WebSocket): - ok, session_id = await _verify_imagine_ws_auth(websocket) - if not ok: - await websocket.close(code=1008) - return - - await websocket.accept() - stop_event = asyncio.Event() - run_task: Optional[asyncio.Task] = None - - async def _send(payload: dict) -> bool: - try: - await websocket.send_text(orjson.dumps(payload).decode()) - return True - except Exception: - return False - - async def _stop_run(): - nonlocal run_task - stop_event.set() - if run_task and not run_task.done(): - run_task.cancel() - try: - await run_task - except Exception: - pass - run_task = None - stop_event.clear() - - async def _run(prompt: str, aspect_ratio: str): - model_id = "grok-imagine-1.0" - model_info = ModelService.get(model_id) - if not model_info or not model_info.is_image: - await _send( - { - "type": "error", - "message": "Image model is not available.", - "code": "model_not_supported", - } - ) - return - - token_mgr = await get_token_manager() - sequence = 0 - run_id = uuid.uuid4().hex - - await _send( - { - "type": "status", - "status": "running", - "prompt": prompt, - "aspect_ratio": aspect_ratio, - "run_id": run_id, - } - ) - - while not stop_event.is_set(): - try: - await token_mgr.reload_if_stale() - token = None - for pool_name in ModelService.pool_candidates_for_model( - model_info.model_id - ): - token = token_mgr.get_token(pool_name) - if token: - break - - if not token: - await _send( - { - "type": "error", - "message": "No available tokens. 
Please try again later.", - "code": "rate_limit_exceeded", - } - ) - await asyncio.sleep(2) - continue - - start_at = time.time() - result = await ImageGenerationService().generate( - token_mgr=token_mgr, - token=token, - model_info=model_info, - prompt=prompt, - n=6, - response_format="b64_json", - size="1024x1024", - aspect_ratio=aspect_ratio, - stream=False, - use_ws=True, - ) - elapsed_ms = int((time.time() - start_at) * 1000) - - images = [img for img in result.data if img and img != "error"] - if images: - # 一次发送所有 6 张图片 - for img_b64 in images: - sequence += 1 - await _send( - { - "type": "image", - "b64_json": img_b64, - "sequence": sequence, - "created_at": int(time.time() * 1000), - "elapsed_ms": elapsed_ms, - "aspect_ratio": aspect_ratio, - "run_id": run_id, - } - ) - else: - await _send( - { - "type": "error", - "message": "Image generation returned empty data.", - "code": "empty_image", - } - ) - - except asyncio.CancelledError: - break - except Exception as e: - logger.warning(f"Imagine stream error: {e}") - await _send( - { - "type": "error", - "message": str(e), - "code": "internal_error", - } - ) - await asyncio.sleep(1.5) - - await _send({"type": "status", "status": "stopped", "run_id": run_id}) - - try: - while True: - try: - raw = await websocket.receive_text() - except (RuntimeError, WebSocketDisconnect): - # WebSocket already closed or disconnected - break - - try: - payload = orjson.loads(raw) - except Exception: - await _send( - { - "type": "error", - "message": "Invalid message format.", - "code": "invalid_payload", - } - ) - continue - - msg_type = payload.get("type") - if msg_type == "start": - prompt = str(payload.get("prompt") or "").strip() - if not prompt: - await _send( - { - "type": "error", - "message": "Prompt cannot be empty.", - "code": "empty_prompt", - } - ) - continue - ratio = str(payload.get("aspect_ratio") or "2:3").strip() - if not ratio: - ratio = "2:3" - ratio = resolve_aspect_ratio(ratio) - await _stop_run() - 
stop_event.clear() - run_task = asyncio.create_task(_run(prompt, ratio)) - elif msg_type == "stop": - await _stop_run() - elif msg_type == "ping": - await _send({"type": "pong"}) - else: - await _send( - { - "type": "error", - "message": "Unknown command.", - "code": "unknown_command", - } - ) - except WebSocketDisconnect: - logger.debug("WebSocket disconnected by client") - except Exception as e: - logger.warning(f"WebSocket error: {e}") - finally: - await _stop_run() - - try: - from starlette.websockets import WebSocketState - if websocket.client_state == WebSocketState.CONNECTED: - await websocket.close(code=1000, reason="Server closing connection") - except Exception as e: - logger.debug(f"WebSocket close ignored: {e}") - if session_id: - await _delete_imagine_session(session_id) - - -class ImagineStartRequest(BaseModel): - prompt: str - aspect_ratio: Optional[str] = "2:3" - - -@router.post("/api/v1/admin/imagine/start", dependencies=[Depends(verify_api_key)]) -async def admin_imagine_start(data: ImagineStartRequest): - prompt = (data.prompt or "").strip() - if not prompt: - raise HTTPException(status_code=400, detail="Prompt cannot be empty") - ratio = resolve_aspect_ratio(str(data.aspect_ratio or "2:3").strip() or "2:3") - task_id = await _create_imagine_session(prompt, ratio) - return {"task_id": task_id, "aspect_ratio": ratio} - - -class ImagineStopRequest(BaseModel): - task_ids: List[str] - - -@router.post("/api/v1/admin/imagine/stop", dependencies=[Depends(verify_api_key)]) -async def admin_imagine_stop(data: ImagineStopRequest): - removed = await _delete_imagine_sessions(data.task_ids or []) - return {"status": "success", "removed": removed} - - -@router.get("/api/v1/admin/imagine/sse") -async def admin_imagine_sse( - request: Request, - task_id: str = Query(""), - prompt: str = Query(""), - aspect_ratio: str = Query("2:3"), -): - """Imagine 图片瀑布流(SSE 兜底)""" - session = None - if task_id: - session = await _get_imagine_session(task_id) - if not session: 
- raise HTTPException(status_code=404, detail="Task not found") - else: - _verify_stream_api_key(request) - - if session: - prompt = str(session.get("prompt") or "").strip() - ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3" - else: - prompt = (prompt or "").strip() - if not prompt: - raise HTTPException(status_code=400, detail="Prompt cannot be empty") - ratio = str(aspect_ratio or "2:3").strip() or "2:3" - ratio = resolve_aspect_ratio(ratio) - - async def event_stream(): - try: - model_id = "grok-imagine-1.0" - model_info = ModelService.get(model_id) - if not model_info or not model_info.is_image: - yield _sse_event( - { - "type": "error", - "message": "Image model is not available.", - "code": "model_not_supported", - } - ) - return - - token_mgr = await get_token_manager() - sequence = 0 - run_id = uuid.uuid4().hex - - yield _sse_event( - { - "type": "status", - "status": "running", - "prompt": prompt, - "aspect_ratio": ratio, - "run_id": run_id, - } - ) - - while True: - if await request.is_disconnected(): - break - if task_id: - session_alive = await _get_imagine_session(task_id) - if not session_alive: - break - - try: - await token_mgr.reload_if_stale() - token = None - for pool_name in ModelService.pool_candidates_for_model( - model_info.model_id - ): - token = token_mgr.get_token(pool_name) - if token: - break - - if not token: - yield _sse_event( - { - "type": "error", - "message": "No available tokens. 
Please try again later.", - "code": "rate_limit_exceeded", - } - ) - await asyncio.sleep(2) - continue - - start_at = time.time() - result = await ImageGenerationService().generate( - token_mgr=token_mgr, - token=token, - model_info=model_info, - prompt=prompt, - n=6, - response_format="b64_json", - size="1024x1024", - aspect_ratio=ratio, - stream=False, - use_ws=True, - ) - elapsed_ms = int((time.time() - start_at) * 1000) - - images = [img for img in result.data if img and img != "error"] - if images: - for img_b64 in images: - sequence += 1 - yield _sse_event( - { - "type": "image", - "b64_json": img_b64, - "sequence": sequence, - "created_at": int(time.time() * 1000), - "elapsed_ms": elapsed_ms, - "aspect_ratio": ratio, - "run_id": run_id, - } - ) - else: - yield _sse_event( - { - "type": "error", - "message": "Image generation returned empty data.", - "code": "empty_image", - } - ) - except asyncio.CancelledError: - break - except Exception as e: - logger.warning(f"Imagine SSE error: {e}") - yield _sse_event( - {"type": "error", "message": str(e), "code": "internal_error"} - ) - await asyncio.sleep(1.5) - - yield _sse_event({"type": "status", "status": "stopped", "run_id": run_id}) - finally: - if task_id: - await _delete_imagine_session(task_id) - - return StreamingResponse( - event_stream(), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - -@router.post("/api/v1/admin/login", dependencies=[Depends(verify_app_key)]) -async def admin_login_api(): - """管理后台登录验证(使用 app_key)""" - return {"status": "success", "api_key": get_admin_api_key()} - - -@router.get("/api/v1/admin/config", dependencies=[Depends(verify_api_key)]) -async def get_config_api(): - """获取当前配置""" - # 暴露原始配置字典 - return config._config - - -@router.post("/api/v1/admin/config", dependencies=[Depends(verify_api_key)]) -async def update_config_api(data: dict): - """更新配置""" - try: - await config.update(data) - return {"status": "success", 
"message": "配置已更新"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/api/v1/admin/storage", dependencies=[Depends(verify_api_key)]) -async def get_storage_info(): - """获取当前存储模式""" - storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower() - if not storage_type: - storage_type = str(get_config("storage.type")).lower() - if not storage_type: - storage = get_storage() - if isinstance(storage, LocalStorage): - storage_type = "local" - elif isinstance(storage, RedisStorage): - storage_type = "redis" - elif isinstance(storage, SQLStorage): - storage_type = { - "mysql": "mysql", - "mariadb": "mysql", - "postgres": "pgsql", - "postgresql": "pgsql", - "pgsql": "pgsql", - }.get(storage.dialect, storage.dialect) - return {"type": storage_type or "local"} - - -@router.get("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)]) -async def get_tokens_api(): - """获取所有 Token""" - storage = get_storage() - tokens = await storage.load_tokens() - return tokens or {} - - -@router.post("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)]) -async def update_tokens_api(data: dict): - """更新 Token 信息""" - storage = get_storage() - try: - from app.services.token.manager import get_token_manager - from app.services.token.models import TokenInfo - - async with storage.acquire_lock("tokens_save", timeout=10): - existing = await storage.load_tokens() or {} - normalized = {} - allowed_fields = set(TokenInfo.model_fields.keys()) - existing_map = {} - for pool_name, tokens in existing.items(): - if not isinstance(tokens, list): - continue - pool_map = {} - for item in tokens: - if isinstance(item, str): - token_data = {"token": item} - elif isinstance(item, dict): - token_data = dict(item) - else: - continue - raw_token = token_data.get("token") - if isinstance(raw_token, str) and raw_token.startswith("sso="): - token_data["token"] = raw_token[4:] - token_key = token_data.get("token") - if isinstance(token_key, str): - 
pool_map[token_key] = token_data - existing_map[pool_name] = pool_map - for pool_name, tokens in (data or {}).items(): - if not isinstance(tokens, list): - continue - pool_list = [] - for item in tokens: - if isinstance(item, str): - token_data = {"token": item} - elif isinstance(item, dict): - token_data = dict(item) - else: - continue - - raw_token = token_data.get("token") - if isinstance(raw_token, str) and raw_token.startswith("sso="): - token_data["token"] = raw_token[4:] - - base = existing_map.get(pool_name, {}).get( - token_data.get("token"), {} - ) - merged = dict(base) - merged.update(token_data) - if merged.get("tags") is None: - merged["tags"] = [] - - filtered = {k: v for k, v in merged.items() if k in allowed_fields} - try: - info = TokenInfo(**filtered) - pool_list.append(info.model_dump()) - except Exception as e: - logger.warning(f"Skip invalid token in pool '{pool_name}': {e}") - continue - normalized[pool_name] = pool_list - - await storage.save_tokens(normalized) - mgr = await get_token_manager() - await mgr.reload() - return {"status": "success", "message": "Token 已更新"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/api/v1/admin/tokens/refresh", dependencies=[Depends(verify_api_key)]) -async def refresh_tokens_api(data: dict): - """刷新 Token 状态""" - try: - mgr = await get_token_manager() - tokens = _collect_tokens(data) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens provided") - - # 去重 - unique_tokens = _dedupe_tokens(tokens) - - raw_results = await UsageService.batch( - unique_tokens, - mgr, - ) - - results = {} - for token, res in raw_results.items(): - if res.get("ok"): - results[token] = res.get("data", False) - else: - results[token] = False - - response = {"status": "success", "results": results} - return response - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/api/v1/admin/tokens/refresh/async", 
dependencies=[Depends(verify_api_key)] -) -async def refresh_tokens_api_async(data: dict): - """刷新 Token 状态(异步批量 + SSE 进度)""" - mgr = await get_token_manager() - tokens = _collect_tokens(data) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens provided") - - # 去重 - unique_tokens = _dedupe_tokens(tokens) - - task = create_task(len(unique_tokens)) - - async def _run(): - try: - - async def _on_item(item: str, res: dict): - task.record(bool(res.get("ok"))) - - raw_results = await UsageService.batch( - unique_tokens, - mgr, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - results: dict[str, bool] = {} - ok_count = 0 - fail_count = 0 - for token, res in raw_results.items(): - if res.get("ok") and res.get("data") is True: - ok_count += 1 - results[token] = True - else: - fail_count += 1 - results[token] = False - - await mgr._save() - - result = { - "status": "success", - "summary": { - "total": len(unique_tokens), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - task.finish(result) - except Exception as e: - task.fail_task(str(e)) - finally: - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(unique_tokens), - } - - -@router.post("/api/v1/admin/tokens/nsfw/enable", dependencies=[Depends(verify_api_key)]) -async def enable_nsfw_api(data: dict): - """批量开启 NSFW (Unhinged) 模式""" - try: - mgr = await get_token_manager() - - # 收集 token 列表 - tokens = _collect_tokens(data) - - # 若未指定,则使用所有 pool 中的 token - if not tokens: - for pool_name, pool in mgr.pools.items(): - for info in pool.list(): - raw = ( - info.token[4:] if info.token.startswith("sso=") else info.token - ) - tokens.append(raw) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens available") - - # 去重 - unique_tokens = _dedupe_tokens(tokens) - - raw_results = await 
NSFWService.batch( - unique_tokens, - mgr, - ) - - # 构造返回结果(mask token) - results = {} - ok_count = 0 - fail_count = 0 - - for token, res in raw_results.items(): - masked = _mask_token(token) - if res.get("ok") and res.get("data", {}).get("success"): - ok_count += 1 - results[masked] = res.get("data", {}) - else: - fail_count += 1 - results[masked] = res.get("data") or {"error": res.get("error")} - - response = { - "status": "success", - "summary": { - "total": len(unique_tokens), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - - # 添加截断提示 - return response - - except HTTPException: - raise - except Exception as e: - logger.error(f"Enable NSFW failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/api/v1/admin/tokens/nsfw/enable/async", dependencies=[Depends(verify_api_key)] -) -async def enable_nsfw_api_async(data: dict): - """批量开启 NSFW (Unhinged) 模式(异步批量 + SSE 进度)""" - mgr = await get_token_manager() - - tokens = _collect_tokens(data) - - if not tokens: - for pool_name, pool in mgr.pools.items(): - for info in pool.list(): - raw = info.token[4:] if info.token.startswith("sso=") else info.token - tokens.append(raw) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens available") - - # 去重 - unique_tokens = _dedupe_tokens(tokens) - - task = create_task(len(unique_tokens)) - - async def _run(): - try: - - async def _on_item(item: str, res: dict): - ok = bool(res.get("ok") and res.get("data", {}).get("success")) - task.record(ok) - - raw_results = await NSFWService.batch( - unique_tokens, - mgr, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - results = {} - ok_count = 0 - fail_count = 0 - for token, res in raw_results.items(): - masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token - if res.get("ok") and res.get("data", {}).get("success"): - ok_count += 1 - results[masked] = res.get("data", {}) 
- else: - fail_count += 1 - results[masked] = res.get("data") or {"error": res.get("error")} - - await mgr._save() - - result = { - "status": "success", - "summary": { - "total": len(unique_tokens), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - task.finish(result) - except Exception as e: - task.fail_task(str(e)) - finally: - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(unique_tokens), - } - - -@router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False) -async def admin_cache_page(): - """缓存管理页""" - return await render_template("cache/cache.html") - - -@router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)]) -async def get_cache_stats_api(request: Request): - """获取缓存统计""" - from app.services.grok.utils.cache import CacheService - from app.services.token.manager import get_token_manager - - try: - cache_service = CacheService() - image_stats = cache_service.get_stats("image") - video_stats = cache_service.get_stats("video") - - mgr = await get_token_manager() - pools = mgr.pools - accounts = [] - for pool_name, pool in pools.items(): - for info in pool.list(): - raw_token = ( - info.token[4:] if info.token.startswith("sso=") else info.token - ) - masked = ( - f"{raw_token[:8]}...{raw_token[-16:]}" - if len(raw_token) > 24 - else raw_token - ) - accounts.append( - { - "token": raw_token, - "token_masked": masked, - "pool": pool_name, - "status": info.status, - "last_asset_clear_at": info.last_asset_clear_at, - } - ) - - scope = request.query_params.get("scope") - selected_token = request.query_params.get("token") - tokens_param = request.query_params.get("tokens") - selected_tokens = [] - if tokens_param: - selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()] - - online_stats = { - "count": 0, - "status": "unknown", - "token": None, - "last_asset_clear_at": None, - } - 
online_details = [] - account_map = {a["token"]: a for a in accounts} - if selected_tokens: - total = 0 - raw_results = await ListService.fetch_assets_details( - selected_tokens, - account_map, - ) - for token, res in raw_results.items(): - if res.get("ok"): - data = res.get("data", {}) - detail = data.get("detail") - total += data.get("count", 0) - else: - account = account_map.get(token) - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": 0, - "status": f"error: {res.get('error')}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - if detail: - online_details.append(detail) - online_stats = { - "count": total, - "status": "ok" if selected_tokens else "no_token", - "token": None, - "last_asset_clear_at": None, - } - scope = "selected" - elif scope == "all": - total = 0 - tokens = list(dict.fromkeys([account["token"] for account in accounts])) - raw_results = await ListService.fetch_assets_details( - tokens, - account_map, - ) - for token, res in raw_results.items(): - if res.get("ok"): - data = res.get("data", {}) - detail = data.get("detail") - total += data.get("count", 0) - else: - account = account_map.get(token) - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": 0, - "status": f"error: {res.get('error')}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - if detail: - online_details.append(detail) - online_stats = { - "count": total, - "status": "ok" if accounts else "no_token", - "token": None, - "last_asset_clear_at": None, - } - else: - token = selected_token - if token: - raw_results = await ListService.fetch_assets_details( - [token], - account_map, - ) - res = raw_results.get(token, {}) - data = res.get("data", {}) - detail = data.get("detail") if res.get("ok") else None - if detail: - online_stats = { - "count": data.get("count", 0), - "status": detail.get("status", 
"ok"), - "token": detail.get("token"), - "token_masked": detail.get("token_masked"), - "last_asset_clear_at": detail.get("last_asset_clear_at"), - } - else: - match = next((a for a in accounts if a["token"] == token), None) - online_stats = { - "count": 0, - "status": f"error: {res.get('error')}", - "token": token, - "token_masked": match["token_masked"] if match else token, - "last_asset_clear_at": match["last_asset_clear_at"] - if match - else None, - } - else: - online_stats = { - "count": 0, - "status": "not_loaded", - "token": None, - "last_asset_clear_at": None, - } - - response = { - "local_image": image_stats, - "local_video": video_stats, - "online": online_stats, - "online_accounts": accounts, - "online_scope": scope or "none", - "online_details": online_details, - } - return response - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/api/v1/admin/cache/online/load/async", dependencies=[Depends(verify_api_key)] -) -async def load_online_cache_api_async(data: dict): - """在线资产统计(异步批量 + SSE 进度)""" - from app.services.grok.utils.cache import CacheService - from app.services.token.manager import get_token_manager - - mgr = await get_token_manager() - - # 账号列表 - accounts = [] - for pool_name, pool in mgr.pools.items(): - for info in pool.list(): - raw_token = info.token[4:] if info.token.startswith("sso=") else info.token - masked = ( - f"{raw_token[:8]}...{raw_token[-16:]}" - if len(raw_token) > 24 - else raw_token - ) - accounts.append( - { - "token": raw_token, - "token_masked": masked, - "pool": pool_name, - "status": info.status, - "last_asset_clear_at": info.last_asset_clear_at, - } - ) - - account_map = {a["token"]: a for a in accounts} - - tokens = data.get("tokens") - scope = data.get("scope") - selected_tokens: List[str] = [] - if isinstance(tokens, list): - selected_tokens = [str(t).strip() for t in tokens if str(t).strip()] - - if not selected_tokens and scope == "all": - selected_tokens = 
[account["token"] for account in accounts] - scope = "all" - elif selected_tokens: - scope = "selected" - else: - raise HTTPException(status_code=400, detail="No tokens provided") - - task = create_task(len(selected_tokens)) - - async def _run(): - try: - cache_service = CacheService() - image_stats = cache_service.get_stats("image") - video_stats = cache_service.get_stats("video") - - async def _on_item(item: str, res: dict): - ok = bool(res.get("data", {}).get("ok")) - task.record(ok) - - raw_results = await ListService.fetch_assets_details( - selected_tokens, - account_map, - include_ok=True, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - online_details = [] - total = 0 - for token, res in raw_results.items(): - data = res.get("data", {}) - detail = data.get("detail") - if detail: - online_details.append(detail) - total += data.get("count", 0) - - online_stats = { - "count": total, - "status": "ok" if selected_tokens else "no_token", - "token": None, - "last_asset_clear_at": None, - } - - result = { - "local_image": image_stats, - "local_video": video_stats, - "online": online_stats, - "online_accounts": accounts, - "online_scope": scope or "none", - "online_details": online_details, - } - task.finish(result) - except Exception as e: - task.fail_task(str(e)) - finally: - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(selected_tokens), - } - - -@router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)]) -async def clear_local_cache_api(data: dict): - """清理本地缓存""" - from app.services.grok.utils.cache import CacheService - - cache_type = data.get("type", "image") - - try: - cache_service = CacheService() - result = cache_service.clear(cache_type) - return {"status": "success", "result": result} - except Exception as e: - raise HTTPException(status_code=500, 
detail=str(e)) - - -@router.get("/api/v1/admin/cache/list", dependencies=[Depends(verify_api_key)]) -async def list_local_cache_api( - cache_type: str = "image", - type_: str = Query(default=None, alias="type"), - page: int = 1, - page_size: int = 1000, -): - """列出本地缓存文件""" - from app.services.grok.utils.cache import CacheService - - try: - if type_: - cache_type = type_ - cache_service = CacheService() - result = cache_service.list_files(cache_type, page, page_size) - return {"status": "success", **result} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)]) -async def delete_local_cache_item_api(data: dict): - """删除单个本地缓存文件""" - from app.services.grok.utils.cache import CacheService - - cache_type = data.get("type", "image") - name = data.get("name") - if not name: - raise HTTPException(status_code=400, detail="Missing file name") - try: - cache_service = CacheService() - result = cache_service.delete_file(cache_type, name) - return {"status": "success", "result": result} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/api/v1/admin/cache/online/clear", dependencies=[Depends(verify_api_key)]) -async def clear_online_cache_api(data: dict): - """清理在线缓存""" - from app.services.token.manager import get_token_manager - try: - mgr = await get_token_manager() - tokens = data.get("tokens") - - if isinstance(tokens, list): - token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] - if not token_list: - raise HTTPException(status_code=400, detail="No tokens provided") - - # 去重并保持顺序 - token_list = list(dict.fromkeys(token_list)) - - results = {} - raw_results = await DeleteService.clear_assets( - token_list, - mgr, - ) - for token, res in raw_results.items(): - if res.get("ok"): - results[token] = res.get("data", {}) - else: - results[token] = {"status": "error", "error": 
res.get("error")} - - return {"status": "success", "results": results} - - token = data.get("token") or mgr.get_token() - if not token: - raise HTTPException( - status_code=400, detail="No available token to perform cleanup" - ) - - raw_results = await DeleteService.clear_assets( - [token], - mgr, - ) - res = raw_results.get(token, {}) - data = res.get("data", {}) - if res.get("ok") and data.get("status") == "success": - return {"status": "success", "result": data.get("result")} - return {"status": "error", "error": data.get("error") or res.get("error")} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/api/v1/admin/cache/online/clear/async", dependencies=[Depends(verify_api_key)] -) -async def clear_online_cache_api_async(data: dict): - """清理在线缓存(异步批量 + SSE 进度)""" - from app.services.token.manager import get_token_manager - - mgr = await get_token_manager() - tokens = data.get("tokens") - if not isinstance(tokens, list): - raise HTTPException(status_code=400, detail="No tokens provided") - - token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] - if not token_list: - raise HTTPException(status_code=400, detail="No tokens provided") - - task = create_task(len(token_list)) - - async def _run(): - try: - async def _on_item(item: str, res: dict): - ok = bool(res.get("data", {}).get("ok")) - task.record(ok) - - raw_results = await DeleteService.clear_assets( - token_list, - mgr, - include_ok=True, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - results = {} - ok_count = 0 - fail_count = 0 - for token, res in raw_results.items(): - data = res.get("data", {}) - if data.get("ok"): - ok_count += 1 - results[token] = {"status": "success", "result": data.get("result")} - else: - fail_count += 1 - results[token] = {"status": "error", "error": data.get("error")} - - result = { - "status": "success", - "summary": { - "total": 
len(token_list), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - task.finish(result) - except Exception as e: - task.fail_task(str(e)) - finally: - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(token_list), - } diff --git a/app/api/v1/admin/__init__.py b/app/api/v1/admin/__init__.py new file mode 100644 index 00000000..63db93d7 --- /dev/null +++ b/app/api/v1/admin/__init__.py @@ -0,0 +1,15 @@ +"""Admin API router (app_key protected).""" + +from fastapi import APIRouter + +from app.api.v1.admin.cache import router as cache_router +from app.api.v1.admin.config import router as config_router +from app.api.v1.admin.token import router as tokens_router + +router = APIRouter() + +router.include_router(config_router) +router.include_router(tokens_router) +router.include_router(cache_router) + +__all__ = ["router"] diff --git a/app/api/v1/admin/cache.py b/app/api/v1/admin/cache.py new file mode 100644 index 00000000..0dc902a7 --- /dev/null +++ b/app/api/v1/admin/cache.py @@ -0,0 +1,445 @@ +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, Query, Request + +from app.core.auth import verify_app_key +from app.core.batch import create_task, expire_task +from app.services.grok.batch_services.assets import ListService, DeleteService +from app.services.token.manager import get_token_manager +router = APIRouter() + + +@router.get("/cache", dependencies=[Depends(verify_app_key)]) +async def cache_stats(request: Request): + """获取缓存统计""" + from app.services.grok.utils.cache import CacheService + + try: + cache_service = CacheService() + image_stats = cache_service.get_stats("image") + video_stats = cache_service.get_stats("video") + + mgr = await get_token_manager() + pools = mgr.pools + accounts = [] + for pool_name, pool in pools.items(): + for info in pool.list(): + raw_token = ( + info.token[4:] if 
info.token.startswith("sso=") else info.token + ) + masked = ( + f"{raw_token[:8]}...{raw_token[-16:]}" + if len(raw_token) > 24 + else raw_token + ) + accounts.append( + { + "token": raw_token, + "token_masked": masked, + "pool": pool_name, + "status": info.status, + "last_asset_clear_at": info.last_asset_clear_at, + } + ) + + scope = request.query_params.get("scope") + selected_token = request.query_params.get("token") + tokens_param = request.query_params.get("tokens") + selected_tokens = [] + if tokens_param: + selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()] + + online_stats = { + "count": 0, + "status": "unknown", + "token": None, + "last_asset_clear_at": None, + } + online_details = [] + account_map = {a["token"]: a for a in accounts} + if selected_tokens: + total = 0 + raw_results = await ListService.fetch_assets_details( + selected_tokens, + account_map, + ) + for token, res in raw_results.items(): + if res.get("ok"): + data = res.get("data", {}) + detail = data.get("detail") + total += data.get("count", 0) + else: + account = account_map.get(token) + detail = { + "token": token, + "token_masked": account["token_masked"] if account else token, + "count": 0, + "status": f"error: {res.get('error')}", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if detail: + online_details.append(detail) + online_stats = { + "count": total, + "status": "ok" if selected_tokens else "no_token", + "token": None, + "last_asset_clear_at": None, + } + scope = "selected" + elif scope == "all": + total = 0 + tokens = list(dict.fromkeys([account["token"] for account in accounts])) + raw_results = await ListService.fetch_assets_details( + tokens, + account_map, + ) + for token, res in raw_results.items(): + if res.get("ok"): + data = res.get("data", {}) + detail = data.get("detail") + total += data.get("count", 0) + else: + account = account_map.get(token) + detail = { + "token": token, + "token_masked": 
account["token_masked"] if account else token, + "count": 0, + "status": f"error: {res.get('error')}", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if detail: + online_details.append(detail) + online_stats = { + "count": total, + "status": "ok" if accounts else "no_token", + "token": None, + "last_asset_clear_at": None, + } + else: + token = selected_token + if token: + raw_results = await ListService.fetch_assets_details( + [token], + account_map, + ) + res = raw_results.get(token, {}) + data = res.get("data", {}) + detail = data.get("detail") if res.get("ok") else None + if detail: + online_stats = { + "count": data.get("count", 0), + "status": detail.get("status", "ok"), + "token": detail.get("token"), + "token_masked": detail.get("token_masked"), + "last_asset_clear_at": detail.get("last_asset_clear_at"), + } + else: + match = next((a for a in accounts if a["token"] == token), None) + online_stats = { + "count": 0, + "status": f"error: {res.get('error')}", + "token": token, + "token_masked": match["token_masked"] if match else token, + "last_asset_clear_at": match["last_asset_clear_at"] + if match + else None, + } + else: + online_stats = { + "count": 0, + "status": "not_loaded", + "token": None, + "last_asset_clear_at": None, + } + + response = { + "local_image": image_stats, + "local_video": video_stats, + "online": online_stats, + "online_accounts": accounts, + "online_scope": scope or "none", + "online_details": online_details, + } + return response + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/cache/list", dependencies=[Depends(verify_app_key)]) +async def list_local( + cache_type: str = "image", + type_: str = Query(default=None, alias="type"), + page: int = 1, + page_size: int = 1000, +): + """列出本地缓存文件""" + from app.services.grok.utils.cache import CacheService + + try: + if type_: + cache_type = type_ + cache_service = CacheService() + result = 
cache_service.list_files(cache_type, page, page_size) + return {"status": "success", **result} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/clear", dependencies=[Depends(verify_app_key)]) +async def clear_local(data: dict): + """清理本地缓存""" + from app.services.grok.utils.cache import CacheService + + cache_type = data.get("type", "image") + + try: + cache_service = CacheService() + result = cache_service.clear(cache_type) + return {"status": "success", "result": result} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/item/delete", dependencies=[Depends(verify_app_key)]) +async def delete_local_item(data: dict): + """删除单个本地缓存文件""" + from app.services.grok.utils.cache import CacheService + + cache_type = data.get("type", "image") + name = data.get("name") + if not name: + raise HTTPException(status_code=400, detail="Missing file name") + try: + cache_service = CacheService() + result = cache_service.delete_file(cache_type, name) + return {"status": "success", "result": result} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/online/clear", dependencies=[Depends(verify_app_key)]) +async def clear_online(data: dict): + """清理在线缓存""" + try: + mgr = await get_token_manager() + tokens = data.get("tokens") + + if isinstance(tokens, list): + token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] + if not token_list: + raise HTTPException(status_code=400, detail="No tokens provided") + + token_list = list(dict.fromkeys(token_list)) + + results = {} + raw_results = await DeleteService.clear_assets( + token_list, + mgr, + ) + for token, res in raw_results.items(): + if res.get("ok"): + results[token] = res.get("data", {}) + else: + results[token] = {"status": "error", "error": res.get("error")} + + return {"status": "success", "results": results} + + token = data.get("token") or 
mgr.get_token() + if not token: + raise HTTPException( + status_code=400, detail="No available token to perform cleanup" + ) + + raw_results = await DeleteService.clear_assets( + [token], + mgr, + ) + res = raw_results.get(token, {}) + data = res.get("data", {}) + if res.get("ok") and data.get("status") == "success": + return {"status": "success", "result": data.get("result")} + return {"status": "error", "error": data.get("error") or res.get("error")} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/online/clear/async", dependencies=[Depends(verify_app_key)]) +async def clear_online_async(data: dict): + """清理在线缓存(异步批量 + SSE 进度)""" + mgr = await get_token_manager() + tokens = data.get("tokens") + if not isinstance(tokens, list): + raise HTTPException(status_code=400, detail="No tokens provided") + + token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] + if not token_list: + raise HTTPException(status_code=400, detail="No tokens provided") + + task = create_task(len(token_list)) + + async def _run(): + try: + async def _on_item(item: str, res: dict): + ok = bool(res.get("data", {}).get("ok")) + task.record(ok) + + raw_results = await DeleteService.clear_assets( + token_list, + mgr, + include_ok=True, + on_item=_on_item, + should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + results = {} + ok_count = 0 + fail_count = 0 + for token, res in raw_results.items(): + data = res.get("data", {}) + if data.get("ok"): + ok_count += 1 + results[token] = {"status": "success", "result": data.get("result")} + else: + fail_count += 1 + results[token] = {"status": "error", "error": data.get("error")} + + result = { + "status": "success", + "summary": { + "total": len(token_list), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + 
asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(token_list), + } + + +@router.post("/cache/online/load/async", dependencies=[Depends(verify_app_key)]) +async def load_cache_async(data: dict): + """在线资产统计(异步批量 + SSE 进度)""" + from app.services.grok.utils.cache import CacheService + + mgr = await get_token_manager() + + accounts = [] + for pool_name, pool in mgr.pools.items(): + for info in pool.list(): + raw_token = info.token[4:] if info.token.startswith("sso=") else info.token + masked = ( + f"{raw_token[:8]}...{raw_token[-16:]}" + if len(raw_token) > 24 + else raw_token + ) + accounts.append( + { + "token": raw_token, + "token_masked": masked, + "pool": pool_name, + "status": info.status, + "last_asset_clear_at": info.last_asset_clear_at, + } + ) + + account_map = {a["token"]: a for a in accounts} + + tokens = data.get("tokens") + scope = data.get("scope") + selected_tokens: List[str] = [] + if isinstance(tokens, list): + selected_tokens = [str(t).strip() for t in tokens if str(t).strip()] + + if not selected_tokens and scope == "all": + selected_tokens = [account["token"] for account in accounts] + scope = "all" + elif selected_tokens: + scope = "selected" + else: + raise HTTPException(status_code=400, detail="No tokens provided") + + task = create_task(len(selected_tokens)) + + async def _run(): + try: + cache_service = CacheService() + image_stats = cache_service.get_stats("image") + video_stats = cache_service.get_stats("video") + + async def _on_item(item: str, res: dict): + ok = bool(res.get("data", {}).get("ok")) + task.record(ok) + + raw_results = await ListService.fetch_assets_details( + selected_tokens, + account_map, + include_ok=True, + on_item=_on_item, + should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + online_details = [] + total = 0 + for token, res in raw_results.items(): + 
data = res.get("data", {}) + detail = data.get("detail") + if detail: + online_details.append(detail) + total += data.get("count", 0) + + online_stats = { + "count": total, + "status": "ok" if selected_tokens else "no_token", + "token": None, + "last_asset_clear_at": None, + } + + result = { + "local_image": image_stats, + "local_video": video_stats, + "online": online_stats, + "online_accounts": accounts, + "online_scope": scope or "none", + "online_details": online_details, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(selected_tokens), + } + diff --git a/app/api/v1/admin/config.py b/app/api/v1/admin/config.py new file mode 100644 index 00000000..f843b76b --- /dev/null +++ b/app/api/v1/admin/config.py @@ -0,0 +1,53 @@ +import os + +from fastapi import APIRouter, Depends, HTTPException + +from app.core.auth import verify_app_key +from app.core.config import config +from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage + +router = APIRouter() + + +@router.get("/verify", dependencies=[Depends(verify_app_key)]) +async def admin_verify(): + """验证后台访问密钥(app_key)""" + return {"status": "success"} + + +@router.get("/config", dependencies=[Depends(verify_app_key)]) +async def get_config(): + """获取当前配置""" + # 暴露原始配置字典 + return config._config + + +@router.post("/config", dependencies=[Depends(verify_app_key)]) +async def update_config(data: dict): + """更新配置""" + try: + await config.update(data) + return {"status": "success", "message": "配置已更新"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/storage", dependencies=[Depends(verify_app_key)]) +async def get_storage(): + """获取当前存储模式""" + storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower() + if not storage_type: + 
storage = get_storage() + if isinstance(storage, LocalStorage): + storage_type = "local" + elif isinstance(storage, RedisStorage): + storage_type = "redis" + elif isinstance(storage, SQLStorage): + storage_type = { + "mysql": "mysql", + "mariadb": "mysql", + "postgres": "pgsql", + "postgresql": "pgsql", + "pgsql": "pgsql", + }.get(storage.dialect, storage.dialect) + return {"type": storage_type or "local"} diff --git a/app/api/v1/admin/token.py b/app/api/v1/admin/token.py new file mode 100644 index 00000000..81b9fef2 --- /dev/null +++ b/app/api/v1/admin/token.py @@ -0,0 +1,395 @@ +import asyncio + +import orjson +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import StreamingResponse + +from app.core.auth import get_app_key, verify_app_key +from app.core.batch import create_task, expire_task, get_task +from app.core.logger import logger +from app.core.storage import get_storage +from app.services.grok.batch_services.usage import UsageService +from app.services.grok.batch_services.nsfw import NSFWService +from app.services.token.manager import get_token_manager + +router = APIRouter() + + +@router.get("/tokens", dependencies=[Depends(verify_app_key)]) +async def get_tokens(): + """获取所有 Token""" + storage = get_storage() + tokens = await storage.load_tokens() + return tokens or {} + + +@router.post("/tokens", dependencies=[Depends(verify_app_key)]) +async def update_tokens(data: dict): + """更新 Token 信息""" + storage = get_storage() + try: + from app.services.token.models import TokenInfo + + async with storage.acquire_lock("tokens_save", timeout=10): + existing = await storage.load_tokens() or {} + normalized = {} + allowed_fields = set(TokenInfo.model_fields.keys()) + existing_map = {} + for pool_name, tokens in existing.items(): + if not isinstance(tokens, list): + continue + pool_map = {} + for item in tokens: + if isinstance(item, str): + token_data = {"token": item} + elif isinstance(item, dict): + token_data = dict(item) + 
else: + continue + raw_token = token_data.get("token") + if isinstance(raw_token, str) and raw_token.startswith("sso="): + token_data["token"] = raw_token[4:] + token_key = token_data.get("token") + if isinstance(token_key, str): + pool_map[token_key] = token_data + existing_map[pool_name] = pool_map + for pool_name, tokens in (data or {}).items(): + if not isinstance(tokens, list): + continue + pool_list = [] + for item in tokens: + if isinstance(item, str): + token_data = {"token": item} + elif isinstance(item, dict): + token_data = dict(item) + else: + continue + + raw_token = token_data.get("token") + if isinstance(raw_token, str) and raw_token.startswith("sso="): + token_data["token"] = raw_token[4:] + + base = existing_map.get(pool_name, {}).get( + token_data.get("token"), {} + ) + merged = dict(base) + merged.update(token_data) + if merged.get("tags") is None: + merged["tags"] = [] + + filtered = {k: v for k, v in merged.items() if k in allowed_fields} + try: + info = TokenInfo(**filtered) + pool_list.append(info.model_dump()) + except Exception as e: + logger.warning(f"Skip invalid token in pool '{pool_name}': {e}") + continue + normalized[pool_name] = pool_list + + await storage.save_tokens(normalized) + mgr = await get_token_manager() + await mgr.reload() + return {"status": "success", "message": "Token 已更新"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/tokens/refresh", dependencies=[Depends(verify_app_key)]) +async def refresh_tokens(data: dict): + """刷新 Token 状态""" + try: + mgr = await get_token_manager() + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens provided") + + unique_tokens = list(dict.fromkeys(tokens)) + + raw_results = await 
UsageService.batch( + unique_tokens, + mgr, + ) + + results = {} + for token, res in raw_results.items(): + if res.get("ok"): + results[token] = res.get("data", False) + else: + results[token] = False + + response = {"status": "success", "results": results} + return response + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/tokens/refresh/async", dependencies=[Depends(verify_app_key)]) +async def refresh_tokens_async(data: dict): + """刷新 Token 状态(异步批量 + SSE 进度)""" + mgr = await get_token_manager() + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens provided") + + unique_tokens = list(dict.fromkeys(tokens)) + + task = create_task(len(unique_tokens)) + + async def _run(): + try: + + async def _on_item(item: str, res: dict): + task.record(bool(res.get("ok"))) + + raw_results = await UsageService.batch( + unique_tokens, + mgr, + on_item=_on_item, + should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + results: dict[str, bool] = {} + ok_count = 0 + fail_count = 0 + for token, res in raw_results.items(): + if res.get("ok") and res.get("data") is True: + ok_count += 1 + results[token] = True + else: + fail_count += 1 + results[token] = False + + await mgr._save() + + result = { + "status": "success", + "summary": { + "total": len(unique_tokens), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(unique_tokens), + } + + 
+@router.get("/batch/{task_id}/stream") +async def batch_stream(task_id: str, request: Request): + app_key = get_app_key() + if app_key: + key = request.query_params.get("app_key") + if key != app_key: + raise HTTPException(status_code=401, detail="Invalid authentication token") + task = get_task(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + + async def event_stream(): + queue = task.attach() + try: + yield f"data: {orjson.dumps({'type': 'snapshot', **task.snapshot()}).decode()}\n\n" + + final = task.final_event() + if final: + yield f"data: {orjson.dumps(final).decode()}\n\n" + return + + while True: + try: + event = await asyncio.wait_for(queue.get(), timeout=15) + except asyncio.TimeoutError: + yield ": ping\n\n" + final = task.final_event() + if final: + yield f"data: {orjson.dumps(final).decode()}\n\n" + return + continue + + yield f"data: {orjson.dumps(event).decode()}\n\n" + if event.get("type") in ("done", "error", "cancelled"): + return + finally: + task.detach(queue) + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + +@router.post("/batch/{task_id}/cancel", dependencies=[Depends(verify_app_key)]) +async def batch_cancel(task_id: str): + task = get_task(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + task.cancel() + return {"status": "success"} + + +@router.post("/tokens/nsfw/enable", dependencies=[Depends(verify_app_key)]) +async def enable_nsfw(data: dict): + """批量开启 NSFW (Unhinged) 模式""" + try: + mgr = await get_token_manager() + + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + for pool_name, pool in mgr.pools.items(): + for info in pool.list(): + raw = ( + info.token[4:] if info.token.startswith("sso=") else info.token + ) + 
tokens.append(raw) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens available") + + unique_tokens = list(dict.fromkeys(tokens)) + + raw_results = await NSFWService.batch( + unique_tokens, + mgr, + ) + + results = {} + ok_count = 0 + fail_count = 0 + + for token, res in raw_results.items(): + masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token + if res.get("ok") and res.get("data", {}).get("success"): + ok_count += 1 + results[masked] = res.get("data", {}) + else: + fail_count += 1 + results[masked] = res.get("data") or {"error": res.get("error")} + + response = { + "status": "success", + "summary": { + "total": len(unique_tokens), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + + return response + + except HTTPException: + raise + except Exception as e: + logger.error(f"Enable NSFW failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/tokens/nsfw/enable/async", dependencies=[Depends(verify_app_key)]) +async def enable_nsfw_async(data: dict): + """批量开启 NSFW (Unhinged) 模式(异步批量 + SSE 进度)""" + mgr = await get_token_manager() + + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + for pool_name, pool in mgr.pools.items(): + for info in pool.list(): + raw = info.token[4:] if info.token.startswith("sso=") else info.token + tokens.append(raw) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens available") + + unique_tokens = list(dict.fromkeys(tokens)) + + task = create_task(len(unique_tokens)) + + async def _run(): + try: + + async def _on_item(item: str, res: dict): + ok = bool(res.get("ok") and res.get("data", {}).get("success")) + task.record(ok) + + raw_results = await NSFWService.batch( + unique_tokens, + mgr, + on_item=_on_item, + 
should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + results = {} + ok_count = 0 + fail_count = 0 + for token, res in raw_results.items(): + masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token + if res.get("ok") and res.get("data", {}).get("success"): + ok_count += 1 + results[masked] = res.get("data", {}) + else: + fail_count += 1 + results[masked] = res.get("data") or {"error": res.get("error")} + + await mgr._save() + + result = { + "status": "success", + "summary": { + "total": len(unique_tokens), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(unique_tokens), + } diff --git a/app/api/v1/chat.py b/app/api/v1/chat.py index f7b72ecd..66f07a31 100644 --- a/app/api/v1/chat.py +++ b/app/api/v1/chat.py @@ -5,14 +5,20 @@ from typing import Any, Dict, List, Optional, Union import base64 import binascii +import time from fastapi import APIRouter from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel, Field from app.services.grok.services.chat import ChatService +from app.services.grok.services.image import ImageGenerationService +from app.services.grok.services.image_edit import ImageEditService from app.services.grok.services.model import ModelService -from app.core.exceptions import ValidationException +from app.services.grok.services.video import VideoService +from app.services.token import get_token_manager +from app.core.config import get_config +from app.core.exceptions import ValidationException, AppException, ErrorType class MessageItem(BaseModel): @@ -21,8 +27,6 @@ class MessageItem(BaseModel): role: str content: Union[str, List[Dict[str, Any]]] - model_config = 
{"extra": "ignore"} - class VideoConfig(BaseModel): """视频生成配置""" @@ -32,6 +36,13 @@ class VideoConfig(BaseModel): resolution_name: Optional[str] = Field("480p", description="视频分辨率: 480p, 720p") preset: Optional[str] = Field("custom", description="风格预设: fun, normal, spicy") +class ImageConfig(BaseModel): + """图片生成配置""" + + n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)") + size: Optional[str] = Field("1024x1024", description="图片尺寸") + response_format: Optional[str] = Field(None, description="响应格式") + class ChatCompletionRequest(BaseModel): """Chat Completions 请求""" @@ -44,14 +55,23 @@ class ChatCompletionRequest(BaseModel): top_p: Optional[float] = Field(0.95, description="nucleus 采样: 0-1") # 视频生成配置 video_config: Optional[VideoConfig] = Field(None, description="视频生成参数") - model_config = {"extra": "ignore"} + # 图片生成配置 + image_config: Optional[ImageConfig] = Field(None, description="图片生成参数") VALID_ROLES = {"developer", "system", "user", "assistant"} USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"} +ALLOWED_IMAGE_SIZES = { + "1280x720", + "720x1280", + "1792x1024", + "1024x1792", + "1024x1024", +} def _validate_media_input(value: str, field_name: str, param: str): + """Verify media input is a valid URL or data URI""" if not isinstance(value, str) or not value.strip(): raise ValidationException( message=f"{field_name} cannot be empty", @@ -81,123 +101,86 @@ def _validate_media_input(value: str, field_name: str, param: str): ) -def _normalize_stream(value: Any) -> Optional[bool]: - if value is None: - return None - if isinstance(value, bool): - return value - if isinstance(value, str): - if value.lower() in ("true", "1", "yes"): - return True - if value.lower() in ("false", "0", "no"): - return False +def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]: + """Extract prompt text and image URLs from messages""" + last_text = "" + image_urls: List[str] = [] + + for msg in messages: + role = msg.role or "user" 
+ content = msg.content + if isinstance(content, str): + text = content.strip() + if text: + last_text = text + continue + if not isinstance(content, list): + continue + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text": + text = block.get("text", "") + if isinstance(text, str) and text.strip(): + last_text = text.strip() + elif block_type == "image_url" and role == "user": + image = block.get("image_url") or {} + url = image.get("url", "") + if isinstance(url, str) and url.strip(): + image_urls.append(url.strip()) + + return last_text, image_urls + + +def _resolve_image_format(value: Optional[str]) -> str: + fmt = value or get_config("app.image_format") or "url" + if isinstance(fmt, str): + fmt = fmt.lower() + if fmt == "base64": + return "b64_json" + if fmt in ("b64_json", "url"): + return fmt raise ValidationException( - message="stream must be a boolean", - param="stream", - code="invalid_stream", + message="image_format must be one of url, base64, b64_json", + param="image_format", + code="invalid_image_format", ) -def _validate_reasoning_effort(value: Any) -> Optional[str]: - allowed = {"none", "minimal", "low", "medium", "high", "xhigh"} - if value is None: - return None - if not isinstance(value, str) or value not in allowed: - raise ValidationException( - message=f"reasoning_effort must be one of {sorted(allowed)}", - param="reasoning_effort", - code="invalid_reasoning_effort", - ) - return value - - -def _validate_temperature(value: Any) -> float: - if value is None: - return 0.8 - try: - val = float(value) - except Exception: - raise ValidationException( - message="temperature must be a float", - param="temperature", - code="invalid_temperature", - ) - if not (0 <= val <= 2): - raise ValidationException( - message="temperature must be between 0 and 2", - param="temperature", - code="invalid_temperature", - ) - return val - - -def _validate_top_p(value: Any) -> float: - if value 
is None: - return 0.95 - try: - val = float(value) - except Exception: - raise ValidationException( - message="top_p must be a float", - param="top_p", - code="invalid_top_p", - ) - if not (0 <= val <= 1): - raise ValidationException( - message="top_p must be between 0 and 1", - param="top_p", - code="invalid_top_p", - ) - return val - - -def _normalize_video_config(config: Optional[VideoConfig]) -> VideoConfig: - if config is None: - config = VideoConfig() - - ratio_map = { - "1280x720": "16:9", - "720x1280": "9:16", - "1792x1024": "3:2", - "1024x1792": "2:3", - "1024x1024": "1:1", - "16:9": "16:9", - "9:16": "9:16", - "3:2": "3:2", - "2:3": "2:3", - "1:1": "1:1", - } - if config.aspect_ratio is None: - config.aspect_ratio = "3:2" - if config.aspect_ratio not in ratio_map: - raise ValidationException( - message=f"aspect_ratio must be one of {list(ratio_map.keys())}", - param="video_config.aspect_ratio", - code="invalid_aspect_ratio", - ) - config.aspect_ratio = ratio_map[config.aspect_ratio] +def _image_field(response_format: str) -> str: + if response_format == "url": + return "url" + return "b64_json" - if config.video_length not in (6, 10, 15): +def _validate_image_config(image_conf: ImageConfig, *, stream: bool): + n = image_conf.n or 1 + if n < 1 or n > 10: raise ValidationException( - message="video_length must be 6, 10, or 15 seconds", - param="video_config.video_length", - code="invalid_video_length", + message="n must be between 1 and 10", + param="image_config.n", + code="invalid_n", ) - if config.resolution_name not in ("480p", "720p"): + if stream and n not in (1, 2): raise ValidationException( - message="resolution_name must be one of ['480p', '720p']", - param="video_config.resolution_name", - code="invalid_resolution", + message="Streaming is only supported when n=1 or n=2", + param="image_config.n", + code="invalid_stream_n", ) - if config.preset not in ("fun", "normal", "spicy", "custom"): + if image_conf.response_format: + allowed_formats = 
{"b64_json", "base64", "url"} + if image_conf.response_format not in allowed_formats: + raise ValidationException( + message="response_format must be one of b64_json, base64, url", + param="image_config.response_format", + code="invalid_response_format", + ) + if image_conf.size and image_conf.size not in ALLOWED_IMAGE_SIZES: raise ValidationException( - message="preset must be one of ['fun', 'normal', 'spicy', 'custom']", - param="video_config.preset", - code="invalid_preset", + message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}", + param="image_config.size", + code="invalid_size", ) - return config - - def validate_request(request: ChatCompletionRequest): """验证请求参数""" # 验证模型 @@ -344,14 +327,174 @@ def validate_request(request: ChatCompletionRequest): code="invalid_content", ) - request.stream = _normalize_stream(request.stream) - request.reasoning_effort = _validate_reasoning_effort(request.reasoning_effort) - request.temperature = _validate_temperature(request.temperature) - request.top_p = _validate_top_p(request.top_p) + # 默认验证 + if request.stream is not None: + if isinstance(request.stream, bool): + pass + elif isinstance(request.stream, str): + if request.stream.lower() in ("true", "1", "yes"): + request.stream = True + elif request.stream.lower() in ("false", "0", "no"): + request.stream = False + else: + raise ValidationException( + message="stream must be a boolean", + param="stream", + code="invalid_stream", + ) + else: + raise ValidationException( + message="stream must be a boolean", + param="stream", + code="invalid_stream", + ) + + allowed_efforts = {"none", "minimal", "low", "medium", "high", "xhigh"} + if request.reasoning_effort is not None: + if not isinstance(request.reasoning_effort, str) or ( + request.reasoning_effort not in allowed_efforts + ): + raise ValidationException( + message=f"reasoning_effort must be one of {sorted(allowed_efforts)}", + param="reasoning_effort", + code="invalid_reasoning_effort", + ) + + if 
request.temperature is None: + request.temperature = 0.8 + else: + try: + request.temperature = float(request.temperature) + except Exception: + raise ValidationException( + message="temperature must be a float", + param="temperature", + code="invalid_temperature", + ) + if not (0 <= request.temperature <= 2): + raise ValidationException( + message="temperature must be between 0 and 2", + param="temperature", + code="invalid_temperature", + ) + + if request.top_p is None: + request.top_p = 0.95 + else: + try: + request.top_p = float(request.top_p) + except Exception: + raise ValidationException( + message="top_p must be a float", + param="top_p", + code="invalid_top_p", + ) + if not (0 <= request.top_p <= 1): + raise ValidationException( + message="top_p must be between 0 and 1", + param="top_p", + code="invalid_top_p", + ) model_info = ModelService.get(request.model) + # image 验证 + if model_info and (model_info.is_image or model_info.is_image_edit): + prompt, image_urls = _extract_prompt_images(request.messages) + if not prompt: + raise ValidationException( + message="Prompt cannot be empty", + param="messages", + code="empty_prompt", + ) + image_conf = request.image_config or ImageConfig() + n = image_conf.n or 1 + if not (1 <= n <= 10): + raise ValidationException( + message="n must be between 1 and 10", + param="image_config.n", + code="invalid_n", + ) + if request.stream and n not in (1, 2): + raise ValidationException( + message="Streaming is only supported when n=1 or n=2", + param="stream", + code="invalid_stream_n", + ) + + response_format = _resolve_image_format(image_conf.response_format) + image_conf.n = n + image_conf.response_format = response_format + if not image_conf.size: + image_conf.size = "1024x1024" + allowed_sizes = { + "1280x720", + "720x1280", + "1792x1024", + "1024x1792", + "1024x1024", + } + if image_conf.size not in allowed_sizes: + raise ValidationException( + message=f"size must be one of {sorted(allowed_sizes)}", + 
param="image_config.size", + code="invalid_size", + ) + request.image_config = image_conf + + # image edit 验证 + if model_info and model_info.is_image_edit: + _, image_urls = _extract_prompt_images(request.messages) + if not image_urls: + raise ValidationException( + message="image_url is required for image edits", + param="messages", + code="missing_image", + ) + + # video 验证 if model_info and model_info.is_video: - request.video_config = _normalize_video_config(request.video_config) + config = request.video_config or VideoConfig() + ratio_map = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", + "16:9": "16:9", + "9:16": "9:16", + "3:2": "3:2", + "2:3": "2:3", + "1:1": "1:1", + } + if config.aspect_ratio is None: + config.aspect_ratio = "3:2" + if config.aspect_ratio not in ratio_map: + raise ValidationException( + message=f"aspect_ratio must be one of {list(ratio_map.keys())}", + param="video_config.aspect_ratio", + code="invalid_aspect_ratio", + ) + config.aspect_ratio = ratio_map[config.aspect_ratio] + + if config.video_length not in (6, 10, 15): + raise ValidationException( + message="video_length must be 6, 10, or 15 seconds", + param="video_config.video_length", + code="invalid_video_length", + ) + if config.resolution_name not in ("480p", "720p"): + raise ValidationException( + message="resolution_name must be one of ['480p', '720p']", + param="video_config.resolution_name", + code="invalid_resolution", + ) + if config.preset not in ("fun", "normal", "spicy", "custom"): + raise ValidationException( + message="preset must be one of ['fun', 'normal', 'spicy', 'custom']", + param="video_config.preset", + code="invalid_preset", + ) + request.video_config = config router = APIRouter(tags=["Chat"]) @@ -367,11 +510,149 @@ async def chat_completions(request: ChatCompletionRequest): logger.debug(f"Chat request: model={request.model}, stream={request.stream}") - # 检测视频模型 + # 检测模型类型 model_info = 
ModelService.get(request.model) - if model_info and model_info.is_video: - from app.services.grok.services.video import VideoService + if model_info and model_info.is_image_edit: + prompt, image_urls = _extract_prompt_images(request.messages) + if not image_urls: + raise ValidationException( + message="Image is required", + param="image", + code="missing_image", + ) + image_url = image_urls[-1] + + is_stream = ( + request.stream if request.stream is not None else get_config("app.stream") + ) + image_conf = request.image_config or ImageConfig() + _validate_image_config(image_conf, stream=bool(is_stream)) + response_format = _resolve_image_format(image_conf.response_format) + response_field = _image_field(response_format) + n = image_conf.n or 1 + + token_mgr = await get_token_manager() + await token_mgr.reload_if_stale() + + token = None + for pool_name in ModelService.pool_candidates_for_model(request.model): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + raise AppException( + message="No available tokens. 
Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + result = await ImageEditService().edit( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + images=[image_url], + n=n, + response_format=response_format, + stream=bool(is_stream), + ) + + if result.stream: + return StreamingResponse( + result.data, + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + data = [{response_field: img} for img in result.data] + return JSONResponse( + content={ + "created": int(time.time()), + "data": data, + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + }, + } + ) + if model_info and model_info.is_image: + prompt, _ = _extract_prompt_images(request.messages) + + is_stream = ( + request.stream if request.stream is not None else get_config("app.stream") + ) + image_conf = request.image_config or ImageConfig() + _validate_image_config(image_conf, stream=bool(is_stream)) + response_format = _resolve_image_format(image_conf.response_format) + response_field = _image_field(response_format) + n = image_conf.n or 1 + size = image_conf.size or "1024x1024" + aspect_ratio_map = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", + } + aspect_ratio = aspect_ratio_map.get(size, "2:3") + + token_mgr = await get_token_manager() + await token_mgr.reload_if_stale() + + token = None + for pool_name in ModelService.pool_candidates_for_model(request.model): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + raise AppException( + message="No available tokens. 
Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + size=size, + aspect_ratio=aspect_ratio, + stream=bool(is_stream), + ) + + if result.stream: + return StreamingResponse( + result.data, + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + data = [{response_field: img} for img in result.data] + usage = result.usage_override or { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + } + return JSONResponse( + content={ + "created": int(time.time()), + "data": data, + "usage": usage, + } + ) + + if model_info and model_info.is_video: # 提取视频配置 (默认值在 Pydantic 模型中处理) v_conf = request.video_config or VideoConfig() diff --git a/app/api/v1/image.py b/app/api/v1/image.py index 1d7a3354..a7dcc9c1 100644 --- a/app/api/v1/image.py +++ b/app/api/v1/image.py @@ -21,6 +21,22 @@ router = APIRouter(tags=["Images"]) +ALLOWED_IMAGE_SIZES = { + "1280x720", + "720x1280", + "1792x1024", + "1024x1792", + "1024x1024", +} + +SIZE_TO_ASPECT = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", +} + class ImageGenerationRequest(BaseModel): """图片生成请求 - OpenAI 兼容""" @@ -28,7 +44,10 @@ class ImageGenerationRequest(BaseModel): prompt: str = Field(..., description="图片描述") model: Optional[str] = Field("grok-imagine-1.0", description="模型名称") n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)") - size: Optional[str] = Field("1024x1024", description="图片尺寸 (暂不支持)") + size: Optional[str] = Field( + "1024x1024", + description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024", + ) quality: Optional[str] = Field("standard", description="图片质量 
(暂不支持)") response_format: Optional[str] = Field(None, description="响应格式") style: Optional[str] = Field(None, description="风格 (暂不支持)") @@ -42,7 +61,10 @@ class ImageEditRequest(BaseModel): model: Optional[str] = Field("grok-imagine-1.0-edit", description="模型名称") image: Optional[Union[str, List[str]]] = Field(None, description="待编辑图片文件") n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)") - size: Optional[str] = Field("1024x1024", description="图片尺寸 (暂不支持)") + size: Optional[str] = Field( + "1024x1024", + description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024", + ) quality: Optional[str] = Field("standard", description="图片质量 (暂不支持)") response_format: Optional[str] = Field(None, description="响应格式") style: Optional[str] = Field(None, description="风格 (暂不支持)") @@ -76,18 +98,14 @@ def _validate_common_request( ) if allow_ws_stream: - # WS 流式仅支持 b64_json (base64 视为同义) - if ( - request.stream - and get_config("image.image_ws") - and request.response_format - and request.response_format not in {"b64_json", "base64"} - ): - raise ValidationException( - message="Streaming with image_ws only supports response_format=b64_json/base64", - param="response_format", - code="invalid_response_format", - ) + if request.stream and request.response_format: + allowed_stream_formats = {"b64_json", "base64", "url"} + if request.response_format not in allowed_stream_formats: + raise ValidationException( + message="Streaming only supports response_format=b64_json/base64/url", + param="response_format", + code="invalid_response_format", + ) if request.response_format: allowed_formats = {"b64_json", "base64", "url"} @@ -98,6 +116,13 @@ def _validate_common_request( code="invalid_response_format", ) + if request.size and request.size not in ALLOWED_IMAGE_SIZES: + raise ValidationException( + message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}", + param="size", + code="invalid_size", + ) + def validate_generation_request(request: ImageGenerationRequest): 
"""验证图片生成请求参数""" @@ -144,26 +169,8 @@ def response_field_name(response_format: str) -> str: def resolve_aspect_ratio(size: str) -> str: """Map OpenAI size to Grok Imagine aspect ratio.""" - size = (size or "").lower() - if size in {"16:9", "9:16", "1:1", "2:3", "3:2"}: - return size - mapping = { - "1024x1024": "1:1", - "512x512": "1:1", - "1024x576": "16:9", - "1280x720": "16:9", - "1536x864": "16:9", - "576x1024": "9:16", - "720x1280": "9:16", - "864x1536": "9:16", - "1024x1536": "2:3", - "512x768": "2:3", - "768x1024": "2:3", - "1536x1024": "3:2", - "768x512": "3:2", - "1024x768": "3:2", - } - return mapping.get(size) or "2:3" + size = (size or "").strip() + return SIZE_TO_ASPECT.get(size) or "2:3" def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]): @@ -174,6 +181,17 @@ def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]): param="model", code="model_not_supported", ) + model_info = ModelService.get(request.model) + if not model_info or not model_info.is_image_edit: + edit_models = [m.model_id for m in ModelService.MODELS if m.is_image_edit] + raise ValidationException( + message=( + f"The model `{request.model}` is not supported for image edits. 
" + f"Supported: {edit_models}" + ), + param="model", + code="model_not_supported", + ) _validate_common_request(request, allow_ws_stream=False) if not images: raise ValidationException( @@ -243,7 +261,6 @@ async def create_image(request: ImageGenerationRequest): # 获取 token 和模型信息 token_mgr, token = await _get_token(request.model) model_info = ModelService.get(request.model) - use_ws = bool(get_config("image.image_ws")) aspect_ratio = resolve_aspect_ratio(request.size) result = await ImageGenerationService().generate( @@ -256,7 +273,6 @@ async def create_image(request: ImageGenerationRequest): size=request.size, aspect_ratio=aspect_ratio, stream=bool(request.stream), - use_ws=use_ws, ) if result.stream: @@ -332,6 +348,8 @@ async def edit_image( edit_request.stream = False response_format = resolve_response_format(edit_request.response_format) + if response_format == "base64": + response_format = "b64_json" edit_request.response_format = response_format response_field = response_field_name(response_format) diff --git a/app/api/v1/models.py b/app/api/v1/models.py index babf35eb..fe5bdc0e 100644 --- a/app/api/v1/models.py +++ b/app/api/v1/models.py @@ -18,7 +18,7 @@ async def list_models(): "id": m.model_id, "object": "model", "created": 0, - "owned_by": "grok2api", + "owned_by": "grok2api@chenyme", } for m in ModelService.list() ] diff --git a/app/api/v1/pages.py b/app/api/v1/pages.py new file mode 100644 index 00000000..cd9472c3 --- /dev/null +++ b/app/api/v1/pages.py @@ -0,0 +1,94 @@ +from pathlib import Path + +import aiofiles +from fastapi import APIRouter, HTTPException +from fastapi.responses import HTMLResponse, RedirectResponse + +from app.core.auth import is_public_enabled + +router = APIRouter() +TEMPLATE_DIR = Path(__file__).resolve().parents[2] / "static" + + +async def render_template(filename: str) -> HTMLResponse: + """渲染指定模板""" + template_path = TEMPLATE_DIR / filename + if not template_path.exists(): + return HTMLResponse(f"Template {filename} not 
found.", status_code=404) + + async with aiofiles.open(template_path, "r", encoding="utf-8") as f: + content = await f.read() + return HTMLResponse(content) + +@router.get("/", include_in_schema=False) +async def root_redirect(): + if is_public_enabled(): + return RedirectResponse(url="/login") + return RedirectResponse(url="/admin/login") + + +@router.get("/login", response_class=HTMLResponse, include_in_schema=False) +async def public_login_page(): + """Public 登录页""" + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return await render_template("public/login.html") + + +@router.get("/imagine", response_class=HTMLResponse, include_in_schema=False) +async def public_imagine_page(): + """Imagine 图片瀑布流""" + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return await render_template("imagine/imagine.html") + + +@router.get("/voice", response_class=HTMLResponse, include_in_schema=False) +async def public_voice_page(): + """Voice Live 调试页""" + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return await render_template("voice/voice.html") + + +@router.get("/admin", include_in_schema=False) +async def admin_root_redirect(): + return RedirectResponse(url="/admin/login") + + +@router.get("/admin/login", response_class=HTMLResponse, include_in_schema=False) +async def admin_login_page(): + """管理后台登录页""" + return await render_template("login/login.html") + + +@router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False) +async def admin_config_page(): + """配置管理页""" + return await render_template("config/config.html") + + +@router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False) +async def admin_token_page(): + """Token 管理页""" + return await render_template("token/token.html") + + +@router.get("/admin/voice", include_in_schema=False) +async def admin_voice_redirect(): + if not is_public_enabled(): + raise 
HTTPException(status_code=404, detail="Not Found") + return RedirectResponse(url="/voice") + + +@router.get("/admin/imagine", include_in_schema=False) +async def admin_imagine_redirect(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return RedirectResponse(url="/imagine") + + +@router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False) +async def admin_cache_page(): + """缓存管理页""" + return await render_template("cache/cache.html") diff --git a/app/api/v1/public/__init__.py b/app/api/v1/public/__init__.py new file mode 100644 index 00000000..0d4ab694 --- /dev/null +++ b/app/api/v1/public/__init__.py @@ -0,0 +1,13 @@ +"""Public API router (public_key protected).""" + +from fastapi import APIRouter + +from app.api.v1.public.imagine import router as imagine_router +from app.api.v1.public.voice import router as voice_router + +router = APIRouter() + +router.include_router(imagine_router) +router.include_router(voice_router) + +__all__ = ["router"] diff --git a/app/api/v1/public/imagine.py b/app/api/v1/public/imagine.py new file mode 100644 index 00000000..5e71436e --- /dev/null +++ b/app/api/v1/public/imagine.py @@ -0,0 +1,449 @@ +import asyncio +import time +import uuid +from typing import Optional, List + +import orjson +from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.core.auth import verify_public_key, get_public_api_key, is_public_enabled +from app.core.logger import logger +from app.api.v1.image import resolve_aspect_ratio +from app.services.grok.services.image import ImageGenerationService +from app.services.grok.services.model import ModelService +from app.services.token.manager import get_token_manager + +router = APIRouter() + +IMAGINE_SESSION_TTL = 600 +_IMAGINE_SESSIONS: dict[str, dict] = {} +_IMAGINE_SESSIONS_LOCK = asyncio.Lock() + + +async def 
_clean_sessions(now: float) -> None: + expired = [ + key + for key, info in _IMAGINE_SESSIONS.items() + if now - float(info.get("created_at") or 0) > IMAGINE_SESSION_TTL + ] + for key in expired: + _IMAGINE_SESSIONS.pop(key, None) + + +async def _new_session(prompt: str, aspect_ratio: str) -> str: + task_id = uuid.uuid4().hex + now = time.time() + async with _IMAGINE_SESSIONS_LOCK: + await _clean_sessions(now) + _IMAGINE_SESSIONS[task_id] = { + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "created_at": now, + } + return task_id + + +async def _get_session(task_id: str) -> Optional[dict]: + if not task_id: + return None + now = time.time() + async with _IMAGINE_SESSIONS_LOCK: + await _clean_sessions(now) + info = _IMAGINE_SESSIONS.get(task_id) + if not info: + return None + created_at = float(info.get("created_at") or 0) + if now - created_at > IMAGINE_SESSION_TTL: + _IMAGINE_SESSIONS.pop(task_id, None) + return None + return dict(info) + + +async def _drop_session(task_id: str) -> None: + if not task_id: + return + async with _IMAGINE_SESSIONS_LOCK: + _IMAGINE_SESSIONS.pop(task_id, None) + + +async def _drop_sessions(task_ids: List[str]) -> int: + if not task_ids: + return 0 + removed = 0 + async with _IMAGINE_SESSIONS_LOCK: + for task_id in task_ids: + if task_id and task_id in _IMAGINE_SESSIONS: + _IMAGINE_SESSIONS.pop(task_id, None) + removed += 1 + return removed + + +@router.websocket("/imagine/ws") +async def public_imagine_ws(websocket: WebSocket): + session_id = None + task_id = websocket.query_params.get("task_id") + if task_id: + info = await _get_session(task_id) + if info: + session_id = task_id + + ok = True + if session_id is None: + public_key = get_public_api_key() + public_enabled = is_public_enabled() + if not public_key: + ok = public_enabled + else: + key = websocket.query_params.get("public_key") + ok = key == public_key + + if not ok: + await websocket.close(code=1008) + return + + await websocket.accept() + stop_event = asyncio.Event() 
+ run_task: Optional[asyncio.Task] = None + + async def _send(payload: dict) -> bool: + try: + await websocket.send_text(orjson.dumps(payload).decode()) + return True + except Exception: + return False + + async def _stop_run(): + nonlocal run_task + stop_event.set() + if run_task and not run_task.done(): + run_task.cancel() + try: + await run_task + except Exception: + pass + run_task = None + stop_event.clear() + + async def _run(prompt: str, aspect_ratio: str): + model_id = "grok-imagine-1.0" + model_info = ModelService.get(model_id) + if not model_info or not model_info.is_image: + await _send( + { + "type": "error", + "message": "Image model is not available.", + "code": "model_not_supported", + } + ) + return + + token_mgr = await get_token_manager() + sequence = 0 + run_id = uuid.uuid4().hex + + await _send( + { + "type": "status", + "status": "running", + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "run_id": run_id, + } + ) + + while not stop_event.is_set(): + try: + await token_mgr.reload_if_stale() + token = None + for pool_name in ModelService.pool_candidates_for_model( + model_info.model_id + ): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + await _send( + { + "type": "error", + "message": "No available tokens. 
Please try again later.", + "code": "rate_limit_exceeded", + } + ) + await asyncio.sleep(2) + continue + + start_at = time.time() + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=6, + response_format="b64_json", + size="1024x1024", + aspect_ratio=aspect_ratio, + stream=False, + ) + elapsed_ms = int((time.time() - start_at) * 1000) + + images = [img for img in result.data if img and img != "error"] + if images: + for img_b64 in images: + sequence += 1 + await _send( + { + "type": "image", + "b64_json": img_b64, + "sequence": sequence, + "created_at": int(time.time() * 1000), + "elapsed_ms": elapsed_ms, + "aspect_ratio": aspect_ratio, + "run_id": run_id, + } + ) + else: + await _send( + { + "type": "error", + "message": "Image generation returned empty data.", + "code": "empty_image", + } + ) + + except asyncio.CancelledError: + break + except Exception as e: + logger.warning(f"Imagine stream error: {e}") + await _send( + { + "type": "error", + "message": str(e), + "code": "internal_error", + } + ) + await asyncio.sleep(1.5) + + await _send({"type": "status", "status": "stopped", "run_id": run_id}) + + try: + while True: + try: + raw = await websocket.receive_text() + except (RuntimeError, WebSocketDisconnect): + break + + try: + payload = orjson.loads(raw) + except Exception: + await _send( + { + "type": "error", + "message": "Invalid message format.", + "code": "invalid_payload", + } + ) + continue + + action = payload.get("type") + if action == "start": + prompt = str(payload.get("prompt") or "").strip() + if not prompt: + await _send( + { + "type": "error", + "message": "Prompt cannot be empty.", + "code": "invalid_prompt", + } + ) + continue + aspect_ratio = resolve_aspect_ratio( + str(payload.get("aspect_ratio") or "2:3").strip() or "2:3" + ) + await _stop_run() + run_task = asyncio.create_task(_run(prompt, aspect_ratio)) + elif action == "stop": + await _stop_run() + else: + 
await _send( + { + "type": "error", + "message": "Unknown action.", + "code": "invalid_action", + } + ) + + except WebSocketDisconnect: + logger.debug("WebSocket disconnected by client") + except Exception as e: + logger.warning(f"WebSocket error: {e}") + finally: + await _stop_run() + + try: + from starlette.websockets import WebSocketState + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close(code=1000, reason="Server closing connection") + except Exception as e: + logger.debug(f"WebSocket close ignored: {e}") + if session_id: + await _drop_session(session_id) + + +@router.get("/imagine/sse") +async def public_imagine_sse( + request: Request, + task_id: str = Query(""), + prompt: str = Query(""), + aspect_ratio: str = Query("2:3"), +): + """Imagine 图片瀑布流(SSE 兜底)""" + session = None + if task_id: + session = await _get_session(task_id) + if not session: + raise HTTPException(status_code=404, detail="Task not found") + else: + public_key = get_public_api_key() + public_enabled = is_public_enabled() + if not public_key: + if not public_enabled: + raise HTTPException(status_code=401, detail="Public access is disabled") + else: + key = request.query_params.get("public_key") + if key != public_key: + raise HTTPException(status_code=401, detail="Invalid authentication token") + + if session: + prompt = str(session.get("prompt") or "").strip() + ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3" + else: + prompt = (prompt or "").strip() + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + ratio = str(aspect_ratio or "2:3").strip() or "2:3" + ratio = resolve_aspect_ratio(ratio) + + async def event_stream(): + try: + model_id = "grok-imagine-1.0" + model_info = ModelService.get(model_id) + if not model_info or not model_info.is_image: + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': 'Image model is not available.', 'code': 'model_not_supported'}).decode()}\n\n" + ) + return + + 
token_mgr = await get_token_manager() + sequence = 0 + run_id = uuid.uuid4().hex + + yield ( + f"data: {orjson.dumps({'type': 'status', 'status': 'running', 'prompt': prompt, 'aspect_ratio': ratio, 'run_id': run_id}).decode()}\n\n" + ) + + while True: + if await request.is_disconnected(): + break + if task_id: + session_alive = await _get_session(task_id) + if not session_alive: + break + + try: + await token_mgr.reload_if_stale() + token = None + for pool_name in ModelService.pool_candidates_for_model( + model_info.model_id + ): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': 'No available tokens. Please try again later.', 'code': 'rate_limit_exceeded'}).decode()}\n\n" + ) + await asyncio.sleep(2) + continue + + start_at = time.time() + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=6, + response_format="b64_json", + size="1024x1024", + aspect_ratio=ratio, + stream=False, + ) + elapsed_ms = int((time.time() - start_at) * 1000) + + images = [img for img in result.data if img and img != "error"] + if images: + for img_b64 in images: + sequence += 1 + payload = { + "type": "image", + "b64_json": img_b64, + "sequence": sequence, + "created_at": int(time.time() * 1000), + "elapsed_ms": elapsed_ms, + "aspect_ratio": ratio, + "run_id": run_id, + } + yield f"data: {orjson.dumps(payload).decode()}\n\n" + else: + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': 'Image generation returned empty data.', 'code': 'empty_image'}).decode()}\n\n" + ) + except asyncio.CancelledError: + break + except Exception as e: + logger.warning(f"Imagine SSE error: {e}") + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': str(e), 'code': 'internal_error'}).decode()}\n\n" + ) + await asyncio.sleep(1.5) + + yield ( + f"data: {orjson.dumps({'type': 'status', 'status': 'stopped', 'run_id': 
run_id}).decode()}\n\n" + ) + finally: + if task_id: + await _drop_session(task_id) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + +class ImagineStartRequest(BaseModel): + prompt: str + aspect_ratio: Optional[str] = "2:3" + + +@router.post("/imagine/start", dependencies=[Depends(verify_public_key)]) +async def public_imagine_start(data: ImagineStartRequest): + prompt = (data.prompt or "").strip() + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + ratio = resolve_aspect_ratio(str(data.aspect_ratio or "2:3").strip() or "2:3") + task_id = await _new_session(prompt, ratio) + return {"task_id": task_id, "aspect_ratio": ratio} + + +class ImagineStopRequest(BaseModel): + task_ids: List[str] + + +@router.post("/imagine/stop", dependencies=[Depends(verify_public_key)]) +async def public_imagine_stop(data: ImagineStopRequest): + removed = await _drop_sessions(data.task_ids or []) + return {"status": "success", "removed": removed} diff --git a/app/api/v1/public/voice.py b/app/api/v1/public/voice.py new file mode 100644 index 00000000..12612f09 --- /dev/null +++ b/app/api/v1/public/voice.py @@ -0,0 +1,80 @@ +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from app.core.auth import verify_public_key +from app.core.exceptions import AppException +from app.services.grok.services.voice import VoiceService +from app.services.token.manager import get_token_manager + +router = APIRouter() + + +class VoiceTokenResponse(BaseModel): + token: str + url: str + participant_name: str = "" + room_name: str = "" + + +@router.get( + "/voice/token", + dependencies=[Depends(verify_public_key)], + response_model=VoiceTokenResponse, +) +async def public_voice_token( + voice: str = "ara", + personality: str = "assistant", + speed: float = 1.0, +): + """获取 Grok Voice Mode (LiveKit) Token""" + token_mgr = await 
get_token_manager() + sso_token = None + for pool_name in ("ssoBasic", "ssoSuper"): + sso_token = token_mgr.get_token(pool_name) + if sso_token: + break + + if not sso_token: + raise AppException( + "No available tokens for voice mode", + code="no_token", + status_code=503, + ) + + service = VoiceService() + try: + data = await service.get_token( + token=sso_token, + voice=voice, + personality=personality, + speed=speed, + ) + token = data.get("token") + if not token: + raise AppException( + "Upstream returned no voice token", + code="upstream_error", + status_code=502, + ) + + return VoiceTokenResponse( + token=token, + url="wss://livekit.grok.com", + participant_name="", + room_name="", + ) + + except Exception as e: + if isinstance(e, AppException): + raise + raise AppException( + f"Voice token error: {str(e)}", + code="voice_error", + status_code=500, + ) + + +@router.get("/verify", dependencies=[Depends(verify_public_key)]) +async def public_verify_api(): + """验证 Public Key""" + return {"status": "success"} diff --git a/app/core/auth.py b/app/core/auth.py index e6bb3c37..2cb45820 100644 --- a/app/core/auth.py +++ b/app/core/auth.py @@ -10,6 +10,8 @@ DEFAULT_API_KEY = "" DEFAULT_APP_KEY = "grok2api" +DEFAULT_PUBLIC_KEY = "" +DEFAULT_PUBLIC_ENABLED = False # 定义 Bearer Scheme security = HTTPBearer( @@ -28,6 +30,28 @@ def get_admin_api_key() -> str: api_key = get_config("app.api_key", DEFAULT_API_KEY) return api_key or "" +def get_app_key() -> str: + """ + 获取 App Key(后台管理密码)。 + """ + app_key = get_config("app.app_key", DEFAULT_APP_KEY) + return app_key or "" + +def get_public_api_key() -> str: + """ + 获取 Public API Key。 + + 为空时表示不启用 public 接口认证。 + """ + public_key = get_config("app.public_key", DEFAULT_PUBLIC_KEY) + return public_key or "" + +def is_public_enabled() -> bool: + """ + 是否开启 public 功能入口。 + """ + return bool(get_config("app.public_enabled", DEFAULT_PUBLIC_ENABLED)) + async def verify_api_key( auth: Optional[HTTPAuthorizationCredentials] = 
Security(security), @@ -66,7 +90,7 @@ async def verify_app_key( app_key 必须配置,否则拒绝登录。 """ - app_key = get_config("app.app_key", DEFAULT_APP_KEY) + app_key = get_app_key() if not app_key: raise HTTPException( @@ -90,3 +114,40 @@ async def verify_app_key( ) return auth.credentials + + +async def verify_public_key( + auth: Optional[HTTPAuthorizationCredentials] = Security(security), +) -> Optional[str]: + """ + 验证 Public Key(public 接口使用)。 + + 默认不公开,需配置 public_key 才能访问;若开启 public_enabled 且未配置 public_key,则放开访问。 + """ + public_key = get_public_api_key() + public_enabled = is_public_enabled() + + if not public_key: + if public_enabled: + return None + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Public access is disabled", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not auth: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if auth.credentials != public_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return auth.credentials diff --git a/app/core/config.py b/app/core/config.py index bec87149..efb0ec27 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -44,31 +44,34 @@ def _migrate_deprecated_config( # 配置映射规则:旧配置 -> 新配置 MIGRATION_MAP = { # grok.* -> 对应的新配置节 - "grok.temporary": "chat.temporary", - "grok.disable_memory": "chat.disable_memory", - "grok.stream": "chat.stream", - "grok.thinking": "chat.thinking", - "grok.dynamic_statsig": "chat.dynamic_statsig", - "grok.filter_tags": "chat.filter_tags", - "grok.timeout": "network.timeout", - "grok.base_proxy_url": "network.base_proxy_url", - "grok.asset_proxy_url": "network.asset_proxy_url", - "grok.cf_clearance": "security.cf_clearance", - "grok.browser": "security.browser", - "grok.user_agent": "security.user_agent", + "grok.temporary": "app.temporary", + 
"grok.disable_memory": "app.disable_memory", + "grok.stream": "app.stream", + "grok.thinking": "app.thinking", + "grok.dynamic_statsig": "app.dynamic_statsig", + "grok.filter_tags": "app.filter_tags", + "grok.timeout": "voice.timeout", + "grok.base_proxy_url": "proxy.base_proxy_url", + "grok.asset_proxy_url": "proxy.asset_proxy_url", + "network.base_proxy_url": "proxy.base_proxy_url", + "network.asset_proxy_url": "proxy.asset_proxy_url", + "grok.cf_clearance": "proxy.cf_clearance", + "grok.browser": "proxy.browser", + "grok.user_agent": "proxy.user_agent", + "security.cf_clearance": "proxy.cf_clearance", + "security.browser": "proxy.browser", + "security.user_agent": "proxy.user_agent", "grok.max_retry": "retry.max_retry", "grok.retry_status_codes": "retry.retry_status_codes", "grok.retry_backoff_base": "retry.retry_backoff_base", "grok.retry_backoff_factor": "retry.retry_backoff_factor", "grok.retry_backoff_max": "retry.retry_backoff_max", "grok.retry_budget": "retry.retry_budget", - "grok.stream_idle_timeout": "timeout.stream_idle_timeout", - "grok.video_idle_timeout": "timeout.video_idle_timeout", - "grok.image_ws": "image.image_ws", - "grok.image_ws_nsfw": "image.image_ws_nsfw", - "grok.image_ws_blocked_seconds": "image.image_ws_blocked_seconds", - "grok.image_ws_final_min_bytes": "image.image_ws_final_min_bytes", - "grok.image_ws_medium_min_bytes": "image.image_ws_medium_min_bytes", + "grok.video_idle_timeout": "video.stream_timeout", + "grok.image_ws_nsfw": "image.nsfw", + "grok.image_ws_blocked_seconds": "image.final_timeout", + "grok.image_ws_final_min_bytes": "image.final_min_bytes", + "grok.image_ws_medium_min_bytes": "image.medium_min_bytes", } deprecated_sections = set(config.keys()) - valid_sections @@ -98,8 +101,32 @@ def _migrate_deprecated_config( migrated_count += 1 logger.debug(f"Migrated config: {old_path} -> {new_path} = {old_value}") + # 兼容旧 chat.* 配置键迁移到 app.* + legacy_chat_map = { + "temporary": "temporary", + "disable_memory": 
"disable_memory", + "stream": "stream", + "thinking": "thinking", + "dynamic_statsig": "dynamic_statsig", + "filter_tags": "filter_tags", + } + chat_section = config.get("chat") + if isinstance(chat_section, dict): + app_section = result.setdefault("app", {}) + for old_key, new_key in legacy_chat_map.items(): + if old_key in chat_section and new_key not in app_section: + app_section[new_key] = chat_section[old_key] + if isinstance(result.get("chat"), dict): + result["chat"].pop(old_key, None) + migrated_count += 1 + logger.debug( + f"Migrated config: chat.{old_key} -> app.{new_key} = {chat_section[old_key]}" + ) + if migrated_count > 0: - logger.info(f"Migrated {migrated_count} config items from deprecated sections") + logger.info( + f"Migrated {migrated_count} config items from deprecated/legacy sections" + ) return result, deprecated_sections diff --git a/app/core/exceptions.py b/app/core/exceptions.py index 4ae9092c..24aa281b 100644 --- a/app/core/exceptions.py +++ b/app/core/exceptions.py @@ -218,7 +218,6 @@ def register_exception_handlers(app): app.add_exception_handler(HTTPException, http_exception_handler) app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_exception_handler(Exception, generic_exception_handler) - app.add_exception_handler(Exception, generic_exception_handler) __all__ = [ diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py index 6fbdb9c1..1c8faa0c 100644 --- a/app/services/grok/batch_services/nsfw.py +++ b/app/services/grok/batch_services/nsfw.py @@ -43,7 +43,7 @@ async def batch( batch_size = get_config("nsfw.batch_size") async def _enable(token: str): try: - browser = get_config("security.browser") + browser = get_config("proxy.browser") async with AsyncSession(impersonate=browser) as session: async def _record_fail(err: UpstreamException, reason: str): status = None diff --git a/app/services/grok/defaults.py b/app/services/grok/defaults.py index 
3a66dd2d..06a41584 100644 --- a/app/services/grok/defaults.py +++ b/app/services/grok/defaults.py @@ -10,26 +10,36 @@ "app_url": "", "app_key": "grok2api", "api_key": "", + "public_key": "", + "public_enabled": False, "image_format": "url", "video_format": "html", + "temporary": True, + "disable_memory": True, + "stream": True, + "thinking": True, + "dynamic_statsig": True, + "filter_tags": ["xaiartifact", "xai:tool_usage_card", "grok:render"], }, - "network": { - "timeout": 120, + "proxy": { "base_proxy_url": "", "asset_proxy_url": "", - }, - "security": { "cf_clearance": "", "browser": "chrome136", "user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36", }, + "voice": { + "timeout": 120, + }, "chat": { - "temporary": True, - "disable_memory": True, - "stream": True, - "thinking": False, - "dynamic_statsig": True, - "filter_tags": ["grok:render", "xaiartifact", "xai:tool_usage_card"], + "concurrent": 10, + "timeout": 60, + "stream_timeout": 60, + }, + "video": { + "concurrent": 10, + "timeout": 60, + "stream_timeout": 60, }, "retry": { "max_retry": 3, @@ -39,16 +49,13 @@ "retry_backoff_max": 30.0, "retry_budget": 90.0, }, - "timeout": { - "stream_idle_timeout": 45.0, - "video_idle_timeout": 90.0, - }, "image": { - "image_ws": True, - "image_ws_nsfw": True, - "image_ws_blocked_seconds": 15, - "image_ws_final_min_bytes": 100000, - "image_ws_medium_min_bytes": 30000, + "timeout": 120, + "stream_timeout": 120, + "final_timeout": 15, + "nsfw": True, + "medium_min_bytes": 30000, + "final_min_bytes": 100000, }, "token": { "auto_refresh": True, @@ -84,9 +91,6 @@ "batch_size": 50, "timeout": 60, }, - "performance": { - "media_max_concurrent": 50, - }, } diff --git a/app/services/grok/processors/__init__.py b/app/services/grok/processors/__init__.py deleted file mode 100644 index 3cf7d613..00000000 --- a/app/services/grok/processors/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -OpenAI 响应格式处理器 
-""" - -from .video import VideoStreamProcessor, VideoCollectProcessor -from .image import ( - ImageStreamProcessor, - ImageCollectProcessor, - ImageWSStreamProcessor, - ImageWSCollectProcessor, -) - -__all__ = [ - "VideoStreamProcessor", - "VideoCollectProcessor", - "ImageStreamProcessor", - "ImageCollectProcessor", - "ImageWSStreamProcessor", - "ImageWSCollectProcessor", -] diff --git a/app/services/grok/processors/image.py b/app/services/grok/processors/image.py deleted file mode 100644 index bbfdb3bd..00000000 --- a/app/services/grok/processors/image.py +++ /dev/null @@ -1,505 +0,0 @@ -""" -Image response processors (HTTP + WebSocket). -""" - -import asyncio -import base64 -import random -import time -from pathlib import Path -from typing import AsyncGenerator, AsyncIterable, List, Dict, Optional - -import orjson -from curl_cffi.requests.errors import RequestsError - -from app.core.config import get_config -from app.core.logger import logger -from app.core.storage import DATA_DIR -from app.core.exceptions import UpstreamException, StreamIdleTimeoutError -from app.services.grok.utils.process import ( - BaseProcessor, - _with_idle_timeout, - _normalize_line, - _collect_images, - _is_http2_error, -) - - -class ImageStreamProcessor(BaseProcessor): - """HTTP image stream processor.""" - - def __init__( - self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" - ): - super().__init__(model, token) - self.partial_index = 0 - self.n = n - self.target_index = random.randint(0, 1) if n == 1 else None - self.response_format = response_format - if response_format == "url": - self.response_field = "url" - elif response_format == "base64": - self.response_field = "base64" - else: - self.response_field = "b64_json" - - def _sse(self, event: str, data: dict) -> str: - """Build SSE response.""" - return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" - - async def process( - self, response: AsyncIterable[bytes] - ) -> AsyncGenerator[str, 
None]: - """Process stream response.""" - final_images = [] - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - # Image generation progress - if img := resp.get("streamingImageGenerationResponse"): - image_index = img.get("imageIndex", 0) - progress = img.get("progress", 0) - - if self.n == 1 and image_index != self.target_index: - continue - - out_index = 0 if self.n == 1 else image_index - - yield self._sse( - "image_generation.partial_image", - { - "type": "image_generation.partial_image", - self.response_field: "", - "index": out_index, - "progress": progress, - }, - ) - continue - - # modelResponse - if mr := resp.get("modelResponse"): - if urls := _collect_images(mr): - for url in urls: - if self.response_format == "url": - processed = await self.process_url(url, "image") - if processed: - final_images.append(processed) - continue - try: - dl_service = self._get_dl() - base64_data = await dl_service.parse_b64( - url, self.token, "image" - ) - if base64_data: - if "," in base64_data: - b64 = base64_data.split(",", 1)[1] - else: - b64 = base64_data - final_images.append(b64) - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - processed = await self.process_url(url, "image") - if processed: - final_images.append(processed) - continue - - for index, b64 in enumerate(final_images): - if self.n == 1: - if index != self.target_index: - continue - out_index = 0 - else: - out_index = index - - yield self._sse( - "image_generation.completed", - { - "type": "image_generation.completed", - self.response_field: b64, - "index": out_index, - "usage": { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - 
"input_tokens_details": { - "text_tokens": 0, - "image_tokens": 0, - }, - }, - }, - ) - except asyncio.CancelledError: - logger.debug("Image stream cancelled by client") - except StreamIdleTimeoutError as e: - raise UpstreamException( - message=f"Image stream idle timeout after {e.idle_seconds}s", - status_code=504, - details={ - "error": str(e), - "type": "stream_idle_timeout", - "idle_seconds": e.idle_seconds, - }, - ) - except RequestsError as e: - if _is_http2_error(e): - logger.warning(f"HTTP/2 stream error in image: {e}") - raise UpstreamException( - message="Upstream connection closed unexpectedly", - status_code=502, - details={"error": str(e), "type": "http2_stream_error"}, - ) - logger.error(f"Image stream request error: {e}") - raise UpstreamException( - message=f"Upstream request failed: {e}", - status_code=502, - details={"error": str(e)}, - ) - except Exception as e: - logger.error( - f"Image stream processing error: {e}", - extra={"error_type": type(e).__name__}, - ) - raise - finally: - await self.close() - - -class ImageCollectProcessor(BaseProcessor): - """HTTP image non-stream processor.""" - - def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): - super().__init__(model, token) - self.response_format = response_format - - async def process(self, response: AsyncIterable[bytes]) -> List[str]: - """Process and collect images.""" - images = [] - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if mr := resp.get("modelResponse"): - if urls := _collect_images(mr): - for url in urls: - if self.response_format == "url": - processed = await self.process_url(url, "image") - if processed: - images.append(processed) - continue - try: - 
dl_service = self._get_dl() - base64_data = await dl_service.parse_b64( - url, self.token, "image" - ) - if base64_data: - if "," in base64_data: - b64 = base64_data.split(",", 1)[1] - else: - b64 = base64_data - images.append(b64) - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - processed = await self.process_url(url, "image") - if processed: - images.append(processed) - - except asyncio.CancelledError: - logger.debug("Image collect cancelled by client") - except StreamIdleTimeoutError as e: - logger.warning(f"Image collect idle timeout: {e}") - except RequestsError as e: - if _is_http2_error(e): - logger.warning(f"HTTP/2 stream error in image collect: {e}") - else: - logger.error(f"Image collect request error: {e}") - except Exception as e: - logger.error( - f"Image collect processing error: {e}", - extra={"error_type": type(e).__name__}, - ) - finally: - await self.close() - - return images - - -class ImageWSBaseProcessor(BaseProcessor): - """WebSocket image processor base.""" - - def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): - super().__init__(model, token) - self.response_format = response_format - if response_format == "url": - self.response_field = "url" - elif response_format == "base64": - self.response_field = "base64" - else: - self.response_field = "b64_json" - self._image_dir: Optional[Path] = None - - def _ensure_image_dir(self) -> Path: - if self._image_dir is None: - base_dir = DATA_DIR / "tmp" / "image" - base_dir.mkdir(parents=True, exist_ok=True) - self._image_dir = base_dir - return self._image_dir - - def _strip_base64(self, blob: str) -> str: - if not blob: - return "" - if "," in blob and "base64" in blob.split(",", 1)[0]: - return blob.split(",", 1)[1] - return blob - - def _filename(self, image_id: str, is_final: bool) -> str: - ext = "jpg" if is_final else "png" - return f"{image_id}.{ext}" - - def _build_file_url(self, filename: str) -> 
str: - app_url = get_config("app.app_url") - if app_url: - return f"{app_url.rstrip('/')}/v1/files/image/{filename}" - return f"/v1/files/image/{filename}" - - def _save_blob(self, image_id: str, blob: str, is_final: bool) -> str: - data = self._strip_base64(blob) - if not data: - return "" - image_dir = self._ensure_image_dir() - filename = self._filename(image_id, is_final) - filepath = image_dir / filename - with open(filepath, "wb") as f: - f.write(base64.b64decode(data)) - return self._build_file_url(filename) - - def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict: - if not existing: - return incoming - if incoming.get("is_final") and not existing.get("is_final"): - return incoming - if existing.get("is_final") and not incoming.get("is_final"): - return existing - if incoming.get("blob_size", 0) > existing.get("blob_size", 0): - return incoming - return existing - - def _to_output(self, image_id: str, item: Dict) -> str: - try: - if self.response_format == "url": - return self._save_blob( - image_id, item.get("blob", ""), item.get("is_final", False) - ) - return self._strip_base64(item.get("blob", "")) - except Exception as e: - logger.warning(f"Image output failed: {e}") - return "" - - -class ImageWSStreamProcessor(ImageWSBaseProcessor): - """WebSocket image stream processor.""" - - def __init__( - self, - model: str, - token: str = "", - n: int = 1, - response_format: str = "b64_json", - size: str = "1024x1024", - ): - super().__init__(model, token, "b64_json") - self.n = n - self.size = size - self._target_id: Optional[str] = None - self._index_map: Dict[str, int] = {} - self._partial_map: Dict[str, int] = {} - - def _assign_index(self, image_id: str) -> Optional[int]: - if image_id in self._index_map: - return self._index_map[image_id] - if len(self._index_map) >= self.n: - return None - self._index_map[image_id] = len(self._index_map) - return self._index_map[image_id] - - def _sse(self, event: str, data: dict) -> str: - return 
f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" - - async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]: - images: Dict[str, Dict] = {} - - async for item in response: - if item.get("type") == "error": - message = item.get("error") or "Upstream error" - code = item.get("error_code") or "upstream_error" - yield self._sse( - "error", - { - "error": { - "message": message, - "type": "server_error", - "code": code, - } - }, - ) - return - if item.get("type") != "image": - continue - - image_id = item.get("image_id") - if not image_id: - continue - - if self.n == 1: - if self._target_id is None: - self._target_id = image_id - index = 0 if image_id == self._target_id else None - else: - index = self._assign_index(image_id) - - images[image_id] = self._pick_best(images.get(image_id), item) - - if index is None: - continue - - if item.get("stage") != "final": - partial_b64 = self._strip_base64(item.get("blob", "")) - if not partial_b64: - continue - partial_index = self._partial_map.get(image_id, 0) - if item.get("stage") == "medium": - partial_index = max(partial_index, 1) - self._partial_map[image_id] = partial_index - yield self._sse( - "image_generation.partial_image", - { - "type": "image_generation.partial_image", - "b64_json": partial_b64, - "created_at": int(time.time()), - "size": self.size, - "index": index, - "partial_image_index": partial_index, - }, - ) - - if self.n == 1: - if self._target_id and self._target_id in images: - selected = [(self._target_id, images[self._target_id])] - else: - selected = ( - [ - max( - images.items(), - key=lambda x: ( - x[1].get("is_final", False), - x[1].get("blob_size", 0), - ), - ) - ] - if images - else [] - ) - else: - selected = [ - (image_id, images[image_id]) - for image_id in self._index_map - if image_id in images - ] - - for image_id, item in selected: - output = self._strip_base64(item.get("blob", "")) - if not output: - continue - - if self.n == 1: - index = 0 - else: - 
index = self._index_map.get(image_id, 0) - yield self._sse( - "image_generation.completed", - { - "type": "image_generation.completed", - "b64_json": output, - "created_at": int(time.time()), - "size": self.size, - "index": index, - "usage": { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, - }, - }, - ) - - -class ImageWSCollectProcessor(ImageWSBaseProcessor): - """WebSocket image non-stream processor.""" - - def __init__( - self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" - ): - super().__init__(model, token, response_format) - self.n = n - - async def process(self, response: AsyncIterable[dict]) -> List[str]: - images: Dict[str, Dict] = {} - - async for item in response: - if item.get("type") == "error": - message = item.get("error") or "Upstream error" - raise UpstreamException(message, details=item) - if item.get("type") != "image": - continue - image_id = item.get("image_id") - if not image_id: - continue - images[image_id] = self._pick_best(images.get(image_id), item) - - selected = sorted( - images.values(), - key=lambda x: (x.get("is_final", False), x.get("blob_size", 0)), - reverse=True, - ) - if self.n: - selected = selected[: self.n] - - results: List[str] = [] - for item in selected: - output = self._to_output(item.get("image_id", ""), item) - if output: - results.append(output) - - return results - - -__all__ = [ - "ImageStreamProcessor", - "ImageCollectProcessor", - "ImageWSStreamProcessor", - "ImageWSCollectProcessor", -] diff --git a/app/services/grok/processors/video.py b/app/services/grok/processors/video.py deleted file mode 100644 index fb501cac..00000000 --- a/app/services/grok/processors/video.py +++ /dev/null @@ -1,235 +0,0 @@ -""" -Video response processors. 
-""" - -import asyncio -import uuid -from typing import Any, AsyncGenerator, AsyncIterable, Optional - -import orjson -from curl_cffi.requests.errors import RequestsError - -from app.core.config import get_config -from app.core.logger import logger -from app.core.exceptions import UpstreamException, StreamIdleTimeoutError -from app.services.grok.utils.process import ( - BaseProcessor, - _with_idle_timeout, - _normalize_line, - _is_http2_error, -) - - -class VideoStreamProcessor(BaseProcessor): - """Video stream response processor.""" - - def __init__(self, model: str, token: str = "", show_think: bool = None): - super().__init__(model, token) - self.response_id: Optional[str] = None - self.think_opened: bool = False - self.role_sent: bool = False - - self.show_think = bool(show_think) - - def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: - """Build SSE response.""" - delta = {} - if role: - delta["role"] = role - delta["content"] = "" - elif content: - delta["content"] = content - - chunk = { - "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "choices": [ - {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} - ], - } - return f"data: {orjson.dumps(chunk).decode()}\n\n" - - async def process( - self, response: AsyncIterable[bytes] - ) -> AsyncGenerator[str, None]: - """Process video stream response.""" - idle_timeout = get_config("timeout.video_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if rid := resp.get("responseId"): - self.response_id = rid - - if not self.role_sent: - yield self._sse(role="assistant") - self.role_sent = True - - # Video generation progress - if 
video_resp := resp.get("streamingVideoGenerationResponse"): - progress = video_resp.get("progress", 0) - - if self.show_think: - if not self.think_opened: - yield self._sse("\n") - self.think_opened = True - yield self._sse(f"正在生成视频中,当前进度{progress}%\n") - - if progress == 100: - video_url = video_resp.get("videoUrl", "") - thumbnail_url = video_resp.get("thumbnailImageUrl", "") - - if self.think_opened and self.show_think: - yield self._sse("\n") - self.think_opened = False - - if video_url: - dl_service = self._get_dl() - rendered = await dl_service.render_video( - video_url, self.token, thumbnail_url - ) - yield self._sse(rendered) - - logger.info(f"Video generated: {video_url}") - continue - - if self.think_opened: - yield self._sse("\n") - yield self._sse(finish="stop") - yield "data: [DONE]\n\n" - except asyncio.CancelledError: - logger.debug( - "Video stream cancelled by client", extra={"model": self.model} - ) - except StreamIdleTimeoutError as e: - raise UpstreamException( - message=f"Video stream idle timeout after {e.idle_seconds}s", - status_code=504, - details={ - "error": str(e), - "type": "stream_idle_timeout", - "idle_seconds": e.idle_seconds, - }, - ) - except RequestsError as e: - if _is_http2_error(e): - logger.warning( - f"HTTP/2 stream error in video: {e}", extra={"model": self.model} - ) - raise UpstreamException( - message="Upstream connection closed unexpectedly", - status_code=502, - details={"error": str(e), "type": "http2_stream_error"}, - ) - logger.error( - f"Video stream request error: {e}", extra={"model": self.model} - ) - raise UpstreamException( - message=f"Upstream request failed: {e}", - status_code=502, - details={"error": str(e)}, - ) - except Exception as e: - logger.error( - f"Video stream processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - finally: - await self.close() - - -class VideoCollectProcessor(BaseProcessor): - """Video non-stream response processor.""" - - def __init__(self, 
model: str, token: str = ""): - super().__init__(model, token) - - async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: - """Process and collect video response.""" - response_id = "" - content = "" - idle_timeout = get_config("timeout.video_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if video_resp := resp.get("streamingVideoGenerationResponse"): - if video_resp.get("progress") == 100: - response_id = resp.get("responseId", "") - video_url = video_resp.get("videoUrl", "") - thumbnail_url = video_resp.get("thumbnailImageUrl", "") - - if video_url: - dl_service = self._get_dl() - content = await dl_service.render_video( - video_url, self.token, thumbnail_url - ) - logger.info(f"Video generated: {video_url}") - - except asyncio.CancelledError: - logger.debug( - "Video collect cancelled by client", extra={"model": self.model} - ) - except StreamIdleTimeoutError as e: - logger.warning( - f"Video collect idle timeout: {e}", extra={"model": self.model} - ) - except RequestsError as e: - if _is_http2_error(e): - logger.warning( - f"HTTP/2 stream error in video collect: {e}", - extra={"model": self.model}, - ) - else: - logger.error( - f"Video collect request error: {e}", extra={"model": self.model} - ) - except Exception as e: - logger.error( - f"Video collect processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - finally: - await self.close() - - return { - "id": response_id, - "object": "chat.completion", - "created": self.created, - "model": self.model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": content, - "refusal": None, - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 0, "completion_tokens": 0, 
"total_tokens": 0}, - } - - -__all__ = ["VideoStreamProcessor", "VideoCollectProcessor"] diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index e2a452b9..0d203276 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -23,11 +23,25 @@ from app.services.grok.services.model import ModelService from app.services.grok.utils.upload import UploadService from app.services.grok.utils import process as proc_base +from app.services.grok.utils.retry import pick_token, rate_limited from app.services.reverse.app_chat import AppChatReverse from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.token import get_token_manager, EffortType +_CHAT_SEMAPHORE = None +_CHAT_SEM_VALUE = None + + +def _get_chat_semaphore() -> asyncio.Semaphore: + global _CHAT_SEMAPHORE, _CHAT_SEM_VALUE + value = max(1, int(get_config("chat.concurrent"))) + if value != _CHAT_SEM_VALUE: + _CHAT_SEM_VALUE = value + _CHAT_SEMAPHORE = asyncio.Semaphore(value) + return _CHAT_SEMAPHORE + + class MessageExtractor: """消息内容提取器""" @@ -110,31 +124,39 @@ async def chat( ): """发送聊天请求""" if stream is None: - stream = get_config("chat.stream") + stream = get_config("app.stream") logger.debug( f"Chat request: model={model}, mode={mode}, stream={stream}, attachments={len(file_attachments or [])}" ) - browser = get_config("security.browser") - session = AsyncSession(impersonate=browser) - try: - stream_response = await AppChatReverse.request( - session, - token, - message=message, - model=model, - mode=mode, - file_attachments=file_attachments, - tool_overrides=tool_overrides, - model_config_override=model_config_override, - ) - logger.info(f"Chat connected: model={model}, stream={stream}") - except Exception: - await session.close() - raise + browser = get_config("proxy.browser") - return stream_response + async def _stream(): + session = AsyncSession(impersonate=browser) + try: + async with _get_chat_semaphore(): + 
stream_response = await AppChatReverse.request( + session, + token, + message=message, + model=model, + mode=mode, + file_attachments=file_attachments, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + logger.info(f"Chat connected: model={model}, stream={stream}") + async for line in stream_response: + yield line + except Exception: + try: + await session.close() + except Exception: + pass + raise + + return _stream() async def chat_openai( self, @@ -180,7 +202,7 @@ async def chat_openai( await upload_service.close() all_attachments = file_ids + image_ids - stream = stream if stream is not None else get_config("chat.stream") + stream = stream if stream is not None else get_config("app.stream") model_config_override = { "temperature": temperature, @@ -221,10 +243,10 @@ async def completions( # 解析参数 if reasoning_effort is None: - show_think = get_config("chat.thinking") + show_think = get_config("app.thinking") else: show_think = reasoning_effort != "none" - is_stream = stream if stream is not None else get_config("chat.stream") + is_stream = stream if stream is not None else get_config("app.stream") # 跨 Token 重试循环 tried_tokens = set() @@ -233,22 +255,7 @@ async def completions( for attempt in range(max_token_retries): # 选择 token - token = None - for pool_name in ModelService.pool_candidates_for_model(model): - token = token_mgr.get_token(pool_name, exclude=tried_tokens) - if token: - break - - if not token and not tried_tokens: - # 首次就无 token,尝试刷新 - logger.info("No available tokens, attempting to refresh cooling tokens...") - result = await token_mgr.refresh_cooling_tokens() - if result.get("recovered", 0) > 0: - for pool_name in ModelService.pool_candidates_for_model(model): - token = token_mgr.get_token(pool_name) - if token: - break - + token = await pick_token(token_mgr, model, tried_tokens) if not token: if last_error: raise last_error @@ -299,10 +306,9 @@ async def completions( return result except UpstreamException as e: - 
status_code = e.details.get("status") if e.details else None last_error = e - if status_code == 429: + if rate_limited(e): # 配额不足,标记 token 为 cooling 并换 token 重试 await token_mgr.mark_rate_limited(token) logger.warning( @@ -334,7 +340,7 @@ def __init__(self, model: str, token: str = "", show_think: bool = None): self.fingerprint: str = "" self.think_opened: bool = False self.role_sent: bool = False - self.filter_tags = get_config("chat.filter_tags") + self.filter_tags = get_config("app.filter_tags") self._tag_buffer: str = "" self._in_filter_tag: bool = False @@ -419,7 +425,7 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N Returns: AsyncGenerator[str, None], async generator of strings """ - idle_timeout = get_config("timeout.stream_idle_timeout") + idle_timeout = get_config("chat.stream_timeout") try: async for line in proc_base._with_idle_timeout( @@ -563,7 +569,7 @@ class CollectProcessor(proc_base.BaseProcessor): def __init__(self, model: str, token: str = ""): super().__init__(model, token) - self.filter_tags = get_config("chat.filter_tags") + self.filter_tags = get_config("app.filter_tags") def _filter_content(self, content: str) -> str: """Filter special tags in content.""" @@ -582,7 +588,7 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: response_id = "" fingerprint = "" content = "" - idle_timeout = get_config("timeout.stream_idle_timeout") + idle_timeout = get_config("chat.stream_timeout") try: async for line in proc_base._with_idle_timeout( diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 4bba8d9f..7db7bf98 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -3,26 +3,26 @@ """ import asyncio +import base64 import math -import random +import time from dataclasses import dataclass -from typing import Any, AsyncGenerator, List, Optional, Union +from pathlib import Path +from typing import Any, AsyncGenerator, 
AsyncIterable, Dict, List, Optional, Union + +import orjson from app.core.config import get_config from app.core.logger import logger -from app.services.grok.processors import ( - ImageStreamProcessor, - ImageCollectProcessor, - ImageWSStreamProcessor, - ImageWSCollectProcessor, -) -from app.services.grok.services.chat import GrokChatService +from app.core.storage import DATA_DIR +from app.core.exceptions import AppException, ErrorType, UpstreamException +from app.services.grok.utils.process import BaseProcessor +from app.services.grok.utils.retry import pick_token, rate_limited from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.token import EffortType from app.services.reverse.ws_imagine import ImagineWebSocketReverse -ImageService = ImagineWebSocketReverse image_service = ImagineWebSocketReverse() @@ -48,47 +48,114 @@ async def generate( size: str, aspect_ratio: str, stream: bool, - use_ws: bool, ) -> ImageGenerationResult: + max_token_retries = int(get_config("retry.max_retry")) + tried_tokens: set[str] = set() + last_error: Optional[Exception] = None + if stream: - if use_ws: - return await self._stream_ws( + async def _stream_retry() -> AsyncGenerator[str, None]: + nonlocal last_error + for attempt in range(max_token_retries): + preferred = token if attempt == 0 else None + current_token = await pick_token( + token_mgr, model_info.model_id, tried_tokens, preferred=preferred + ) + if not current_token: + if last_error: + raise last_error + raise AppException( + message="No available tokens. 
Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + tried_tokens.add(current_token) + yielded = False + try: + result = await self._stream_ws( + token_mgr=token_mgr, + token=current_token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + size=size, + aspect_ratio=aspect_ratio, + ) + async for chunk in result.data: + yielded = True + yield chunk + return + except UpstreamException as e: + last_error = e + if rate_limited(e): + if yielded: + raise + await token_mgr.mark_rate_limited(current_token) + logger.warning( + f"Token {current_token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise + + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + return ImageGenerationResult(stream=True, data=_stream_retry()) + + for attempt in range(max_token_retries): + preferred = token if attempt == 0 else None + current_token = await pick_token( + token_mgr, model_info.model_id, tried_tokens, preferred=preferred + ) + if not current_token: + if last_error: + raise last_error + raise AppException( + message="No available tokens. 
Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + tried_tokens.add(current_token) + try: + return await self._collect_ws( token_mgr=token_mgr, - token=token, + token=current_token, model_info=model_info, prompt=prompt, n=n, response_format=response_format, - size=size, aspect_ratio=aspect_ratio, ) - return await self._stream_http( - token_mgr=token_mgr, - token=token, - model_info=model_info, - prompt=prompt, - n=n, - response_format=response_format, - ) + except UpstreamException as e: + last_error = e + if rate_limited(e): + await token_mgr.mark_rate_limited(current_token) + logger.warning( + f"Token {current_token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise - if use_ws: - return await self._collect_ws( - token_mgr=token_mgr, - token=token, - model_info=model_info, - prompt=prompt, - n=n, - response_format=response_format, - aspect_ratio=aspect_ratio, - ) - - return await self._collect_http( - token_mgr=token_mgr, - token=token, - model_info=model_info, - prompt=prompt, - n=n, - response_format=response_format, + if last_error: + raise last_error + raise AppException( + message="No available tokens. 
Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, ) async def _stream_ws( @@ -103,7 +170,7 @@ async def _stream_ws( size: str, aspect_ratio: str, ) -> ImageGenerationResult: - enable_nsfw = bool(get_config("image.image_ws_nsfw")) + enable_nsfw = bool(get_config("image.nsfw")) upstream = image_service.stream( token=token, prompt=prompt, @@ -126,37 +193,6 @@ async def _stream_ws( ) return ImageGenerationResult(stream=True, data=stream) - async def _stream_http( - self, - *, - token_mgr: Any, - token: str, - model_info: Any, - prompt: str, - n: int, - response_format: str, - ) -> ImageGenerationResult: - response = await GrokChatService().chat( - token=token, - message=f"Image Generation: {prompt}", - model=model_info.grok_model, - mode=model_info.model_mode, - stream=True, - ) - processor = ImageStreamProcessor( - model_info.model_id, - token, - n=n, - response_format=response_format, - ) - stream = wrap_stream_with_usage( - processor.process(response), - token_mgr, - token, - model_info.model_id, - ) - return ImageGenerationResult(stream=True, data=stream) - async def _collect_ws( self, *, @@ -168,7 +204,7 @@ async def _collect_ws( response_format: str, aspect_ratio: str, ) -> ImageGenerationResult: - enable_nsfw = bool(get_config("image.image_ws_nsfw")) + enable_nsfw = bool(get_config("image.nsfw")) all_images: List[str] = [] seen = set() expected_per_call = 6 @@ -227,60 +263,6 @@ async def _fetch_batch(call_target: int): stream=False, data=selected, usage_override=usage_override ) - async def _collect_http( - self, - *, - token_mgr: Any, - token: str, - model_info: Any, - prompt: str, - n: int, - response_format: str, - ) -> ImageGenerationResult: - calls_needed = (n + 1) // 2 - - async def _call_grok(): - success = False - try: - response = await GrokChatService().chat( - token=token, - message=f"Image Generation: {prompt}", - model=model_info.grok_model, - mode=model_info.model_mode, - 
stream=True, - ) - processor = ImageCollectProcessor( - model_info.model_id, token, response_format=response_format - ) - images = await processor.process(response) - success = True - return images - except Exception as e: - logger.error(f"Grok image call failed: {e}") - return [] - finally: - if success: - try: - await token_mgr.consume(token, self._get_effort(model_info)) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - - if calls_needed == 1: - all_images = await _call_grok() - else: - tasks = [_call_grok() for _ in range(calls_needed)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - all_images: List[str] = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"Concurrent call failed: {result}") - elif isinstance(result, list): - all_images.extend(result) - - selected = self._select_images(all_images, n) - return ImageGenerationResult(stream=False, data=selected) - @staticmethod def _get_effort(model_info: Any) -> EffortType: return ( @@ -292,16 +274,330 @@ def _get_effort(model_info: Any) -> EffortType: @staticmethod def _select_images(images: List[str], n: int) -> List[str]: if len(images) >= n: - return random.sample(images, n) + return images[:n] selected = images.copy() while len(selected) < n: selected.append("error") return selected -__all__ = [ - "image_service", - "ImageService", - "ImageGenerationService", - "ImageGenerationResult", -] +class ImageWSBaseProcessor(BaseProcessor): + """WebSocket image processor base.""" + + def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): + if response_format == "base64": + response_format = "b64_json" + super().__init__(model, token) + self.response_format = response_format + if response_format == "url": + self.response_field = "url" + elif response_format == "base64": + self.response_field = "base64" + else: + self.response_field = "b64_json" + self._image_dir: Optional[Path] = None + + def 
_ensure_image_dir(self) -> Path: + if self._image_dir is None: + base_dir = DATA_DIR / "tmp" / "image" + base_dir.mkdir(parents=True, exist_ok=True) + self._image_dir = base_dir + return self._image_dir + + def _strip_base64(self, blob: str) -> str: + if not blob: + return "" + if "," in blob and "base64" in blob.split(",", 1)[0]: + return blob.split(",", 1)[1] + return blob + + def _guess_ext(self, blob: str) -> Optional[str]: + if not blob: + return None + header = "" + data = blob + if "," in blob and "base64" in blob.split(",", 1)[0]: + header, data = blob.split(",", 1) + header = header.lower() + if "image/png" in header: + return "png" + if "image/jpeg" in header or "image/jpg" in header: + return "jpg" + if data.startswith("iVBORw0KGgo"): + return "png" + if data.startswith("/9j/"): + return "jpg" + return None + + def _filename(self, image_id: str, is_final: bool, ext: Optional[str] = None) -> str: + if ext: + ext = ext.lower() + if ext == "jpeg": + ext = "jpg" + if not ext: + ext = "jpg" if is_final else "png" + return f"{image_id}.{ext}" + + def _build_file_url(self, filename: str) -> str: + app_url = get_config("app.app_url") + if app_url: + return f"{app_url.rstrip('/')}/v1/files/image/{filename}" + return f"/v1/files/image/{filename}" + + async def _save_blob( + self, image_id: str, blob: str, is_final: bool, ext: Optional[str] = None + ) -> str: + data = self._strip_base64(blob) + if not data: + return "" + image_dir = self._ensure_image_dir() + ext = ext or self._guess_ext(blob) + filename = self._filename(image_id, is_final, ext=ext) + filepath = image_dir / filename + + def _write_file(): + with open(filepath, "wb") as f: + f.write(base64.b64decode(data)) + + await asyncio.to_thread(_write_file) + return self._build_file_url(filename) + + def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict: + if not existing: + return incoming + if incoming.get("is_final") and not existing.get("is_final"): + return incoming + if 
existing.get("is_final") and not incoming.get("is_final"): + return existing + if incoming.get("blob_size", 0) > existing.get("blob_size", 0): + return incoming + return existing + + async def _to_output(self, image_id: str, item: Dict) -> str: + try: + if self.response_format == "url": + return await self._save_blob( + image_id, + item.get("blob", ""), + item.get("is_final", False), + ext=item.get("ext"), + ) + return self._strip_base64(item.get("blob", "")) + except Exception as e: + logger.warning(f"Image output failed: {e}") + return "" + + +class ImageWSStreamProcessor(ImageWSBaseProcessor): + """WebSocket image stream processor.""" + + def __init__( + self, + model: str, + token: str = "", + n: int = 1, + response_format: str = "b64_json", + size: str = "1024x1024", + ): + super().__init__(model, token, response_format) + self.n = n + self.size = size + self._target_id: Optional[str] = None + self._index_map: Dict[str, int] = {} + self._partial_map: Dict[str, int] = {} + self._initial_sent: set[str] = set() + + def _assign_index(self, image_id: str) -> Optional[int]: + if image_id in self._index_map: + return self._index_map[image_id] + if len(self._index_map) >= self.n: + return None + self._index_map[image_id] = len(self._index_map) + return self._index_map[image_id] + + def _sse(self, event: str, data: dict) -> str: + return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" + + async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]: + images: Dict[str, Dict] = {} + + async for item in response: + if item.get("type") == "error": + message = item.get("error") or "Upstream error" + code = item.get("error_code") or "upstream_error" + status = item.get("status") + if code == "rate_limit_exceeded" or status == 429: + raise UpstreamException(message, details=item) + yield self._sse( + "error", + { + "error": { + "message": message, + "type": "server_error", + "code": code, + } + }, + ) + return + if item.get("type") != 
"image": + continue + + image_id = item.get("image_id") + if not image_id: + continue + + if self.n == 1: + if self._target_id is None: + self._target_id = image_id + index = 0 if image_id == self._target_id else None + else: + index = self._assign_index(image_id) + + images[image_id] = self._pick_best(images.get(image_id), item) + + if index is None: + continue + + if item.get("stage") != "final": + if image_id not in self._initial_sent: + self._initial_sent.add(image_id) + stage = item.get("stage") or "preview" + if stage == "medium": + partial_index = 1 + self._partial_map[image_id] = 1 + else: + partial_index = 0 + self._partial_map[image_id] = 0 + else: + stage = item.get("stage") or "partial" + if stage == "preview": + continue + partial_index = self._partial_map.get(image_id, 0) + if stage == "medium": + partial_index = max(partial_index, 1) + self._partial_map[image_id] = partial_index + + if self.response_format == "url": + partial_id = f"{image_id}-{stage}-{partial_index}" + partial_out = await self._save_blob( + partial_id, + item.get("blob", ""), + False, + ext=item.get("ext"), + ) + else: + partial_out = self._strip_base64(item.get("blob", "")) + if not partial_out: + continue + yield self._sse( + "image_generation.partial_image", + { + "type": "image_generation.partial_image", + self.response_field: partial_out, + "created_at": int(time.time()), + "size": self.size, + "index": index, + "partial_image_index": partial_index, + }, + ) + + if self.n == 1: + if self._target_id and self._target_id in images: + selected = [(self._target_id, images[self._target_id])] + else: + selected = ( + [ + max( + images.items(), + key=lambda x: ( + x[1].get("is_final", False), + x[1].get("blob_size", 0), + ), + ) + ] + if images + else [] + ) + else: + selected = [ + (image_id, images[image_id]) + for image_id in self._index_map + if image_id in images + ] + + for image_id, item in selected: + if self.response_format == "url": + output = await self._save_blob( + 
f"{image_id}-final", + item.get("blob", ""), + item.get("is_final", False), + ext=item.get("ext"), + ) + else: + output = await self._to_output(image_id, item) + if not output: + continue + + if self.n == 1: + index = 0 + else: + index = self._index_map.get(image_id, 0) + yield self._sse( + "image_generation.completed", + { + "type": "image_generation.completed", + self.response_field: output, + "created_at": int(time.time()), + "size": self.size, + "index": index, + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + }, + }, + ) + + +class ImageWSCollectProcessor(ImageWSBaseProcessor): + """WebSocket image non-stream processor.""" + + def __init__( + self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" + ): + super().__init__(model, token, response_format) + self.n = n + + async def process(self, response: AsyncIterable[dict]) -> List[str]: + images: Dict[str, Dict] = {} + + async for item in response: + if item.get("type") == "error": + message = item.get("error") or "Upstream error" + raise UpstreamException(message, details=item) + if item.get("type") != "image": + continue + image_id = item.get("image_id") + if not image_id: + continue + images[image_id] = self._pick_best(images.get(image_id), item) + + selected = sorted( + images.values(), + key=lambda x: (x.get("is_final", False), x.get("blob_size", 0)), + reverse=True, + ) + if self.n: + selected = selected[: self.n] + + results: List[str] = [] + for item in selected: + output = await self._to_output(item.get("image_id", ""), item) + if output: + results.append(output) + + return results + + +__all__ = ["ImageGenerationService"] diff --git a/app/services/grok/services/image_edit.py b/app/services/grok/services/image_edit.py index 0eb81777..eba6f1f3 100644 --- a/app/services/grok/services/image_edit.py +++ b/app/services/grok/services/image_edit.py @@ -6,15 +6,32 @@ import random import re 
from dataclasses import dataclass -from typing import AsyncGenerator, List, Union, Any +from typing import AsyncGenerator, AsyncIterable, List, Union, Any -from app.core.exceptions import AppException, ErrorType +import orjson +from curl_cffi.requests.errors import RequestsError + +from app.core.config import get_config +from app.core.exceptions import ( + AppException, + ErrorType, + UpstreamException, + StreamIdleTimeoutError, +) from app.core.logger import logger -from app.services.grok.processors import ImageCollectProcessor, ImageStreamProcessor +from app.services.grok.utils.process import ( + BaseProcessor, + _with_idle_timeout, + _normalize_line, + _collect_images, + _is_http2_error, +) from app.services.grok.utils.upload import UploadService +from app.services.grok.utils.retry import pick_token, rate_limited from app.services.grok.services.chat import GrokChatService from app.services.grok.services.video import VideoService from app.services.grok.utils.stream import wrap_stream_with_usage +from app.services.token import EffortType @dataclass @@ -38,60 +55,115 @@ async def edit( response_format: str, stream: bool, ) -> ImageEditResult: - image_urls = await self._upload_images(images, token) - parent_post_id = await self._get_parent_post_id(token, image_urls) - - model_config_override = { - "modelMap": { - "imageEditModel": "imagine", - "imageEditModelConfig": { - "imageReferences": image_urls, - }, - } - } - if parent_post_id: - model_config_override["modelMap"]["imageEditModelConfig"][ - "parentPostId" - ] = parent_post_id - - tool_overrides = {"imageGen": True} + max_token_retries = int(get_config("retry.max_retry")) + tried_tokens: set[str] = set() + last_error: Exception | None = None - if stream: - response = await GrokChatService().chat( - token=token, - message=prompt, - model=model_info.grok_model, - mode=None, - stream=True, - tool_overrides=tool_overrides, - model_config_override=model_config_override, - ) - processor = ImageStreamProcessor( - 
model_info.model_id, - token, - n=n, - response_format=response_format, - ) - return ImageEditResult( - stream=True, - data=wrap_stream_with_usage( - processor.process(response), - token_mgr, - token, - model_info.model_id, - ), + for attempt in range(max_token_retries): + preferred = token if attempt == 0 else None + current_token = await pick_token( + token_mgr, model_info.model_id, tried_tokens, preferred=preferred ) + if not current_token: + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) - images_out = await self._collect_images( - token=token, - prompt=prompt, - model_info=model_info, - n=n, - response_format=response_format, - tool_overrides=tool_overrides, - model_config_override=model_config_override, + tried_tokens.add(current_token) + try: + image_urls = await self._upload_images(images, current_token) + parent_post_id = await self._get_parent_post_id( + current_token, image_urls + ) + + model_config_override = { + "modelMap": { + "imageEditModel": "imagine", + "imageEditModelConfig": { + "imageReferences": image_urls, + }, + } + } + if parent_post_id: + model_config_override["modelMap"]["imageEditModelConfig"][ + "parentPostId" + ] = parent_post_id + + tool_overrides = {"imageGen": True} + + if stream: + response = await GrokChatService().chat( + token=current_token, + message=prompt, + model=model_info.grok_model, + mode=None, + stream=True, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + processor = ImageStreamProcessor( + model_info.model_id, + current_token, + n=n, + response_format=response_format, + ) + return ImageEditResult( + stream=True, + data=wrap_stream_with_usage( + processor.process(response), + token_mgr, + current_token, + model_info.model_id, + ), + ) + + images_out = await self._collect_images( + token=current_token, + prompt=prompt, + 
model_info=model_info, + n=n, + response_format=response_format, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + try: + effort = ( + EffortType.HIGH + if (model_info and model_info.cost.value == "high") + else EffortType.LOW + ) + await token_mgr.consume(current_token, effort) + logger.debug( + f"Image edit completed, recorded usage (effort={effort.value})" + ) + except Exception as e: + logger.warning(f"Failed to record image edit usage: {e}") + return ImageEditResult(stream=False, data=images_out) + + except UpstreamException as e: + last_error = e + if rate_limited(e): + await token_mgr.mark_rate_limited(current_token) + logger.warning( + f"Token {current_token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise + + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, ) - return ImageEditResult(stream=False, data=images_out) async def _upload_images(self, images: List[str], token: str) -> List[str]: image_urls: List[str] = [] @@ -172,6 +244,9 @@ async def _call_edit(): ) return await processor.process(response) + last_error: Exception | None = None + rate_limit_error: Exception | None = None + if calls_needed == 1: all_images = await _call_edit() else: @@ -182,11 +257,23 @@ async def _call_edit(): for result in results: if isinstance(result, Exception): logger.error(f"Concurrent call failed: {result}") + last_error = result + if rate_limited(result): + rate_limit_error = result elif isinstance(result, list): all_images.extend(result) + if not all_images: + if rate_limit_error: + raise rate_limit_error + if last_error: + raise last_error + raise UpstreamException( + "Image edit returned no results", details={"error": "empty_result"} + ) + if len(all_images) >= n: - return random.sample(all_images, n) + return 
all_images[:n] selected_images = all_images.copy() while len(selected_images) < n: @@ -194,4 +281,229 @@ async def _call_edit(): return selected_images +class ImageStreamProcessor(BaseProcessor): + """HTTP image stream processor.""" + + def __init__( + self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" + ): + super().__init__(model, token) + self.partial_index = 0 + self.n = n + self.target_index = 0 if n == 1 else None + self.response_format = response_format + if response_format == "url": + self.response_field = "url" + elif response_format == "base64": + self.response_field = "base64" + else: + self.response_field = "b64_json" + + def _sse(self, event: str, data: dict) -> str: + """Build SSE response.""" + return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" + + async def process( + self, response: AsyncIterable[bytes] + ) -> AsyncGenerator[str, None]: + """Process stream response.""" + final_images = [] + idle_timeout = get_config("image.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + # Image generation progress + if img := resp.get("streamingImageGenerationResponse"): + image_index = img.get("imageIndex", 0) + progress = img.get("progress", 0) + + if self.n == 1 and image_index != self.target_index: + continue + + out_index = 0 if self.n == 1 else image_index + + yield self._sse( + "image_generation.partial_image", + { + "type": "image_generation.partial_image", + self.response_field: "", + "index": out_index, + "progress": progress, + }, + ) + continue + + # modelResponse + if mr := resp.get("modelResponse"): + if urls := _collect_images(mr): + for url in urls: + if self.response_format == "url": + processed = await self.process_url(url, "image") + if processed: + 
final_images.append(processed) + continue + try: + dl_service = self._get_dl() + base64_data = await dl_service.parse_b64( + url, self.token, "image" + ) + if base64_data: + if "," in base64_data: + b64 = base64_data.split(",", 1)[1] + else: + b64 = base64_data + final_images.append(b64) + except Exception as e: + logger.warning( + f"Failed to convert image to base64, falling back to URL: {e}" + ) + processed = await self.process_url(url, "image") + if processed: + final_images.append(processed) + continue + + for index, b64 in enumerate(final_images): + if self.n == 1: + if index != self.target_index: + continue + out_index = 0 + else: + out_index = index + + yield self._sse( + "image_generation.completed", + { + "type": "image_generation.completed", + self.response_field: b64, + "index": out_index, + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": { + "text_tokens": 0, + "image_tokens": 0, + }, + }, + }, + ) + except asyncio.CancelledError: + logger.debug("Image stream cancelled by client") + except StreamIdleTimeoutError as e: + raise UpstreamException( + message=f"Image stream idle timeout after {e.idle_seconds}s", + status_code=504, + details={ + "error": str(e), + "type": "stream_idle_timeout", + "idle_seconds": e.idle_seconds, + }, + ) + except RequestsError as e: + if _is_http2_error(e): + logger.warning(f"HTTP/2 stream error in image: {e}") + raise UpstreamException( + message="Upstream connection closed unexpectedly", + status_code=502, + details={"error": str(e), "type": "http2_stream_error"}, + ) + logger.error(f"Image stream request error: {e}") + raise UpstreamException( + message=f"Upstream request failed: {e}", + status_code=502, + details={"error": str(e)}, + ) + except Exception as e: + logger.error( + f"Image stream processing error: {e}", + extra={"error_type": type(e).__name__}, + ) + raise + finally: + await self.close() + + +class ImageCollectProcessor(BaseProcessor): + """HTTP image 
non-stream processor.""" + + def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): + if response_format == "base64": + response_format = "b64_json" + super().__init__(model, token) + self.response_format = response_format + + async def process(self, response: AsyncIterable[bytes]) -> List[str]: + """Process and collect images.""" + images = [] + idle_timeout = get_config("image.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + if mr := resp.get("modelResponse"): + if urls := _collect_images(mr): + for url in urls: + if self.response_format == "url": + processed = await self.process_url(url, "image") + if processed: + images.append(processed) + continue + try: + dl_service = self._get_dl() + base64_data = await dl_service.parse_b64( + url, self.token, "image" + ) + if base64_data: + if "," in base64_data: + b64 = base64_data.split(",", 1)[1] + else: + b64 = base64_data + images.append(b64) + except Exception as e: + logger.warning( + f"Failed to convert image to base64, falling back to URL: {e}" + ) + processed = await self.process_url(url, "image") + if processed: + images.append(processed) + + except asyncio.CancelledError: + logger.debug("Image collect cancelled by client") + except StreamIdleTimeoutError as e: + logger.warning(f"Image collect idle timeout: {e}") + except RequestsError as e: + if _is_http2_error(e): + logger.warning(f"HTTP/2 stream error in image collect: {e}") + else: + logger.error(f"Image collect request error: {e}") + except Exception as e: + logger.error( + f"Image collect processing error: {e}", + extra={"error_type": type(e).__name__}, + ) + finally: + await self.close() + + return images + + __all__ = ["ImageEditService", "ImageEditResult"] diff --git 
a/app/services/grok/services/model.py b/app/services/grok/services/model.py index e7cd8d61..f5c1e257 100644 --- a/app/services/grok/services/model.py +++ b/app/services/grok/services/model.py @@ -33,8 +33,9 @@ class ModelInfo(BaseModel): cost: Cost = Field(default=Cost.LOW) display_name: str description: str = "" - is_video: bool = False is_image: bool = False + is_image_edit: bool = False + is_video: bool = False class ModelService: @@ -45,105 +46,157 @@ class ModelService: model_id="grok-3", grok_model="grok-3", model_mode="MODEL_MODE_GROK_3", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-3", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-3-mini", grok_model="grok-3", model_mode="MODEL_MODE_GROK_3_MINI_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-3-MINI", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-3-thinking", grok_model="grok-3", model_mode="MODEL_MODE_GROK_3_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-3-THINKING", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4", grok_model="grok-4", model_mode="MODEL_MODE_GROK_4", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4-mini", grok_model="grok-4-mini", model_mode="MODEL_MODE_GROK_4_MINI_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4-MINI", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4-thinking", grok_model="grok-4", model_mode="MODEL_MODE_GROK_4_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4-THINKING", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4-heavy", grok_model="grok-4", model_mode="MODEL_MODE_HEAVY", - cost=Cost.HIGH, tier=Tier.SUPER, + cost=Cost.HIGH, display_name="GROK-4-HEAVY", + is_image=False, + is_image_edit=False, 
+ is_video=False, ), ModelInfo( model_id="grok-4.1-mini", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_GROK_4_1_MINI_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4.1-MINI", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4.1-fast", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4.1-FAST", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4.1-expert", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_EXPERT", + tier=Tier.SUPER, cost=Cost.HIGH, display_name="GROK-4.1-EXPERT", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4.1-thinking", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_GROK_4_1_THINKING", + tier=Tier.SUPER, cost=Cost.HIGH, display_name="GROK-4.1-THINKING", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-imagine-1.0", grok_model="grok-3", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.HIGH, display_name="Grok Image", description="Image generation model", is_image=True, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-imagine-1.0-edit", grok_model="imagine-image-edit", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.HIGH, display_name="Grok Image Edit", description="Image edit model", - is_image=True, + is_image=False, + is_image_edit=True, + is_video=False, ), ModelInfo( model_id="grok-imagine-1.0-video", grok_model="grok-3", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.HIGH, display_name="Grok Video", description="Video generation model", + is_image=False, + is_image_edit=False, is_video=True, ), ] diff --git a/app/services/grok/services/video.py b/app/services/grok/services/video.py index 32c9f6e5..5ba65048 100644 --- a/app/services/grok/services/video.py +++ b/app/services/grok/services/video.py @@ -3,8 +3,12 
@@ """ import asyncio -from typing import AsyncGenerator +import uuid +from typing import Any, AsyncGenerator, AsyncIterable, Optional + +import orjson from curl_cffi.requests import AsyncSession +from curl_cffi.requests.errors import RequestsError from app.core.logger import logger from app.core.config import get_config @@ -13,33 +17,39 @@ AppException, ValidationException, ErrorType, + StreamIdleTimeoutError, ) from app.services.grok.services.model import ModelService from app.services.token import get_token_manager, EffortType -from app.services.grok.processors import VideoStreamProcessor, VideoCollectProcessor from app.services.grok.utils.stream import wrap_stream_with_usage +from app.services.grok.utils.process import ( + BaseProcessor, + _with_idle_timeout, + _normalize_line, + _is_http2_error, +) +from app.services.grok.utils.retry import rate_limited from app.services.reverse.app_chat import AppChatReverse from app.services.reverse.media_post import MediaPostReverse -_MEDIA_SEMAPHORE = None -_MEDIA_SEM_VALUE = 0 +_VIDEO_SEMAPHORE = None +_VIDEO_SEM_VALUE = 0 - -def _get_semaphore() -> asyncio.Semaphore: - """Get or refresh the semaphore.""" - global _MEDIA_SEMAPHORE, _MEDIA_SEM_VALUE - value = max(1, int(get_config("performance.media_max_concurrent"))) - if value != _MEDIA_SEM_VALUE: - _MEDIA_SEM_VALUE = value - _MEDIA_SEMAPHORE = asyncio.Semaphore(value) - return _MEDIA_SEMAPHORE +def _get_video_semaphore() -> asyncio.Semaphore: + """Reverse 接口并发控制(video 服务)。""" + global _VIDEO_SEMAPHORE, _VIDEO_SEM_VALUE + value = max(1, int(get_config("video.concurrent"))) + if value != _VIDEO_SEM_VALUE: + _VIDEO_SEM_VALUE = value + _VIDEO_SEMAPHORE = asyncio.Semaphore(value) + return _VIDEO_SEMAPHORE class VideoService: """Video generation service.""" def __init__(self): - self.timeout = get_config("network.timeout") + self.timeout = None async def create_post( self, @@ -54,12 +64,13 @@ async def create_post( raise ValidationException("media_url is required for image 
posts") async with AsyncSession() as session: - response = await MediaPostReverse.request( - session, - token, - media_type, - media_url or "", - ) + async with _get_video_semaphore(): + response = await MediaPostReverse.request( + session, + token, + media_type, + media_url or "", + ) post_id = response.json().get("post", {}).get("id", "") if not post_id: @@ -80,124 +91,64 @@ async def create_image_post(self, token: str, image_url: str) -> str: token, prompt="", media_type="MEDIA_POST_TYPE_IMAGE", media_url=image_url ) - def _build_payload( + async def generate( self, + token: str, prompt: str, - post_id: str, aspect_ratio: str = "3:2", video_length: int = 6, resolution_name: str = "480p", preset: str = "normal", - ) -> dict: - """Build video generation payload.""" + ) -> AsyncGenerator[bytes, None]: + """Generate video.""" + logger.info( + f"Video generation: prompt='{prompt[:50]}...', ratio={aspect_ratio}, length={video_length}s, preset={preset}" + ) + post_id = await self.create_post(token, prompt) mode_map = { "fun": "--mode=extremely-crazy", "normal": "--mode=normal", "spicy": "--mode=extremely-spicy-or-crazy", } mode_flag = mode_map.get(preset, "--mode=custom") - - payload = { - "temporary": True, - "modelName": "grok-3", - "message": f"{prompt} {mode_flag}", - "toolOverrides": {"videoGen": True}, - "enableSideBySide": True, - "deviceEnvInfo": { - "darkModeEnabled": False, - "devicePixelRatio": 2, - "screenWidth": 1920, - "screenHeight": 1080, - "viewportWidth": 1920, - "viewportHeight": 1080, - }, - "responseMetadata": { - "experiments": [], - "modelConfigOverride": { - "modelMap": { - "videoGenModelConfig": { - "aspectRatio": aspect_ratio, - "parentPostId": post_id, - "resolutionName": resolution_name, - "videoLength": video_length, - } - } - }, - }, + message = f"{prompt} {mode_flag}" + model_config_override = { + "modelMap": { + "videoGenModelConfig": { + "aspectRatio": aspect_ratio, + "parentPostId": post_id, + "resolutionName": resolution_name, + 
"videoLength": video_length, + } + } } - logger.debug(f"Video generation payload: {payload}") - - return payload - - async def _generate_internal( - self, - token: str, - post_id: str, - prompt: str, - aspect_ratio: str, - video_length: int, - resolution_name: str, - preset: str, - ) -> AsyncGenerator[bytes, None]: - """Internal generation logic.""" - session = None - try: - payload = self._build_payload( - prompt, post_id, aspect_ratio, video_length, resolution_name, preset - ) - + async def _stream(): session = AsyncSession() - stream_response = await AppChatReverse.request( - session, - token, - message=payload.get("message"), - model=payload.get("modelName"), - tool_overrides=payload.get("toolOverrides"), - model_config_override=( - (payload.get("responseMetadata") or {}).get("modelConfigOverride") - ), - ) - - logger.info(f"Video generation started: post_id={post_id}") - - return stream_response - - except Exception as e: - if session: + try: + async with _get_video_semaphore(): + stream_response = await AppChatReverse.request( + session, + token, + message=message, + model="grok-3", + tool_overrides={"videoGen": True}, + model_config_override=model_config_override, + ) + logger.info(f"Video generation started: post_id={post_id}") + async for line in stream_response: + yield line + except Exception as e: try: await session.close() except Exception: pass - logger.error(f"Video generation error: {e}") - if isinstance(e, AppException): - raise - raise UpstreamException(f"Video generation error: {str(e)}") + logger.error(f"Video generation error: {e}") + if isinstance(e, AppException): + raise + raise UpstreamException(f"Video generation error: {str(e)}") - async def generate( - self, - token: str, - prompt: str, - aspect_ratio: str = "3:2", - video_length: int = 6, - resolution_name: str = "480p", - preset: str = "normal", - ) -> AsyncGenerator[bytes, None]: - """Generate video.""" - logger.info( - f"Video generation: prompt='{prompt[:50]}...', 
ratio={aspect_ratio}, length={video_length}s, preset={preset}" - ) - async with _get_semaphore(): - post_id = await self.create_post(token, prompt) - return await self._generate_internal( - token, - post_id, - prompt, - aspect_ratio, - video_length, - resolution_name, - preset, - ) + return _stream() async def generate_from_image( self, @@ -213,11 +164,51 @@ async def generate_from_image( logger.info( f"Image to video: prompt='{prompt[:50]}...', image={image_url[:80]}" ) - async with _get_semaphore(): - post_id = await self.create_image_post(token, image_url) - return await self._generate_internal( - token, post_id, prompt, aspect_ratio, video_length, resolution, preset - ) + post_id = await self.create_image_post(token, image_url) + mode_map = { + "fun": "--mode=extremely-crazy", + "normal": "--mode=normal", + "spicy": "--mode=extremely-spicy-or-crazy", + } + mode_flag = mode_map.get(preset, "--mode=custom") + message = f"{prompt} {mode_flag}" + model_config_override = { + "modelMap": { + "videoGenModelConfig": { + "aspectRatio": aspect_ratio, + "parentPostId": post_id, + "resolutionName": resolution, + "videoLength": video_length, + } + } + } + + async def _stream(): + session = AsyncSession() + try: + async with _get_video_semaphore(): + stream_response = await AppChatReverse.request( + session, + token, + message=message, + model="grok-3", + tool_overrides={"videoGen": True}, + model_config_override=model_config_override, + ) + logger.info(f"Video generation started: post_id={post_id}") + async for line in stream_response: + yield line + except Exception as e: + try: + await session.close() + except Exception: + pass + logger.error(f"Video generation error: {e}") + if isinstance(e, AppException): + raise + raise UpstreamException(f"Video generation error: {str(e)}") + + return _stream() @staticmethod async def completions( @@ -235,85 +226,358 @@ async def completions( token_mgr = await get_token_manager() await token_mgr.reload_if_stale() - # Select token based 
on video requirements and pool candidates. - pool_candidates = ModelService.pool_candidates_for_model(model) - token_info = token_mgr.get_token_for_video( - resolution=resolution, - video_length=video_length, - pool_candidates=pool_candidates, - ) - - if not token_info: - raise AppException( - message="No available tokens. Please try again later.", - error_type=ErrorType.RATE_LIMIT.value, - code="rate_limit_exceeded", - status_code=429, - ) - - # Extract token string from TokenInfo. - token = token_info.token - if token.startswith("sso="): - token = token[4:] + max_token_retries = int(get_config("retry.max_retry")) + last_error: Exception | None = None if reasoning_effort is None: - show_think = get_config("chat.thinking") + show_think = get_config("app.thinking") else: show_think = reasoning_effort != "none" - is_stream = stream if stream is not None else get_config("chat.stream") + is_stream = stream if stream is not None else get_config("app.stream") # Extract content. from app.services.grok.services.chat import MessageExtractor from app.services.grok.utils.upload import UploadService - prompt, file_attachments, image_attachments = MessageExtractor.extract( - messages, is_video=True - ) + prompt, file_attachments, image_attachments = MessageExtractor.extract(messages) + + for attempt in range(max_token_retries): + # Select token based on video requirements and pool candidates. + pool_candidates = ModelService.pool_candidates_for_model(model) + token_info = token_mgr.get_token_for_video( + resolution=resolution, + video_length=video_length, + pool_candidates=pool_candidates, + ) + + if not token_info: + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + # Extract token string from TokenInfo. + token = token_info.token + if token.startswith("sso="): + token = token[4:] - # Handle image attachments. 
- image_url = None - if image_attachments: - upload_service = UploadService() try: - for attach_data in image_attachments: - _, file_uri = await upload_service.upload_file(attach_data, token) - image_url = f"https://assets.grok.com/{file_uri}" - logger.info(f"Image uploaded for video: {image_url}") - break - finally: - await upload_service.close() - - # Generate video. - service = VideoService() - if image_url: - response = await service.generate_from_image( - token, prompt, image_url, aspect_ratio, video_length, resolution, preset + # Handle image attachments. + image_url = None + if image_attachments: + upload_service = UploadService() + try: + for attach_data in image_attachments: + _, file_uri = await upload_service.upload_file( + attach_data, token + ) + image_url = f"https://assets.grok.com/{file_uri}" + logger.info(f"Image uploaded for video: {image_url}") + break + finally: + await upload_service.close() + + # Generate video. + service = VideoService() + if image_url: + response = await service.generate_from_image( + token, + prompt, + image_url, + aspect_ratio, + video_length, + resolution, + preset, + ) + else: + response = await service.generate( + token, + prompt, + aspect_ratio, + video_length, + resolution, + preset, + ) + + # Process response. 
+ if is_stream: + processor = VideoStreamProcessor(model, token, show_think) + return wrap_stream_with_usage( + processor.process(response), token_mgr, token, model + ) + + result = await VideoCollectProcessor(model, token).process(response) + try: + model_info = ModelService.get(model) + effort = ( + EffortType.HIGH + if (model_info and model_info.cost.value == "high") + else EffortType.LOW + ) + await token_mgr.consume(token, effort) + logger.debug( + f"Video completed, recorded usage (effort={effort.value})" + ) + except Exception as e: + logger.warning(f"Failed to record video usage: {e}") + return result + + except UpstreamException as e: + last_error = e + if rate_limited(e): + await token_mgr.mark_rate_limited(token) + logger.warning( + f"Token {token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise + + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + +class VideoStreamProcessor(BaseProcessor): + """Video stream response processor.""" + + def __init__(self, model: str, token: str = "", show_think: bool = None): + super().__init__(model, token) + self.response_id: Optional[str] = None + self.think_opened: bool = False + self.role_sent: bool = False + + self.show_think = bool(show_think) + + def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: + """Build SSE response.""" + delta = {} + if role: + delta["role"] = role + delta["content"] = "" + elif content: + delta["content"] = content + + chunk = { + "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "choices": [ + {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} + ], + } + return f"data: {orjson.dumps(chunk).decode()}\n\n" + + 
async def process( + self, response: AsyncIterable[bytes] + ) -> AsyncGenerator[str, None]: + """Process video stream response.""" + idle_timeout = get_config("video.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + is_thinking = bool(resp.get("isThinking")) + + if rid := resp.get("responseId"): + self.response_id = rid + + if not self.role_sent: + yield self._sse(role="assistant") + self.role_sent = True + + if token := resp.get("token"): + if is_thinking: + if not self.show_think: + continue + if not self.think_opened: + yield self._sse("\n") + self.think_opened = True + else: + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + yield self._sse(token) + continue + + if video_resp := resp.get("streamingVideoGenerationResponse"): + progress = video_resp.get("progress", 0) + + if is_thinking: + if not self.show_think: + continue + if not self.think_opened: + yield self._sse("\n") + self.think_opened = True + else: + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + if self.show_think: + yield self._sse(f"正在生成视频中,当前进度{progress}%\n") + + if progress == 100: + video_url = video_resp.get("videoUrl", "") + thumbnail_url = video_resp.get("thumbnailImageUrl", "") + + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + + if video_url: + dl_service = self._get_dl() + rendered = await dl_service.render_video( + video_url, self.token, thumbnail_url + ) + yield self._sse(rendered) + + logger.info(f"Video generated: {video_url}") + continue + + if self.think_opened: + yield self._sse("\n") + yield self._sse(finish="stop") + yield "data: [DONE]\n\n" + except asyncio.CancelledError: + logger.debug( + "Video stream cancelled by client", extra={"model": 
self.model} ) - else: - response = await service.generate( - token, prompt, aspect_ratio, video_length, resolution, preset + except StreamIdleTimeoutError as e: + raise UpstreamException( + message=f"Video stream idle timeout after {e.idle_seconds}s", + status_code=504, + details={ + "error": str(e), + "type": "stream_idle_timeout", + "idle_seconds": e.idle_seconds, + }, ) - - # Process response. - if is_stream: - processor = VideoStreamProcessor(model, token, show_think) - return wrap_stream_with_usage( - processor.process(response), token_mgr, token, model + except RequestsError as e: + if _is_http2_error(e): + logger.warning( + f"HTTP/2 stream error in video: {e}", extra={"model": self.model} + ) + raise UpstreamException( + message="Upstream connection closed unexpectedly", + status_code=502, + details={"error": str(e), "type": "http2_stream_error"}, + ) + logger.error( + f"Video stream request error: {e}", extra={"model": self.model} + ) + raise UpstreamException( + message=f"Upstream request failed: {e}", + status_code=502, + details={"error": str(e)}, ) + except Exception as e: + logger.error( + f"Video stream processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + finally: + await self.close() + + +class VideoCollectProcessor(BaseProcessor): + """Video non-stream response processor.""" + + def __init__(self, model: str, token: str = ""): + super().__init__(model, token) + + async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: + """Process and collect video response.""" + response_id = "" + content = "" + idle_timeout = get_config("video.stream_timeout") - result = await VideoCollectProcessor(model, token).process(response) try: - model_info = ModelService.get(model) - effort = ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + 
continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + if video_resp := resp.get("streamingVideoGenerationResponse"): + if video_resp.get("progress") == 100: + response_id = resp.get("responseId", "") + video_url = video_resp.get("videoUrl", "") + thumbnail_url = video_resp.get("thumbnailImageUrl", "") + + if video_url: + dl_service = self._get_dl() + content = await dl_service.render_video( + video_url, self.token, thumbnail_url + ) + logger.info(f"Video generated: {video_url}") + + except asyncio.CancelledError: + logger.debug( + "Video collect cancelled by client", extra={"model": self.model} + ) + except StreamIdleTimeoutError as e: + logger.warning( + f"Video collect idle timeout: {e}", extra={"model": self.model} ) - await token_mgr.consume(token, effort) - logger.debug(f"Video completed, recorded usage (effort={effort.value})") + except RequestsError as e: + if _is_http2_error(e): + logger.warning( + f"HTTP/2 stream error in video collect: {e}", + extra={"model": self.model}, + ) + else: + logger.error( + f"Video collect request error: {e}", extra={"model": self.model} + ) except Exception as e: - logger.warning(f"Failed to record video usage: {e}") - return result + logger.error( + f"Video collect processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + finally: + await self.close() + + return { + "id": response_id, + "object": "chat.completion", + "created": self.created, + "model": self.model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + "refusal": None, + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } __all__ = ["VideoService"] diff --git a/app/services/grok/services/voice.py b/app/services/grok/services/voice.py index a6dc5ed6..b72fce3e 100644 --- a/app/services/grok/services/voice.py +++ 
b/app/services/grok/services/voice.py @@ -20,7 +20,7 @@ async def get_token( personality: str = "assistant", speed: float = 1.0, ) -> Dict[str, Any]: - browser = get_config("security.browser") + browser = get_config("proxy.browser") async with AsyncSession(impersonate=browser) as session: response = await LivekitTokenReverse.request( session, diff --git a/app/services/grok/utils/download.py b/app/services/grok/utils/download.py index a13c7179..edfb279e 100644 --- a/app/services/grok/utils/download.py +++ b/app/services/grok/utils/download.py @@ -110,6 +110,17 @@ async def render_video( ''' + @staticmethod + def _is_url(value: str) -> bool: + """Check if the value is a URL.""" + try: + parsed = urlparse(value) + return bool( + parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"] + ) + except Exception: + return False + async def parse_b64(self, file_path: str, token: str, media_type: str = "image") -> str: """Download and return data URI.""" try: @@ -229,51 +240,54 @@ async def _check_limit(self): self._cleanup_running = True try: - async with _file_lock("cache_cleanup", timeout=5): - limit_mb = get_config("cache.limit_mb") - total_size = 0 - all_files: List[Tuple[Path, float, int]] = [] - - for d in [self.image_dir, self.video_dir]: - if d.exists(): - for f in d.glob("*"): - if f.is_file(): - try: - stat = f.stat() - total_size += stat.st_size - all_files.append( - (f, stat.st_mtime, stat.st_size) - ) - except Exception: - pass - current_mb = total_size / 1024 / 1024 - - if current_mb <= limit_mb: - return - - logger.info( - f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..." 
- ) - all_files.sort(key=lambda x: x[1]) - - deleted_count = 0 - deleted_size = 0 - target_mb = limit_mb * 0.8 - - for f, _, size in all_files: - try: - f.unlink() - deleted_count += 1 - deleted_size += size - total_size -= size - if (total_size / 1024 / 1024) <= target_mb: - break - except Exception: - pass - - logger.info( - f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)" - ) + try: + async with _file_lock("cache_cleanup", timeout=5): + limit_mb = get_config("cache.limit_mb") + total_size = 0 + all_files: List[Tuple[Path, float, int]] = [] + + for d in [self.image_dir, self.video_dir]: + if d.exists(): + for f in d.glob("*"): + if f.is_file(): + try: + stat = f.stat() + total_size += stat.st_size + all_files.append( + (f, stat.st_mtime, stat.st_size) + ) + except Exception: + pass + current_mb = total_size / 1024 / 1024 + + if current_mb <= limit_mb: + return + + logger.info( + f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..." + ) + all_files.sort(key=lambda x: x[1]) + + deleted_count = 0 + deleted_size = 0 + target_mb = limit_mb * 0.8 + + for f, _, size in all_files: + try: + f.unlink() + deleted_count += 1 + deleted_size += size + total_size -= size + if (total_size / 1024 / 1024) <= target_mb: + break + except Exception: + pass + + logger.info( + f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)" + ) + except Exception as e: + logger.warning(f"Cache cleanup failed: {e}") finally: self._cleanup_running = False diff --git a/app/services/grok/utils/process.py b/app/services/grok/utils/process.py index 12249491..69353c65 100644 --- a/app/services/grok/utils/process.py +++ b/app/services/grok/utils/process.py @@ -87,6 +87,16 @@ async def _with_idle_timeout( return iterator = iterable.__aiter__() + + async def _maybe_aclose(it): + aclose = getattr(it, "aclose", None) + if not aclose: + return + try: + await aclose() + except Exception: + pass + while True: try: item = await 
asyncio.wait_for(iterator.__anext__(), timeout=idle_timeout) @@ -96,7 +106,11 @@ async def _with_idle_timeout( f"Stream idle timeout after {idle_timeout}s", extra={"model": model, "idle_timeout": idle_timeout}, ) + await _maybe_aclose(iterator) raise StreamIdleTimeoutError(idle_timeout) + except asyncio.CancelledError: + await _maybe_aclose(iterator) + raise except StopAsyncIteration: break diff --git a/app/services/grok/utils/retry.py b/app/services/grok/utils/retry.py new file mode 100644 index 00000000..e0b1edb5 --- /dev/null +++ b/app/services/grok/utils/retry.py @@ -0,0 +1,45 @@ +""" +Retry helpers for token switching. +""" + +from typing import Optional, Set + +from app.core.exceptions import UpstreamException +from app.services.grok.services.model import ModelService + + +async def pick_token( + token_mgr, + model_id: str, + tried: Set[str], + preferred: Optional[str] = None, +) -> Optional[str]: + if preferred and preferred not in tried: + return preferred + + token = None + for pool_name in ModelService.pool_candidates_for_model(model_id): + token = token_mgr.get_token(pool_name, exclude=tried) + if token: + break + + if not token and not tried: + result = await token_mgr.refresh_cooling_tokens() + if result.get("recovered", 0) > 0: + for pool_name in ModelService.pool_candidates_for_model(model_id): + token = token_mgr.get_token(pool_name) + if token: + break + + return token + + +def rate_limited(error: Exception) -> bool: + if not isinstance(error, UpstreamException): + return False + status = error.details.get("status") if error.details else None + code = error.details.get("error_code") if error.details else None + return status == 429 or code == "rate_limit_exceeded" + + +__all__ = ["pick_token", "rate_limited"] diff --git a/app/services/grok/utils/upload.py b/app/services/grok/utils/upload.py index 96707923..0861a4ed 100644 --- a/app/services/grok/utils/upload.py +++ b/app/services/grok/utils/upload.py @@ -136,7 +136,7 @@ async def parse_b64(self, 
url: str) -> Tuple[str, str, str]: lock_name = f"ul_url_{hashlib.sha1(url.encode()).hexdigest()[:16]}" timeout = float(get_config("asset.upload_timeout")) - proxy_url = get_config("network.base_proxy_url") + proxy_url = get_config("proxy.base_proxy_url") proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None lock_timeout = max(1, int(get_config("asset.upload_timeout"))) diff --git a/app/services/reverse/accept_tos.py b/app/services/reverse/accept_tos.py index 203e1f62..8459be46 100644 --- a/app/services/reverse/accept_tos.py +++ b/app/services/reverse/accept_tos.py @@ -30,7 +30,7 @@ async def request(session: AsyncSession, token: str) -> GrpcStatus: """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None # Build headers @@ -52,7 +52,7 @@ async def request(session: AsyncSession, token: str) -> GrpcStatus: # Curl Config timeout = get_config("nsfw.timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( diff --git a/app/services/reverse/app_chat.py b/app/services/reverse/app_chat.py index 6ddeb07e..d7075a05 100644 --- a/app/services/reverse/app_chat.py +++ b/app/services/reverse/app_chat.py @@ -41,7 +41,7 @@ def build_payload( "viewportWidth": 2056, "viewportHeight": 1083, }, - "disableMemory": get_config("chat.disable_memory"), + "disableMemory": get_config("app.disable_memory"), "disableSearch": False, "disableSelfHarmShortCircuit": False, "disableTextFollowUps": False, @@ -64,7 +64,7 @@ def build_payload( "returnImageBytes": False, "returnRawGrokInXaiRequest": False, "sendFinalMetadata": True, - "temporary": get_config("chat.temporary"), + "temporary": get_config("app.temporary"), "toolOverrides": tool_overrides or {}, } @@ -101,7 +101,7 @@ async def request( """ try: # Get proxies - base_proxy = 
get_config("network.base_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None # Build headers @@ -123,8 +123,12 @@ async def request( ) # Curl Config - timeout = get_config("network.timeout") - browser = get_config("security.browser") + timeout = max( + float(get_config("chat.timeout") or 0), + float(get_config("video.timeout") or 0), + float(get_config("image.timeout") or 0), + ) + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( diff --git a/app/services/reverse/assets_delete.py b/app/services/reverse/assets_delete.py index a982f4b9..79423107 100644 --- a/app/services/reverse/assets_delete.py +++ b/app/services/reverse/assets_delete.py @@ -32,8 +32,8 @@ async def request(session: AsyncSession, token: str, asset_id: str) -> Any: """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") - assert_proxy = get_config("network.asset_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") + assert_proxy = get_config("proxy.asset_proxy_url") if assert_proxy: proxies = {"http": assert_proxy, "https": assert_proxy} else: @@ -49,7 +49,7 @@ async def request(session: AsyncSession, token: str, asset_id: str) -> Any: # Curl Config timeout = get_config("asset.delete_timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.delete( diff --git a/app/services/reverse/assets_download.py b/app/services/reverse/assets_download.py index e2491a04..ec03794d 100644 --- a/app/services/reverse/assets_download.py +++ b/app/services/reverse/assets_download.py @@ -48,8 +48,8 @@ async def request(session: AsyncSession, token: str, file_path: str) -> Any: url = f"{DOWNLOAD_API}{file_path}" # Get proxies - base_proxy = get_config("network.base_proxy_url") - assert_proxy = get_config("network.asset_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") + 
assert_proxy = get_config("proxy.asset_proxy_url") if assert_proxy: proxies = {"http": assert_proxy, "https": assert_proxy} else: @@ -75,7 +75,7 @@ async def request(session: AsyncSession, token: str, file_path: str) -> Any: # Curl Config timeout = get_config("asset.download_timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.get( diff --git a/app/services/reverse/assets_list.py b/app/services/reverse/assets_list.py index 9b7762ff..5c84fe99 100644 --- a/app/services/reverse/assets_list.py +++ b/app/services/reverse/assets_list.py @@ -32,8 +32,8 @@ async def request(session: AsyncSession, token: str, params: Dict[str, Any]) -> """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") - assert_proxy = get_config("network.asset_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") + assert_proxy = get_config("proxy.asset_proxy_url") if assert_proxy: proxies = {"http": assert_proxy, "https": assert_proxy} else: @@ -49,7 +49,7 @@ async def request(session: AsyncSession, token: str, params: Dict[str, Any]) -> # Curl Config timeout = get_config("asset.list_timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.get( diff --git a/app/services/reverse/assets_upload.py b/app/services/reverse/assets_upload.py index 517e7598..b9d96731 100644 --- a/app/services/reverse/assets_upload.py +++ b/app/services/reverse/assets_upload.py @@ -34,8 +34,8 @@ async def request(session: AsyncSession, token: str, fileName: str, fileMimeType """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") - assert_proxy = get_config("network.asset_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") + assert_proxy = get_config("proxy.asset_proxy_url") if assert_proxy: proxies = {"http": assert_proxy, "https": assert_proxy} else: @@ -58,7 +58,7 @@ async def request(session: 
AsyncSession, token: str, fileName: str, fileMimeType # Curl Config timeout = get_config("asset.upload_timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( diff --git a/app/services/reverse/media_post.py b/app/services/reverse/media_post.py index f7d358ea..044e9189 100644 --- a/app/services/reverse/media_post.py +++ b/app/services/reverse/media_post.py @@ -20,7 +20,12 @@ class MediaPostReverse: """/rest/media/post/create reverse interface.""" @staticmethod - async def request(session: AsyncSession, token: str, mediaType: str, mediaUrl: str) -> Any: + async def request( + session: AsyncSession, + token: str, + mediaType: str, + mediaUrl: str, + ) -> Any: """Create media post in Grok. Args: @@ -34,7 +39,7 @@ async def request(session: AsyncSession, token: str, mediaType: str, mediaUrl: s """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None # Build headers @@ -52,8 +57,8 @@ async def request(session: AsyncSession, token: str, mediaType: str, mediaUrl: s } # Curl Config - timeout = get_config("network.timeout") - browser = get_config("security.browser") + timeout = get_config("video.timeout") + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( diff --git a/app/services/reverse/nsfw_mgmt.py b/app/services/reverse/nsfw_mgmt.py index 8056b231..ca5afc46 100644 --- a/app/services/reverse/nsfw_mgmt.py +++ b/app/services/reverse/nsfw_mgmt.py @@ -30,7 +30,7 @@ async def request(session: AsyncSession, token: str) -> GrpcStatus: """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None # Build headers @@ -55,7 +55,7 @@ async def request(session: 
AsyncSession, token: str) -> GrpcStatus: # Curl Config timeout = get_config("nsfw.timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( diff --git a/app/services/reverse/rate_limits.py b/app/services/reverse/rate_limits.py index 198164c4..10e6d71f 100644 --- a/app/services/reverse/rate_limits.py +++ b/app/services/reverse/rate_limits.py @@ -31,7 +31,7 @@ async def request(session: AsyncSession, token: str) -> Any: """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None # Build headers @@ -50,7 +50,7 @@ async def request(session: AsyncSession, token: str) -> Any: # Curl Config timeout = get_config("usage.timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( diff --git a/app/services/reverse/set_birth.py b/app/services/reverse/set_birth.py index 556a331c..d76c4c60 100644 --- a/app/services/reverse/set_birth.py +++ b/app/services/reverse/set_birth.py @@ -32,7 +32,7 @@ async def request(session: AsyncSession, token: str) -> Any: """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None # Build headers @@ -59,7 +59,7 @@ async def request(session: AsyncSession, token: str) -> Any: # Curl Config timeout = get_config("nsfw.timeout") - browser = get_config("security.browser") + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py index 03a8f253..e0c534e6 100644 --- a/app/services/reverse/utils/headers.py +++ b/app/services/reverse/utils/headers.py @@ -27,7 +27,7 @@ def 
build_sso_cookie(sso_token: str) -> str: cookie = f"sso={sso_token}; sso-rw={sso_token}" # CF Clearance - cf_clearance = get_config("security.cf_clearance") + cf_clearance = get_config("proxy.cf_clearance") if cf_clearance: cookie += f";cf_clearance={cf_clearance}" @@ -48,7 +48,7 @@ def build_ws_headers(token: Optional[str] = None, origin: Optional[str] = None, """ headers = { "Origin": origin or "https://grok.com", - "User-Agent": get_config("security.user_agent"), + "User-Agent": get_config("proxy.user_agent"), "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", "Cache-Control": "no-cache", "Pragma": "no-cache", @@ -90,7 +90,7 @@ def build_headers(cookie_token: str, content_type: Optional[str] = None, origin: "Sec-Ch-Ua-Model": "", "Sec-Ch-Ua-Platform": '"macOS"', "Sec-Fetch-Mode": "cors", - "User-Agent": get_config("security.user_agent"), + "User-Agent": get_config("proxy.user_agent"), } # Cookie diff --git a/app/services/reverse/utils/statsig.py b/app/services/reverse/utils/statsig.py index 69e81968..485885f1 100644 --- a/app/services/reverse/utils/statsig.py +++ b/app/services/reverse/utils/statsig.py @@ -31,7 +31,7 @@ def gen_id() -> str: Returns: Base64 encoded ID. 
""" - dynamic = get_config("chat.dynamic_statsig") + dynamic = get_config("app.dynamic_statsig") # Dynamic Statsig ID if dynamic: diff --git a/app/services/reverse/utils/websocket.py b/app/services/reverse/utils/websocket.py index 67f15164..f13586a1 100644 --- a/app/services/reverse/utils/websocket.py +++ b/app/services/reverse/utils/websocket.py @@ -91,13 +91,14 @@ class WebSocketClient: """WebSocket client with proxy support.""" def __init__(self, proxy: Optional[str] = None) -> None: - self.proxy = proxy or get_config("network.base_proxy_url") + self.proxy = proxy or get_config("proxy.base_proxy_url") self._ssl_context = _default_ssl_context() async def connect( self, url: str, headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, ws_kwargs: Optional[Mapping[str, object]] = None, ) -> WebSocketConnection: """Connect to the WebSocket. @@ -114,7 +115,12 @@ async def connect( connector, proxy = resolve_proxy(self.proxy, self._ssl_context) # Build client timeout - client_timeout = aiohttp.ClientTimeout(total=get_config("network.timeout")) + total_timeout = ( + float(timeout) + if timeout is not None + else float(get_config("voice.timeout") or 120) + ) + client_timeout = aiohttp.ClientTimeout(total=total_timeout) # Create session session = aiohttp.ClientSession(connector=connector, timeout=client_timeout) diff --git a/app/services/reverse/ws_imagine.py b/app/services/reverse/ws_imagine.py index d4c3c7cf..e9e648af 100644 --- a/app/services/reverse/ws_imagine.py +++ b/app/services/reverse/ws_imagine.py @@ -30,18 +30,24 @@ def __init__(self) -> None: self._url_pattern = re.compile(r"/images/([a-f0-9-]+)\.(png|jpg|jpeg)") self._client = WebSocketClient() - def _extract_image_id(self, url: str) -> Optional[str]: + def _parse_image_url(self, url: str) -> tuple[Optional[str], Optional[str]]: match = self._url_pattern.search(url or "") - return match.group(1) if match else None + if not match: + return None, None + return match.group(1), 
match.group(2).lower() def _is_final_image(self, url: str, blob_size: int, final_min_bytes: int) -> bool: - return (url or "").lower().endswith((".jpg", ".jpeg")) and blob_size > final_min_bytes + url_lower = (url or "").lower() + if url_lower.endswith((".jpg", ".jpeg")): + return True + return blob_size > final_min_bytes def _classify_image(self, url: str, blob: str, final_min_bytes: int, medium_min_bytes: int) -> Optional[Dict[str, object]]: if not url or not blob: return None - image_id = self._extract_image_id(url) or uuid.uuid4().hex + image_id, ext = self._parse_image_url(url) + image_id = image_id or uuid.uuid4().hex blob_size = len(blob) is_final = self._is_final_image(url, blob_size, final_min_bytes) @@ -54,6 +60,7 @@ def _classify_image(self, url: str, blob: str, final_min_bytes: int, medium_min_ return { "type": "image", "image_id": image_id, + "ext": ext, "stage": stage, "blob": blob, "blob_size": blob_size, @@ -120,6 +127,11 @@ async def stream( logger.warning(f"WebSocket blocked, retry {attempt + 1}/{retries}") except Exception as e: logger.error(f"WebSocket stream failed: {e}") + yield { + "type": "error", + "error_code": "ws_stream_failed", + "error": str(e), + } return async def _stream_once( @@ -132,26 +144,33 @@ async def _stream_once( ) -> AsyncGenerator[Dict[str, object], None]: request_id = str(uuid.uuid4()) headers = build_ws_headers(token=token) - timeout = float(get_config("network.timeout")) - blocked_seconds = float(get_config("image.image_ws_blocked_seconds")) - blocked_grace = min(10.0, blocked_seconds) - final_min_bytes = int(get_config("image.image_ws_final_min_bytes")) - medium_min_bytes = int(get_config("image.image_ws_medium_min_bytes")) + timeout = float(get_config("image.timeout")) + stream_timeout = float(get_config("image.stream_timeout")) + final_timeout = float(get_config("image.final_timeout")) + blocked_grace = min(10.0, final_timeout) + final_min_bytes = int(get_config("image.final_min_bytes")) + medium_min_bytes = 
int(get_config("image.medium_min_bytes")) try: conn = await self._client.connect( WS_IMAGINE_URL, headers=headers, + timeout=timeout, ws_kwargs={ "heartbeat": 20, - "receive_timeout": timeout, + "receive_timeout": stream_timeout, }, ) except Exception as e: + status = getattr(e, "status", None) + error_code = ( + "rate_limit_exceeded" if status == 429 else "connection_failed" + ) logger.error(f"WebSocket connect failed: {e}") yield { "type": "error", - "error_code": "connection_failed", + "error_code": error_code, + "status": status, "error": str(e), } return @@ -238,7 +257,7 @@ async def _stream_once( if ( medium_received_time and completed == 0 - and time.monotonic() - medium_received_time > blocked_seconds + and time.monotonic() - medium_received_time > final_timeout ): raise _BlockedError() diff --git a/app/services/reverse/ws_livekit.py b/app/services/reverse/ws_livekit.py index 095e31b2..bf3d92ae 100644 --- a/app/services/reverse/ws_livekit.py +++ b/app/services/reverse/ws_livekit.py @@ -44,7 +44,7 @@ async def request( """ try: # Get proxies - base_proxy = get_config("network.base_proxy_url") + base_proxy = get_config("proxy.base_proxy_url") proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None # Build headers @@ -72,8 +72,8 @@ async def request( } # Curl Config - timeout = get_config("network.timeout") - browser = get_config("security.browser") + timeout = get_config("voice.timeout") + browser = get_config("proxy.browser") async def _do_request(): response = await session.post( @@ -164,7 +164,9 @@ async def connect(self, token: str) -> WebSocketConnection: ws_headers = build_ws_headers() try: - return await self._client.connect(url, headers=ws_headers) + return await self._client.connect( + url, headers=ws_headers, timeout=get_config("voice.timeout") + ) except Exception as e: logger.error(f"LivekitWebSocketReverse: Connect failed, {e}") raise UpstreamException( diff --git a/app/services/token/manager.py b/app/services/token/manager.py 
index c543d814..4955909b 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -425,10 +425,18 @@ async def record_fail( token = pool.get(raw_token) if token: if status_code == 401: - token.record_fail(status_code, reason) + threshold = get_config("token.fail_threshold", FAIL_THRESHOLD) + try: + threshold = int(threshold) + except (TypeError, ValueError): + threshold = FAIL_THRESHOLD + if threshold < 1: + threshold = 1 + + token.record_fail(status_code, reason, threshold=threshold) logger.warning( f"Token {raw_token[:10]}...: recorded {status_code} failure " - f"({token.fail_count}/{FAIL_THRESHOLD}) - {reason}" + f"({token.fail_count}/{threshold}) - {reason}" ) else: logger.info( diff --git a/app/services/token/models.py b/app/services/token/models.py index d2853d0a..86300d90 100644 --- a/app/services/token/models.py +++ b/app/services/token/models.py @@ -128,7 +128,12 @@ def reset(self, default_quota: Optional[int] = None): self.fail_count = 0 self.last_fail_reason = None - def record_fail(self, status_code: int = 401, reason: str = ""): + def record_fail( + self, + status_code: int = 401, + reason: str = "", + threshold: Optional[int] = None, + ): """记录失败,达到阈值后自动标记为 expired""" # 仅 401 计入失败 if status_code != 401: @@ -138,7 +143,8 @@ def record_fail(self, status_code: int = 401, reason: str = ""): self.last_fail_at = int(datetime.now().timestamp() * 1000) self.last_fail_reason = reason - if self.fail_count >= FAIL_THRESHOLD: + limit = FAIL_THRESHOLD if threshold is None else threshold + if self.fail_count >= limit: self.status = TokenStatus.EXPIRED def record_success(self, is_usage: bool = True): diff --git a/app/services/token/scheduler.py b/app/services/token/scheduler.py index 14277132..5ec8cafb 100644 --- a/app/services/token/scheduler.py +++ b/app/services/token/scheduler.py @@ -36,7 +36,7 @@ async def _refresh_loop(self): lock_acquired = await lock.acquire(blocking=False) else: try: - async with 
storage.acquire_lock("token_refresh", timeout=0): + async with storage.acquire_lock("token_refresh", timeout=1): lock_acquired = True except StorageError: lock_acquired = False diff --git a/app/static/cache/cache.html b/app/static/cache/cache.html index 128ab971..59051ade 100644 --- a/app/static/cache/cache.html +++ b/app/static/cache/cache.html @@ -198,7 +198,7 @@

缓存管理

- + diff --git a/app/static/cache/cache.js b/app/static/cache/cache.js index 40b99d64..261cccd3 100644 --- a/app/static/cache/cache.js +++ b/app/static/cache/cache.js @@ -84,7 +84,7 @@ function createIconButton(title, svg, onClick) { } async function init() { - apiKey = await ensureApiKey(); + apiKey = await ensureAdminKey(); if (apiKey === null) return; cacheUI(); setupCacheCards(); @@ -233,7 +233,7 @@ async function loadStats(options = {}) { } else { currentScope = 'none'; } - const url = `/api/v1/admin/cache${params.toString() ? `?${params.toString()}` : ''}`; + const url = `/v1/admin/cache${params.toString() ? `?${params.toString()}` : ''}`; const res = await fetch(url, { headers: buildAuthHeaders(apiKey) }); @@ -446,7 +446,7 @@ async function clearCache(type) { if (!ok) return; try { - const res = await fetch('/api/v1/admin/cache/clear', { + const res = await fetch('/v1/admin/cache/clear', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -770,7 +770,7 @@ async function loadLocalCacheList(type) { body.innerHTML = `加载中...`; try { const params = new URLSearchParams({ type, page: '1', page_size: '1000' }); - const res = await fetch(`/api/v1/admin/cache/list?${params.toString()}`, { + const res = await fetch(`/v1/admin/cache/list?${params.toString()}`, { headers: buildAuthHeaders(apiKey) }); if (!res.ok) { @@ -897,7 +897,7 @@ async function deleteLocalFile(type, name) { async function requestDeleteLocalFile(type, name) { try { - const res = await fetch('/api/v1/admin/cache/item/delete', { + const res = await fetch('/v1/admin/cache/item/delete', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1104,7 +1104,7 @@ async function startBatchLoad(tokens) { refreshBatchUI(); try { - const res = await fetch('/api/v1/admin/cache/online/load/async', { + const res = await fetch('/v1/admin/cache/online/load/async', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1242,7 +1242,7 @@ async function startBatchDelete(tokens) 
{ updateDeleteButton(); refreshBatchUI(); try { - const res = await fetch('/api/v1/admin/cache/online/clear/async', { + const res = await fetch('/v1/admin/cache/online/clear/async', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1337,7 +1337,7 @@ async function clearOnlineCache(targetToken = '', skipConfirm = false) { showToast('正在清理在线资产,请稍候...', 'info'); try { - const res = await fetch('/api/v1/admin/cache/online/clear', { + const res = await fetch('/v1/admin/cache/online/clear', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/app/static/common/admin-auth.js b/app/static/common/admin-auth.js index e450ca03..36898d01 100644 --- a/app/static/common/admin-auth.js +++ b/app/static/common/admin-auth.js @@ -1,8 +1,10 @@ const APP_KEY_STORAGE = 'grok2api_app_key'; +const PUBLIC_KEY_STORAGE = 'grok2api_public_key'; const APP_KEY_ENC_PREFIX = 'enc:v1:'; const APP_KEY_XOR_PREFIX = 'enc:xor:'; const APP_KEY_SECRET = 'grok2api-admin-key'; -let cachedApiKey = null; +let cachedAdminKey = null; +let cachedPublicKey = null; const textEncoder = new TextEncoder(); const textDecoder = new TextDecoder(); @@ -112,6 +114,17 @@ async function getStoredAppKey() { } } +async function getStoredPublicKey() { + const stored = localStorage.getItem(PUBLIC_KEY_STORAGE) || ''; + if (!stored) return ''; + try { + return await decryptAppKey(stored); + } catch (e) { + clearStoredPublicKey(); + return ''; + } +} + async function storeAppKey(appKey) { if (!appKey) { clearStoredAppKey(); @@ -121,34 +134,78 @@ async function storeAppKey(appKey) { localStorage.setItem(APP_KEY_STORAGE, encrypted || ''); } +async function storePublicKey(publicKey) { + if (!publicKey) { + clearStoredPublicKey(); + return; + } + const encrypted = await encryptAppKey(publicKey); + localStorage.setItem(PUBLIC_KEY_STORAGE, encrypted || ''); +} + function clearStoredAppKey() { localStorage.removeItem(APP_KEY_STORAGE); - cachedApiKey = null; + cachedAdminKey = null; } -async 
function requestApiKey(appKey) { - const headers = appKey ? { 'Authorization': `Bearer ${appKey}` } : {}; - const res = await fetch('/api/v1/admin/login', { method: 'POST', headers }); - if (!res.ok) { - throw new Error('Unauthorized'); - } - const data = await res.json(); - const rawApiKey = data.api_key || ''; - cachedApiKey = rawApiKey ? `Bearer ${rawApiKey}` : ''; - return cachedApiKey; +function clearStoredPublicKey() { + localStorage.removeItem(PUBLIC_KEY_STORAGE); + cachedPublicKey = null; +} + +async function verifyKey(url, key) { + const headers = key ? { 'Authorization': `Bearer ${key}` } : {}; + const res = await fetch(url, { method: 'GET', headers }); + return res.ok; } -async function ensureApiKey() { +async function ensureAdminKey() { + if (cachedAdminKey) return cachedAdminKey; const appKey = await getStoredAppKey(); if (!appKey) { - window.location.href = '/admin'; + window.location.href = '/admin/login'; return null; } try { - return await requestApiKey(appKey); + const ok = await verifyKey('/v1/admin/verify', appKey); + if (!ok) throw new Error('Unauthorized'); + cachedAdminKey = `Bearer ${appKey}`; + return cachedAdminKey; } catch (e) { clearStoredAppKey(); - window.location.href = '/admin'; + window.location.href = '/admin/login'; + return null; + } +} + +async function ensurePublicKey() { + if (cachedPublicKey !== null) return cachedPublicKey; + + const key = await getStoredPublicKey(); + if (!key) { + try { + const ok = await verifyKey('/v1/public/verify', ''); + if (ok) { + cachedPublicKey = ''; + return cachedPublicKey; + } + } catch (e) { + // ignore + } + return null; + } + + if (!key) { + return null; + } + + try { + const ok = await verifyKey('/v1/public/verify', key); + if (!ok) throw new Error('Unauthorized'); + cachedPublicKey = `Bearer ${key}`; + return cachedPublicKey; + } catch (e) { + clearStoredPublicKey(); return null; } } @@ -159,14 +216,20 @@ function buildAuthHeaders(apiKey) { function logout() { clearStoredAppKey(); - 
window.location.href = '/admin'; + clearStoredPublicKey(); + window.location.href = '/admin/login'; +} + +function publicLogout() { + clearStoredPublicKey(); + window.location.href = '/login'; } async function fetchStorageType() { - const apiKey = await ensureApiKey(); + const apiKey = await ensureAdminKey(); if (apiKey === null) return null; try { - const res = await fetch('/api/v1/admin/storage', { + const res = await fetch('/v1/admin/storage', { headers: buildAuthHeaders(apiKey) }); if (!res.ok) return null; diff --git a/app/static/common/batch-sse.js b/app/static/common/batch-sse.js index 5f202999..d1c2f0b0 100644 --- a/app/static/common/batch-sse.js +++ b/app/static/common/batch-sse.js @@ -9,7 +9,7 @@ if (!taskId) return null; // Query param expects raw key const rawKey = normalizeApiKey(apiKey); - const url = `/api/v1/admin/batch/${taskId}/stream?api_key=${encodeURIComponent(rawKey || '')}`; + const url = `/v1/admin/batch/${taskId}/stream?app_key=${encodeURIComponent(rawKey || '')}`; const es = new EventSource(url); es.onmessage = (e) => { @@ -38,7 +38,7 @@ if (!taskId) return; try { const rawKey = normalizeApiKey(apiKey); - await fetch(`/api/v1/admin/batch/${taskId}/cancel`, { + await fetch(`/v1/admin/batch/${taskId}/cancel`, { method: 'POST', headers: rawKey ? { Authorization: `Bearer ${rawKey}` } : undefined }); diff --git a/app/static/common/header.html b/app/static/common/header.html index 7303639d..f994e474 100644 --- a/app/static/common/header.html +++ b/app/static/common/header.html @@ -23,14 +23,8 @@ 缓存管理 - +
+ Public diff --git a/app/static/login/login.js b/app/static/login/login.js index b9bd6cb6..48dfe956 100644 --- a/app/static/login/login.js +++ b/app/static/login/login.js @@ -1,13 +1,19 @@ const apiKeyInput = document.getElementById('api-key-input'); +const publicKeyInput = document.getElementById('public-key-input'); if (apiKeyInput) { apiKeyInput.addEventListener('keypress', (e) => { if (e.key === 'Enter') login(); }); } +if (publicKeyInput) { + publicKeyInput.addEventListener('keypress', (e) => { + if (e.key === 'Enter') login(); + }); +} async function requestLogin(key) { - const res = await fetch('/api/v1/admin/login', { - method: 'POST', + const res = await fetch('/v1/admin/verify', { + method: 'GET', headers: { 'Authorization': `Bearer ${key}` } }); return res.ok; @@ -15,12 +21,16 @@ async function requestLogin(key) { async function login() { const input = (apiKeyInput ? apiKeyInput.value : '').trim(); + const publicKey = (publicKeyInput ? publicKeyInput.value : '').trim(); if (!input) return; try { const ok = await requestLogin(input); if (ok) { await storeAppKey(input); + if (publicKey) { + await storePublicKey(publicKey); + } window.location.href = '/admin/token'; } else { showToast('密钥无效', 'error'); diff --git a/app/static/public/login.html b/app/static/public/login.html new file mode 100644 index 00000000..bcd0a516 --- /dev/null +++ b/app/static/public/login.html @@ -0,0 +1,68 @@ + + + + + + + Grok2API - Public + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/app/static/public/login.js b/app/static/public/login.js new file mode 100644 index 00000000..63a47ba9 --- /dev/null +++ b/app/static/public/login.js @@ -0,0 +1,51 @@ +const publicKeyInput = document.getElementById('public-key-input'); +if (publicKeyInput) { + publicKeyInput.addEventListener('keypress', (e) => { + if (e.key === 'Enter') login(); + }); +} + +async function requestPublicLogin(key) { + const headers = key ? 
{ 'Authorization': `Bearer ${key}` } : {}; + const res = await fetch('/v1/public/verify', { + method: 'GET', + headers + }); + return res.ok; +} + +async function login() { + const input = (publicKeyInput ? publicKeyInput.value : '').trim(); + try { + const ok = await requestPublicLogin(input); + if (ok) { + await storePublicKey(input); + window.location.href = '/imagine'; + } else { + showToast('密钥无效', 'error'); + } + } catch (e) { + showToast('连接失败', 'error'); + } +} + +(async () => { + try { + const stored = await getStoredPublicKey(); + if (stored) { + const ok = await requestPublicLogin(stored); + if (ok) { + window.location.href = '/imagine'; + return; + } + clearStoredPublicKey(); + } + + const ok = await requestPublicLogin(''); + if (ok) { + window.location.href = '/imagine'; + } + } catch (e) { + return; + } +})(); diff --git a/app/static/token/token.html b/app/static/token/token.html index f43aec9c..ce6ca07f 100644 --- a/app/static/token/token.html +++ b/app/static/token/token.html @@ -294,7 +294,7 @@ - + diff --git a/app/static/token/token.js b/app/static/token/token.js index 1ed1d0af..8c41e78b 100644 --- a/app/static/token/token.js +++ b/app/static/token/token.js @@ -107,7 +107,7 @@ function getPaginationData() { } async function init() { - apiKey = await ensureApiKey(); + apiKey = await ensureAdminKey(); if (apiKey === null) return; setupEditPoolDefaults(); setupConfirmDialog(); @@ -116,7 +116,7 @@ async function init() { async function loadData() { try { - const res = await fetch('/api/v1/admin/tokens', { + const res = await fetch('/v1/admin/tokens', { headers: buildAuthHeaders(apiKey) }); if (res.ok) { @@ -536,7 +536,7 @@ async function syncToServer() { }); try { - const res = await fetch('/api/v1/admin/tokens', { + const res = await fetch('/v1/admin/tokens', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -622,7 +622,7 @@ async function refreshStatus(token) { btn.innerHTML = ``; } - const res = await 
fetch('/api/v1/admin/tokens/refresh', { + const res = await fetch('/v1/admin/tokens/refresh', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -673,7 +673,7 @@ async function startBatchRefresh() { setActionButtonsState(); try { - const res = await fetch('/api/v1/admin/tokens/refresh/async', { + const res = await fetch('/v1/admin/tokens/refresh/async', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1039,7 +1039,7 @@ async function batchEnableNSFW() { try { const tokens = selected.length > 0 ? selected.map(t => t.token) : null; - const res = await fetch('/api/v1/admin/tokens/nsfw/enable/async', { + const res = await fetch('/v1/admin/tokens/nsfw/enable/async', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/app/static/voice/voice.html b/app/static/voice/voice.html index e9b0e928..a11b1480 100644 --- a/app/static/voice/voice.html +++ b/app/static/voice/voice.html @@ -148,7 +148,7 @@

Voice Live

- + diff --git a/app/static/voice/voice.js b/app/static/voice/voice.js index 2c14315a..3da0a2a1 100644 --- a/app/static/voice/voice.js +++ b/app/static/voice/voice.js @@ -128,9 +128,10 @@ } try { - const apiKey = await ensureApiKey(); - if (apiKey === null) { - toast('请先登录后台', 'error'); + const authHeader = await ensurePublicKey(); + if (authHeader === null) { + toast('请先配置 Public Key', 'error'); + window.location.href = '/login'; return; } @@ -145,9 +146,9 @@ speed: speedRange.value }); - const headers = buildAuthHeaders(apiKey); + const headers = buildAuthHeaders(authHeader); - const response = await fetch(`/api/v1/admin/voice/token?${params.toString()}`, { + const response = await fetch(`/v1/public/voice/token?${params.toString()}`, { headers }); diff --git a/config.defaults.toml b/config.defaults.toml index 959df3ee..bbe3898c 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -6,24 +6,34 @@ app_url = "http://127.0.0.1:8000" app_key = "grok2api" # API 调用密钥(可选) api_key = "" +# Public 调用密钥(可选) +public_key = "" +# 是否公开功能玩法(public 入口) +public_enabled = false # 生成图片的格式(url 或 base64) image_format = "url" # 生成视频的格式(html 或 url) video_format = "html" +# 是否启用临时对话模式 +temporary = true +# 是否禁用 Grok 记忆功能 +disable_memory = true +# 是否默认启用流式响应 +stream = true +# 是否默认启用思维链输出 +thinking = true +# 是否动态生成 Statsig 指纹 +dynamic_statsig = true +# 过滤的特殊标签列表 +filter_tags = ["xaiartifact","xai:tool_usage_card","grok:render"] -# ==================== 网络配置 ==================== -[network] -# 请求超时时间(秒) -timeout = 120 +# ==================== 代理配置 ==================== +[proxy] # 基础代理地址(代理到 Grok 官网) base_proxy_url = "" # 资源代理地址(代理静态资源如图片/视频) asset_proxy_url = "" - - -# ==================== 反爬虫验证 ==================== -[security] # Cloudflare Clearance Cookie cf_clearance = "" # curl_cffi 浏览器指纹 @@ -32,22 +42,6 @@ browser = "chrome136" user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" -# ==================== 
对话配置 ==================== -[chat] -# 是否启用临时对话模式 -temporary = true -# 是否禁用 Grok 记忆功能 -disable_memory = true -# 是否默认启用流式响应 -stream = true -# 是否启用思维链输出 -thinking = true -# 是否动态生成 Statsig 指纹 -dynamic_statsig = true -# 过滤的特殊标签列表 -filter_tags = ["xaiartifact","xai:tool_usage_card","grok:render"] - - # ==================== 重试策略 ==================== [retry] # 最大重试次数 @@ -64,28 +58,6 @@ retry_backoff_max = 30.0 retry_budget = 90.0 -# ==================== 超时配置 ==================== -[timeout] -# 流式响应空闲超时(秒) -stream_idle_timeout = 120.0 -# 视频生成空闲超时(秒) -video_idle_timeout = 90.0 - - -# ==================== 图片生成 ==================== -[image] -# 是否启用 WebSocket 直连生成图片 -image_ws = true -# WebSocket 生成时是否启用 NSFW -image_ws_nsfw = true -# 收到中等质量图后等待最终图的超时秒数 -image_ws_blocked_seconds = 15 -# 判定为最终图的最小字节数 -image_ws_final_min_bytes = 100000 -# 判定为中等质量图的最小字节数 -image_ws_medium_min_bytes = 30000 - - # ==================== Token 池管理 ==================== [token] # 是否启用 Token 自动刷新 @@ -101,7 +73,6 @@ save_delay_ms = 500 # 多 worker 状态同步间隔(秒) reload_interval_sec = 30 - # ==================== 缓存管理 ==================== [cache] # 是否启用自动清理 @@ -109,7 +80,45 @@ enable_auto_clean = true # 缓存大小上限(MB) limit_mb = 1024 -# ==================== Asset ==================== +# ==================== 对话配置 ==================== +[chat] +# Reverse 接口并发上限 +concurrent = 10 +# Reverse 接口超时时间(秒) +timeout = 60 +# 流式空闲超时时间(秒) +stream_timeout = 60 + +# ==================== 图像配置 ==================== +[image] +# WebSocket 请求超时时间(秒) +timeout = 120 +# WebSocket 流式空闲超时时间(秒) +stream_timeout = 120 +# 中等图后等待最终图的超时秒数 +final_timeout = 15 +# 是否启用 NSFW +nsfw = true +# 判定为中等质量图的最小字节数 +medium_min_bytes = 30000 +# 判定为最终图的最小字节数 +final_min_bytes = 100000 + +# ==================== 视频配置 ==================== +[video] +# Reverse 接口并发上限 +concurrent = 10 +# Reverse 接口超时时间(秒) +timeout = 60 +# 流式空闲超时时间(秒) +stream_timeout = 60 + +# ==================== 语音配置 ==================== +[voice] +# Voice 请求超时时间(秒) +timeout = 120 + +# ==================== 资产配置 
==================== [asset] # 上传并发数 upload_concurrent = 30 @@ -141,7 +150,7 @@ batch_size = 50 # NSFW 请求超时时间(秒) timeout = 60 -# ==================== Usage ==================== +# ==================== 用量配置 ==================== [usage] # Usage 批量开启并发上限 concurrent = 10 @@ -149,10 +158,3 @@ concurrent = 10 batch_size = 50 # Usage 请求超时时间(秒) timeout = 60 - - -# ==================== 并发性能 ==================== -[performance] - -# Media 生成并发上限 -media_max_concurrent = 50 diff --git a/data/config.toml b/data/config.toml index 92187997..d43d9efe 100644 --- a/data/config.toml +++ b/data/config.toml @@ -2,26 +2,36 @@ app_url = "http://127.0.0.1:8000" app_key = "grok2api" api_key = "" +public_key = "" +public_enabled = false image_format = "url" video_format = "html" +temporary = true +disable_memory = true +stream = true +thinking = true +dynamic_statsig = true +filter_tags = ["xaiartifact","xai:tool_usage_card","grok:render"] -[network] -timeout = 120 +[proxy] base_proxy_url = "" asset_proxy_url = "" - -[security] cf_clearance = "" browser = "chrome136" user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" +[voice] +timeout = 120 + [chat] -temporary = true -disable_memory = true -stream = true -thinking = true -dynamic_statsig = true -filter_tags = ["xaiartifact","xai:tool_usage_card","grok:render"] +concurrent = 10 +timeout = 60 +stream_timeout = 60 + +[video] +concurrent = 10 +timeout = 60 +stream_timeout = 60 [retry] max_retry = 3 @@ -31,16 +41,13 @@ retry_backoff_factor = 2.0 retry_backoff_max = 30.0 retry_budget = 90.0 -[timeout] -stream_idle_timeout = 120.0 -video_idle_timeout = 90.0 - [image] -image_ws = true -image_ws_nsfw = true -image_ws_blocked_seconds = 15 -image_ws_final_min_bytes = 100000 -image_ws_medium_min_bytes = 30000 +timeout = 120 +stream_timeout = 120 +final_timeout = 15 +nsfw = true +medium_min_bytes = 30000 +final_min_bytes = 100000 [token] auto_refresh = true @@ -75,10 +82,3 
@@ timeout = 60 concurrent = 10 batch_size = 50 timeout = 60 - -[performance] -media_max_concurrent = 50 -assets_max_concurrent = 25 -assets_delete_batch_size = 10 -assets_batch_size = 10 -assets_max_tokens = 1000 diff --git a/docs/README.en.md b/docs/README.en.md index d8e31cbc..ed433999 100644 --- a/docs/README.en.md +++ b/docs/README.en.md @@ -1,11 +1,13 @@ # Grok2API -[中文](../README.md) | **English** +[中文](../readme.md) | **English** > [!NOTE] -> This project is for learning and research only. You must comply with Grok's Terms of Use and applicable laws. Do not use it for illegal purposes. +> This project is for learning and research only. You must comply with Grok **Terms of Use** and **local laws and regulations**. Do not use for illegal purposes. -Grok2API rebuilt with **FastAPI**, fully aligned with the latest web call format. Supports streaming and non-streaming chat, image generation/editing, deep thinking, token pool concurrency, and automatic load balancing. +Grok2API rebuilt with **FastAPI**, fully aligned with the latest web call format. Supports streaming/non-streaming chat, image generation/editing, deep reasoning, token pool concurrency, and automatic load balancing. + +### NOTE: The project is no longer accepting PRs and feature updates; this is the last structure optimization. image @@ -13,7 +15,7 @@ Grok2API rebuilt with **FastAPI**, fully aligned with the latest web call format ## Usage -### How to start +### How to Start - Local development @@ -23,91 +25,93 @@ uv sync uv run main.py ``` -- Deployment +### How to Deploy +#### docker compose ``` git clone https://github.com/chenyme/grok2api docker compose up -d ``` -### One-click deploy (Render) - -[![Deploy to Render](https://render.com/images/deploy-to-render-button.svg)](https://render.com/deploy?repo=https://github.com/chenyme/grok2api) - -> Render free instances spin down after 15 minutes of inactivity; data is lost on resume/restart/redeploy. 
-> -> For persistence, use MySQL / Redis / PostgreSQL, on Render set: SERVER_STORAGE_TYPE (mysql/redis/pgsql) and SERVER_STORAGE_URL. - -#### Vercel Deployment +#### Vercel [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,SERVER_STORAGE_TYPE,SERVER_STORAGE_URL&envDefaults=%7B%22DATA_DIR%22%3A%22/tmp/data%22%2C%22LOG_FILE_ENABLED%22%3A%22false%22%2C%22LOG_LEVEL%22%3A%22INFO%22%2C%22SERVER_STORAGE_TYPE%22%3A%22local%22%2C%22SERVER_STORAGE_URL%22%3A%22%22%7D) -> Make sure to set DATA_DIR=/tmp/data and disable file logging (LOG_FILE_ENABLED=false). +> Make sure to set `DATA_DIR=/tmp/data` and disable file logging with `LOG_FILE_ENABLED=false`. > -> For persistence, use MySQL / Redis / PostgreSQL. On Vercel set: SERVER_STORAGE_TYPE (mysql/redis/pgsql) and SERVER_STORAGE_URL. +> For persistence, use MySQL / Redis / PostgreSQL and set `SERVER_STORAGE_TYPE` (mysql/redis/pgsql) and `SERVER_STORAGE_URL` in Vercel env vars. -#### Render Deployment +#### Render [![Deploy to Render](https://render.com/images/deploy-to-render-button.svg)](https://render.com/deploy?repo=https://github.com/chenyme/grok2api) -> Render free instances spin down after 15 minutes of inactivity; data is lost on resume/restart/redeploy. +> Render free instances sleep after 15 minutes of inactivity; data is lost on restart/redeploy. > -> For persistence, use MySQL / Redis / PostgreSQL. On Render set: SERVER_STORAGE_TYPE (mysql/redis/pgsql) and SERVER_STORAGE_URL. +> For persistence, use MySQL / Redis / PostgreSQL and set `SERVER_STORAGE_TYPE` (mysql/redis/pgsql) and `SERVER_STORAGE_URL` in Render env vars. + +### Admin Panel + +Access: `http://<host>:8000/admin` +Default password: `grok2api` (config key `app.app_key`; changing it is recommended). -### Admin panel +**Features**: -URL: `http://<host>:8000/admin` -Default password: `grok2api` (config key `app.app_key`, change it in production).
+- **Token Management**: import/add/delete tokens, view status and quota +- **Status Filter**: filter by status (active/limited/expired) or NSFW status +- **Batch Ops**: batch refresh/export/delete/enable NSFW +- **NSFW Enable**: one-click Unhinged for tokens (proxy or cf_clearance required) +- **Config Management**: update system config online +- **Cache Management**: view and clear media cache -### Environment variables +### Environment Variables -| Variable | Description | Default | Example | -| :--- | :--- | :--- | :--- | -| `LOG_LEVEL` | Log level | `INFO` | `DEBUG` | -| `LOG_FILE_ENABLED` | Enable file logging | `true` | `false` | -| `DATA_DIR` | Data directory (config/tokens/locks) | `./data` | `/data` | -| `SERVER_HOST` | Bind address | `0.0.0.0` | `0.0.0.0` | -| `SERVER_PORT` | Service port | `8000` | `8000` | -| `SERVER_WORKERS` | Uvicorn worker count | `1` | `2` | -| `SERVER_STORAGE_TYPE` | Storage type (`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | -| `SERVER_STORAGE_URL` | Storage URL (empty for local) | `""` | `postgresql+asyncpg://user:password@host:5432/db` | +> Configure `.env` -> MySQL example: `mysql+aiomysql://user:password@host:3306/db` (if you set `mysql://`, it will be normalized to `mysql+aiomysql://`) +| Name | Description | Default | Example | +| :--------------------- | :-------------------------------------------------- | :---------- | :---------------------------------------------------- | +| `LOG_LEVEL` | Log level | `INFO` | `DEBUG` | +| `LOG_FILE_ENABLED` | Enable file logging | `true` | `false` | +| `DATA_DIR` | Data dir (config/tokens/locks) | `./data` | `/data` | +| `SERVER_HOST` | Bind address | `0.0.0.0` | `0.0.0.0` | +| `SERVER_PORT` | Server port | `8000` | `8000` | +| `SERVER_WORKERS` | Uvicorn worker count | `1` | `2` | +| `SERVER_STORAGE_TYPE` | Storage type (`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | +| `SERVER_STORAGE_URL` | Storage DSN (optional for local) | `""` | 
`postgresql+asyncpg://user:password@host:5432/db` | -### Usage limits +> MySQL example: `mysql+aiomysql://user:password@host:3306/db` (if you provide `mysql://`, it will be converted to `mysql+aiomysql://`) + +### Quotas - Basic account: 80 requests / 20h - Super account: 140 requests / 2h ### Models -| Model | Cost | Account | Chat | Image | Video | -| :--- | :---: | :--- | :---: | :---: | :---: | -| `grok-3` | 1 | Basic/Super | Yes | Yes | - | -| `grok-3-fast` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4-mini` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4-fast` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4-heavy` | 4 | Super | Yes | Yes | - | -| `grok-4.1` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4.1-thinking` | 4 | Basic/Super | Yes | Yes | - | -| `grok-imagine-1.0` | 4 | Basic/Super | - | Yes | - | -| `grok-imagine-1.0-edit` | 4 | Basic/Super | - | Yes | - | -| `grok-imagine-1.0-video` | - | Basic/Super | - | - | Yes | +| Model | Cost | Account | Chat | Image | Video | +| :---------------------- | :--: | :---------- | :--: | :---: | :---: | +| `grok-3` | 1 | Basic/Super | Yes | Yes | - | +| `grok-3-fast` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4-mini` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4-fast` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4-heavy` | 4 | Super | Yes | Yes | - | +| `grok-4.1` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4.1-thinking` | 4 | Basic/Super | Yes | Yes | - | +| `grok-imagine-1.0` | 4 | Basic/Super | - | Yes | - | +| `grok-imagine-1.0-edit` | 4 | Basic/Super | - | Yes | - | +| `grok-imagine-1.0-video`| - | Basic/Super | - | - | Yes |
## API ### `POST /v1/chat/completions` + > Generic endpoint: chat, image generation, image editing, video generation, video upscaling ```bash -curl http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer $GROK2API_API_KEY" \ - -d '{ +curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer $GROK2API_API_KEY" -d '{ "model": "grok-4", "messages": [{"role":"user","content":"Hello"}] }' @@ -118,19 +122,45 @@ curl http://localhost:8000/v1/chat/completions \
-| Field | Type | Description | Allowed values | -| :--- | :--- | :--- | :--- | -| `model` | string | Model ID | - | -| `messages` | array | Message list | `developer`, `system`, `user`, `assistant` | -| `stream` | boolean | Enable streaming | `true`, `false` | -| `thinking` | string | Thinking mode | `enabled`, `disabled`, `null` | -| `video_config` | object | **Video model only** | - | -| └─ `aspect_ratio` | string | Video aspect ratio | `16:9`, `9:16`, `1:1`, `2:3`, `3:2` | -| └─ `video_length` | integer | Video length (seconds) | `6`, `10`, `15` | -| └─ `resolution_name` | string | Resolution | `480p`, `720p` | -| └─ `preset` | string | Style preset | `fun`, `normal`, `spicy` | - -Note: any other parameters will be discarded and ignored. +| Field | Type | Description | Allowed values | +| :--------------------- | :------ | :-------------------------- | :--------------------------------------------------------------------------------------------------------------- | +| `model` | string | Model ID | See model list above | +| `messages` | array | Message list | See message format below | +| `stream` | boolean | Enable streaming | `true`, `false` | +| `reasoning_effort` | string | Reasoning effort | `none`, `minimal`, `low`, `medium`, `high`, `xhigh` | +| `temperature` | number | Sampling temperature | `0` ~ `2` | +| `top_p` | number | Nucleus sampling | `0` ~ `1` | +| `video_config` | object | **Video model only** | Supported: `grok-imagine-1.0-video` | +| └─ `aspect_ratio` | string | Video aspect ratio | `16:9`, `9:16`, `1:1`, `2:3`, `3:2`, `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─ `video_length` | integer | Video length (seconds) | `6`, `10`, `15` | +| └─ `resolution_name` | string | Resolution | `480p`, `720p` | +| └─ `preset` | string | Style preset | `fun`, `normal`, `spicy`, `custom` | +| `image_config` | object | **Image models only** | Supported: `grok-imagine-1.0` / `grok-imagine-1.0-edit` | +| └─ `n` | integer | Number of 
images | `1` ~ `10` | +| └─ `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─ `response_format` | string | Response format | `url`, `b64_json`, `base64` | + +**Message format (messages)**: + +| Field | Type | Description | +| :-------- | :----------- | :-------------------------------------------------- | +| `role` | string | `developer`, `system`, `user`, `assistant` | +| `content` | string/array | Message content (plain text or multimodal array) | + +**Multimodal content block types (content array)**: + +| type | Description | Example | +| :------------ | :---------- | :----------------------------------------------------------------------- | +| `text` | Text | `{"type": "text", "text": "Describe this image"}` | +| `image_url` | Image URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | +| `input_audio` | Audio | `{"type": "input_audio", "input_audio": {"data": "https://..."}}` | +| `file` | File | `{"type": "file", "file": {"file_data": "https://..."}}` | + +**Notes**: +- `image_url`/`input_audio`/`file` only support a URL or a Data URI (`data:<mime>;base64,...`); raw base64 will be rejected. +- `reasoning_effort`: `none` disables thinking output; any other value enables it. +- `grok-imagine-1.0-edit` requires an image; if multiple are provided, the last image and last text are used. +- Any other parameters will be discarded and ignored.
@@ -139,13 +169,11 @@ Note: any other parameters will be discarded and ignored.
### `POST /v1/images/generations` + > Image generation endpoint ```bash -curl http://localhost:8000/v1/images/generations \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer $GROK2API_API_KEY" \ - -d '{ +curl http://localhost:8000/v1/images/generations -H "Content-Type: application/json" -H "Authorization: Bearer $GROK2API_API_KEY" -d '{ "model": "grok-imagine-1.0", "prompt": "A cat floating in space", "n": 1 @@ -157,19 +185,19 @@ curl http://localhost:8000/v1/images/generations \
-| Field | Type | Description | Allowed values | -| :--- | :--- | :--- | :--- | -| `model` | string | Image model ID | `grok-imagine-1.0` | -| `prompt` | string | Prompt | - | -| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | -| `stream` | boolean | Enable streaming | `true`, `false` | -| `size` | string | Image size | `1024x1024` (WS mode maps to aspect ratio) | -| `quality` | string | Image quality | `standard` (not customizable yet) | -| `response_format` | string | Response format | `url`, `b64_json` | -| `style` | string | Style | - (not supported yet) | +| Field | Type | Description | Allowed values | +| :----------------- | :------ | :--------------- | :----------------------------------------------------------------- | +| `model` | string | Image model ID | `grok-imagine-1.0` | +| `prompt` | string | Prompt | - | +| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | +| `stream` | boolean | Enable streaming | `true`, `false` | +| `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | Image quality | - (not supported) | +| `response_format` | string | Response format | `url`, `b64_json`, `base64` | +| `style` | string | Style | - | -Note: when `grok.image_ws=true`, `size` is mapped to aspect ratio (only 5 supported: `16:9`, `9:16`, `1:1`, `2:3`, `3:2`); you can also pass those ratio strings directly: -`1024x576/1280x720/1536x864 -> 16:9`, `576x1024/720x1280/864x1536 -> 9:16`, `1024x1024/512x512 -> 1:1`, `1024x1536/512x768/768x1024 -> 2:3`, `1536x1024/768x512/1024x768 -> 3:2`, otherwise defaults to `2:3`. Other parameters are ignored. +**Notes**: +- `quality` and `style` are OpenAI compatibility placeholders and are not customizable yet.
@@ -178,15 +206,11 @@ Note: when `grok.image_ws=true`, `size` is mapped to aspect ratio (only 5 suppor
### `POST /v1/images/edits` + > Image edit endpoint (multipart/form-data) ```bash -curl http://localhost:8000/v1/images/edits \ - -H "Authorization: Bearer $GROK2API_API_KEY" \ - -F "model=grok-imagine-1.0-edit" \ - -F "prompt=Make it sharper" \ - -F "image=@/path/to/image.png" \ - -F "n=1" +curl http://localhost:8000/v1/images/edits -H "Authorization: Bearer $GROK2API_API_KEY" -F "model=grok-imagine-1.0-edit" -F "prompt=Make it sharper" -F "image=@/path/to/image.png" -F "n=1" ```
@@ -194,19 +218,20 @@ curl http://localhost:8000/v1/images/edits \
-| Field | Type | Description | Allowed values | -| :--- | :--- | :--- | :--- | -| `model` | string | Image model ID | `grok-imagine-1.0-edit` | -| `prompt` | string | Edit prompt | - | -| `image` | file | Image file | `png`, `jpg`, `webp` | -| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | -| `stream` | boolean | Enable streaming | `true`, `false` | -| `size` | string | Image size | `1024x1024` (not customizable yet) | -| `quality` | string | Image quality | `standard` (not customizable yet) | -| `response_format` | string | Response format | `url`, `b64_json` | -| `style` | string | Style | - (not supported yet) | - -Note: `size`, `quality`, `style` are OpenAI compatibility placeholders and are not customizable yet. +| Field | Type | Description | Allowed values | +| :----------------- | :------ | :--------------- | :----------------------------------------------------------------- | +| `model` | string | Image model ID | `grok-imagine-1.0-edit` | +| `prompt` | string | Edit prompt | - | +| `image` | file | Image file | `png`, `jpg`, `webp` | +| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | +| `stream` | boolean | Enable streaming | `true`, `false` | +| `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | Image quality | - (not supported) | +| `response_format` | string | Response format | `url`, `b64_json`, `base64` | +| `style` | string | Style | - (not supported) | + +**Notes**: +- `quality` and `style` are OpenAI compatibility placeholders and are not customizable yet.
@@ -219,64 +244,74 @@ Note: `size`, `quality`, `style` are OpenAI compatibility placeholders and are n Config file: `data/config.toml` > [!NOTE] -> In production or behind a reverse proxy, make sure `app.app_url` is set to the public URL. -> Otherwise file links may be incorrect or return 403. +> In production or reverse proxy environments, set `app.app_url` to a publicly accessible URL, +> otherwise file links may be incorrect or return 403. > [!TIP] -> **v2.0 Config Upgrade**: Existing users will have their config **auto-migrated** to the new structure upon update. -> Custom values from the old `[grok]` section will be automatically mapped to the corresponding new sections. - -| Module | Field | Key | Description | Default | -| :--- | :--- | :--- | :--- | :--- | -| **app** | `app_url` | App URL | External access URL for Grok2API (used for file links). | `http://127.0.0.1:8000` | -| | `app_key` | Admin password | Password for the Grok2API admin panel (required). | `grok2api` | -| | `api_key` | API key | Token for calling Grok2API (optional). | `""` | -| | `image_format` | Image format | Output image format (`url` or `base64`). | `url` | -| | `video_format` | Video format | Output video format (html tag or processed url). | `html` | -| **network** | `timeout` | Request timeout | Timeout for Grok requests (seconds). | `120` | -| | `base_proxy_url` | Base proxy URL | Base service address proxying Grok official site. | `""` | -| | `asset_proxy_url` | Asset proxy URL | Proxy URL for Grok static assets (images/videos). | `""` | -| **security** | `cf_clearance` | CF Clearance | Cloudflare clearance cookie for bypassing anti-bot. | `""` | -| | `browser` | Browser fingerprint | curl_cffi browser fingerprint (e.g. chrome136). | `chrome136` | -| | `user_agent` | User-Agent | HTTP User-Agent string. | `Mozilla/5.0 (Macintosh; ...)` | -| **chat** | `temporary` | Temporary chat | Enable temporary conversation mode. 
| `true` | -| | `disable_memory` | Disable memory | Disable Grok memory to prevent irrelevant context. | `true` | -| | `stream` | Streaming | Enable streaming by default. | `true` | -| | `thinking` | Thinking chain | Enable model thinking output. | `true` | -| | `dynamic_statsig` | Dynamic fingerprint | Enable dynamic Statsig value generation. | `true` | -| | `filter_tags` | Filter tags | Auto-filter special tags in Grok responses. | `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | -| **retry** | `max_retry` | Max retries | Max retries on Grok request failure. | `3` | -| | `retry_status_codes` | Retry status codes | HTTP status codes that trigger retry. | `[401, 429, 403]` | -| | `retry_backoff_base` | Backoff base | Base delay for retry backoff (seconds). | `0.5` | -| | `retry_backoff_factor` | Backoff factor | Exponential multiplier for retry backoff. | `2.0` | -| | `retry_backoff_max` | Backoff max | Max wait per retry (seconds). | `30.0` | -| | `retry_budget` | Backoff budget | Max total retry time per request (seconds). | `90.0` | -| **timeout** | `stream_idle_timeout` | Stream idle timeout | Idle timeout for streaming responses (seconds). | `120.0` | -| | `video_idle_timeout` | Video idle timeout | Idle timeout for video generation (seconds). | `90.0` | -| **image** | `image_ws` | WebSocket generation | When enabled, `/v1/images/generations` uses WebSocket. | `true` | -| | `image_ws_nsfw` | NSFW mode | Enable NSFW in WebSocket requests. | `true` | -| | `image_ws_blocked_seconds` | Blocked threshold | Mark blocked if no final image after this many seconds post-medium. | `15` | -| | `image_ws_final_min_bytes` | Final min bytes | Minimum bytes to treat an image as final (JPG usually > 100KB). | `100000` | -| | `image_ws_medium_min_bytes` | Medium min bytes | Minimum bytes for medium quality image. | `30000` | -| **token** | `auto_refresh` | Auto refresh | Enable automatic token refresh. 
| `true` | -| | `refresh_interval_hours` | Refresh interval | Regular token refresh interval (hours). | `8` | -| | `super_refresh_interval_hours` | Super refresh interval | Super token refresh interval (hours). | `2` | -| | `fail_threshold` | Failure threshold | Consecutive failures before a token is disabled. | `5` | -| | `save_delay_ms` | Save delay | Debounced save delay for token changes (ms). | `500` | -| | `reload_interval_sec` | Sync interval | Token state refresh interval in multi-worker setups (sec). | `30` | -| **cache** | `enable_auto_clean` | Auto clean | Enable cache auto clean; cleanup when exceeding limit. | `true` | -| | `limit_mb` | Cleanup threshold | Cache size threshold (MB) that triggers cleanup. | `1024` | -| **performance** | `media_max_concurrent` | Media concurrency | Concurrency cap for video/media generation. Recommended 50. | `50` | -| | `assets_max_concurrent` | Assets concurrency | Concurrency cap for batch asset find/delete. Recommended 25. | `25` | -| | `assets_batch_size` | Assets batch size | Batch size for asset find/delete. Recommended 10. | `10` | -| | `assets_max_tokens` | Assets max tokens | Max tokens per asset find/delete batch. Recommended 1000. | `1000` | -| | `assets_delete_batch_size` | Assets delete batch | Batch concurrency for single-account asset deletion. Recommended 10. | `10` | -| | `usage_max_concurrent` | Token refresh concurrency | Concurrency cap for batch usage refresh. Recommended 25. | `25` | -| | `usage_batch_size` | Token refresh batch size | Batch size for usage refresh. Recommended 50. | `50` | -| | `usage_max_tokens` | Token refresh max tokens | Max tokens per usage refresh batch. Recommended 1000. | `1000` | -| | `nsfw_max_concurrent` | NSFW enable concurrency | Concurrency cap for enabling NSFW in batch. Recommended 10. | `10` | -| | `nsfw_batch_size` | NSFW enable batch size | Batch size for enabling NSFW. Recommended 50. 
| `50` | -| | `nsfw_max_tokens` | NSFW enable max tokens | Max tokens per NSFW batch to avoid mistakes. Recommended 1000. | `1000` | +> **v2.0 config structure upgrade**: legacy config will be **automatically migrated** to the new structure. +> Custom values under the old `[grok]` section are mapped to the new sections. + +| Module | Field | Name | Description | Default | +| :------------------- | :----------------------------- | :--------------------- | :----------------------------------------------------------------- | :---------------------------------------------------------- | +| **app** | `app_url` | App URL | External access URL for Grok2API (used for file links). | `http://127.0.0.1:8000` | +| | `app_key` | Admin password | Password for Grok2API admin panel (required). | `grok2api` | +| | `api_key` | API key | Token for calling Grok2API (optional). | `""` | +| | `image_format` | Image format | Output image format (url or base64). | `url` | +| | `video_format` | Video format | Output video format (html or url, url is processed). | `html` | +| | `temporary` | Temporary chat | Enable temporary conversation mode. | `true` | +| | `disable_memory` | Disable memory | Disable Grok memory to prevent irrelevant context. | `true` | +| | `stream` | Streaming | Enable streaming by default. | `true` | +| | `thinking` | Thinking chain | Enable model thinking output. | `true` | +| | `dynamic_statsig` | Dynamic fingerprint | Enable dynamic Statsig generation. | `true` | +| | `filter_tags` | Filter tags | Auto-filter special tags in Grok responses. | `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | +| **proxy** | `base_proxy_url` | Base proxy URL | Base service address proxying Grok official site. | `""` | +| | `asset_proxy_url` | Asset proxy URL | Proxy URL for Grok static assets (images/videos). | `""` | +| | `cf_clearance` | CF Clearance | Cloudflare clearance cookie for anti-bot. 
| `""` | +| | `browser` | Browser fingerprint | curl_cffi browser fingerprint (e.g. chrome136). | `chrome136` | +| | `user_agent` | User-Agent | HTTP User-Agent string. | `Mozilla/5.0 (Macintosh; ...)` | +| **voice** | `timeout` | Request timeout | Voice request timeout (seconds). | `120` | +| **chat** | `concurrent` | Concurrency | Reverse interface concurrency limit. | `10` | +| | `timeout` | Request timeout | Reverse interface timeout (seconds). | `60` | +| | `stream_timeout` | Stream idle timeout | Stream idle timeout (seconds). | `60` | +| **video** | `concurrent` | Concurrency | Reverse interface concurrency limit. | `10` | +| | `timeout` | Request timeout | Reverse interface timeout (seconds). | `60` | +| | `stream_timeout` | Stream idle timeout | Stream idle timeout (seconds). | `60` | +| **retry** | `max_retry` | Max retries | Max retries on Grok request failure. | `3` | +| | `retry_status_codes` | Retry status codes | HTTP status codes that trigger retry. | `[401, 429, 403]` | +| | `retry_backoff_base` | Backoff base | Base delay for retry backoff (seconds). | `0.5` | +| | `retry_backoff_factor` | Backoff factor | Exponential multiplier for retry backoff. | `2.0` | +| | `retry_backoff_max` | Backoff max | Max wait per retry (seconds). | `30.0` | +| | `retry_budget` | Backoff budget | Max total retry time per request (seconds). | `90.0` | +| **image** | `timeout` | Request timeout | WebSocket request timeout (seconds). | `120` | +| | `stream_timeout` | Stream idle timeout | WebSocket stream idle timeout (seconds). | `120` | +| | `final_timeout` | Final image timeout | Timeout after medium image before final (seconds). | `15` | +| | `nsfw` | NSFW mode | Enable NSFW in WebSocket requests. | `true` | +| | `medium_min_bytes` | Medium min bytes | Minimum bytes for medium quality image. | `30000` | +| | `final_min_bytes` | Final min bytes | Minimum bytes to treat an image as final (JPG usually > 100KB). 
| `100000` | +| **token** | `auto_refresh` | Auto refresh | Enable automatic token refresh. | `true` | +| | `refresh_interval_hours` | Refresh interval | Regular token refresh interval (hours). | `8` | +| | `super_refresh_interval_hours` | Super refresh interval | Super token refresh interval (hours). | `2` | +| | `fail_threshold` | Failure threshold | Consecutive failures before a token is disabled. | `5` | +| | `save_delay_ms` | Save delay | Debounced save delay for token changes (ms). | `500` | +| | `reload_interval_sec` | Sync interval | Token state refresh interval in multi-worker setups (sec). | `30` | +| **cache** | `enable_auto_clean` | Auto clean | Enable cache auto clean; cleanup when exceeding limit. | `true` | +| | `limit_mb` | Cleanup threshold | Cache size threshold (MB) that triggers cleanup. | `1024` | +| **asset** | `upload_concurrent` | Upload concurrency | Max concurrency for upload. Recommended 30. | `30` | +| | `upload_timeout` | Upload timeout | Upload timeout (seconds). Recommended 60. | `60` | +| | `download_concurrent` | Download concurrency | Max concurrency for download. Recommended 30. | `30` | +| | `download_timeout` | Download timeout | Download timeout (seconds). Recommended 60. | `60` | +| | `list_concurrent` | List concurrency | Max concurrency for asset listing. Recommended 10. | `10` | +| | `list_timeout` | List timeout | List timeout (seconds). Recommended 60. | `60` | +| | `list_batch_size` | List batch size | Batch size per list request. Recommended 10. | `10` | +| | `delete_concurrent` | Delete concurrency | Max concurrency for asset delete. Recommended 10. | `10` | +| | `delete_timeout` | Delete timeout | Delete timeout (seconds). Recommended 60. | `60` | +| | `delete_batch_size` | Delete batch size | Batch size per delete request. Recommended 10. | `10` | +| **nsfw** | `concurrent` | Concurrency | Max concurrency for enabling NSFW. Recommended 10. | `10` | +| | `batch_size` | Batch size | Batch size for enabling NSFW. 
Recommended 50. | `50` | +| | `timeout` | Request timeout | NSFW enable request timeout (seconds). Recommended 60. | `60` | +| **usage** | `concurrent` | Concurrency | Max concurrency for usage refresh. Recommended 10. | `10` | +| | `batch_size` | Batch size | Batch size for usage refresh. Recommended 50. | `50` | +| | `timeout` | Request timeout | Usage query timeout (seconds). Recommended 60. | `60` |
diff --git a/main.py b/main.py index 8a0e537d..d94218ec 100644 --- a/main.py +++ b/main.py @@ -121,10 +121,14 @@ def create_app() -> FastAPI: if static_dir.exists(): app.mount("/static", StaticFiles(directory=static_dir), name="static") - # 注册管理路由 + # 注册管理与公共路由 from app.api.v1.admin import router as admin_router + from app.api.v1.public import router as public_router + from app.api.v1.pages import router as pages_router - app.include_router(admin_router) + app.include_router(admin_router, prefix="/v1/admin") + app.include_router(public_router, prefix="/v1/public") + app.include_router(pages_router) return app diff --git a/readme.md b/readme.md index 880908ab..3850d655 100644 --- a/readme.md +++ b/readme.md @@ -69,16 +69,16 @@ docker compose up -d > 配置 `.env` 文件 -| 变量名 | 说明 | 默认值 | 示例 | +| 变量名 | 说明 | 默认值 | 示例 | | :---------------------- | :-------------------------------------------------- | :---------- | :-------------------------------------------------- | -| `LOG_LEVEL` | 日志级别 | `INFO` | `DEBUG` | -| `LOG_FILE_ENABLED` | 是否启用文件日志 | `true` | `false` | -| `DATA_DIR` | 数据目录(配置/Token/锁) | `./data` | `/data` | -| `SERVER_HOST` | 服务监听地址 | `0.0.0.0` | `0.0.0.0` | -| `SERVER_PORT` | 服务端口 | `8000` | `8000` | -| `SERVER_WORKERS` | Uvicorn worker 数量 | `1` | `2` | -| `SERVER_STORAGE_TYPE` | 存储类型(`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | -| `SERVER_STORAGE_URL` | 存储连接串(local 时可为空) | `""` | `postgresql+asyncpg://user:password@host:5432/db` | +| `LOG_LEVEL` | 日志级别 | `INFO` | `DEBUG` | +| `LOG_FILE_ENABLED` | 是否启用文件日志 | `true` | `false` | +| `DATA_DIR` | 数据目录(配置/Token/锁) | `./data` | `/data` | +| `SERVER_HOST` | 服务监听地址 | `0.0.0.0` | `0.0.0.0` | +| `SERVER_PORT` | 服务端口 | `8000` | `8000` | +| `SERVER_WORKERS` | Uvicorn worker 数量 | `1` | `2` | +| `SERVER_STORAGE_TYPE` | 存储类型(`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | +| `SERVER_STORAGE_URL` | 存储连接串(local 时可为空) | `""` | `postgresql+asyncpg://user:password@host:5432/db` | > MySQL 
示例:`mysql+aiomysql://user:password@host:3306/db`(若填 `mysql://` 会自动转为 `mysql+aiomysql://`) @@ -126,34 +126,46 @@ curl http://localhost:8000/v1/chat/completions \
-| 字段 | 类型 | 说明 | 可用参数 | -| :---------------------- | :------ | :----------------------------- | :-------------------------------------------- | -| `model` | string | 模型名称 | 见上方模型列表 | -| `messages` | array | 消息列表 | 见下方消息格式 | -| `stream` | boolean | 是否开启流式输出 | `true`, `false` | -| `thinking` | string | 思维链模式 | `enabled`, `disabled`, `null` | -| `video_config` | object | **视频模型专用配置对象** | - | -| └─`aspect_ratio` | string | 视频宽高比 | `16:9`, `9:16`, `1:1`, `2:3`, `3:2` | -| └─`video_length` | integer | 视频时长 (秒) | `6`, `10`, `15` | -| └─`resolution_name` | string | 分辨率 | `480p`, `720p` | -| └─`preset` | string | 风格预设 | `fun`, `normal`, `spicy`, `custom` | +| 字段 | 类型 | 说明 | 可用参数 | +| :---------------------- | :------ | :----------------------------- | :------------------------------------------------------------------------------------------------- | +| `model` | string | 模型名称 | 见上方模型列表 | +| `messages` | array | 消息列表 | 见下方消息格式 | +| `stream` | boolean | 是否开启流式输出 | `true`, `false` | +| `reasoning_effort` | string | 推理强度 | `none`, `minimal`, `low`, `medium`, `high`, `xhigh` | +| `temperature` | number | 采样温度 | `0` ~ `2` | +| `top_p` | number | nucleus 采样 | `0` ~ `1` | +| `video_config` | object | **视频模型专用配置对象** | 支持:`grok-imagine-1.0-video` | +| └─`aspect_ratio` | string | 视频宽高比 | `16:9`, `9:16`, `1:1`, `2:3`, `3:2`, `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─`video_length` | integer | 视频时长 (秒) | `6`, `10`, `15` | +| └─`resolution_name` | string | 分辨率 | `480p`, `720p` | +| └─`preset` | string | 风格预设 | `fun`, `normal`, `spicy`, `custom` | +| `image_config` | object | **图片模型专用配置对象** | 支持:`grok-imagine-1.0` / `grok-imagine-1.0-edit` | +| └─`n` | integer | 生成数量 | `1` ~ `10` | +| └─`size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─`response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | **消息格式 (messages)**: -| 字段 | 类型 | 说明 | +| 字段 | 类型 | 说明 | | :---------- | :----------- | 
:------------------------------------------------------- | -| `role` | string | 角色:`developer`, `system`, `user`, `assistant` | -| `content` | string/array | 消息内容,支持纯文本或多模态数组 | +| `role` | string | 角色:`developer`, `system`, `user`, `assistant` | +| `content` | string/array | 消息内容,支持纯文本或多模态数组 | **多模态内容块类型 (content array)**: -| type | 说明 | 示例 | -| :------------ | :------- | :------------------------------------------------------------- | -| `text` | 文本内容 | `{"type": "text", "text": "描述这张图片"}` | -| `image_url` | 图片 URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | -| `file` | 文件 | `{"type": "file", "file": {"url": "https://..."}}` | +| type | 说明 | 示例 | +| :------------ | :------- | :---------------------------------------------------------------- | +| `text` | 文本内容 | `{"type": "text", "text": "描述这张图片"}` | +| `image_url` | 图片 URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | +| `input_audio` | 音频 | `{"type": "input_audio", "input_audio": {"data": "https://..."}}` | +| `file` | 文件 | `{"type": "file", "file": {"file_data": "https://..."}}` | -注:除上述外的其他参数将自动丢弃并忽略 +**注意事项**: + +- `image_url/input_audio/file` 仅支持 URL 或 Data URI(`data:<mime>;base64,...`),裸 base64 会报错。 +- `reasoning_effort`:`none` 表示不输出思考,其他值都会输出思考内容。 +- `grok-imagine-1.0-edit` 必须提供图片,多图默认取最后一张与最后一个文本。 +- 除上述外的其他参数将自动丢弃并忽略。
@@ -181,20 +193,19 @@ curl http://localhost:8000/v1/images/generations \
-| 字段 | 类型 | 说明 | 可用参数 | -| :------------------ | :------ | :--------------- | :------------------------------------------- | -| `model` | string | 图像模型名 | `grok-imagine-1.0` | -| `prompt` | string | 图像描述提示词 | - | -| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | -| `stream` | boolean | 是否开启流式输出 | `true`, `false` | -| `size` | string | 图片尺寸 | `1024x1024` (WS 模式支持映射到比例) | -| `quality` | string | 图片质量 | `standard` (暂不支持自定义) | -| `response_format` | string | 响应格式 | `url`, `b64_json` | -| `style` | string | 风格 | - (暂不支持) | - -注:`quality`、`style` 参数为 OpenAI 兼容保留,当前版本暂不支持自定义。 -当开启 `grok.image_ws=true` 时,`size` 将映射为宽高比(仅支持 5 种:`16:9`、`9:16`、`1:1`、`2:3`、`3:2`),也可以直接传以上比例字符串: -`1024x576/1280x720/1536x864 -> 16:9`,`576x1024/720x1280/864x1536 -> 9:16`,`1024x1024/512x512 -> 1:1`,`1024x1536/512x768/768x1024 -> 2:3`,`1536x1024/768x512/1024x768 -> 3:2`,其他值默认 `2:3`。 +| 字段 | 类型 | 说明 | 可用参数 | +| :------------------ | :------ | :--------------- | :------------------------------------------------------------ | +| `model` | string | 图像模型名 | `grok-imagine-1.0` | +| `prompt` | string | 图像描述提示词 | - | +| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | +| `stream` | boolean | 是否开启流式输出 | `true`, `false` | +| `size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | 图片质量 | - (暂不支持) | +| `response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | +| `style` | string | 风格 | - | + +**注意事项**: +- `quality`、`style` 参数为 OpenAI 兼容保留,当前版本暂不支持自定义。
@@ -220,19 +231,20 @@ curl http://localhost:8000/v1/images/edits \
-| 字段 | 类型 | 说明 | 可用参数 | -| :------------------ | :------ | :--------------- | :------------------------------------------- | -| `model` | string | 图像模型名 | `grok-imagine-1.0-edit` | -| `prompt` | string | 编辑描述 | - | -| `image` | file | 待编辑图片 | `png`, `jpg`, `webp` | -| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | -| `stream` | boolean | 是否开启流式输出 | `true`, `false` | -| `size` | string | 图片尺寸 | `1024x1024` (暂不支持自定义) | -| `quality` | string | 图片质量 | `standard` (暂不支持自定义) | -| `response_format` | string | 响应格式 | `url`, `b64_json` | -| `style` | string | 风格 | - (暂不支持) | - -注:`size`、`quality`、`style` 参数为 OpenAI 兼容保留,当前版本暂不支持自定义 +| 字段 | 类型 | 说明 | 可用参数 | +| :------------------ | :------ | :--------------- | :------------------------------------------------------------ | +| `model` | string | 图像模型名 | `grok-imagine-1.0-edit` | +| `prompt` | string | 编辑描述 | - | +| `image` | file | 待编辑图片 | `png`, `jpg`, `webp` | +| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | +| `stream` | boolean | 是否开启流式输出 | `true`, `false` | +| `size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | 图片质量 | - (暂不支持) | +| `response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | +| `style` | string | 风格 | - (暂不支持) | + +**注意事项**: +- `quality`、`style` 参数为 OpenAI 兼容保留,当前版本暂不支持自定义。
@@ -252,57 +264,67 @@ curl http://localhost:8000/v1/images/edits \ > **v2.0 配置结构升级**:旧版本用户更新后,配置会**自动迁移**到新结构,无需手动修改。 > 旧的 `[grok]` 配置节中的自定义值会自动映射到对应的新配置节。 -| 模块 | 字段 | 配置名 | 说明 | 默认值 | +| 模块 | 字段 | 配置名 | 说明 | 默认值 | | :-------------------- | :------------------------------- | :----------------- | :---------------------------------------------------- | :-------------------------------------------------------- | -| **app** | `app_url` | 应用地址 | 当前 Grok2API 服务的外部访问 URL,用于文件链接访问。 | `http://127.0.0.1:8000` | -| | `app_key` | 后台密码 | 登录 Grok2API 管理后台的密码(必填)。 | `grok2api` | -| | `api_key` | API 密钥 | 调用 Grok2API 服务的 Token(可选)。 | `""` | -| | `image_format` | 图片格式 | 生成的图片格式(url 或 base64)。 | `url` | -| | `video_format` | 视频格式 | 生成的视频格式(html 或 url,url 为处理后的链接)。 | `html` | -| **network** | `timeout` | 请求超时 | 请求 Grok 服务的超时时间(秒)。 | `120` | -| | `base_proxy_url` | 基础代理 URL | 代理请求到 Grok 官网的基础服务地址。 | `""` | -| | `asset_proxy_url` | 资源代理 URL | 代理请求到 Grok 官网的静态资源(图片/视频)地址。 | `""` | -| **security** | `cf_clearance` | CF Clearance | Cloudflare 验证 Cookie,用于绕过反爬虫验证。 | `""` | -| | `browser` | 浏览器指纹 | curl_cffi 浏览器指纹标识(如 chrome136)。 | `chrome136` | -| | `user_agent` | User-Agent | HTTP 请求的 User-Agent 字符串。 | `Mozilla/5.0 (Macintosh; ...)` | -| **chat** | `temporary` | 临时对话 | 是否启用临时对话模式。 | `true` | -| | `disable_memory` | 禁用记忆 | 禁用 Grok 记忆功能,防止响应中出现不相关上下文。 | `true` | -| | `stream` | 流式响应 | 是否默认启用流式输出。 | `true` | -| | `thinking` | 思维链 | 是否启用模型思维链输出。 | `true` | -| | `dynamic_statsig` | 动态指纹 | 是否启用动态生成 Statsig 值。 | `true` | -| | `filter_tags` | 过滤标签 | 自动过滤 Grok 响应中的特殊标签。 | `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | -| **retry** | `max_retry` | 最大重试 | 请求 Grok 服务失败时的最大重试次数。 | `3` | -| | `retry_status_codes` | 重试状态码 | 触发重试的 HTTP 状态码列表。 | `[401, 429, 403]` | -| | `retry_backoff_base` | 退避基数 | 重试退避的基础延迟(秒)。 | `0.5` | -| | `retry_backoff_factor` | 退避倍率 | 重试退避的指数放大系数。 | `2.0` | -| | `retry_backoff_max` | 退避上限 | 单次重试等待的最大延迟(秒)。 | `30.0` | -| | `retry_budget` | 退避预算 | 单次请求的最大重试总耗时(秒)。 | 
`90.0` | -| **timeout** | `stream_idle_timeout` | 流空闲超时 | 流式响应空闲超时(秒),超过将断开。 | `120.0` | -| | `video_idle_timeout` | 视频空闲超时 | 视频生成空闲超时(秒),超过将断开。 | `90.0` | -| **image** | `image_ws` | WebSocket 生成 | 启用后 `/v1/images/generations` 走 WebSocket 直连。 | `true` | -| | `image_ws_nsfw` | NSFW 模式 | WebSocket 请求是否启用 NSFW。 | `true` | -| | `image_ws_blocked_seconds` | Blocked 阈值 | 收到中等图后超过该秒数仍无最终图则判定 blocked。 | `15` | -| | `image_ws_final_min_bytes` | 最终图最小字节 | 判定最终图的最小字节数(通常 JPG > 100KB)。 | `100000` | -| | `image_ws_medium_min_bytes` | 中等图最小字节 | 判定中等质量图的最小字节数。 | `30000` | -| **token** | `auto_refresh` | 自动刷新 | 是否开启 Token 自动刷新机制。 | `true` | -| | `refresh_interval_hours` | 刷新间隔 | 普通 Token 刷新的时间间隔(小时)。 | `8` | -| | `super_refresh_interval_hours` | Super 刷新间隔 | Super Token 刷新的时间间隔(小时)。 | `2` | -| | `fail_threshold` | 失败阈值 | 单个 Token 连续失败多少次后被标记为不可用。 | `5` | -| | `save_delay_ms` | 保存延迟 | Token 变更合并写入的延迟(毫秒)。 | `500` | -| | `reload_interval_sec` | 同步间隔 | 多 worker 场景下 Token 状态刷新间隔(秒)。 | `30` | -| **cache** | `enable_auto_clean` | 自动清理 | 是否启用缓存自动清理,开启后按上限自动回收。 | `true` | -| | `limit_mb` | 清理阈值 | 缓存大小阈值(MB),超过阈值会触发清理。 | `1024` | -| **performance** | `media_max_concurrent` | Media 并发上限 | 视频/媒体生成请求的并发上限。推荐 50。 | `50` | -| | `assets_max_concurrent` | Assets 并发上限 | 批量查找/删除资产时的并发请求上限。推荐 25。 | `25` | -| | `assets_batch_size` | Assets 批次大小 | 批量查找/删除资产时的单批处理数量。推荐 10。 | `10` | -| | `assets_max_tokens` | Assets 最大数量 | 单次批量查找/删除资产时的处理数量上限。推荐 1000。 | `1000` | -| | `assets_delete_batch_size` | Assets 删除批量 | 单账号批量删除资产时的单批并发数量。推荐 10。 | `10` | -| | `usage_max_concurrent` | Token 刷新并发上限 | 批量刷新 Token 用量时的并发请求上限。推荐 25。 | `25` | -| | `usage_batch_size` | Token 刷新批次大小 | 批量刷新 Token 用量的单批处理数量。推荐 50。 | `50` | -| | `usage_max_tokens` | Token 刷新最大数量 | 单次批量刷新 Token 用量时的处理数量上限。推荐 1000。 | `1000` | -| | `nsfw_max_concurrent` | NSFW 开启并发上限 | 批量开启 NSFW 模式时的并发请求上限。推荐 10。 | `10` | -| | `nsfw_batch_size` | NSFW 开启批次大小 | 批量开启 NSFW 模式的单批处理数量。推荐 50。 | `50` | -| | `nsfw_max_tokens` | NSFW 开启最大数量 | 单次批量开启 NSFW 的 Token 数量上限。推荐 
1000。 | `1000` | +| **app** | `app_url` | 应用地址 | 当前 Grok2API 服务的外部访问 URL,用于文件链接访问。 | `http://127.0.0.1:8000` | +| | `app_key` | 后台密码 | 登录 Grok2API 管理后台的密码(必填)。 | `grok2api` | +| | `api_key` | API 密钥 | 调用 Grok2API 服务的 Token(可选)。 | `""` | +| | `image_format` | 图片格式 | 生成的图片格式(url 或 base64)。 | `url` | +| | `video_format` | 视频格式 | 生成的视频格式(html 或 url,url 为处理后的链接)。 | `html` | +| | `temporary` | 临时对话 | 是否启用临时对话模式。 | `true` | +| | `disable_memory` | 禁用记忆 | 禁用 Grok 记忆功能,防止响应中出现不相关上下文。 | `true` | +| | `stream` | 流式响应 | 是否默认启用流式输出。 | `true` | +| | `thinking` | 思维链 | 是否启用模型思维链输出。 | `true` | +| | `dynamic_statsig` | 动态指纹 | 是否启用动态生成 Statsig 值。 | `true` | +| | `filter_tags` | 过滤标签 | 自动过滤 Grok 响应中的特殊标签。 | `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | +| **proxy** | `base_proxy_url` | 基础代理 URL | 代理请求到 Grok 官网的基础服务地址。 | `""` | +| | `asset_proxy_url` | 资源代理 URL | 代理请求到 Grok 官网的静态资源(图片/视频)地址。 | `""` | +| | `cf_clearance` | CF Clearance | Cloudflare 验证 Cookie,用于绕过反爬虫验证。 | `""` | +| | `browser` | 浏览器指纹 | curl_cffi 浏览器指纹标识(如 chrome136)。 | `chrome136` | +| | `user_agent` | User-Agent | HTTP 请求的 User-Agent 字符串。 | `Mozilla/5.0 (Macintosh; ...)` | +| **voice** | `timeout` | 请求超时 | Voice 请求超时时间(秒)。 | `120` | +| **chat** | `concurrent` | 并发上限 | Reverse 接口并发上限。 | `10` | +| | `timeout` | 请求超时 | Reverse 接口超时时间(秒)。 | `60` | +| | `stream_timeout` | 流空闲超时 | 流式空闲超时时间(秒)。 | `60` | +| **video** | `concurrent` | 并发上限 | Reverse 接口并发上限。 | `10` | +| | `timeout` | 请求超时 | Reverse 接口超时时间(秒)。 | `60` | +| | `stream_timeout` | 流空闲超时 | 流式空闲超时时间(秒)。 | `60` | +| **retry** | `max_retry` | 最大重试 | 请求 Grok 服务失败时的最大重试次数。 | `3` | +| | `retry_status_codes` | 重试状态码 | 触发重试的 HTTP 状态码列表。 | `[401, 429, 403]` | +| | `retry_backoff_base` | 退避基数 | 重试退避的基础延迟(秒)。 | `0.5` | +| | `retry_backoff_factor` | 退避倍率 | 重试退避的指数放大系数。 | `2.0` | +| | `retry_backoff_max` | 退避上限 | 单次重试等待的最大延迟(秒)。 | `30.0` | +| | `retry_budget` | 退避预算 | 单次请求的最大重试总耗时(秒)。 | `90.0` | +| **image** | `timeout` | 请求超时 | WebSocket 请求超时时间(秒)。 | `120` | +| | 
`stream_timeout` | 流空闲超时 | WebSocket 流式空闲超时时间(秒)。 | `120` | +| | `final_timeout` | 最终图超时 | 收到中等图后等待最终图的超时秒数。 | `15` | +| | `nsfw` | NSFW 模式 | WebSocket 请求是否启用 NSFW。 | `true` | +| | `medium_min_bytes` | 中等图最小字节 | 判定中等质量图的最小字节数。 | `30000` | +| | `final_min_bytes` | 最终图最小字节 | 判定最终图的最小字节数(通常 JPG > 100KB)。 | `100000` | +| **token** | `auto_refresh` | 自动刷新 | 是否开启 Token 自动刷新机制。 | `true` | +| | `refresh_interval_hours` | 刷新间隔 | 普通 Token 刷新的时间间隔(小时)。 | `8` | +| | `super_refresh_interval_hours` | Super 刷新间隔 | Super Token 刷新的时间间隔(小时)。 | `2` | +| | `fail_threshold` | 失败阈值 | 单个 Token 连续失败多少次后被标记为不可用。 | `5` | +| | `save_delay_ms` | 保存延迟 | Token 变更合并写入的延迟(毫秒)。 | `500` | +| | `reload_interval_sec` | 同步间隔 | 多 worker 场景下 Token 状态刷新间隔(秒)。 | `30` | +| **cache** | `enable_auto_clean` | 自动清理 | 是否启用缓存自动清理,开启后按上限自动回收。 | `true` | +| | `limit_mb` | 清理阈值 | 缓存大小阈值(MB),超过阈值会触发清理。 | `1024` | +| **asset** | `upload_concurrent` | 上传并发 | 上传接口的最大并发数。推荐 30。 | `30` | +| | `upload_timeout` | 上传超时 | 上传接口超时时间(秒)。推荐 60。 | `60` | +| | `download_concurrent` | 下载并发 | 下载接口的最大并发数。推荐 30。 | `30` | +| | `download_timeout` | 下载超时 | 下载接口超时时间(秒)。推荐 60。 | `60` | +| | `list_concurrent` | 查询并发 | 资产查询接口的最大并发数。推荐 10。 | `10` | +| | `list_timeout` | 查询超时 | 资产查询接口超时时间(秒)。推荐 60。 | `60` | +| | `list_batch_size` | 查询批次大小 | 单次查询可处理的 Token 数量。推荐 10。 | `10` | +| | `delete_concurrent` | 删除并发 | 资产删除接口的最大并发数。推荐 10。 | `10` | +| | `delete_timeout` | 删除超时 | 资产删除接口超时时间(秒)。推荐 60。 | `60` | +| | `delete_batch_size` | 删除批次大小 | 单次删除可处理的 Token 数量。推荐 10。 | `10` | +| **nsfw** | `concurrent` | 并发上限 | 批量开启 NSFW 模式时的并发请求上限。推荐 10。 | `10` | +| | `batch_size` | 批次大小 | 批量开启 NSFW 模式的单批处理数量。推荐 50。 | `50` | +| | `timeout` | 请求超时 | NSFW 开启相关请求的超时时间(秒)。推荐 60。 | `60` | +| **usage** | `concurrent` | 并发上限 | 批量刷新用量时的并发请求上限。推荐 10。 | `10` | +| | `batch_size` | 批次大小 | 批量刷新用量的单批处理数量。推荐 50。 | `50` | +| | `timeout` | 请求超时 | 用量查询接口的超时时间(秒)。推荐 60。 | `60` |
diff --git a/tests/test_model.py b/tests/test_model.py deleted file mode 100644 index b4e96e78..00000000 --- a/tests/test_model.py +++ /dev/null @@ -1,463 +0,0 @@ -import argparse -import asyncio -import json -import os -import sys -from datetime import datetime, timezone -from pathlib import Path -from typing import Any, Iterable - -from dotenv import load_dotenv - -ROOT = Path(__file__).resolve().parents[1] -if str(ROOT) not in sys.path: - sys.path.insert(0, str(ROOT)) - -load_dotenv(ROOT / ".env") - -from app.core.config import config -from app.core.logger import setup_logging -from app.services.grok.services.chat import GrokChatService -from app.services.grok.processors import CollectProcessor -from app.services.grok.services.usage import UsageService - - -async def _fetch_usage(token: str, timeout: float) -> int | None: - usage = UsageService() - try: - result = await asyncio.wait_for(usage.get(token), timeout=timeout) - except Exception as exc: - print(f"Usage fetch failed: {exc}") - return None - try: - return int(result.get("remainingTokens")) - except Exception: - return None - - -async def _run_once( - model: str, - mode: str, - token: str, - message: str, - timeout: float, - lock: asyncio.Lock | None = None, -) -> tuple[bool, int | None, int | None]: - async def _execute() -> tuple[bool, int | None, int | None]: - service = GrokChatService() - before = await _fetch_usage(token, timeout) - http_ok = False - try: - print(f"Requesting model={model} mode={mode} ...") - response = await asyncio.wait_for( - service.chat( - token=token, - message=message, - model=model, - mode=mode, - think=False, - stream=False, - ), - timeout=timeout, - ) - http_ok = True - except Exception as exc: - print(f"Request failed: {exc}") - after = await _fetch_usage(token, timeout) - return False, before, after - - processor = CollectProcessor(model, token) - try: - print("Collecting response ...") - result = await asyncio.wait_for(processor.process(response), timeout=timeout) - 
except Exception as exc: - print(f"Collect failed: {exc}") - after = await _fetch_usage(token, timeout) - return False, before, after - - content = ( - result.get("choices", [{}])[0] - .get("message", {}) - .get("content", "") - .strip() - ) - after = await _fetch_usage(token, timeout) - ok = http_ok and bool(content) - return ok, before, after - - if lock is None: - return await _execute() - async with lock: - return await _execute() - - -async def _run_test( - grok_model: str, - model_mode: str, - basic_token: str, - super_token: str, - message: str, - out_path: str, - timeout: float, - model_id: str | None = None, - lock_map: dict[str, asyncio.Lock] | None = None, - load_config: bool = True, -) -> tuple[dict, bool]: - if load_config: - await config.load() - - basic_lock = lock_map.get(basic_token) if lock_map else None - super_lock = lock_map.get(super_token) if lock_map else None - - print("Testing basic token ...") - basic_task = asyncio.create_task( - _run_once(grok_model, model_mode, basic_token, message, timeout, basic_lock) - ) - print("Testing super token ...") - super_task = asyncio.create_task( - _run_once(grok_model, model_mode, super_token, message, timeout, super_lock) - ) - basic_ok, basic_before, basic_after = await basic_task - super_ok, super_before, super_after = await super_task - - basic_delta = ( - (basic_before - basic_after) - if (basic_before is not None and basic_after is not None) - else None - ) - super_delta = ( - (super_before - super_after) - if (super_before is not None and super_after is not None) - else None - ) - - cost_guess = _guess_cost(basic_delta, super_delta) - - payload = { - "model_id": model_id or grok_model, - "grok_model": grok_model, - "model_mode": model_mode, - "basic": { - "ok": bool(basic_ok), - "before": basic_before, - "after": basic_after, - "delta": basic_delta, - }, - "super": { - "ok": bool(super_ok), - "before": super_before, - "after": super_after, - "delta": super_delta, - }, - "cost_guess": cost_guess, - 
"timestamp": datetime.now(timezone.utc).isoformat(), - } - - ok = bool(basic_ok and super_ok) - if out_path: - _append_result(out_path, payload) - print(f"Appended results to {out_path}") - return payload, ok - - -def _guess_cost(basic_delta: int | None, super_delta: int | None) -> str | None: - for delta in (basic_delta, super_delta): - if delta is None: - continue - return "high" if delta >= 4 else "low" - return None - - -def _load_tokens_file(path: str) -> dict: - if not path: - return {} - file_path = Path(path) - if not file_path.exists(): - return {} - try: - with file_path.open("r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, dict): - return data - except Exception as exc: - print(f"Failed to read tokens file: {exc}") - return {} - - -def _prompt_if_missing(value: str, label: str) -> str: - if value: - return value - return input(f"{label}: ").strip() - - -def _append_result(out_path: str, payload: dict) -> None: - out_file = Path(out_path) - data = [] - if out_file.exists(): - try: - with out_file.open("r", encoding="utf-8") as f: - existing = json.load(f) - if isinstance(existing, list): - data = existing - elif isinstance(existing, dict): - data = [existing] - except Exception as exc: - print(f"Failed to read existing results, overwrite: {exc}") - data = [] - if isinstance(payload, list): - data.extend(payload) - else: - data.append(payload) - with out_file.open("w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=True, indent=2) - - -def _normalize_matrix_item(item: Any) -> dict[str, Any]: - if isinstance(item, (list, tuple)): - if len(item) < 3: - return {} - return { - "grok_model": str(item[0]).strip(), - "model_mode": str(item[1]).strip(), - "model_id": str(item[2]).strip(), - } - if isinstance(item, dict): - grok_model = item.get("grok_model") or item.get("model") or item.get("grok") - model_mode = item.get("model_mode") or item.get("mode") - model_id = item.get("model_id") or item.get("id") or item.get("name") - 
tier = item.get("tier") - if not (grok_model and model_mode and model_id): - return {} - normalized = { - "grok_model": str(grok_model).strip(), - "model_mode": str(model_mode).strip(), - "model_id": str(model_id).strip(), - } - if tier: - normalized["tier"] = str(tier).strip() - return normalized - return {} - - -def _parse_matrix_text(text: str) -> list[dict[str, Any]]: - items: list[dict[str, Any]] = [] - for line in text.splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - parts = line.split() - if len(parts) < 3: - continue - entry = { - "grok_model": parts[0], - "model_mode": parts[1], - "model_id": parts[2], - } - if len(parts) >= 4: - entry["tier"] = parts[3] - items.append(entry) - return items - - -def _load_matrix(matrix_file: str | None, matrix_inline: str | None) -> list[dict]: - if matrix_inline: - try: - data = json.loads(matrix_inline) - if isinstance(data, list): - items = [_normalize_matrix_item(x) for x in data] - return [x for x in items if x] - except Exception: - return _parse_matrix_text(matrix_inline) - - if matrix_file: - file_path = Path(matrix_file) - if file_path.exists(): - text = file_path.read_text(encoding="utf-8").strip() - if not text: - return [] - try: - data = json.loads(text) - if isinstance(data, list): - items = [_normalize_matrix_item(x) for x in data] - return [x for x in items if x] - except Exception: - return _parse_matrix_text(text) - return [] - - -def _build_token_locks(tokens: Iterable[str]) -> dict[str, asyncio.Lock]: - locks: dict[str, asyncio.Lock] = {} - for token in tokens: - if token and token not in locks: - locks[token] = asyncio.Lock() - return locks - - -def _format_model_list(results: Iterable[dict]) -> str: - lines = [] - for item in results: - model_id = item.get("model_id") or item.get("grok_model") or "" - grok_model = item.get("grok_model") or "" - model_mode = item.get("model_mode") or "" - tier = item.get("tier") - cost_guess = item.get("cost_guess") - cost = 
"Cost.HIGH" if cost_guess == "high" else "Cost.LOW" - display_name = model_id.upper() if model_id else "" - - lines.append(" ModelInfo(") - lines.append(f' model_id="{model_id}",') - lines.append(f' grok_model="{grok_model}",') - lines.append(f' model_mode="{model_mode}",') - if tier and str(tier).upper() == "SUPER": - lines.append(" tier=Tier.SUPER,") - lines.append(f" cost={cost},") - lines.append(f' display_name="{display_name}",') - lines.append(" ),") - lines.append("") - - return "\n".join(lines).rstrip() - - -async def _run_matrix( - matrix: list[dict], - basic_token: str, - super_token: str, - message: str, - out_path: str, - timeout: float, - max_concurrent: int, -) -> tuple[list[dict], bool]: - await config.load() - locks = _build_token_locks([basic_token, super_token]) - sem = asyncio.Semaphore(max(1, int(max_concurrent))) - - async def _one(entry: dict) -> tuple[dict, bool]: - async with sem: - payload, ok = await _run_test( - entry["grok_model"], - entry["model_mode"], - basic_token, - super_token, - message, - out_path="", - timeout=timeout, - model_id=entry.get("model_id"), - lock_map=locks, - load_config=False, - ) - if entry.get("tier"): - payload["tier"] = entry["tier"] - return payload, ok - - tasks = [_one(entry) for entry in matrix] - pairs = await asyncio.gather(*tasks) - results = [payload for payload, _ in pairs] - all_ok = all(ok for _, ok in pairs) - - if out_path: - _append_result(out_path, results) - print(f"Appended results to {out_path}") - return results, all_ok - - -def main() -> int: - parser = argparse.ArgumentParser( - description="Test Grok model by grok_model and model_mode using basic/super tokens." - ) - parser.add_argument("grok_model", nargs="?", help="e.g. grok-4-1-thinking-1129") - parser.add_argument("model_mode", nargs="?", help="e.g. 
MODEL_MODE_GROK_4_1_THINKING") - parser.add_argument("--model-id", dest="model_id", help="model id for output") - parser.add_argument("--basic-token", dest="basic_token", help="basic account token") - parser.add_argument("--super-token", dest="super_token", help="super account token") - parser.add_argument( - "--tokens-file", - default="data/model_tokens.json", - help="path to tokens json file", - ) - parser.add_argument("--matrix", help="inline JSON or line-based model list") - parser.add_argument("--matrix-file", help="path to model list (json or text)") - parser.add_argument( - "--max-concurrent", - type=int, - default=2, - help="max concurrent model tests", - ) - parser.add_argument( - "--emit-model-list", - action="store_true", - help="print ModelInfo list snippet", - ) - parser.add_argument("--emit-model-list-out", help="write ModelInfo list snippet") - parser.add_argument("--message", default="Ping", help="test prompt") - parser.add_argument("--out", default="model.json", help="output json path") - parser.add_argument("--timeout", type=float, default=120, help="timeout seconds") - parser.add_argument("--log-level", help="log level (overrides LOG_LEVEL)") - args = parser.parse_args() - - tokens_file_data = _load_tokens_file(args.tokens_file) - basic_token = _prompt_if_missing( - args.basic_token - or tokens_file_data.get("basic_token", "") - or os.getenv("BASIC_TOKEN", ""), - "basic_token", - ) - super_token = _prompt_if_missing( - args.super_token - or tokens_file_data.get("super_token", "") - or os.getenv("SUPER_TOKEN", ""), - "super_token", - ) - if not basic_token or not super_token: - print("basic_token and super_token are required.") - return 2 - - log_level = args.log_level or os.getenv("LOG_LEVEL", "INFO") - setup_logging(level=log_level, json_console=False, file_logging=False) - matrix = _load_matrix(args.matrix_file, args.matrix) - if matrix: - results, ok = asyncio.run( - _run_matrix( - matrix, - basic_token, - super_token, - args.message, - 
args.out, - args.timeout, - args.max_concurrent, - ) - ) - if args.emit_model_list or args.emit_model_list_out: - snippet = _format_model_list(results) - if args.emit_model_list_out: - Path(args.emit_model_list_out).write_text( - snippet + "\n", encoding="utf-8" - ) - print(f"Model list written to {args.emit_model_list_out}") - else: - print(snippet) - return 0 if ok else 1 - - grok_model = args.grok_model or os.getenv("GROK_MODEL", "") - model_mode = args.model_mode or os.getenv("MODEL_MODE", "") - if not grok_model: - print("grok_model is required.") - return 2 - - _payload, ok = asyncio.run( - _run_test( - grok_model, - model_mode, - basic_token, - super_token, - args.message, - args.out, - args.timeout, - model_id=args.model_id, - lock_map=_build_token_locks([basic_token, super_token]), - ) - ) - return 0 if ok else 1 - - -if __name__ == "__main__": - raise SystemExit(main()) From 928faabbac422fb7d715d3969ea239460619ddef Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Sun, 15 Feb 2026 17:43:57 +0800 Subject: [PATCH 22/27] refactor: update configuration settings, enhance public API integration, and streamline service imports while removing deprecated files --- app/api/pages/__init__.py | 13 + app/api/pages/admin.py | 32 + app/api/pages/public.py | 44 ++ app/api/v1/pages.py | 94 --- app/api/v1/public/__init__.py | 2 + app/api/v1/public/imagine.py | 154 +++-- app/api/v1/public/video.py | 274 ++++++++ app/core/response_middleware.py | 14 + app/services/grok/defaults.py | 110 +-- app/services/grok/services/chat.py | 52 +- app/services/grok/services/image.py | 15 +- app/services/grok/services/video.py | 105 ++- app/services/reverse/__init__.py | 2 + app/services/reverse/media_post.py | 17 +- app/services/reverse/video_upscale.py | 109 +++ app/services/token/manager.py | 8 + app/static/{cache => admin/css}/cache.css | 0 app/static/{config => admin/css}/config.css | 12 + app/static/{token => admin/css}/token.css | 0 
app/static/{cache => admin/js}/cache.js | 0 app/static/{config => admin/js}/config.js | 62 +- app/static/{login => admin/js}/login.js | 0 app/static/{token => admin/js}/token.js | 0 app/static/{cache => admin/pages}/cache.html | 22 +- .../{config => admin/pages}/config.html | 18 +- app/static/{login => admin/pages}/login.html | 16 +- app/static/{token => admin/pages}/token.html | 22 +- app/static/common/{ => css}/common.css | 0 app/static/{login => common/css}/login.css | 0 app/static/common/{ => css}/toast.css | 0 app/static/common/{ => html}/footer.html | 0 app/static/common/{ => html}/header.html | 14 +- .../common/{ => html}/public-header.html | 4 +- .../{ => common/img}/favicon/favicon.ico | Bin app/static/common/{ => js}/admin-auth.js | 0 app/static/common/{ => js}/batch-sse.js | 0 app/static/common/{ => js}/draggable.js | 0 app/static/common/{ => js}/footer.js | 2 +- app/static/common/{ => js}/header.js | 2 +- app/static/common/{ => js}/public-header.js | 14 +- app/static/common/{ => js}/toast.js | 0 .../{imagine => public/css}/imagine.css | 189 +++++- app/static/public/css/video.css | 482 +++++++++++++ app/static/{voice => public/css}/voice.css | 0 app/static/{imagine => public/js}/imagine.js | 235 ++++++- app/static/public/{ => js}/login.js | 0 app/static/public/js/video.js | 640 ++++++++++++++++++ app/static/{voice => public/js}/voice.js | 0 .../{imagine => public/pages}/imagine.html | 125 ++-- app/static/public/{ => pages}/login.html | 16 +- app/static/public/pages/video.html | 168 +++++ app/static/{voice => public/pages}/voice.html | 18 +- config.defaults.toml | 38 +- data/config.toml | 50 +- main.py | 2 +- 55 files changed, 2695 insertions(+), 501 deletions(-) create mode 100644 app/api/pages/__init__.py create mode 100644 app/api/pages/admin.py create mode 100644 app/api/pages/public.py delete mode 100644 app/api/v1/pages.py create mode 100644 app/api/v1/public/video.py create mode 100644 app/services/reverse/video_upscale.py rename app/static/{cache 
=> admin/css}/cache.css (100%) rename app/static/{config => admin/css}/config.css (80%) rename app/static/{token => admin/css}/token.css (100%) rename app/static/{cache => admin/js}/cache.js (100%) rename app/static/{config => admin/js}/config.js (85%) rename app/static/{login => admin/js}/login.js (100%) rename app/static/{token => admin/js}/token.js (100%) rename app/static/{cache => admin/pages}/cache.html (93%) rename app/static/{config => admin/pages}/config.html (75%) rename app/static/{login => admin/pages}/login.html (82%) rename app/static/{token => admin/pages}/token.html (95%) rename app/static/common/{ => css}/common.css (100%) rename app/static/{login => common/css}/login.css (100%) rename app/static/common/{ => css}/toast.css (100%) rename app/static/common/{ => html}/footer.html (100%) rename app/static/common/{ => html}/header.html (73%) rename app/static/common/{ => html}/public-header.html (89%) rename app/static/{ => common/img}/favicon/favicon.ico (100%) rename app/static/common/{ => js}/admin-auth.js (100%) rename app/static/common/{ => js}/batch-sse.js (100%) rename app/static/common/{ => js}/draggable.js (100%) rename app/static/common/{ => js}/footer.js (85%) rename app/static/common/{ => js}/header.js (93%) rename app/static/common/{ => js}/public-header.js (59%) rename app/static/common/{ => js}/toast.js (100%) rename app/static/{imagine => public/css}/imagine.css (80%) create mode 100644 app/static/public/css/video.css rename app/static/{voice => public/css}/voice.css (100%) rename app/static/{imagine => public/js}/imagine.js (79%) rename app/static/public/{ => js}/login.js (100%) create mode 100644 app/static/public/js/video.js rename app/static/{voice => public/js}/voice.js (100%) rename app/static/{imagine => public/pages}/imagine.html (63%) rename app/static/public/{ => pages}/login.html (83%) create mode 100644 app/static/public/pages/video.html rename app/static/{voice => public/pages}/voice.html (92%) diff --git 
a/app/api/pages/__init__.py b/app/api/pages/__init__.py new file mode 100644 index 00000000..6aada6e5 --- /dev/null +++ b/app/api/pages/__init__.py @@ -0,0 +1,13 @@ +"""UI pages router.""" + +from fastapi import APIRouter + +from app.api.pages.admin import router as admin_router +from app.api.pages.public import router as public_router + +router = APIRouter() + +router.include_router(public_router) +router.include_router(admin_router) + +__all__ = ["router"] diff --git a/app/api/pages/admin.py b/app/api/pages/admin.py new file mode 100644 index 00000000..bb581e89 --- /dev/null +++ b/app/api/pages/admin.py @@ -0,0 +1,32 @@ +from pathlib import Path + +from fastapi import APIRouter +from fastapi.responses import FileResponse, RedirectResponse + +router = APIRouter() +STATIC_DIR = Path(__file__).resolve().parents[2] / "static" + + +@router.get("/admin", include_in_schema=False) +async def admin_root(): + return RedirectResponse(url="/admin/login") + + +@router.get("/admin/login", include_in_schema=False) +async def admin_login(): + return FileResponse(STATIC_DIR / "admin/pages/login.html") + + +@router.get("/admin/config", include_in_schema=False) +async def admin_config(): + return FileResponse(STATIC_DIR / "admin/pages/config.html") + + +@router.get("/admin/cache", include_in_schema=False) +async def admin_cache(): + return FileResponse(STATIC_DIR / "admin/pages/cache.html") + + +@router.get("/admin/token", include_in_schema=False) +async def admin_token(): + return FileResponse(STATIC_DIR / "admin/pages/token.html") diff --git a/app/api/pages/public.py b/app/api/pages/public.py new file mode 100644 index 00000000..0792df99 --- /dev/null +++ b/app/api/pages/public.py @@ -0,0 +1,44 @@ +from pathlib import Path + +from fastapi import APIRouter, HTTPException +from fastapi.responses import FileResponse, RedirectResponse + +from app.core.auth import is_public_enabled + +router = APIRouter() +STATIC_DIR = Path(__file__).resolve().parents[2] / "static" + + 
+@router.get("/", include_in_schema=False) +async def root(): + if is_public_enabled(): + return RedirectResponse(url="/login") + return RedirectResponse(url="/admin/login") + + +@router.get("/login", include_in_schema=False) +async def public_login(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/login.html") + + +@router.get("/imagine", include_in_schema=False) +async def public_imagine(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/imagine.html") + + +@router.get("/voice", include_in_schema=False) +async def public_voice(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/voice.html") + + +@router.get("/video", include_in_schema=False) +async def public_video(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/video.html") diff --git a/app/api/v1/pages.py b/app/api/v1/pages.py deleted file mode 100644 index cd9472c3..00000000 --- a/app/api/v1/pages.py +++ /dev/null @@ -1,94 +0,0 @@ -from pathlib import Path - -import aiofiles -from fastapi import APIRouter, HTTPException -from fastapi.responses import HTMLResponse, RedirectResponse - -from app.core.auth import is_public_enabled - -router = APIRouter() -TEMPLATE_DIR = Path(__file__).resolve().parents[2] / "static" - - -async def render_template(filename: str) -> HTMLResponse: - """渲染指定模板""" - template_path = TEMPLATE_DIR / filename - if not template_path.exists(): - return HTMLResponse(f"Template {filename} not found.", status_code=404) - - async with aiofiles.open(template_path, "r", encoding="utf-8") as f: - content = await f.read() - return HTMLResponse(content) - -@router.get("/", include_in_schema=False) -async def root_redirect(): - if 
is_public_enabled(): - return RedirectResponse(url="/login") - return RedirectResponse(url="/admin/login") - - -@router.get("/login", response_class=HTMLResponse, include_in_schema=False) -async def public_login_page(): - """Public 登录页""" - if not is_public_enabled(): - raise HTTPException(status_code=404, detail="Not Found") - return await render_template("public/login.html") - - -@router.get("/imagine", response_class=HTMLResponse, include_in_schema=False) -async def public_imagine_page(): - """Imagine 图片瀑布流""" - if not is_public_enabled(): - raise HTTPException(status_code=404, detail="Not Found") - return await render_template("imagine/imagine.html") - - -@router.get("/voice", response_class=HTMLResponse, include_in_schema=False) -async def public_voice_page(): - """Voice Live 调试页""" - if not is_public_enabled(): - raise HTTPException(status_code=404, detail="Not Found") - return await render_template("voice/voice.html") - - -@router.get("/admin", include_in_schema=False) -async def admin_root_redirect(): - return RedirectResponse(url="/admin/login") - - -@router.get("/admin/login", response_class=HTMLResponse, include_in_schema=False) -async def admin_login_page(): - """管理后台登录页""" - return await render_template("login/login.html") - - -@router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False) -async def admin_config_page(): - """配置管理页""" - return await render_template("config/config.html") - - -@router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False) -async def admin_token_page(): - """Token 管理页""" - return await render_template("token/token.html") - - -@router.get("/admin/voice", include_in_schema=False) -async def admin_voice_redirect(): - if not is_public_enabled(): - raise HTTPException(status_code=404, detail="Not Found") - return RedirectResponse(url="/voice") - - -@router.get("/admin/imagine", include_in_schema=False) -async def admin_imagine_redirect(): - if not is_public_enabled(): - raise 
HTTPException(status_code=404, detail="Not Found") - return RedirectResponse(url="/imagine") - - -@router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False) -async def admin_cache_page(): - """缓存管理页""" - return await render_template("cache/cache.html") diff --git a/app/api/v1/public/__init__.py b/app/api/v1/public/__init__.py index 0d4ab694..984bf0d3 100644 --- a/app/api/v1/public/__init__.py +++ b/app/api/v1/public/__init__.py @@ -3,11 +3,13 @@ from fastapi import APIRouter from app.api.v1.public.imagine import router as imagine_router +from app.api.v1.public.video import router as video_router from app.api.v1.public.voice import router as voice_router router = APIRouter() router.include_router(imagine_router) +router.include_router(video_router) router.include_router(voice_router) __all__ = ["router"] diff --git a/app/api/v1/public/imagine.py b/app/api/v1/public/imagine.py index 5e71436e..83f59c34 100644 --- a/app/api/v1/public/imagine.py +++ b/app/api/v1/public/imagine.py @@ -1,7 +1,7 @@ import asyncio import time import uuid -from typing import Optional, List +from typing import Optional, List, Dict, Any import orjson from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect @@ -9,6 +9,7 @@ from pydantic import BaseModel from app.core.auth import verify_public_key, get_public_api_key, is_public_enabled +from app.core.config import get_config from app.core.logger import logger from app.api.v1.image import resolve_aspect_ratio from app.services.grok.services.image import ImageGenerationService @@ -32,7 +33,35 @@ async def _clean_sessions(now: float) -> None: _IMAGINE_SESSIONS.pop(key, None) -async def _new_session(prompt: str, aspect_ratio: str) -> str: +def _parse_sse_chunk(chunk: str) -> Optional[Dict[str, Any]]: + if not chunk: + return None + event = None + data_lines: List[str] = [] + for raw in str(chunk).splitlines(): + line = raw.strip() + if not line: + continue + if 
line.startswith("event:"): + event = line[6:].strip() + continue + if line.startswith("data:"): + data_lines.append(line[5:].strip()) + if not data_lines: + return None + data_str = "\n".join(data_lines) + if data_str == "[DONE]": + return None + try: + payload = orjson.loads(data_str) + except orjson.JSONDecodeError: + return None + if event and isinstance(payload, dict) and "type" not in payload: + payload["type"] = event + return payload + + +async def _new_session(prompt: str, aspect_ratio: str, nsfw: Optional[bool]) -> str: task_id = uuid.uuid4().hex now = time.time() async with _IMAGINE_SESSIONS_LOCK: @@ -40,6 +69,7 @@ async def _new_session(prompt: str, aspect_ratio: str) -> str: _IMAGINE_SESSIONS[task_id] = { "prompt": prompt, "aspect_ratio": aspect_ratio, + "nsfw": nsfw, "created_at": now, } return task_id @@ -126,7 +156,7 @@ async def _stop_run(): run_task = None stop_event.clear() - async def _run(prompt: str, aspect_ratio: str): + async def _run(prompt: str, aspect_ratio: str, nsfw: Optional[bool]): model_id = "grok-imagine-1.0" model_info = ModelService.get(model_id) if not model_info or not model_info.is_image: @@ -140,7 +170,6 @@ async def _run(prompt: str, aspect_ratio: str): return token_mgr = await get_token_manager() - sequence = 0 run_id = uuid.uuid4().hex await _send( @@ -175,7 +204,6 @@ async def _run(prompt: str, aspect_ratio: str): await asyncio.sleep(2) continue - start_at = time.time() result = await ImageGenerationService().generate( token_mgr=token_mgr, token=token, @@ -185,33 +213,38 @@ async def _run(prompt: str, aspect_ratio: str): response_format="b64_json", size="1024x1024", aspect_ratio=aspect_ratio, - stream=False, + stream=True, + enable_nsfw=nsfw, ) - elapsed_ms = int((time.time() - start_at) * 1000) - - images = [img for img in result.data if img and img != "error"] - if images: - for img_b64 in images: - sequence += 1 + if result.stream: + async for chunk in result.data: + payload = _parse_sse_chunk(chunk) + if not payload: + 
continue + if isinstance(payload, dict): + payload.setdefault("run_id", run_id) + await _send(payload) + else: + images = [img for img in result.data if img and img != "error"] + if images: + for img_b64 in images: + await _send( + { + "type": "image", + "b64_json": img_b64, + "created_at": int(time.time() * 1000), + "aspect_ratio": aspect_ratio, + "run_id": run_id, + } + ) + else: await _send( { - "type": "image", - "b64_json": img_b64, - "sequence": sequence, - "created_at": int(time.time() * 1000), - "elapsed_ms": elapsed_ms, - "aspect_ratio": aspect_ratio, - "run_id": run_id, + "type": "error", + "message": "Image generation returned empty data.", + "code": "empty_image", } ) - else: - await _send( - { - "type": "error", - "message": "Image generation returned empty data.", - "code": "empty_image", - } - ) except asyncio.CancelledError: break @@ -262,8 +295,11 @@ async def _run(prompt: str, aspect_ratio: str): aspect_ratio = resolve_aspect_ratio( str(payload.get("aspect_ratio") or "2:3").strip() or "2:3" ) + nsfw = payload.get("nsfw") + if nsfw is not None: + nsfw = bool(nsfw) await _stop_run() - run_task = asyncio.create_task(_run(prompt, aspect_ratio)) + run_task = asyncio.create_task(_run(prompt, aspect_ratio, nsfw)) elif action == "stop": await _stop_run() else: @@ -319,12 +355,16 @@ async def public_imagine_sse( if session: prompt = str(session.get("prompt") or "").strip() ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3" + nsfw = session.get("nsfw") else: prompt = (prompt or "").strip() if not prompt: raise HTTPException(status_code=400, detail="Prompt cannot be empty") ratio = str(aspect_ratio or "2:3").strip() or "2:3" ratio = resolve_aspect_ratio(ratio) + nsfw = request.query_params.get("nsfw") + if nsfw is not None: + nsfw = str(nsfw).lower() in ("1", "true", "yes", "on") async def event_stream(): try: @@ -369,7 +409,6 @@ async def event_stream(): await asyncio.sleep(2) continue - start_at = time.time() result = await 
ImageGenerationService().generate( token_mgr=token_mgr, token=token, @@ -379,28 +418,35 @@ async def event_stream(): response_format="b64_json", size="1024x1024", aspect_ratio=ratio, - stream=False, + stream=True, + enable_nsfw=nsfw, ) - elapsed_ms = int((time.time() - start_at) * 1000) - - images = [img for img in result.data if img and img != "error"] - if images: - for img_b64 in images: - sequence += 1 - payload = { - "type": "image", - "b64_json": img_b64, - "sequence": sequence, - "created_at": int(time.time() * 1000), - "elapsed_ms": elapsed_ms, - "aspect_ratio": ratio, - "run_id": run_id, - } + if result.stream: + async for chunk in result.data: + payload = _parse_sse_chunk(chunk) + if not payload: + continue + if isinstance(payload, dict): + payload.setdefault("run_id", run_id) yield f"data: {orjson.dumps(payload).decode()}\n\n" else: - yield ( - f"data: {orjson.dumps({'type': 'error', 'message': 'Image generation returned empty data.', 'code': 'empty_image'}).decode()}\n\n" - ) + images = [img for img in result.data if img and img != "error"] + if images: + for img_b64 in images: + sequence += 1 + payload = { + "type": "image", + "b64_json": img_b64, + "sequence": sequence, + "created_at": int(time.time() * 1000), + "aspect_ratio": ratio, + "run_id": run_id, + } + yield f"data: {orjson.dumps(payload).decode()}\n\n" + else: + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': 'Image generation returned empty data.', 'code': 'empty_image'}).decode()}\n\n" + ) except asyncio.CancelledError: break except Exception as e: @@ -424,9 +470,19 @@ async def event_stream(): ) +@router.get("/imagine/config") +async def public_imagine_config(): + return { + "final_min_bytes": int(get_config("image.final_min_bytes") or 0), + "medium_min_bytes": int(get_config("image.medium_min_bytes") or 0), + "nsfw": bool(get_config("image.nsfw")), + } + + class ImagineStartRequest(BaseModel): prompt: str aspect_ratio: Optional[str] = "2:3" + nsfw: Optional[bool] = None 
@router.post("/imagine/start", dependencies=[Depends(verify_public_key)]) @@ -435,7 +491,7 @@ async def public_imagine_start(data: ImagineStartRequest): if not prompt: raise HTTPException(status_code=400, detail="Prompt cannot be empty") ratio = resolve_aspect_ratio(str(data.aspect_ratio or "2:3").strip() or "2:3") - task_id = await _new_session(prompt, ratio) + task_id = await _new_session(prompt, ratio, data.nsfw) return {"task_id": task_id, "aspect_ratio": ratio} diff --git a/app/api/v1/public/video.py b/app/api/v1/public/video.py new file mode 100644 index 00000000..c88182c8 --- /dev/null +++ b/app/api/v1/public/video.py @@ -0,0 +1,274 @@ +import asyncio +import time +import uuid +from typing import Optional, List, Dict, Any + +import orjson +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.core.auth import verify_public_key +from app.core.logger import logger +from app.services.grok.services.video import VideoService +from app.services.grok.services.model import ModelService + +router = APIRouter() + +VIDEO_SESSION_TTL = 600 +_VIDEO_SESSIONS: dict[str, dict] = {} +_VIDEO_SESSIONS_LOCK = asyncio.Lock() + +_VIDEO_RATIO_MAP = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", + "16:9": "16:9", + "9:16": "9:16", + "3:2": "3:2", + "2:3": "2:3", + "1:1": "1:1", +} + + +async def _clean_sessions(now: float) -> None: + expired = [ + key + for key, info in _VIDEO_SESSIONS.items() + if now - float(info.get("created_at") or 0) > VIDEO_SESSION_TTL + ] + for key in expired: + _VIDEO_SESSIONS.pop(key, None) + + +async def _new_session( + prompt: str, + aspect_ratio: str, + video_length: int, + resolution_name: str, + preset: str, + image_url: Optional[str], + reasoning_effort: Optional[str], +) -> str: + task_id = uuid.uuid4().hex + now = time.time() + async with _VIDEO_SESSIONS_LOCK: + await 
_clean_sessions(now) + _VIDEO_SESSIONS[task_id] = { + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "video_length": video_length, + "resolution_name": resolution_name, + "preset": preset, + "image_url": image_url, + "reasoning_effort": reasoning_effort, + "created_at": now, + } + return task_id + + +async def _get_session(task_id: str) -> Optional[dict]: + if not task_id: + return None + now = time.time() + async with _VIDEO_SESSIONS_LOCK: + await _clean_sessions(now) + info = _VIDEO_SESSIONS.get(task_id) + if not info: + return None + created_at = float(info.get("created_at") or 0) + if now - created_at > VIDEO_SESSION_TTL: + _VIDEO_SESSIONS.pop(task_id, None) + return None + return dict(info) + + +async def _drop_session(task_id: str) -> None: + if not task_id: + return + async with _VIDEO_SESSIONS_LOCK: + _VIDEO_SESSIONS.pop(task_id, None) + + +async def _drop_sessions(task_ids: List[str]) -> int: + if not task_ids: + return 0 + removed = 0 + async with _VIDEO_SESSIONS_LOCK: + for task_id in task_ids: + if task_id and task_id in _VIDEO_SESSIONS: + _VIDEO_SESSIONS.pop(task_id, None) + removed += 1 + return removed + + +def _normalize_ratio(value: Optional[str]) -> str: + raw = (value or "").strip() + return _VIDEO_RATIO_MAP.get(raw, "") + + +def _validate_image_url(image_url: str) -> None: + value = (image_url or "").strip() + if not value: + return + if value.startswith("data:"): + return + if value.startswith("http://") or value.startswith("https://"): + return + raise HTTPException( + status_code=400, + detail="image_url must be a URL or data URI (data:;base64,...)", + ) + + +class VideoStartRequest(BaseModel): + prompt: str + aspect_ratio: Optional[str] = "3:2" + video_length: Optional[int] = 6 + resolution_name: Optional[str] = "480p" + preset: Optional[str] = "normal" + image_url: Optional[str] = None + reasoning_effort: Optional[str] = None + + +@router.post("/video/start", dependencies=[Depends(verify_public_key)]) +async def 
public_video_start(data: VideoStartRequest): + prompt = (data.prompt or "").strip() + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + + aspect_ratio = _normalize_ratio(data.aspect_ratio) + if not aspect_ratio: + raise HTTPException( + status_code=400, + detail="aspect_ratio must be one of ['16:9','9:16','3:2','2:3','1:1']", + ) + + video_length = int(data.video_length or 6) + if video_length not in (6, 10, 15): + raise HTTPException( + status_code=400, detail="video_length must be 6, 10, or 15 seconds" + ) + + resolution_name = str(data.resolution_name or "480p") + if resolution_name not in ("480p", "720p"): + raise HTTPException( + status_code=400, + detail="resolution_name must be one of ['480p','720p']", + ) + + preset = str(data.preset or "normal") + if preset not in ("fun", "normal", "spicy", "custom"): + raise HTTPException( + status_code=400, + detail="preset must be one of ['fun','normal','spicy','custom']", + ) + + image_url = (data.image_url or "").strip() or None + if image_url: + _validate_image_url(image_url) + + reasoning_effort = (data.reasoning_effort or "").strip() or None + if reasoning_effort: + allowed = {"none", "minimal", "low", "medium", "high", "xhigh"} + if reasoning_effort not in allowed: + raise HTTPException( + status_code=400, + detail=f"reasoning_effort must be one of {sorted(allowed)}", + ) + + task_id = await _new_session( + prompt, + aspect_ratio, + video_length, + resolution_name, + preset, + image_url, + reasoning_effort, + ) + return {"task_id": task_id, "aspect_ratio": aspect_ratio} + + +@router.get("/video/sse") +async def public_video_sse(request: Request, task_id: str = Query("")): + session = await _get_session(task_id) + if not session: + raise HTTPException(status_code=404, detail="Task not found") + + prompt = str(session.get("prompt") or "").strip() + aspect_ratio = str(session.get("aspect_ratio") or "3:2") + video_length = int(session.get("video_length") or 6) + resolution_name = 
str(session.get("resolution_name") or "480p") + preset = str(session.get("preset") or "normal") + image_url = session.get("image_url") + reasoning_effort = session.get("reasoning_effort") + + async def event_stream(): + try: + model_id = "grok-imagine-1.0-video" + model_info = ModelService.get(model_id) + if not model_info or not model_info.is_video: + payload = { + "error": "Video model is not available.", + "code": "model_not_supported", + } + yield f"data: {orjson.dumps(payload).decode()}\n\n" + yield "data: [DONE]\n\n" + return + + if image_url: + messages: List[Dict[str, Any]] = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ] + else: + messages = [{"role": "user", "content": prompt}] + + stream = await VideoService.completions( + model_id, + messages, + stream=True, + reasoning_effort=reasoning_effort, + aspect_ratio=aspect_ratio, + video_length=video_length, + resolution=resolution_name, + preset=preset, + ) + + async for chunk in stream: + if await request.is_disconnected(): + break + yield chunk + except Exception as e: + logger.warning(f"Public video SSE error: {e}") + payload = {"error": str(e), "code": "internal_error"} + yield f"data: {orjson.dumps(payload).decode()}\n\n" + yield "data: [DONE]\n\n" + finally: + await _drop_session(task_id) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + +class VideoStopRequest(BaseModel): + task_ids: List[str] + + +@router.post("/video/stop", dependencies=[Depends(verify_public_key)]) +async def public_video_stop(data: VideoStopRequest): + removed = await _drop_sessions(data.task_ids or []) + return {"status": "success", "removed": removed} + + +__all__ = ["router"] diff --git a/app/core/response_middleware.py b/app/core/response_middleware.py index 2cfa8b3e..4c0a07ec 100644 --- a/app/core/response_middleware.py 
+++ b/app/core/response_middleware.py @@ -25,6 +25,20 @@ async def dispatch(self, request: Request, call_next): request.state.trace_id = trace_id start_time = time.time() + path = request.url.path + + if path.startswith("/static/") or path in ( + "/", + "/login", + "/imagine", + "/voice", + "/admin", + "/admin/login", + "/admin/config", + "/admin/cache", + "/admin/token", + ): + return await call_next(request) # 记录请求信息 logger.info( diff --git a/app/services/grok/defaults.py b/app/services/grok/defaults.py index 06a41584..d7af7eb7 100644 --- a/app/services/grok/defaults.py +++ b/app/services/grok/defaults.py @@ -1,101 +1,33 @@ """ Grok 服务默认配置 -此文件定义所有 Grok 相关服务的默认值,会在应用启动时注册到配置系统中。 +此文件读取 config.defaults.toml,作为 Grok 服务的默认值来源。 """ -# Grok 服务默认配置 -GROK_DEFAULTS = { - "app": { - "app_url": "", - "app_key": "grok2api", - "api_key": "", - "public_key": "", - "public_enabled": False, - "image_format": "url", - "video_format": "html", - "temporary": True, - "disable_memory": True, - "stream": True, - "thinking": True, - "dynamic_statsig": True, - "filter_tags": ["xaiartifact", "xai:tool_usage_card", "grok:render"], - }, - "proxy": { - "base_proxy_url": "", - "asset_proxy_url": "", - "cf_clearance": "", - "browser": "chrome136", - "user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36", - }, - "voice": { - "timeout": 120, - }, - "chat": { - "concurrent": 10, - "timeout": 60, - "stream_timeout": 60, - }, - "video": { - "concurrent": 10, - "timeout": 60, - "stream_timeout": 60, - }, - "retry": { - "max_retry": 3, - "retry_status_codes": [401, 429, 403], - "retry_backoff_base": 0.5, - "retry_backoff_factor": 2.0, - "retry_backoff_max": 30.0, - "retry_budget": 90.0, - }, - "image": { - "timeout": 120, - "stream_timeout": 120, - "final_timeout": 15, - "nsfw": True, - "medium_min_bytes": 30000, - "final_min_bytes": 100000, - }, - "token": { - "auto_refresh": True, - "refresh_interval_hours": 8, - 
"super_refresh_interval_hours": 2, - "fail_threshold": 5, - "save_delay_ms": 500, - "reload_interval_sec": 30, - }, - "cache": { - "enable_auto_clean": True, - "limit_mb": 1024, - }, - "asset": { - "upload_concurrent": 30, - "upload_timeout": 60, - "download_concurrent": 30, - "download_timeout": 60, - "list_concurrent": 10, - "list_timeout": 60, - "list_batch_size": 10, - "delete_concurrent": 10, - "delete_timeout": 60, - "delete_batch_size": 10, - }, - "nsfw": { - "concurrent": 10, - "batch_size": 50, - "timeout": 60, - }, - "usage": { - "concurrent": 10, - "batch_size": 50, - "timeout": 60, - }, -} +from pathlib import Path +import tomllib + +from app.core.logger import logger + +DEFAULTS_FILE = Path(__file__).resolve().parent.parent.parent.parent / "config.defaults.toml" + +# Grok 服务默认配置(运行时从 config.defaults.toml 读取并缓存) +GROK_DEFAULTS: dict = {} def get_grok_defaults(): """获取 Grok 默认配置""" + global GROK_DEFAULTS + if GROK_DEFAULTS: + return GROK_DEFAULTS + if not DEFAULTS_FILE.exists(): + logger.warning(f"Defaults file not found: {DEFAULTS_FILE}") + return GROK_DEFAULTS + try: + with DEFAULTS_FILE.open("rb") as f: + GROK_DEFAULTS = tomllib.load(f) + except Exception as e: + logger.warning(f"Failed to load defaults from {DEFAULTS_FILE}: {e}") return GROK_DEFAULTS diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index 0d203276..260bb438 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -341,59 +341,19 @@ def __init__(self, model: str, token: str = "", show_think: bool = None): self.think_opened: bool = False self.role_sent: bool = False self.filter_tags = get_config("app.filter_tags") - self._tag_buffer: str = "" - self._in_filter_tag: bool = False self.show_think = bool(show_think) def _filter_token(self, token: str) -> str: - """Filter special tags (supports cross-token tag filtering).""" - if not self.filter_tags: + """Filter special tags in current token only.""" + if not 
self.filter_tags or not token: return token - result = [] - i = 0 - while i < len(token): - char = token[i] - - if self._in_filter_tag: - self._tag_buffer += char - if char == ">": - if "/>" in self._tag_buffer: - self._in_filter_tag = False - self._tag_buffer = "" - else: - for tag in self.filter_tags: - if f"" in self._tag_buffer: - self._in_filter_tag = False - self._tag_buffer = "" - break - i += 1 - continue - - if char == "<": - remaining = token[i:] - tag_started = False - for tag in self.filter_tags: - if remaining.startswith(f"<{tag}"): - tag_started = True - break - if len(remaining) < len(tag) + 1: - for j in range(1, len(remaining) + 1): - if f"<{tag}".startswith(remaining[:j]): - tag_started = True - break - - if tag_started: - self._in_filter_tag = True - self._tag_buffer = char - i += 1 - continue - - result.append(char) - i += 1 + for tag in self.filter_tags: + if f"<{tag}" in token or f" str: """Build SSE response.""" diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 7db7bf98..75e26987 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -48,6 +48,7 @@ async def generate( size: str, aspect_ratio: str, stream: bool, + enable_nsfw: Optional[bool] = None, ) -> ImageGenerationResult: max_token_retries = int(get_config("retry.max_retry")) tried_tokens: set[str] = set() @@ -83,6 +84,7 @@ async def _stream_retry() -> AsyncGenerator[str, None]: response_format=response_format, size=size, aspect_ratio=aspect_ratio, + enable_nsfw=enable_nsfw, ) async for chunk in result.data: yielded = True @@ -137,6 +139,7 @@ async def _stream_retry() -> AsyncGenerator[str, None]: n=n, response_format=response_format, aspect_ratio=aspect_ratio, + enable_nsfw=enable_nsfw, ) except UpstreamException as e: last_error = e @@ -169,8 +172,10 @@ async def _stream_ws( response_format: str, size: str, aspect_ratio: str, + enable_nsfw: Optional[bool] = None, ) -> ImageGenerationResult: - enable_nsfw = 
bool(get_config("image.nsfw")) + if enable_nsfw is None: + enable_nsfw = bool(get_config("image.nsfw")) upstream = image_service.stream( token=token, prompt=prompt, @@ -203,8 +208,10 @@ async def _collect_ws( n: int, response_format: str, aspect_ratio: str, + enable_nsfw: Optional[bool] = None, ) -> ImageGenerationResult: - enable_nsfw = bool(get_config("image.nsfw")) + if enable_nsfw is None: + enable_nsfw = bool(get_config("image.nsfw")) all_images: List[str] = [] seen = set() expected_per_call = 6 @@ -498,6 +505,8 @@ async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, No "size": self.size, "index": index, "partial_image_index": partial_index, + "image_id": image_id, + "stage": stage, }, ) @@ -550,6 +559,8 @@ async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, No "created_at": int(time.time()), "size": self.size, "index": index, + "image_id": image_id, + "stage": "final", "usage": { "total_tokens": 0, "input_tokens": 0, diff --git a/app/services/grok/services/video.py b/app/services/grok/services/video.py index 5ba65048..70f477e3 100644 --- a/app/services/grok/services/video.py +++ b/app/services/grok/services/video.py @@ -4,6 +4,7 @@ import asyncio import uuid +import re from typing import Any, AsyncGenerator, AsyncIterable, Optional import orjson @@ -31,6 +32,8 @@ from app.services.grok.utils.retry import rate_limited from app.services.reverse.app_chat import AppChatReverse from app.services.reverse.media_post import MediaPostReverse +from app.services.reverse.video_upscale import VideoUpscaleReverse +from app.services.token.manager import BASIC_POOL_NAME _VIDEO_SEMAPHORE = None _VIDEO_SEM_VALUE = 0 @@ -63,13 +66,17 @@ async def create_post( if media_type == "MEDIA_POST_TYPE_IMAGE" and not media_url: raise ValidationException("media_url is required for image posts") + prompt_value = prompt if media_type == "MEDIA_POST_TYPE_VIDEO" else "" + media_value = media_url or "" + async with AsyncSession() as session: 
async with _get_video_semaphore(): response = await MediaPostReverse.request( session, token, media_type, - media_url or "", + media_value, + prompt=prompt_value, ) post_id = response.json().get("post", {}).get("id", "") @@ -264,6 +271,8 @@ async def completions( token = token_info.token if token.startswith("sso="): token = token[4:] + pool_name = token_mgr.get_pool_name_for_token(token) + should_upscale = resolution == "720p" and pool_name == BASIC_POOL_NAME try: # Handle image attachments. @@ -305,12 +314,19 @@ async def completions( # Process response. if is_stream: - processor = VideoStreamProcessor(model, token, show_think) + processor = VideoStreamProcessor( + model, + token, + show_think, + upscale_on_finish=should_upscale, + ) return wrap_stream_with_usage( processor.process(response), token_mgr, token, model ) - result = await VideoCollectProcessor(model, token).process(response) + result = await VideoCollectProcessor( + model, token, upscale_on_finish=should_upscale + ).process(response) try: model_info = ModelService.get(model) effort = ( @@ -350,13 +366,53 @@ async def completions( class VideoStreamProcessor(BaseProcessor): """Video stream response processor.""" - def __init__(self, model: str, token: str = "", show_think: bool = None): + def __init__( + self, + model: str, + token: str = "", + show_think: bool = None, + upscale_on_finish: bool = False, + ): super().__init__(model, token) self.response_id: Optional[str] = None self.think_opened: bool = False self.role_sent: bool = False self.show_think = bool(show_think) + self.upscale_on_finish = bool(upscale_on_finish) + + @staticmethod + def _extract_video_id(video_url: str) -> str: + if not video_url: + return "" + match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url) + if match: + return match.group(1) + match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url) + if match: + return match.group(1) + return "" + + async def _upscale_video_url(self, video_url: str) -> str: + 
if not video_url or not self.upscale_on_finish: + return video_url + video_id = self._extract_video_id(video_url) + if not video_id: + logger.warning("Video upscale skipped: unable to extract video id") + return video_url + try: + async with AsyncSession() as session: + response = await VideoUpscaleReverse.request( + session, self.token, video_id + ) + payload = response.json() if response is not None else {} + hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None + if hd_url: + logger.info(f"Video upscale completed: {hd_url}") + return hd_url + except Exception as e: + logger.warning(f"Video upscale failed: {e}") + return video_url def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: """Build SSE response.""" @@ -443,6 +499,9 @@ async def process( self.think_opened = False if video_url: + if self.upscale_on_finish: + yield self._sse("正在对视频进行超分辨率\n") + video_url = await self._upscale_video_url(video_url) dl_service = self._get_dl() rendered = await dl_service.render_video( video_url, self.token, thumbnail_url @@ -500,8 +559,42 @@ async def process( class VideoCollectProcessor(BaseProcessor): """Video non-stream response processor.""" - def __init__(self, model: str, token: str = ""): + def __init__(self, model: str, token: str = "", upscale_on_finish: bool = False): super().__init__(model, token) + self.upscale_on_finish = bool(upscale_on_finish) + + @staticmethod + def _extract_video_id(video_url: str) -> str: + if not video_url: + return "" + match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url) + if match: + return match.group(1) + match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url) + if match: + return match.group(1) + return "" + + async def _upscale_video_url(self, video_url: str) -> str: + if not video_url or not self.upscale_on_finish: + return video_url + video_id = self._extract_video_id(video_url) + if not video_id: + logger.warning("Video upscale skipped: unable to extract 
video id") + return video_url + try: + async with AsyncSession() as session: + response = await VideoUpscaleReverse.request( + session, self.token, video_id + ) + payload = response.json() if response is not None else {} + hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None + if hd_url: + logger.info(f"Video upscale completed: {hd_url}") + return hd_url + except Exception as e: + logger.warning(f"Video upscale failed: {e}") + return video_url async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: """Process and collect video response.""" @@ -528,6 +621,8 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: thumbnail_url = video_resp.get("thumbnailImageUrl", "") if video_url: + if self.upscale_on_finish: + video_url = await self._upscale_video_url(video_url) dl_service = self._get_dl() content = await dl_service.render_video( video_url, self.token, thumbnail_url diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py index 08734a8e..6e8aebfa 100644 --- a/app/services/reverse/__init__.py +++ b/app/services/reverse/__init__.py @@ -9,6 +9,7 @@ from .nsfw_mgmt import NsfwMgmtReverse from .rate_limits import RateLimitsReverse from .set_birth import SetBirthReverse +from .video_upscale import VideoUpscaleReverse from .ws_livekit import LivekitTokenReverse, LivekitWebSocketReverse from .ws_imagine import ImagineWebSocketReverse from .utils.headers import build_headers @@ -24,6 +25,7 @@ "NsfwMgmtReverse", "RateLimitsReverse", "SetBirthReverse", + "VideoUpscaleReverse", "LivekitTokenReverse", "LivekitWebSocketReverse", "ImagineWebSocketReverse", diff --git a/app/services/reverse/media_post.py b/app/services/reverse/media_post.py index 044e9189..6e70d539 100644 --- a/app/services/reverse/media_post.py +++ b/app/services/reverse/media_post.py @@ -25,6 +25,7 @@ async def request( token: str, mediaType: str, mediaUrl: str, + prompt: str = "", ) -> Any: """Create media post in Grok. 
@@ -51,10 +52,11 @@ async def request( ) # Build payload - payload = { - "mediaType": mediaType, - "mediaUrl": mediaUrl, - } + payload = {"mediaType": mediaType} + if mediaUrl: + payload["mediaUrl"] = mediaUrl + if prompt: + payload["prompt"] = prompt # Curl Config timeout = get_config("video.timeout") @@ -71,13 +73,18 @@ async def _do_request(): ) if response.status_code != 200: + content = "" + try: + content = await response.text() + except Exception: + pass logger.error( f"MediaPostReverse: Media post create failed, {response.status_code}", extra={"error_type": "UpstreamException"}, ) raise UpstreamException( message=f"MediaPostReverse: Media post create failed, {response.status_code}", - details={"status": response.status_code}, + details={"status": response.status_code, "body": content}, ) return response diff --git a/app/services/reverse/video_upscale.py b/app/services/reverse/video_upscale.py new file mode 100644 index 00000000..f6c70e17 --- /dev/null +++ b/app/services/reverse/video_upscale.py @@ -0,0 +1,109 @@ +""" +Reverse interface: video upscale. +""" + +import orjson +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +VIDEO_UPSCALE_API = "https://grok.com/rest/media/video/upscale" + + +class VideoUpscaleReverse: + """/rest/media/video/upscale reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, video_id: str) -> Any: + """Upscale video (image upscaling endpoint) in Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + video_id: str, the video id. + + Returns: + Any: The response from the request. 
+ """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com", + ) + + # Build payload + payload = {"videoId": video_id} + + # Curl Config + timeout = get_config("video.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + VIDEO_UPSCALE_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + content = "" + try: + content = await response.text() + except Exception: + pass + logger.error( + f"VideoUpscaleReverse: Upscale failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"VideoUpscaleReverse: Upscale failed, {response.status_code}", + details={"status": response.status_code, "body": content}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "video_upscale_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"VideoUpscaleReverse: Upscale failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"VideoUpscaleReverse: Upscale failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["VideoUpscaleReverse"] diff --git a/app/services/token/manager.py b/app/services/token/manager.py index 4955909b..1530cbea 100644 --- 
a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -299,6 +299,14 @@ def get_token_for_video( ) return None + def get_pool_name_for_token(self, token_str: str) -> Optional[str]: + """Return pool name for the given token string.""" + raw_token = token_str.replace("sso=", "") + for pool_name, pool in self.pools.items(): + if pool.get(raw_token): + return pool_name + return None + async def consume( self, token_str: str, effort: EffortType = EffortType.LOW ) -> bool: diff --git a/app/static/cache/cache.css b/app/static/admin/css/cache.css similarity index 100% rename from app/static/cache/cache.css rename to app/static/admin/css/cache.css diff --git a/app/static/config/config.css b/app/static/admin/css/config.css similarity index 80% rename from app/static/config/config.css rename to app/static/admin/css/config.css index b887da93..8fa851be 100644 --- a/app/static/config/config.css +++ b/app/static/admin/css/config.css @@ -28,6 +28,7 @@ .config-field { padding-top: 2px; + position: relative; } .config-field-title { @@ -45,3 +46,14 @@ .config-field-input { margin-top: 6px; } + +.config-field.has-action { + padding-right: 44px; +} + +.config-field-action { + position: absolute; + right: 0; + top: 50%; + transform: translateY(-50%); +} diff --git a/app/static/token/token.css b/app/static/admin/css/token.css similarity index 100% rename from app/static/token/token.css rename to app/static/admin/css/token.css diff --git a/app/static/cache/cache.js b/app/static/admin/js/cache.js similarity index 100% rename from app/static/cache/cache.js rename to app/static/admin/js/cache.js diff --git a/app/static/config/config.js b/app/static/admin/js/config.js similarity index 85% rename from app/static/config/config.js rename to app/static/admin/js/config.js index 5c5f0cd0..e2978a66 100644 --- a/app/static/config/config.js +++ b/app/static/admin/js/config.js @@ -37,17 +37,17 @@ const LOCALE_MAP = { "label": "应用设置", "api_key": { title: "API 密钥", desc: "调用 Grok2API 服务的 
Token(可选)。" }, "app_key": { title: "后台密码", desc: "登录 Grok2API 管理后台的密码(必填)。" }, - "public_key": { title: "Public Key", desc: "Public 接口访问密钥(可选)。" }, - "public_enabled": { title: "公开功能玩法", desc: "开启后 public 入口可访问(public_key 为空时默认放开)。" }, + "public_enabled": { title: "启用功能玩法", desc: "是否启用功能玩法入口(关闭则功能玩法页面不可访问)。" }, + "public_key": { title: "Public 密码", desc: "功能玩法页面的访问密码(可选)。" }, "app_url": { title: "应用地址", desc: "当前 Grok2API 服务的外部访问 URL,用于文件链接访问。" }, - "image_format": { title: "图片格式", desc: "生成的图片格式(url 或 base64)。" }, - "video_format": { title: "视频格式", desc: "生成的视频格式(html 或 url,url 为处理后的链接)。" }, - "temporary": { title: "临时对话", desc: "是否启用临时对话模式。" }, - "disable_memory": { title: "禁用记忆", desc: "禁用 Grok 记忆功能,以防止响应中出现不相关上下文。" }, + "image_format": { title: "图片格式", desc: "默认生成的图片格式(url 或 base64)。" }, + "video_format": { title: "视频格式", desc: "默认生成的视频格式(html 或 url,url 为处理后的链接)。" }, + "temporary": { title: "临时对话", desc: "是否默认启用临时对话模式。" }, + "disable_memory": { title: "禁用记忆", desc: "是否默认禁用 Grok 记忆功能。" }, "stream": { title: "流式响应", desc: "是否默认启用流式输出。" }, - "thinking": { title: "思维链", desc: "是否启用模型思维链输出。" }, - "dynamic_statsig": { title: "动态指纹", desc: "是否启用动态生成 Statsig 值。" }, - "filter_tags": { title: "过滤标签", desc: "自动过滤 Grok 响应中的特殊标签。" } + "thinking": { title: "思维链", desc: "是否默认启用思维链输出。" }, + "dynamic_statsig": { title: "动态指纹", desc: "是否默认启用动态生成 Statsig 指纹。" }, + "filter_tags": { title: "过滤标签", desc: "设置自动过滤 Grok 响应中的特殊标签。" } }, @@ -258,6 +258,15 @@ function buildSecretInput(section, key, val) { const wrapper = document.createElement('div'); wrapper.className = 'flex items-center gap-2'; + const genBtn = document.createElement('button'); + genBtn.className = 'flex-none w-[32px] h-[32px] flex items-center justify-center bg-black text-white rounded-md hover:opacity-80 transition-opacity'; + genBtn.type = 'button'; + genBtn.title = '生成'; + genBtn.innerHTML = ``; + genBtn.onclick = () => { + input.value = randomKey(16); + }; + const copyBtn = document.createElement('button'); 
copyBtn.className = 'flex-none w-[32px] h-[32px] flex items-center justify-center bg-black text-white rounded-md hover:opacity-80 transition-opacity'; copyBtn.type = 'button'; @@ -265,11 +274,29 @@ function buildSecretInput(section, key, val) { copyBtn.onclick = () => copyToClipboard(input.value, copyBtn); wrapper.appendChild(input); + wrapper.appendChild(genBtn); wrapper.appendChild(copyBtn); return { input, node: wrapper }; } +function randomKey(len) { + const chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'; + const out = []; + if (window.crypto && window.crypto.getRandomValues) { + const buf = new Uint8Array(len); + window.crypto.getRandomValues(buf); + for (let i = 0; i < len; i++) { + out.push(chars[buf[i] % chars.length]); + } + return out.join(''); + } + for (let i = 0; i < len; i++) { + out.push(chars[Math.floor(Math.random() * chars.length)]); + } + return out.join(''); +} + async function init() { apiKey = await ensureAdminKey(); if (apiKey === null) return; @@ -399,6 +426,23 @@ function buildFieldCard(section, key, val) { } fieldCard.appendChild(inputWrapper); + if (section === 'app' && key === 'public_enabled') { + fieldCard.classList.add('has-action'); + const link = document.createElement('a'); + link.href = '/login'; + link.className = 'config-field-action flex-none w-[32px] h-[32px] flex items-center justify-center bg-black text-white rounded-md hover:opacity-80 transition-opacity'; + link.title = '功能玩法'; + link.setAttribute('aria-label', '功能玩法'); + link.innerHTML = ``; + link.style.display = val ? 'inline-flex' : 'none'; + fieldCard.appendChild(link); + if (built && built.input) { + built.input.addEventListener('change', () => { + link.style.display = built.input.checked ? 
'inline-flex' : 'none'; + }); + } + } + return fieldCard; } diff --git a/app/static/login/login.js b/app/static/admin/js/login.js similarity index 100% rename from app/static/login/login.js rename to app/static/admin/js/login.js diff --git a/app/static/token/token.js b/app/static/admin/js/token.js similarity index 100% rename from app/static/token/token.js rename to app/static/admin/js/token.js diff --git a/app/static/cache/cache.html b/app/static/admin/pages/cache.html similarity index 93% rename from app/static/cache/cache.html rename to app/static/admin/pages/cache.html index 59051ade..5acb03e2 100644 --- a/app/static/cache/cache.html +++ b/app/static/admin/pages/cache.html @@ -5,13 +5,13 @@ Grok2API - 缓存管理 - + - - - + + + @@ -196,13 +196,13 @@

缓存管理

- - - - - - - + + + + + + + diff --git a/app/static/config/config.html b/app/static/admin/pages/config.html similarity index 75% rename from app/static/config/config.html rename to app/static/admin/pages/config.html index 95a244e2..2225fff2 100644 --- a/app/static/config/config.html +++ b/app/static/admin/pages/config.html @@ -5,13 +5,13 @@ Grok2API - 配置管理 - + - - - + + + @@ -46,11 +46,11 @@

配置管理

- - - - - + + + + + diff --git a/app/static/login/login.html b/app/static/admin/pages/login.html similarity index 82% rename from app/static/login/login.html rename to app/static/admin/pages/login.html index 3ee9a91a..9e94cd5f 100644 --- a/app/static/login/login.html +++ b/app/static/admin/pages/login.html @@ -5,7 +5,7 @@ Grok2API - 登录 - + @@ -24,9 +24,9 @@ } } - - - + + + @@ -55,10 +55,10 @@ - - - - + + + + diff --git a/app/static/token/token.html b/app/static/admin/pages/token.html similarity index 95% rename from app/static/token/token.html rename to app/static/admin/pages/token.html index ce6ca07f..6302d958 100644 --- a/app/static/token/token.html +++ b/app/static/admin/pages/token.html @@ -5,13 +5,13 @@ Grok2API - Token 管理 - + - - - + + + @@ -291,13 +291,13 @@ - - - - - - - + + + + + + + diff --git a/app/static/common/common.css b/app/static/common/css/common.css similarity index 100% rename from app/static/common/common.css rename to app/static/common/css/common.css diff --git a/app/static/login/login.css b/app/static/common/css/login.css similarity index 100% rename from app/static/login/login.css rename to app/static/common/css/login.css diff --git a/app/static/common/toast.css b/app/static/common/css/toast.css similarity index 100% rename from app/static/common/toast.css rename to app/static/common/css/toast.css diff --git a/app/static/common/footer.html b/app/static/common/html/footer.html similarity index 100% rename from app/static/common/footer.html rename to app/static/common/html/footer.html diff --git a/app/static/common/header.html b/app/static/common/html/header.html similarity index 73% rename from app/static/common/header.html rename to app/static/common/html/header.html index f994e474..2330a996 100644 --- a/app/static/common/header.html +++ b/app/static/common/html/header.html @@ -14,17 +14,9 @@ class="text-xs text-[var(--accents-4)] hover:text-black">@chenyme
- -
- Public + Token管理 + 配置管理 + 缓存管理
Imagine + Video Voice Live - Public Key diff --git a/app/static/favicon/favicon.ico b/app/static/common/img/favicon/favicon.ico similarity index 100% rename from app/static/favicon/favicon.ico rename to app/static/common/img/favicon/favicon.ico diff --git a/app/static/common/admin-auth.js b/app/static/common/js/admin-auth.js similarity index 100% rename from app/static/common/admin-auth.js rename to app/static/common/js/admin-auth.js diff --git a/app/static/common/batch-sse.js b/app/static/common/js/batch-sse.js similarity index 100% rename from app/static/common/batch-sse.js rename to app/static/common/js/batch-sse.js diff --git a/app/static/common/draggable.js b/app/static/common/js/draggable.js similarity index 100% rename from app/static/common/draggable.js rename to app/static/common/js/draggable.js diff --git a/app/static/common/footer.js b/app/static/common/js/footer.js similarity index 85% rename from app/static/common/footer.js rename to app/static/common/js/footer.js index 6256b2d3..e32efdcd 100644 --- a/app/static/common/footer.js +++ b/app/static/common/js/footer.js @@ -2,7 +2,7 @@ async function loadAdminFooter() { const container = document.getElementById('app-footer'); if (!container) return; try { - const res = await fetch('/static/common/footer.html?v=1'); + const res = await fetch('/static/common/html/footer.html?v=1'); if (!res.ok) return; container.innerHTML = await res.text(); } catch (e) { diff --git a/app/static/common/header.js b/app/static/common/js/header.js similarity index 93% rename from app/static/common/header.js rename to app/static/common/js/header.js index 0433458a..a7e0a259 100644 --- a/app/static/common/header.js +++ b/app/static/common/js/header.js @@ -2,7 +2,7 @@ async function loadAdminHeader() { const container = document.getElementById('app-header'); if (!container) return; try { - const res = await fetch('/static/common/header.html?v=5'); + const res = await fetch('/static/common/html/header.html?v=1'); if (!res.ok) return; 
container.innerHTML = await res.text(); const path = window.location.pathname; diff --git a/app/static/common/public-header.js b/app/static/common/js/public-header.js similarity index 59% rename from app/static/common/public-header.js rename to app/static/common/js/public-header.js index 7cd79159..82797ec2 100644 --- a/app/static/common/public-header.js +++ b/app/static/common/js/public-header.js @@ -2,9 +2,21 @@ async function loadPublicHeader() { const container = document.getElementById('app-header'); if (!container) return; try { - const res = await fetch('/static/common/public-header.html?v=1'); + const res = await fetch('/static/common/html/public-header.html?v=1'); if (!res.ok) return; container.innerHTML = await res.text(); + const logoutBtn = container.querySelector('#public-logout-btn'); + if (logoutBtn) { + logoutBtn.classList.add('hidden'); + try { + const verify = await fetch('/v1/public/verify', { method: 'GET' }); + if (verify.status === 401) { + logoutBtn.classList.remove('hidden'); + } + } catch (e) { + // Ignore verification errors and keep it hidden + } + } const path = window.location.pathname; const links = container.querySelectorAll('a[data-nav]'); links.forEach((link) => { diff --git a/app/static/common/toast.js b/app/static/common/js/toast.js similarity index 100% rename from app/static/common/toast.js rename to app/static/common/js/toast.js diff --git a/app/static/imagine/imagine.css b/app/static/public/css/imagine.css similarity index 80% rename from app/static/imagine/imagine.css rename to app/static/public/css/imagine.css index 99c462ba..5a15513b 100644 --- a/app/static/imagine/imagine.css +++ b/app/static/public/css/imagine.css @@ -23,6 +23,10 @@ overflow: visible; } +.imagine-card.settings-card { + padding: 16px; +} + .imagine-card-collapsible { transition: all 0.3s ease; } @@ -138,65 +142,143 @@ } .imagine-textarea { - min-height: 96px; + height: auto; + min-height: 0; resize: vertical; line-height: 1.5; + flex: 1; } -.settings-layout 
{ +.settings-grid { + --row-gap: 10px; + --row-h: 52px; + display: grid; + grid-template-columns: 2.2fr 1fr; + grid-template-rows: calc(var(--row-h) * 2 + var(--row-gap)) var(--row-h) var(--row-h); + column-gap: 16px; + row-gap: var(--row-gap); + align-items: stretch; +} + +.settings-block { display: flex; flex-direction: column; - gap: 16px; + gap: 6px; + height: 100%; +} + +.settings-block-prompt { + grid-column: 1; + grid-row: 1; } -.settings-top { +.settings-block-duo { + grid-column: 2; + grid-row: 1; + height: 100%; display: grid; - grid-template-columns: 2fr 1fr; - gap: 16px; - align-items: start; - padding-bottom: 8px; + grid-template-rows: 1fr 1fr; + gap: 10px; +} + +.settings-block-row2 { + grid-column: 1; + grid-row: 2; +} + +.settings-block-row2b { + grid-column: 2; + grid-row: 2; +} + +.settings-block-row3 { + grid-column: 1; + grid-row: 3; +} + +.settings-block-row3b { + grid-column: 2; + grid-row: 3; +} + +.settings-block-single { + height: 100%; + align-self: stretch; + justify-content: stretch; } .settings-prompt { display: flex; flex-direction: column; + height: 100%; +} + +.settings-row { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 8px; + align-items: stretch; + height: 100%; } -.settings-options { +.settings-field { display: flex; flex-direction: column; - gap: 16px; + height: 100%; + justify-content: space-between; } -.settings-bottom-row { - display: grid; - grid-template-columns: 2fr 1fr; - gap: 16px; - align-items: start; +.settings-field .field-label { + margin-bottom: 0; } -.settings-toggles { - display: grid; - grid-template-columns: 1fr 1fr; - gap: 16px; +.toggle-row { + display: flex; + align-items: center; + justify-content: space-between; + gap: 10px; + height: 100%; + padding: 0 10px; + border-radius: 10px; + background: #f7f7f8; + border: 1px solid var(--accents-1); } -.settings-toggle { +.toggle-text { display: flex; flex-direction: column; + gap: 2px; + min-width: 0; } -.settings-toggle .field-label { - 
margin-bottom: 6px; +.toggle-title { + font-size: 11px; + font-weight: 600; + color: var(--accents-6); + line-height: 1.2; +} + +.toggle-desc { + font-size: 10px; + color: var(--accents-4); + line-height: 1.2; +} + +.toggle-row-fill { + height: 100%; } .settings-folder { display: flex; flex-direction: column; + align-items: flex-start; + height: 100%; + justify-content: space-between; } .settings-folder .field-label { - margin-bottom: 6px; + margin-bottom: 0; + line-height: 1; } .folder-select-btn { @@ -217,16 +299,29 @@ } @media (max-width: 768px) { - .settings-top { + .settings-grid { grid-template-columns: 1fr; + grid-template-rows: auto; } - - .settings-bottom-row { + + .settings-row { grid-template-columns: 1fr; } - - .settings-toggles { - grid-template-columns: 1fr; + + .settings-block-prompt, + .settings-block-duo, + .settings-block-row2, + .settings-block-row2b, + .settings-block-row3, + .settings-block-row3b { + grid-column: 1; + grid-row: auto; + } + + .settings-block-duo { + height: auto; + display: flex; + flex-direction: column; } } @@ -374,14 +469,14 @@ } .meta-item { - padding: 9px 12px; + padding: 12px; border-radius: 10px; background: #f7f7f8; display: flex; align-items: center; justify-content: space-between; gap: 12px; - min-height: 38px; + min-height: 42px; } .meta-label { @@ -449,9 +544,39 @@ .imagine-empty { text-align: center; - padding: 28px 0 8px; + padding: 32px 24px; font-size: 12px; color: var(--accents-4); + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + gap: 6px; + min-height: 220px; + border: 1px dashed var(--border); + border-radius: 14px; + background: linear-gradient(135deg, #fafafa 0%, #f3f4f6 100%); +} + +.empty-title { + font-size: 13px; + font-weight: 600; + color: var(--accents-6); +} + +.empty-subtitle { + font-size: 11px; + color: var(--accents-4); +} + +.empty-hint { + margin-top: 4px; + font-size: 10px; + color: var(--accents-3); + padding: 2px 8px; + border-radius: 999px; + 
border: 1px solid var(--accents-1); + background: #fff; } .waterfall { diff --git a/app/static/public/css/video.css b/app/static/public/css/video.css new file mode 100644 index 00000000..a634b7a9 --- /dev/null +++ b/app/static/public/css/video.css @@ -0,0 +1,482 @@ +:root { + --video-surface: #ffffff; + --video-muted: #f2f4f7; + --video-bg: #f6f7fb; + --video-outline: #e6e6e6; + --video-ink: #0f172a; + --video-accent: #0f172a; + --video-glow: rgba(15, 23, 42, 0.08); +} + +body { + background-color: var(--video-bg); + background-image: none; +} + +.video-header-row { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; + flex-wrap: wrap; +} + +.video-hero { + padding: 16px 20px; + border-radius: 12px; + border: none; + background: #fff; + box-shadow: none; +} + +.hero-kicker { + font-size: 11px; + color: var(--accents-4); + text-transform: uppercase; + letter-spacing: 0.18em; + margin-bottom: 6px; +} + +.hero-note { + display: flex; + gap: 16px; + flex-wrap: wrap; + margin-top: 12px; + font-size: 11px; + color: var(--accents-5); +} + +.hero-note-item { + display: flex; + align-items: center; + gap: 8px; + background: #f5f5f5; + border: none; + padding: 6px 10px; + border-radius: 999px; +} + +.hero-dot { + width: 6px; + height: 6px; + border-radius: 50%; + background: var(--video-accent); + display: inline-block; +} + +.video-actions { + display: flex; + gap: 8px; + flex-wrap: wrap; +} + +.video-top-grid { + display: grid; + grid-template-columns: minmax(0, 2fr) minmax(0, 1fr); + gap: 24px; + align-items: stretch; + margin-bottom: 24px; +} + +@media (max-width: 1024px) { + .video-top-grid { + grid-template-columns: 1fr; + } +} + +.video-card { + background: var(--video-surface); + border: none; + border-radius: 12px; + padding: 20px; + box-shadow: none; + position: relative; + overflow: hidden; + --field-height: 32px; + --field-label: 11px; + --field-label-gap: 6px; + --field-block: calc(var(--field-height) + var(--field-label) + 
var(--field-label-gap)); + --field-row-gap: 12px; +} + +.video-card::after { + content: none; +} + +.video-card-glow::before { + content: none; +} + +.video-card-contrast { + background: #fff; + color: var(--video-ink); + border: none; +} + +.video-card-contrast .card-title, +.video-card-contrast .meta-label, +.video-card-contrast .meta-value, +.video-card-contrast .progress-meta, +.video-card-contrast .video-meta { + color: #e2e8f0; +} + +.card-title { + font-size: 13px; + font-weight: 600; + color: var(--video-ink); + margin-bottom: 12px; +} + +.field-label { + display: block; + font-size: 11px; + color: var(--accents-4); + margin-bottom: 6px; + line-height: 1; +} + +.video-card input.geist-input, +.video-card select.geist-input { + height: var(--field-height); +} + +.settings-grid { + display: grid; + grid-template-columns: minmax(0, 2fr) minmax(0, 1fr) minmax(0, 1fr); + grid-template-rows: var(--field-block) var(--field-block) auto; + column-gap: 16px; + row-gap: var(--field-row-gap); + align-items: start; +} + +.settings-block { + display: flex; + flex-direction: column; +} + +.prompt-block { + grid-column: 1; + grid-row: 1 / span 2; +} + +.ref-block { + grid-column: 1; + grid-row: 3; +} + +.ratio-block { + grid-column: 2; + grid-row: 1; +} + +.length-block { + grid-column: 3; + grid-row: 1; +} + +.resolution-block { + grid-column: 2; + grid-row: 2; +} + +.preset-block { + grid-column: 3; + grid-row: 2; +} + +.upload-block { + grid-column: 2; + grid-row: 3; +} + +.clear-block { + grid-column: 3; + grid-row: 3; +} + +.upload-block .geist-button-outline, +.clear-block .geist-button-outline { + height: 32px; + width: 100%; +} + +.ref-controls { + display: grid; + grid-template-columns: 1fr; + gap: 8px; +} + +.ref-controls .geist-input { + min-width: 0; +} + +.ref-file-input { + display: none; +} + +.ref-meta { + margin-top: 6px; +} + +.ref-name { + font-size: 11px; + color: var(--accents-4); + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + 
max-width: 100%; + display: inline-block; +} + +@media (max-width: 640px) { + .ref-controls { + grid-template-columns: 1fr; + } +} + +@media (max-width: 900px) { + .settings-grid { + grid-template-columns: 1fr; + grid-template-rows: none; + } + + .settings-block { + grid-column: auto; + grid-row: auto; + } +} + +.video-textarea { + min-height: calc(var(--field-block) * 2 + var(--field-row-gap) - var(--field-label) - var(--field-label-gap)); + height: calc(var(--field-block) * 2 + var(--field-row-gap) - var(--field-label) - var(--field-label-gap)); + resize: vertical; +} + +.status-header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 8px; + margin-bottom: 12px; +} + +.status-text { + font-size: 11px; + color: var(--accents-4); +} + +.status-text.connected { + color: #059669; +} + +.status-text.connecting { + color: #d97706; +} + +.status-text.error { + color: #dc2626; +} + +.progress-wrap { + margin-bottom: 12px; +} + +.progress-bar { + width: 100%; + height: 8px; + border-radius: 999px; + background: #f0f0f0; + overflow: hidden; + position: relative; +} + +.progress-fill { + height: 100%; + width: 0%; + border-radius: 999px; + background: #111; + transition: width 0.3s ease; + position: absolute; + left: 0; + top: 0; +} + +.progress-bar.indeterminate .progress-fill { + width: 40%; + animation: progress-indeterminate 1.2s ease-in-out infinite; +} + +@keyframes progress-indeterminate { + 0% { + transform: translateX(-120%); + } + 100% { + transform: translateX(220%); + } +} + +.progress-meta { + display: flex; + align-items: center; + justify-content: space-between; + margin-top: 8px; + font-size: 11px; + color: var(--accents-4); +} + +.meta-grid { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 12px; +} + +.meta-item { + padding: 10px 12px; + border-radius: 10px; + background: #f5f5f5; + border: 1px solid transparent; + display: flex; + align-items: center; + justify-content: space-between; + gap: 
12px; + min-width: 0; +} + +.meta-label { + font-size: 10px; + color: var(--accents-4); +} + +.meta-value { + font-size: 12px; + font-weight: 600; + color: var(--accents-7); + text-align: right; + overflow-wrap: anywhere; +} + +.video-preview-header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; + flex-wrap: wrap; + margin-bottom: 12px; +} + +.preview-actions { + display: flex; + align-items: center; + gap: 8px; + flex-wrap: wrap; +} + +.video-empty { + text-align: center; + color: var(--accents-4); + font-size: 12px; + padding: 42px 12px; + background: #f5f5f5; + border-radius: 12px; + border: 1px dashed var(--video-outline); +} + +.video-stage { + min-height: 240px; + padding: 12px; + border-radius: 12px; + border: 1px solid var(--video-outline); + background: #111; + display: flex; + flex-direction: column; + gap: 12px; + align-items: stretch; +} + +.video-item { + background: #000; + border-radius: 10px; + padding: 10px; + display: flex; + flex-direction: column; + gap: 10px; +} + +.video-item-bar { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; + color: #e2e8f0; + font-size: 11px; +} + +.video-item-title { + font-size: 11px; + color: #e2e8f0; +} + +.video-item-actions { + display: flex; + align-items: center; + gap: 8px; +} + +.video-item-actions .geist-button-outline { + height: 26px; + padding: 0 10px; + border-color: rgba(148, 163, 184, 0.35); + color: #e2e8f0; +} + +.video-item-body { + border-radius: 8px; + background: #000; + overflow: hidden; + min-height: 140px; + display: flex; + align-items: center; + justify-content: center; +} + +.video-item video { + width: 100%; + border-radius: 8px; + background: #000; + max-height: 380px; +} + +.video-item-link { + display: none; + font-size: 11px; + color: rgba(226, 232, 240, 0.7); + word-break: break-all; +} + +.video-item-link.has-url { + display: block; +} + +.video-item-placeholder { + font-size: 12px; + color: rgba(226, 232, 
240, 0.7); +} + +.video-item.is-pending .video-item-actions .video-open { + display: none; +} + +.video-stage.hidden, +.video-empty.hidden { + display: none; +} + +.preview-actions .geist-button-outline { + height: 30px; +} + +.video-card-contrast .geist-button-outline { + border-color: var(--video-outline); + color: var(--accents-7); +} diff --git a/app/static/voice/voice.css b/app/static/public/css/voice.css similarity index 100% rename from app/static/voice/voice.css rename to app/static/public/css/voice.css diff --git a/app/static/imagine/imagine.js b/app/static/public/js/imagine.js similarity index 79% rename from app/static/imagine/imagine.js rename to app/static/public/js/imagine.js index 306c75bb..cb9234ca 100644 --- a/app/static/imagine/imagine.js +++ b/app/static/public/js/imagine.js @@ -7,6 +7,9 @@ const concurrentSelect = document.getElementById('concurrentSelect'); const autoScrollToggle = document.getElementById('autoScrollToggle'); const autoDownloadToggle = document.getElementById('autoDownloadToggle'); + const reverseInsertToggle = document.getElementById('reverseInsertToggle'); + const autoFilterToggle = document.getElementById('autoFilterToggle'); + const nsfwSelect = document.getElementById('nsfwSelect'); const selectFolderBtn = document.getElementById('selectFolderBtn'); const folderPath = document.getElementById('folderPath'); const statusText = document.getElementById('statusText'); @@ -36,6 +39,9 @@ let useFileSystemAPI = false; let isSelectionMode = false; let selectedImages = new Set(); + let streamSequence = 0; + const streamImageMap = new Map(); + let finalMinBytesDefault = 100000; function toast(message, type) { if (typeof showToast === 'function') { @@ -103,6 +109,23 @@ function updateModeValue() {} + async function loadFilterDefaults() { + try { + const res = await fetch('/v1/public/imagine/config', { cache: 'no-store' }); + if (!res.ok) return; + const data = await res.json(); + const value = parseInt(data && data.final_min_bytes, 
10); + if (Number.isFinite(value) && value >= 0) { + finalMinBytesDefault = value; + } + if (nsfwSelect && typeof data.nsfw === 'boolean') { + nsfwSelect.value = data.nsfw ? 'true' : 'false'; + } + } catch (e) { + // ignore + } + } + function updateLatency(value) { if (value) { @@ -121,6 +144,15 @@ function updateError(value) {} + function isLikelyBase64(raw) { + if (!raw) return false; + if (raw.startsWith('data:')) return true; + if (raw.startsWith('http://') || raw.startsWith('https://')) return false; + const head = raw.slice(0, 16); + if (head.startsWith('/9j/') || head.startsWith('iVBOR') || head.startsWith('R0lGOD')) return true; + return /^[A-Za-z0-9+/=\s]+$/.test(raw); + } + function inferMime(base64) { if (!base64) return 'image/jpeg'; if (base64.startsWith('iVBOR')) return 'image/png'; @@ -129,6 +161,31 @@ return 'image/jpeg'; } + function estimateBase64Bytes(raw) { + if (!raw) return null; + if (raw.startsWith('http://') || raw.startsWith('https://')) { + return null; + } + if (raw.startsWith('/') && !isLikelyBase64(raw)) { + return null; + } + let base64 = raw; + if (raw.startsWith('data:')) { + const comma = raw.indexOf(','); + base64 = comma >= 0 ? raw.slice(comma + 1) : ''; + } + base64 = base64.replace(/\s/g, ''); + if (!base64) return 0; + let padding = 0; + if (base64.endsWith('==')) padding = 2; + else if (base64.endsWith('=')) padding = 1; + return Math.max(0, Math.floor((base64.length * 3) / 4) - padding); + } + + function getFinalMinBytes() { + return Number.isFinite(finalMinBytesDefault) && finalMinBytesDefault >= 0 ? 
finalMinBytesDefault : 100000; + } + function dataUrlToBlob(dataUrl) { const parts = (dataUrl || '').split(','); if (parts.length < 2) return null; @@ -149,14 +206,14 @@ } } - async function createImagineTask(prompt, ratio, authHeader) { + async function createImagineTask(prompt, ratio, authHeader, nsfwEnabled) { const res = await fetch('/v1/public/imagine/start', { method: 'POST', headers: { ...buildAuthHeaders(authHeader), 'Content-Type': 'application/json' }, - body: JSON.stringify({ prompt, aspect_ratio: ratio }) + body: JSON.stringify({ prompt, aspect_ratio: ratio, nsfw: nsfwEnabled }) }); if (!res.ok) { const text = await res.text(); @@ -166,10 +223,10 @@ return data && data.task_id ? String(data.task_id) : ''; } - async function createImagineTasks(prompt, ratio, concurrent, authHeader) { + async function createImagineTasks(prompt, ratio, concurrent, authHeader, nsfwEnabled) { const tasks = []; for (let i = 0; i < concurrent; i++) { - const taskId = await createImagineTask(prompt, ratio, authHeader); + const taskId = await createImagineTask(prompt, ratio, authHeader, nsfwEnabled); if (!taskId) { throw new Error('Missing task id'); } @@ -239,6 +296,13 @@ function appendImage(base64, meta) { if (!waterfall) return; + if (autoFilterToggle && autoFilterToggle.checked) { + const bytes = estimateBase64Bytes(base64 || ''); + const minBytes = getFinalMinBytes(); + if (bytes !== null && bytes < minBytes) { + return; + } + } if (emptyState) { emptyState.style.display = 'none'; } @@ -282,10 +346,18 @@ item.classList.add('selection-mode'); } - waterfall.appendChild(item); + if (reverseInsertToggle && reverseInsertToggle.checked) { + waterfall.prepend(item); + } else { + waterfall.appendChild(item); + } if (autoScrollToggle && autoScrollToggle.checked) { - window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }); + if (reverseInsertToggle && reverseInsertToggle.checked) { + window.scrollTo({ top: 0, behavior: 'smooth' }); + } else { + window.scrollTo({ 
top: document.body.scrollHeight, behavior: 'smooth' }); + } } if (autoDownloadToggle && autoDownloadToggle.checked) { @@ -304,6 +376,132 @@ } } + function upsertStreamImage(raw, meta, imageId, isFinal) { + if (!waterfall || !raw) return; + if (emptyState) { + emptyState.style.display = 'none'; + } + + if (isFinal && autoFilterToggle && autoFilterToggle.checked) { + const bytes = estimateBase64Bytes(raw); + const minBytes = getFinalMinBytes(); + if (bytes !== null && bytes < minBytes) { + const existing = imageId ? streamImageMap.get(imageId) : null; + if (existing) { + if (selectedImages.has(existing)) { + selectedImages.delete(existing); + updateSelectedCount(); + } + existing.remove(); + streamImageMap.delete(imageId); + if (imageCount > 0) { + imageCount -= 1; + updateCount(imageCount); + } + } + return; + } + } + + const isDataUrl = typeof raw === 'string' && raw.startsWith('data:'); + const looksLikeBase64 = typeof raw === 'string' && isLikelyBase64(raw); + const isHttpUrl = typeof raw === 'string' && (raw.startsWith('http://') || raw.startsWith('https://') || (raw.startsWith('/') && !looksLikeBase64)); + const mime = isDataUrl || isHttpUrl ? '' : inferMime(raw); + const dataUrl = isDataUrl || isHttpUrl ? raw : `data:${mime};base64,${raw}`; + + let item = imageId ? streamImageMap.get(imageId) : null; + let isNew = false; + if (!item) { + isNew = true; + streamSequence += 1; + const sequence = streamSequence; + + item = document.createElement('div'); + item.className = 'waterfall-item'; + + const checkbox = document.createElement('div'); + checkbox.className = 'image-checkbox'; + + const img = document.createElement('img'); + img.loading = 'lazy'; + img.decoding = 'async'; + img.alt = imageId ? 
`image-${imageId}` : 'image'; + img.src = dataUrl; + + const metaBar = document.createElement('div'); + metaBar.className = 'waterfall-meta'; + const left = document.createElement('div'); + left.textContent = `#${sequence}`; + const right = document.createElement('span'); + right.textContent = ''; + if (meta && meta.elapsed_ms) { + right.textContent = `${meta.elapsed_ms}ms`; + } + + metaBar.appendChild(left); + metaBar.appendChild(right); + + item.appendChild(checkbox); + item.appendChild(img); + item.appendChild(metaBar); + + const prompt = (meta && meta.prompt) ? String(meta.prompt) : (promptInput ? promptInput.value.trim() : ''); + item.dataset.imageUrl = dataUrl; + item.dataset.prompt = prompt || 'image'; + + if (isSelectionMode) { + item.classList.add('selection-mode'); + } + + if (reverseInsertToggle && reverseInsertToggle.checked) { + waterfall.prepend(item); + } else { + waterfall.appendChild(item); + } + + if (imageId) { + streamImageMap.set(imageId, item); + } + + imageCount += 1; + updateCount(imageCount); + } else { + const img = item.querySelector('img'); + if (img) { + img.src = dataUrl; + } + item.dataset.imageUrl = dataUrl; + const right = item.querySelector('.waterfall-meta span'); + if (right && meta && meta.elapsed_ms) { + right.textContent = `${meta.elapsed_ms}ms`; + } + } + + updateError(''); + + if (isNew && autoScrollToggle && autoScrollToggle.checked) { + if (reverseInsertToggle && reverseInsertToggle.checked) { + window.scrollTo({ top: 0, behavior: 'smooth' }); + } else { + window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }); + } + } + + if (isFinal && autoDownloadToggle && autoDownloadToggle.checked) { + const timestamp = Date.now(); + const ext = mime === 'image/png' ? 
'png' : 'jpg'; + const filename = `imagine_${timestamp}_${imageId || streamSequence}.${ext}`; + + if (useFileSystemAPI && directoryHandle) { + saveToFileSystem(raw, filename).catch(() => { + downloadImage(raw, filename); + }); + } else { + downloadImage(raw, filename); + } + } + } + function handleMessage(raw) { let data = null; try { @@ -313,7 +511,15 @@ } if (!data || typeof data !== 'object') return; - if (data.type === 'image') { + if (data.type === 'image_generation.partial_image' || data.type === 'image_generation.completed') { + const imageId = data.image_id || data.imageId; + const payload = data.b64_json || data.url || data.image; + if (!payload || !imageId) { + return; + } + const isFinal = data.type === 'image_generation.completed' || data.stage === 'final'; + upsertStreamImage(payload, data, imageId, isFinal); + } else if (data.type === 'image') { imageCount += 1; updateCount(imageCount); updateLatency(data.elapsed_ms); @@ -329,8 +535,8 @@ } setStatus('', '已停止'); } - } else if (data.type === 'error') { - const message = data.message || '生成失败'; + } else if (data.type === 'error' || data.error) { + const message = data.message || (data.error && data.error.message) || '生成失败'; updateError(message); toast(message, 'error'); } @@ -442,6 +648,7 @@ const concurrent = concurrentSelect ? parseInt(concurrentSelect.value, 10) : 1; const ratio = ratioSelect ? ratioSelect.value : '2:3'; + const nsfwEnabled = nsfwSelect ? nsfwSelect.value === 'true' : true; if (isRunning) { toast('已在运行中', 'warning'); @@ -459,7 +666,7 @@ let taskIds = []; try { - taskIds = await createImagineTasks(prompt, ratio, concurrent, authHeader); + taskIds = await createImagineTasks(prompt, ratio, concurrent, authHeader, nsfwEnabled); } catch (e) { setStatus('error', '创建任务失败'); startBtn.disabled = false; @@ -557,10 +764,12 @@ if (!ws || ws.readyState !== WebSocket.OPEN) return; const prompt = promptOverride || (promptInput ? promptInput.value.trim() : ''); const ratio = ratioSelect ? 
ratioSelect.value : '2:3'; + const nsfwEnabled = nsfwSelect ? nsfwSelect.value === 'true' : true; const payload = { type: 'start', prompt, - aspect_ratio: ratio + aspect_ratio: ratio, + nsfw: nsfwEnabled }; ws.send(JSON.stringify(payload)); updateError(''); @@ -590,6 +799,8 @@ if (waterfall) { waterfall.innerHTML = ''; } + streamImageMap.clear(); + streamSequence = 0; imageCount = 0; totalLatency = 0; latencyCount = 0; @@ -624,6 +835,8 @@ }); } + loadFilterDefaults(); + if (ratioSelect) { ratioSelect.addEventListener('change', () => { if (isRunning) { diff --git a/app/static/public/login.js b/app/static/public/js/login.js similarity index 100% rename from app/static/public/login.js rename to app/static/public/js/login.js diff --git a/app/static/public/js/video.js b/app/static/public/js/video.js new file mode 100644 index 00000000..076483ce --- /dev/null +++ b/app/static/public/js/video.js @@ -0,0 +1,640 @@ +(() => { + const startBtn = document.getElementById('startBtn'); + const stopBtn = document.getElementById('stopBtn'); + const clearBtn = document.getElementById('clearBtn'); + const promptInput = document.getElementById('promptInput'); + const imageUrlInput = document.getElementById('imageUrlInput'); + const imageFileInput = document.getElementById('imageFileInput'); + const imageFileName = document.getElementById('imageFileName'); + const clearImageFileBtn = document.getElementById('clearImageFileBtn'); + const selectImageFileBtn = document.getElementById('selectImageFileBtn'); + const ratioSelect = document.getElementById('ratioSelect'); + const lengthSelect = document.getElementById('lengthSelect'); + const resolutionSelect = document.getElementById('resolutionSelect'); + const presetSelect = document.getElementById('presetSelect'); + const statusText = document.getElementById('statusText'); + const progressBar = document.getElementById('progressBar'); + const progressFill = document.getElementById('progressFill'); + const progressText = 
document.getElementById('progressText'); + const durationValue = document.getElementById('durationValue'); + const aspectValue = document.getElementById('aspectValue'); + const lengthValue = document.getElementById('lengthValue'); + const resolutionValue = document.getElementById('resolutionValue'); + const presetValue = document.getElementById('presetValue'); + const videoEmpty = document.getElementById('videoEmpty'); + const videoStage = document.getElementById('videoStage'); + + let currentSource = null; + let currentTaskId = ''; + let isRunning = false; + let progressBuffer = ''; + let contentBuffer = ''; + let collectingContent = false; + let startAt = 0; + let fileDataUrl = ''; + let elapsedTimer = null; + let lastProgress = 0; + let currentPreviewItem = null; + let previewCount = 0; + const DEFAULT_REASONING_EFFORT = 'low'; + + function toast(message, type) { + if (typeof showToast === 'function') { + showToast(message, type); + } + } + + function setStatus(state, text) { + if (!statusText) return; + statusText.textContent = text; + statusText.classList.remove('connected', 'connecting', 'error'); + if (state) { + statusText.classList.add(state); + } + } + + function setButtons(running) { + if (!startBtn || !stopBtn) return; + if (running) { + startBtn.classList.add('hidden'); + stopBtn.classList.remove('hidden'); + } else { + startBtn.classList.remove('hidden'); + stopBtn.classList.add('hidden'); + startBtn.disabled = false; + } + } + + function updateProgress(value) { + const safe = Math.max(0, Math.min(100, Number(value) || 0)); + lastProgress = safe; + if (progressFill) { + progressFill.style.width = `${safe}%`; + } + if (progressText) { + progressText.textContent = `${safe}%`; + } + } + + function updateMeta() { + if (aspectValue && ratioSelect) { + aspectValue.textContent = ratioSelect.value; + } + if (lengthValue && lengthSelect) { + lengthValue.textContent = `${lengthSelect.value}s`; + } + if (resolutionValue && resolutionSelect) { + 
resolutionValue.textContent = resolutionSelect.value; + } + if (presetValue && presetSelect) { + presetValue.textContent = presetSelect.value; + } + } + + function resetOutput(keepPreview) { + progressBuffer = ''; + contentBuffer = ''; + collectingContent = false; + lastProgress = 0; + currentPreviewItem = null; + updateProgress(0); + setIndeterminate(false); + if (!keepPreview) { + if (videoStage) { + videoStage.innerHTML = ''; + videoStage.classList.add('hidden'); + } + if (videoEmpty) { + videoEmpty.classList.remove('hidden'); + } + previewCount = 0; + } + if (durationValue) { + durationValue.textContent = '耗时 -'; + } + } + + function initPreviewSlot() { + if (!videoStage) return; + previewCount += 1; + currentPreviewItem = document.createElement('div'); + currentPreviewItem.className = 'video-item'; + currentPreviewItem.dataset.index = String(previewCount); + currentPreviewItem.classList.add('is-pending'); + + const header = document.createElement('div'); + header.className = 'video-item-bar'; + + const title = document.createElement('div'); + title.className = 'video-item-title'; + title.textContent = `视频 ${previewCount}`; + + const actions = document.createElement('div'); + actions.className = 'video-item-actions'; + + const openBtn = document.createElement('a'); + openBtn.className = 'geist-button-outline text-xs px-3 video-open hidden'; + openBtn.target = '_blank'; + openBtn.rel = 'noopener'; + openBtn.textContent = '打开'; + + const downloadBtn = document.createElement('button'); + downloadBtn.className = 'geist-button-outline text-xs px-3 video-download'; + downloadBtn.type = 'button'; + downloadBtn.textContent = '下载'; + downloadBtn.disabled = true; + + actions.appendChild(openBtn); + actions.appendChild(downloadBtn); + header.appendChild(title); + header.appendChild(actions); + + const body = document.createElement('div'); + body.className = 'video-item-body'; + body.innerHTML = '
生成中…
'; + + const link = document.createElement('div'); + link.className = 'video-item-link'; + + currentPreviewItem.appendChild(header); + currentPreviewItem.appendChild(body); + currentPreviewItem.appendChild(link); + videoStage.appendChild(currentPreviewItem); + videoStage.classList.remove('hidden'); + if (videoEmpty) { + videoEmpty.classList.add('hidden'); + } + } + + function ensurePreviewSlot() { + if (!currentPreviewItem) { + initPreviewSlot(); + } + return currentPreviewItem; + } + + function updateItemLinks(item, url) { + if (!item) return; + const openBtn = item.querySelector('.video-open'); + const downloadBtn = item.querySelector('.video-download'); + const link = item.querySelector('.video-item-link'); + const safeUrl = url || ''; + item.dataset.url = safeUrl; + if (link) { + link.textContent = safeUrl; + link.classList.toggle('has-url', Boolean(safeUrl)); + } + if (openBtn) { + if (safeUrl) { + openBtn.href = safeUrl; + openBtn.classList.remove('hidden'); + } else { + openBtn.classList.add('hidden'); + openBtn.removeAttribute('href'); + } + } + if (downloadBtn) { + downloadBtn.dataset.url = safeUrl; + downloadBtn.disabled = !safeUrl; + } + if (safeUrl) { + item.classList.remove('is-pending'); + } + } + + function setIndeterminate(active) { + if (!progressBar) return; + if (active) { + progressBar.classList.add('indeterminate'); + } else { + progressBar.classList.remove('indeterminate'); + } + } + + function startElapsedTimer() { + stopElapsedTimer(); + if (!durationValue) return; + elapsedTimer = setInterval(() => { + if (!startAt) return; + const seconds = Math.max(0, Math.round((Date.now() - startAt) / 1000)); + durationValue.textContent = `耗时 ${seconds}s`; + }, 1000); + } + + function stopElapsedTimer() { + if (elapsedTimer) { + clearInterval(elapsedTimer); + elapsedTimer = null; + } + } + + function clearFileSelection() { + fileDataUrl = ''; + if (imageFileInput) { + imageFileInput.value = ''; + } + if (imageFileName) { + imageFileName.textContent = 
'未选择文件'; + } + } + + function normalizeAuthHeader(authHeader) { + if (!authHeader) return ''; + if (authHeader.startsWith('Bearer ')) { + return authHeader.slice(7).trim(); + } + return authHeader; + } + + function buildSseUrl(taskId, rawPublicKey) { + const httpProtocol = window.location.protocol === 'https:' ? 'https' : 'http'; + const base = `${httpProtocol}://${window.location.host}/v1/public/video/sse`; + const params = new URLSearchParams(); + params.set('task_id', taskId); + params.set('t', String(Date.now())); + if (rawPublicKey) { + params.set('public_key', rawPublicKey); + } + return `${base}?${params.toString()}`; + } + + async function createVideoTask(authHeader) { + const prompt = promptInput ? promptInput.value.trim() : ''; + const rawUrl = imageUrlInput ? imageUrlInput.value.trim() : ''; + if (fileDataUrl && rawUrl) { + toast('参考图只能选择其一:URL/Base64 或 本地上传', 'error'); + throw new Error('invalid_reference'); + } + const imageUrl = fileDataUrl || rawUrl; + const res = await fetch('/v1/public/video/start', { + method: 'POST', + headers: { + ...buildAuthHeaders(authHeader), + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + prompt, + image_url: imageUrl || null, + reasoning_effort: DEFAULT_REASONING_EFFORT, + aspect_ratio: ratioSelect ? ratioSelect.value : '3:2', + video_length: lengthSelect ? parseInt(lengthSelect.value, 10) : 6, + resolution_name: resolutionSelect ? resolutionSelect.value : '480p', + preset: presetSelect ? presetSelect.value : 'normal' + }) + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || 'Failed to create task'); + } + const data = await res.json(); + return data && data.task_id ? 
String(data.task_id) : ''; + } + + async function stopVideoTask(taskId, authHeader) { + if (!taskId) return; + try { + await fetch('/v1/public/video/stop', { + method: 'POST', + headers: { + ...buildAuthHeaders(authHeader), + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ task_ids: [taskId] }) + }); + } catch (e) { + // ignore + } + } + + function extractVideoInfo(buffer) { + if (!buffer) return null; + if (buffer.includes('/gi); + if (matches && matches.length) { + return { html: matches[matches.length - 1] }; + } + } + const mdMatches = buffer.match(/\[video\]\(([^)]+)\)/g); + if (mdMatches && mdMatches.length) { + const last = mdMatches[mdMatches.length - 1]; + const urlMatch = last.match(/\[video\]\(([^)]+)\)/); + if (urlMatch) { + return { url: urlMatch[1] }; + } + } + const urlMatches = buffer.match(/https?:\/\/[^\s<)]+/g); + if (urlMatches && urlMatches.length) { + return { url: urlMatches[urlMatches.length - 1] }; + } + return null; + } + + function renderVideoFromHtml(html) { + const container = ensurePreviewSlot(); + if (!container) return; + const body = container.querySelector('.video-item-body'); + if (!body) return; + body.innerHTML = html; + const videoEl = body.querySelector('video'); + let videoUrl = ''; + if (videoEl) { + videoEl.controls = true; + videoEl.preload = 'metadata'; + const source = videoEl.querySelector('source'); + if (source && source.getAttribute('src')) { + videoUrl = source.getAttribute('src'); + } else if (videoEl.getAttribute('src')) { + videoUrl = videoEl.getAttribute('src'); + } + } + updateItemLinks(container, videoUrl); + } + + function renderVideoFromUrl(url) { + const container = ensurePreviewSlot(); + if (!container) return; + const safeUrl = url || ''; + const body = container.querySelector('.video-item-body'); + if (!body) return; + body.innerHTML = `\n \n `; + updateItemLinks(container, safeUrl); + } + + function handleDelta(text) { + if (!text) return; + if (text.includes('') || text.includes('')) { 
+ return; + } + if (text.includes('超分辨率')) { + setStatus('connecting', '超分辨率中'); + setIndeterminate(true); + if (progressText) { + progressText.textContent = '超分辨率中'; + } + return; + } + + if (!collectingContent) { + const maybeVideo = text.includes(' { + setStatus('connected', '生成中'); + }; + + es.onmessage = (event) => { + if (!event || !event.data) return; + if (event.data === '[DONE]') { + finishRun(); + return; + } + let payload = null; + try { + payload = JSON.parse(event.data); + } catch (e) { + return; + } + if (payload && payload.error) { + toast(payload.error, 'error'); + setStatus('error', '生成失败'); + finishRun(true); + return; + } + const choice = payload.choices && payload.choices[0]; + const delta = choice && choice.delta ? choice.delta : null; + if (delta && delta.content) { + handleDelta(delta.content); + } + if (choice && choice.finish_reason === 'stop') { + finishRun(); + } + }; + + es.onerror = () => { + if (!isRunning) return; + setStatus('error', '连接错误'); + finishRun(true); + }; + } + + async function stopConnection() { + const authHeader = await ensurePublicKey(); + if (authHeader !== null) { + await stopVideoTask(currentTaskId, authHeader); + } + closeSource(); + isRunning = false; + currentTaskId = ''; + stopElapsedTimer(); + setButtons(false); + setStatus('', '未连接'); + } + + function finishRun(hasError) { + if (!isRunning) return; + closeSource(); + isRunning = false; + setButtons(false); + stopElapsedTimer(); + if (!hasError) { + setStatus('connected', '完成'); + setIndeterminate(false); + updateProgress(100); + } + if (durationValue && startAt) { + const seconds = Math.max(0, Math.round((Date.now() - startAt) / 1000)); + durationValue.textContent = `耗时 ${seconds}s`; + } + } + + if (startBtn) { + startBtn.addEventListener('click', () => startConnection()); + } + + if (stopBtn) { + stopBtn.addEventListener('click', () => stopConnection()); + } + + if (clearBtn) { + clearBtn.addEventListener('click', () => resetOutput()); + } + + if (videoStage) 
{ + videoStage.addEventListener('click', async (event) => { + const target = event.target; + if (!(target instanceof HTMLElement)) return; + if (!target.classList.contains('video-download')) return; + event.preventDefault(); + const item = target.closest('.video-item'); + if (!item) return; + const url = item.dataset.url || target.dataset.url || ''; + const index = item.dataset.index || ''; + if (!url) return; + try { + const response = await fetch(url, { mode: 'cors' }); + if (!response.ok) { + throw new Error('download_failed'); + } + const blob = await response.blob(); + const blobUrl = URL.createObjectURL(blob); + const anchor = document.createElement('a'); + anchor.href = blobUrl; + anchor.download = index ? `grok_video_${index}.mp4` : 'grok_video.mp4'; + document.body.appendChild(anchor); + anchor.click(); + anchor.remove(); + URL.revokeObjectURL(blobUrl); + } catch (e) { + toast('下载失败,请检查视频链接是否可访问', 'error'); + } + }); + } + + if (imageFileInput) { + imageFileInput.addEventListener('change', () => { + const file = imageFileInput.files && imageFileInput.files[0]; + if (!file) { + clearFileSelection(); + return; + } + if (imageUrlInput && imageUrlInput.value.trim()) { + imageUrlInput.value = ''; + } + if (imageFileName) { + imageFileName.textContent = file.name; + } + const reader = new FileReader(); + reader.onload = () => { + if (typeof reader.result === 'string') { + fileDataUrl = reader.result; + } else { + fileDataUrl = ''; + toast('文件读取失败', 'error'); + } + }; + reader.onerror = () => { + fileDataUrl = ''; + toast('文件读取失败', 'error'); + }; + reader.readAsDataURL(file); + }); + } + + if (selectImageFileBtn && imageFileInput) { + selectImageFileBtn.addEventListener('click', () => { + imageFileInput.click(); + }); + } + + if (clearImageFileBtn) { + clearImageFileBtn.addEventListener('click', () => { + clearFileSelection(); + }); + } + + if (imageUrlInput) { + imageUrlInput.addEventListener('input', () => { + if (imageUrlInput.value.trim() && fileDataUrl) { + 
clearFileSelection(); + } + }); + } + + if (promptInput) { + promptInput.addEventListener('keydown', (event) => { + if ((event.metaKey || event.ctrlKey) && event.key === 'Enter') { + event.preventDefault(); + startConnection(); + } + }); + } + + updateMeta(); +})(); diff --git a/app/static/voice/voice.js b/app/static/public/js/voice.js similarity index 100% rename from app/static/voice/voice.js rename to app/static/public/js/voice.js diff --git a/app/static/imagine/imagine.html b/app/static/public/pages/imagine.html similarity index 63% rename from app/static/imagine/imagine.html rename to app/static/public/pages/imagine.html index bcd5c352..515057f0 100644 --- a/app/static/imagine/imagine.html +++ b/app/static/public/pages/imagine.html @@ -5,13 +5,13 @@ Grok2API - Imagine 瀑布流 - + - - - + + + @@ -28,17 +28,41 @@

Imagine 瀑布流

-
-
生成设置
+
-
-
+
生成设置
+
+
-
-
+
+
+
+
+
自动跟随
+
滚动到最新
+
+ +
+
+
+
自动保存
+
自动保存图片
+
+ +
+
+
+
+
-
+
- - 滚动到最新 - +
+
+
+
自动过滤
+
过滤不达标图片
-
- - + +
+
+
+
+
+ + +
+
+ +
-
- - +
+
+
+
+
反向新增
+
最新显示在上方
+
+
@@ -133,7 +171,10 @@

Imagine 瀑布流

-
等待连接中...
+
+
等待连接中
+
启动任务后,生成的图片会自动汇聚到这里
+
@@ -196,12 +237,12 @@

Imagine 瀑布流

- - - - + + + + - + diff --git a/app/static/public/login.html b/app/static/public/pages/login.html similarity index 83% rename from app/static/public/login.html rename to app/static/public/pages/login.html index bcd0a516..3cdc099d 100644 --- a/app/static/public/login.html +++ b/app/static/public/pages/login.html @@ -5,7 +5,7 @@ Grok2API - Public - + @@ -24,9 +24,9 @@ } } - - - + + + @@ -59,10 +59,10 @@ - - - - + + + + diff --git a/app/static/public/pages/video.html b/app/static/public/pages/video.html new file mode 100644 index 00000000..a4f45cbb --- /dev/null +++ b/app/static/public/pages/video.html @@ -0,0 +1,168 @@ + + + + + + + Grok2API - Video 生成 + + + + + + + + + + +
+
+ +
+
+
+
+

Video 生成

+

生成短视频,支持参考图与多种预设风格。

+
+
+ + +
+
+ +
+ +
+
+
生成设置
+
+
+ + +
+
+ +
+ +
+
+ 未选择文件 +
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + + +
+
+ + +
+
+
+ +
+
+
运行状态
+ 未连接 +
+
+
+
+
+
+ 0% + 耗时 - +
+
+
+
+
比例
+
-
+
+
+
时长
+
-
+
+
+
分辨率
+
-
+
+
+
预设
+
-
+
+
+
+
+ +
+
+
视频预览
+
+ +
+
+
等待生成视频
+ +
+
+
+ + + + + + + + + + + diff --git a/app/static/voice/voice.html b/app/static/public/pages/voice.html similarity index 92% rename from app/static/voice/voice.html rename to app/static/public/pages/voice.html index a11b1480..32994187 100644 --- a/app/static/voice/voice.html +++ b/app/static/public/pages/voice.html @@ -5,13 +5,13 @@ Grok2API - Voice Live - + - - - + + + @@ -147,11 +147,11 @@

Voice Live

- - - - - + + + + + diff --git a/config.defaults.toml b/config.defaults.toml index bbe3898c..24bce1f0 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -6,10 +6,10 @@ app_url = "http://127.0.0.1:8000" app_key = "grok2api" # API 调用密钥(可选) api_key = "" +# 是否启用 public 功能玩法 +public_enabled = false # Public 调用密钥(可选) public_key = "" -# 是否公开功能玩法(public 入口) -public_enabled = false # 生成图片的格式(url 或 base64) image_format = "url" # 生成视频的格式(html 或 url) @@ -53,9 +53,9 @@ retry_backoff_base = 0.5 # 退避倍率 retry_backoff_factor = 2.0 # 单次重试最大延迟(秒) -retry_backoff_max = 30.0 +retry_backoff_max = 20.0 # 总重试预算时间(秒) -retry_budget = 90.0 +retry_budget = 60.0 # ==================== Token 池管理 ==================== @@ -78,12 +78,12 @@ reload_interval_sec = 30 # 是否启用自动清理 enable_auto_clean = true # 缓存大小上限(MB) -limit_mb = 1024 +limit_mb = 512 # ==================== 对话配置 ==================== [chat] # Reverse 接口并发上限 -concurrent = 10 +concurrent = 50 # Reverse 接口超时时间(秒) timeout = 60 # 流式空闲超时时间(秒) @@ -92,9 +92,9 @@ stream_timeout = 60 # ==================== 图像配置 ==================== [image] # WebSocket 请求超时时间(秒) -timeout = 120 +timeout = 60 # WebSocket 流式空闲超时时间(秒) -stream_timeout = 120 +stream_timeout = 60 # 中等图后等待最终图的超时秒数 final_timeout = 15 # 是否启用 NSFW @@ -107,7 +107,7 @@ final_min_bytes = 100000 # ==================== 视频配置 ==================== [video] # Reverse 接口并发上限 -concurrent = 10 +concurrent = 100 # Reverse 接口超时时间(秒) timeout = 60 # 流式空闲超时时间(秒) @@ -116,44 +116,44 @@ stream_timeout = 60 # ==================== 语音配置 ==================== [voice] # Voice 请求超时时间(秒) -timeout = 120 +timeout = 60 # ==================== 资产配置 ==================== [asset] # 上传并发数 -upload_concurrent = 30 +upload_concurrent = 100 # 上传超时时间(秒) upload_timeout = 60 # 下载并发数 -download_concurrent = 30 +download_concurrent = 100 # 下载超时时间(秒) download_timeout = 60 # 资产查询并发数 -list_concurrent = 10 +list_concurrent = 100 # 资产查询超时时间(秒) list_timeout = 60 # 资产查询批次大小(Token 维度) -list_batch_size = 10 +list_batch_size = 50 # 资产删除并发数 
-delete_concurrent = 10 +delete_concurrent = 100 # 资产删除超时时间(秒) delete_timeout = 60 # 资产删除批次大小(Token 维度) -delete_batch_size = 10 +delete_batch_size = 50 # ==================== NSFW ==================== [nsfw] # NSFW 批量开启并发上限 -concurrent = 10 +concurrent = 60 # NSFW 批量开启批次大小 -batch_size = 50 +batch_size = 30 # NSFW 请求超时时间(秒) timeout = 60 # ==================== 用量配置 ==================== [usage] # Usage 批量开启并发上限 -concurrent = 10 +concurrent = 100 # Usage 批量开启批次大小 batch_size = 50 # Usage 请求超时时间(秒) diff --git a/data/config.toml b/data/config.toml index d43d9efe..7ad6fadc 100644 --- a/data/config.toml +++ b/data/config.toml @@ -2,8 +2,8 @@ app_url = "http://127.0.0.1:8000" app_key = "grok2api" api_key = "" +public_enabled = true public_key = "" -public_enabled = false image_format = "url" video_format = "html" temporary = true @@ -20,34 +20,13 @@ cf_clearance = "" browser = "chrome136" user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" -[voice] -timeout = 120 - -[chat] -concurrent = 10 -timeout = 60 -stream_timeout = 60 - -[video] -concurrent = 10 -timeout = 60 -stream_timeout = 60 - [retry] max_retry = 3 retry_status_codes = [401,429,403] retry_backoff_base = 0.5 -retry_backoff_factor = 2.0 -retry_backoff_max = 30.0 -retry_budget = 90.0 - -[image] -timeout = 120 -stream_timeout = 120 -final_timeout = 15 -nsfw = true -medium_min_bytes = 30000 -final_min_bytes = 100000 +retry_backoff_factor = 2 +retry_backoff_max = 30 +retry_budget = 90 [token] auto_refresh = true @@ -61,6 +40,27 @@ reload_interval_sec = 30 enable_auto_clean = true limit_mb = 1024 +[chat] +concurrent = 10 +timeout = 60 +stream_timeout = 60 + +[image] +timeout = 120 +stream_timeout = 120 +final_timeout = 15 +nsfw = true +medium_min_bytes = 30000 +final_min_bytes = 100000 + +[video] +concurrent = 10 +timeout = 60 +stream_timeout = 60 + +[voice] +timeout = 120 + [asset] upload_concurrent = 30 upload_timeout = 60 diff --git 
a/main.py b/main.py index d94218ec..f618d872 100644 --- a/main.py +++ b/main.py @@ -124,7 +124,7 @@ def create_app() -> FastAPI: # 注册管理与公共路由 from app.api.v1.admin import router as admin_router from app.api.v1.public import router as public_router - from app.api.v1.pages import router as pages_router + from app.api.pages import router as pages_router app.include_router(admin_router, prefix="/v1/admin") app.include_router(public_router, prefix="/v1/public") From 7b093288dc13dd4163378da7d403ba53d5a9dd8d Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Sun, 15 Feb 2026 18:01:34 +0800 Subject: [PATCH 23/27] docs: update README and documentation to reflect new features, including video generation and improved deployment instructions --- docs/README.en.md | 354 ++++++++++++++++++++++++---------------------- readme.md | 326 +++++++++++++++++++++--------------------- 2 files changed, 352 insertions(+), 328 deletions(-) diff --git a/docs/README.en.md b/docs/README.en.md index ed433999..e92283a8 100644 --- a/docs/README.en.md +++ b/docs/README.en.md @@ -5,102 +5,108 @@ > [!NOTE] > This project is for learning and research only. You must comply with Grok **Terms of Use** and **local laws and regulations**. Do not use for illegal purposes. -Grok2API rebuilt with **FastAPI**, fully aligned with the latest web call format. Supports streaming/non-streaming chat, image generation/editing, deep reasoning, token pool concurrency, and automatic load balancing. +Grok2API rebuilt with **FastAPI**, fully aligned with the latest web call format. Supports streaming/non-streaming chat, image generation/editing, video generation/upscale, deep reasoning, token pool concurrency, and automatic load balancing. -### NOTE: The project is no longer accepting PRs and feature updates; this is the last structure optimization. +> [!IMPORTANT] +> The project is no longer accepting PRs or new features; this is the last structure optimization. -image +image
-## Usage +## Quick Start -### How to Start +### Local -- Local development - -``` +```bash uv sync - uv run main.py ``` -### How to Deploy +### Docker Compose -#### docker compose -``` +```bash git clone https://github.com/chenyme/grok2api +cd grok2api docker compose up -d ``` -#### Vercel +### Vercel [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,SERVER_STORAGE_TYPE,SERVER_STORAGE_URL&envDefaults=%7B%22DATA_DIR%22%3A%22/tmp/data%22%2C%22LOG_FILE_ENABLED%22%3A%22false%22%2C%22LOG_LEVEL%22%3A%22INFO%22%2C%22SERVER_STORAGE_TYPE%22%3A%22local%22%2C%22SERVER_STORAGE_URL%22%3A%22%22%7D) -> Make sure to set `DATA_DIR=/tmp/data` and disable file logging with `LOG_FILE_ENABLED=false`. +> Set `DATA_DIR=/tmp/data` and disable file logs with `LOG_FILE_ENABLED=false`. > -> For persistence, use MySQL / Redis / PostgreSQL and set `SERVER_STORAGE_TYPE` (mysql/redis/pgsql) and `SERVER_STORAGE_URL` in Vercel env vars. +> For persistence, use MySQL / Redis / PostgreSQL and set `SERVER_STORAGE_TYPE` and `SERVER_STORAGE_URL`. -#### Render +### Render [![Deploy to Render](https://render.com/images/deploy-to-render-button.svg)](https://render.com/deploy?repo=https://github.com/chenyme/grok2api) -> Render free instances sleep after 15 minutes of inactivity; restart/redeploy will lose data. +> Render free instances sleep after 15 minutes of inactivity; redeploy/restart will lose data. > -> For persistence, use MySQL / Redis / PostgreSQL and set `SERVER_STORAGE_TYPE` (mysql/redis/pgsql) and `SERVER_STORAGE_URL` in Render env vars. +> For persistence, use MySQL / Redis / PostgreSQL and set `SERVER_STORAGE_TYPE` and `SERVER_STORAGE_URL`. + +
-### Admin Panel +## Admin Panel -Access: `http://<host>:8000/admin` -Default password: `grok2api` (config `app.app_key`, recommended to change). +- Access: `http://<host>:8000/admin` +- Default password: `grok2api` (config `app.app_key`, recommended to change) **Features**: - **Token Management**: import/add/delete tokens, view status and quota - **Status Filter**: filter by status (active/limited/expired) or NSFW status - **Batch Ops**: batch refresh/export/delete/enable NSFW -- **NSFW Enable**: one-click Unhinged for tokens (proxy or cf_clearance required) +- **NSFW Enable**: one-click Unhinged for tokens (proxy or `cf_clearance` required) - **Config Management**: update system config online - **Cache Management**: view and clear media cache -### Environment Variables +
+ +## Environment Variables > Configure `.env` -| Name | Description | Default | Example | -| :--------------------- | :-------------------------------------------------- | :---------- | :---------------------------------------------------- | -| `LOG_LEVEL` | Log level | `INFO` | `DEBUG` | -| `LOG_FILE_ENABLED` | Enable file logging | `true` | `false` | -| `DATA_DIR` | Data dir (config/tokens/locks) | `./data` | `/data` | -| `SERVER_HOST` | Bind address | `0.0.0.0` | `0.0.0.0` | -| `SERVER_PORT` | Server port | `8000` | `8000` | -| `SERVER_WORKERS` | Uvicorn worker count | `1` | `2` | -| `SERVER_STORAGE_TYPE` | Storage type (`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | -| `SERVER_STORAGE_URL` | Storage DSN (optional for local) | `""` | `postgresql+asyncpg://user:password@host:5432/db` | +| Name | Description | Default | Example | +| :-- | :-- | :-- | :-- | +| `LOG_LEVEL` | Log level | `INFO` | `DEBUG` | +| `LOG_FILE_ENABLED` | Enable file logging | `true` | `false` | +| `DATA_DIR` | Data dir (config/tokens/locks) | `./data` | `/data` | +| `SERVER_HOST` | Bind address | `0.0.0.0` | `0.0.0.0` | +| `SERVER_PORT` | Server port | `8000` | `8000` | +| `SERVER_WORKERS` | Uvicorn worker count | `1` | `2` | +| `SERVER_STORAGE_TYPE` | Storage type (`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | +| `SERVER_STORAGE_URL` | Storage DSN (optional for local) | `""` | `postgresql+asyncpg://user:password@host:5432/db` | + +> MySQL example: `mysql+aiomysql://user:password@host:3306/db` (if you provide `mysql://`, it will be converted to `mysql+aiomysql://`). -> MySQL example: `mysql+aiomysql://user:password@host:3306/db` (if you provide `mysql://`, it will be converted to `mysql+aiomysql://`) +
-### Quotas +## Quotas - Basic account: 80 requests / 20h - Super account: 140 requests / 2h -### Models - -| Model | Cost | Account | Chat | Image | Video | -| :---------------------- | :--: | :---------- | :--: | :---: | :---: | -| `grok-3` | 1 | Basic/Super | Yes | Yes | - | -| `grok-3-fast` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4-mini` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4-fast` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4-heavy` | 4 | Super | Yes | Yes | - | -| `grok-4.1` | 1 | Basic/Super | Yes | Yes | - | -| `grok-4.1-thinking` | 4 | Basic/Super | Yes | Yes | - | -| `grok-imagine-1.0` | 4 | Basic/Super | - | Yes | - | -| `grok-imagine-1.0-edit` | 4 | Basic/Super | - | Yes | - | -| `grok-imagine-1.0-video`| - | Basic/Super | - | - | Yes | +
+ +## Models + +| Model | Cost | Account | Chat | Image | Video | +| :-- | :--: | :-- | :--: | :--: | :--: | +| `grok-3` | 1 | Basic/Super | Yes | Yes | - | +| `grok-3-fast` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4-mini` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4-fast` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4-heavy` | 4 | Super | Yes | Yes | - | +| `grok-4.1` | 1 | Basic/Super | Yes | Yes | - | +| `grok-4.1-thinking` | 4 | Basic/Super | Yes | Yes | - | +| `grok-imagine-1.0` | 4 | Basic/Super | - | Yes | - | +| `grok-imagine-1.0-edit` | 4 | Basic/Super | - | Yes | - | +| `grok-imagine-1.0-video` | - | Basic/Super | - | - | Yes |
@@ -111,7 +117,10 @@ Default password: `grok2api` (config `app.app_key`, recommended to change). > Generic endpoint: chat, image generation, image editing, video generation, video upscaling ```bash -curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer $GROK2API_API_KEY" -d '{ +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $GROK2API_API_KEY" \ + -d '{ "model": "grok-4", "messages": [{"role":"user","content":"Hello"}] }' @@ -122,41 +131,42 @@ curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/j
-| Field | Type | Description | Allowed values | -| :--------------------- | :------ | :-------------------------- | :--------------------------------------------------------------------------------------------------------------- | -| `model` | string | Model ID | See model list above | -| `messages` | array | Message list | See message format below | -| `stream` | boolean | Enable streaming | `true`, `false` | -| `reasoning_effort` | string | Reasoning effort | `none`, `minimal`, `low`, `medium`, `high`, `xhigh` | -| `temperature` | number | Sampling temperature | `0` ~ `2` | -| `top_p` | number | Nucleus sampling | `0` ~ `1` | -| `video_config` | object | **Video model only** | Supported: `grok-imagine-1.0-video` | -| └─ `aspect_ratio` | string | Video aspect ratio | `16:9`, `9:16`, `1:1`, `2:3`, `3:2`, `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| └─ `video_length` | integer | Video length (seconds) | `6`, `10`, `15` | -| └─ `resolution_name` | string | Resolution | `480p`, `720p` | -| └─ `preset` | string | Style preset | `fun`, `normal`, `spicy`, `custom` | -| `image_config` | object | **Image models only** | Supported: `grok-imagine-1.0` / `grok-imagine-1.0-edit` | -| └─ `n` | integer | Number of images | `1` ~ `10` | -| └─ `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| └─ `response_format` | string | Response format | `url`, `b64_json`, `base64` | +| Field | Type | Description | Allowed values | +| :-- | :-- | :-- | :-- | +| `model` | string | Model ID | See model list above | +| `messages` | array | Message list | See message format below | +| `stream` | boolean | Enable streaming | `true`, `false` | +| `reasoning_effort` | string | Reasoning effort | `none`, `minimal`, `low`, `medium`, `high`, `xhigh` | +| `temperature` | number | Sampling temperature | `0` ~ `2` | +| `top_p` | number | Nucleus sampling | `0` ~ `1` | +| `video_config` | object | **Video model only** | Supported: 
`grok-imagine-1.0-video` | +| └─ `aspect_ratio` | string | Video aspect ratio | `16:9`, `9:16`, `1:1`, `2:3`, `3:2`, `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─ `video_length` | integer | Video length (seconds) | `6`, `10`, `15` | +| └─ `resolution_name` | string | Resolution | `480p`, `720p` | +| └─ `preset` | string | Style preset | `fun`, `normal`, `spicy`, `custom` | +| `image_config` | object | **Image models only** | Supported: `grok-imagine-1.0` / `grok-imagine-1.0-edit` | +| └─ `n` | integer | Number of images | `1` ~ `10` | +| └─ `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─ `response_format` | string | Response format | `url`, `b64_json`, `base64` | **Message format (messages)**: -| Field | Type | Description | -| :-------- | :----------- | :-------------------------------------------------- | -| `role` | string | `developer`, `system`, `user`, `assistant` | -| `content` | string/array | Message content (plain text or multimodal array) | +| Field | Type | Description | +| :-- | :-- | :-- | +| `role` | string | `developer`, `system`, `user`, `assistant` | +| `content` | string/array | Plain text or multimodal array | **Multimodal content block types (content array)**: -| type | Description | Example | -| :------------ | :---------- | :----------------------------------------------------------------------- | -| `text` | Text | `{"type": "text", "text": "Describe this image"}` | -| `image_url` | Image URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | -| `input_audio` | Audio | `{"type": "input_audio", "input_audio": {"data": "https://..."}}` | -| `file` | File | `{"type": "file", "file": {"file_data": "https://..."}}` | +| type | Description | Example | +| :-- | :-- | :-- | +| `text` | Text | `{"type": "text", "text": "Describe this image"}` | +| `image_url` | Image URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | +| `input_audio` | Audio | 
`{"type": "input_audio", "input_audio": {"data": "https://..."}}` | +| `file` | File | `{"type": "file", "file": {"file_data": "https://..."}}` | **Notes**: + - `image_url/input_audio/file` only supports URL or Data URI (`data:;base64,...`); raw base64 will be rejected. - `reasoning_effort`: `none` disables thinking output; any other value enables it. - `grok-imagine-1.0-edit` requires an image; if multiple are provided, the last image and last text are used. @@ -173,7 +183,10 @@ curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/j > Image generation endpoint ```bash -curl http://localhost:8000/v1/images/generations -H "Content-Type: application/json" -H "Authorization: Bearer $GROK2API_API_KEY" -d '{ +curl http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $GROK2API_API_KEY" \ + -d '{ "model": "grok-imagine-1.0", "prompt": "A cat floating in space", "n": 1 @@ -185,18 +198,19 @@ curl http://localhost:8000/v1/images/generations -H "Content-Type: application
-| Field | Type | Description | Allowed values | -| :----------------- | :------ | :--------------- | :----------------------------------------------------------------- | -| `model` | string | Image model ID | `grok-imagine-1.0` | -| `prompt` | string | Prompt | - | -| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | -| `stream` | boolean | Enable streaming | `true`, `false` | -| `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| `quality` | string | Image quality | - (not supported) | -| `response_format` | string | Response format | `url`, `b64_json`, `base64` | -| `style` | string | Style | - | +| Field | Type | Description | Allowed values | +| :-- | :-- | :-- | :-- | +| `model` | string | Image model ID | `grok-imagine-1.0` | +| `prompt` | string | Prompt | - | +| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | +| `stream` | boolean | Enable streaming | `true`, `false` | +| `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | Image quality | - (not supported) | +| `response_format` | string | Response format | `url`, `b64_json`, `base64` | +| `style` | string | Style | - (not supported) | **Notes**: + - `quality` and `style` are OpenAI compatibility placeholders and are not customizable yet.
@@ -210,7 +224,12 @@ curl http://localhost:8000/v1/images/generations -H "Content-Type: application > Image edit endpoint (multipart/form-data) ```bash -curl http://localhost:8000/v1/images/edits -H "Authorization: Bearer $GROK2API_API_KEY" -F "model=grok-imagine-1.0-edit" -F "prompt=Make it sharper" -F "image=@/path/to/image.png" -F "n=1" +curl http://localhost:8000/v1/images/edits \ + -H "Authorization: Bearer $GROK2API_API_KEY" \ + -F "model=grok-imagine-1.0-edit" \ + -F "prompt=Make the image clearer" \ + -F "image=@/path/to/image.png" \ + -F "n=1" ```
@@ -218,19 +237,20 @@ curl http://localhost:8000/v1/images/edits -H "Authorization: Bearer $GROK2API
-| Field | Type | Description | Allowed values | -| :----------------- | :------ | :--------------- | :----------------------------------------------------------------- | -| `model` | string | Image model ID | `grok-imagine-1.0-edit` | -| `prompt` | string | Edit prompt | - | -| `image` | file | Image file | `png`, `jpg`, `webp` | -| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | -| `stream` | boolean | Enable streaming | `true`, `false` | -| `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| `quality` | string | Image quality | - (not supported) | -| `response_format` | string | Response format | `url`, `b64_json`, `base64` | -| `style` | string | Style | - (not supported) | +| Field | Type | Description | Allowed values | +| :-- | :-- | :-- | :-- | +| `model` | string | Image model ID | `grok-imagine-1.0-edit` | +| `prompt` | string | Edit prompt | - | +| `image` | file | Source image | `png`, `jpg`, `webp` | +| `n` | integer | Number of images | `1` - `10` (streaming: `1` or `2` only) | +| `stream` | boolean | Enable streaming | `true`, `false` | +| `size` | string | Image size | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | Image quality | - (not supported) | +| `response_format` | string | Response format | `url`, `b64_json`, `base64` | +| `style` | string | Style | - (not supported) | **Notes**: + - `quality` and `style` are OpenAI compatibility placeholders and are not customizable yet.
@@ -244,74 +264,74 @@ curl http://localhost:8000/v1/images/edits -H "Authorization: Bearer $GROK2API Config file: `data/config.toml` > [!NOTE] -> In production or reverse proxy environments, set `app.app_url` to a publicly accessible URL, -> otherwise file links may be incorrect or return 403. +> In production or behind a reverse proxy, make sure `app.app_url` is a publicly accessible URL, +> otherwise file links may be incorrect or result in 403. > [!TIP] -> **v2.0 config structure upgrade**: legacy config will be **automatically migrated** to the new structure. -> Custom values under the old `[grok]` section are mapped to the new sections. - -| Module | Field | Name | Description | Default | -| :------------------- | :----------------------------- | :--------------------- | :----------------------------------------------------------------- | :---------------------------------------------------------- | -| **app** | `app_url` | App URL | External access URL for Grok2API (used for file links). | `http://127.0.0.1:8000` | -| | `app_key` | Admin password | Password for Grok2API admin panel (required). | `grok2api` | -| | `api_key` | API key | Token for calling Grok2API (optional). | `""` | -| | `image_format` | Image format | Output image format (url or base64). | `url` | -| | `video_format` | Video format | Output video format (html or url, url is processed). | `html` | -| | `temporary` | Temporary chat | Enable temporary conversation mode. | `true` | -| | `disable_memory` | Disable memory | Disable Grok memory to prevent irrelevant context. | `true` | -| | `stream` | Streaming | Enable streaming by default. | `true` | -| | `thinking` | Thinking chain | Enable model thinking output. | `true` | -| | `dynamic_statsig` | Dynamic fingerprint | Enable dynamic Statsig generation. | `true` | -| | `filter_tags` | Filter tags | Auto-filter special tags in Grok responses. 
| `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | -| **proxy** | `base_proxy_url` | Base proxy URL | Base service address proxying Grok official site. | `""` | -| | `asset_proxy_url` | Asset proxy URL | Proxy URL for Grok static assets (images/videos). | `""` | -| | `cf_clearance` | CF Clearance | Cloudflare clearance cookie for anti-bot. | `""` | -| | `browser` | Browser fingerprint | curl_cffi browser fingerprint (e.g. chrome136). | `chrome136` | -| | `user_agent` | User-Agent | HTTP User-Agent string. | `Mozilla/5.0 (Macintosh; ...)` | -| **voice** | `timeout` | Request timeout | Voice request timeout (seconds). | `120` | -| **chat** | `concurrent` | Concurrency | Reverse interface concurrency limit. | `10` | -| | `timeout` | Request timeout | Reverse interface timeout (seconds). | `60` | -| | `stream_timeout` | Stream idle timeout | Stream idle timeout (seconds). | `60` | -| **video** | `concurrent` | Concurrency | Reverse interface concurrency limit. | `10` | -| | `timeout` | Request timeout | Reverse interface timeout (seconds). | `60` | -| | `stream_timeout` | Stream idle timeout | Stream idle timeout (seconds). | `60` | -| **retry** | `max_retry` | Max retries | Max retries on Grok request failure. | `3` | -| | `retry_status_codes` | Retry status codes | HTTP status codes that trigger retry. | `[401, 429, 403]` | -| | `retry_backoff_base` | Backoff base | Base delay for retry backoff (seconds). | `0.5` | -| | `retry_backoff_factor` | Backoff factor | Exponential multiplier for retry backoff. | `2.0` | -| | `retry_backoff_max` | Backoff max | Max wait per retry (seconds). | `30.0` | -| | `retry_budget` | Backoff budget | Max total retry time per request (seconds). | `90.0` | -| **image** | `timeout` | Request timeout | WebSocket request timeout (seconds). | `120` | -| | `stream_timeout` | Stream idle timeout | WebSocket stream idle timeout (seconds). 
| `120` | -| | `final_timeout` | Final image timeout | Timeout after medium image before final (seconds). | `15` | -| | `nsfw` | NSFW mode | Enable NSFW in WebSocket requests. | `true` | -| | `medium_min_bytes` | Medium min bytes | Minimum bytes for medium quality image. | `30000` | -| | `final_min_bytes` | Final min bytes | Minimum bytes to treat an image as final (JPG usually > 100KB). | `100000` | -| **token** | `auto_refresh` | Auto refresh | Enable automatic token refresh. | `true` | -| | `refresh_interval_hours` | Refresh interval | Regular token refresh interval (hours). | `8` | -| | `super_refresh_interval_hours` | Super refresh interval | Super token refresh interval (hours). | `2` | -| | `fail_threshold` | Failure threshold | Consecutive failures before a token is disabled. | `5` | -| | `save_delay_ms` | Save delay | Debounced save delay for token changes (ms). | `500` | -| | `reload_interval_sec` | Sync interval | Token state refresh interval in multi-worker setups (sec). | `30` | -| **cache** | `enable_auto_clean` | Auto clean | Enable cache auto clean; cleanup when exceeding limit. | `true` | -| | `limit_mb` | Cleanup threshold | Cache size threshold (MB) that triggers cleanup. | `1024` | -| **asset** | `upload_concurrent` | Upload concurrency | Max concurrency for upload. Recommended 30. | `30` | -| | `upload_timeout` | Upload timeout | Upload timeout (seconds). Recommended 60. | `60` | -| | `download_concurrent` | Download concurrency | Max concurrency for download. Recommended 30. | `30` | -| | `download_timeout` | Download timeout | Download timeout (seconds). Recommended 60. | `60` | -| | `list_concurrent` | List concurrency | Max concurrency for asset listing. Recommended 10. | `10` | -| | `list_timeout` | List timeout | List timeout (seconds). Recommended 60. | `60` | -| | `list_batch_size` | List batch size | Batch size per list request. Recommended 10. | `10` | -| | `delete_concurrent` | Delete concurrency | Max concurrency for asset delete. 
Recommended 10. | `10` | -| | `delete_timeout` | Delete timeout | Delete timeout (seconds). Recommended 60. | `60` | -| | `delete_batch_size` | Delete batch size | Batch size per delete request. Recommended 10. | `10` | -| **nsfw** | `concurrent` | Concurrency | Max concurrency for enabling NSFW. Recommended 10. | `10` | -| | `batch_size` | Batch size | Batch size for enabling NSFW. Recommended 50. | `50` | -| | `timeout` | Request timeout | NSFW enable request timeout (seconds). Recommended 60. | `60` | -| **usage** | `concurrent` | Concurrency | Max concurrency for usage refresh. Recommended 10. | `10` | -| | `batch_size` | Batch size | Batch size for usage refresh. Recommended 50. | `50` | -| | `timeout` | Request timeout | Usage query timeout (seconds). Recommended 60. | `60` | +> **v2.0 config migration**: old configs are automatically migrated. The old `[grok]` section +> is mapped into the new config structure. + +| Module | Field | Key | Description | Default | +| :-- | :-- | :-- | :-- | :-- | +| **app** | `app_url` | App URL | External base URL used for file links. | `http://127.0.0.1:8000` | +| | `app_key` | Admin password | Login password for admin panel. | `grok2api` | +| | `api_key` | API key | Optional API key for access. | `""` | +| | `image_format` | Image format | `url` or `base64`. | `url` | +| | `video_format` | Video format | `html` or `url` (processed link). | `html` | +| | `temporary` | Temporary chat | Enable temporary chat mode. | `true` | +| | `disable_memory` | Disable memory | Disable Grok memory. | `true` | +| | `stream` | Stream | Enable streaming by default. | `true` | +| | `thinking` | Thinking | Enable reasoning output. | `true` | +| | `dynamic_statsig` | Dynamic statsig | Generate dynamic Statsig values. | `true` | +| | `filter_tags` | Filter tags | Filter special tags in responses. | `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | +| **proxy** | `base_proxy_url` | Base proxy URL | Proxy to Grok web. 
| `""` | +| | `asset_proxy_url` | Asset proxy URL | Proxy to Grok assets (img/video). | `""` | +| | `cf_clearance` | CF Clearance | Cloudflare clearance cookie. | `""` | +| | `browser` | Browser fingerprint | curl_cffi fingerprint (e.g. chrome136). | `chrome136` | +| | `user_agent` | User-Agent | HTTP User-Agent string. | `Mozilla/5.0 (Macintosh; ...)` | +| **voice** | `timeout` | Timeout | Voice request timeout (seconds). | `120` | +| **chat** | `concurrent` | Concurrency | Reverse interface concurrency limit. | `10` | +| | `timeout` | Timeout | Reverse request timeout (seconds). | `60` | +| | `stream_timeout` | Stream idle timeout | Stream idle timeout (seconds). | `60` | +| **video** | `concurrent` | Concurrency | Reverse interface concurrency limit. | `10` | +| | `timeout` | Timeout | Reverse request timeout (seconds). | `60` | +| | `stream_timeout` | Stream idle timeout | Stream idle timeout (seconds). | `60` | +| **retry** | `max_retry` | Max retry | Max retries for upstream failures. | `3` | +| | `retry_status_codes` | Retry codes | HTTP status codes that trigger retry. | `[401, 429, 403]` | +| | `retry_backoff_base` | Backoff base | Retry backoff base seconds. | `0.5` | +| | `retry_backoff_factor` | Backoff factor | Exponential backoff factor. | `2.0` | +| | `retry_backoff_max` | Backoff max | Max delay per retry (seconds). | `30.0` | +| | `retry_budget` | Retry budget | Max total retry time (seconds). | `90.0` | +| **image** | `timeout` | Timeout | WebSocket timeout (seconds). | `120` | +| | `stream_timeout` | Stream idle timeout | WS stream idle timeout (seconds). | `120` | +| | `final_timeout` | Final timeout | Wait time after medium image (seconds). | `15` | +| | `nsfw` | NSFW | Enable NSFW. | `true` | +| | `medium_min_bytes` | Medium min bytes | Minimum size for medium image. | `30000` | +| | `final_min_bytes` | Final min bytes | Minimum size for final image (JPG > 100KB typical). 
| `100000` | +| **token** | `auto_refresh` | Auto refresh | Enable token auto refresh. | `true` | +| | `refresh_interval_hours` | Refresh interval | Basic token refresh interval (hours). | `8` | +| | `super_refresh_interval_hours` | Super refresh interval | Super token refresh interval (hours). | `2` | +| | `fail_threshold` | Fail threshold | Consecutive failures before a token is disabled. | `5` | +| | `save_delay_ms` | Save delay | Debounce delay for saving token changes (ms). | `500` | +| | `reload_interval_sec` | Reload interval | Multi-worker token reload interval (seconds). | `30` | +| **cache** | `enable_auto_clean` | Auto clean | Enable cache auto cleanup. | `true` | +| | `limit_mb` | Size limit | Cleanup threshold (MB). | `1024` | +| **asset** | `upload_concurrent` | Upload concurrency | Max upload concurrency (recommended 30). | `30` | +| | `upload_timeout` | Upload timeout | Upload timeout (seconds). | `60` | +| | `download_concurrent` | Download concurrency | Max download concurrency (recommended 30). | `30` | +| | `download_timeout` | Download timeout | Download timeout (seconds). | `60` | +| | `list_concurrent` | List concurrency | Max list concurrency (recommended 10). | `10` | +| | `list_timeout` | List timeout | List timeout (seconds). | `60` | +| | `list_batch_size` | List batch size | Tokens per list batch (recommended 10). | `10` | +| | `delete_concurrent` | Delete concurrency | Max delete concurrency (recommended 10). | `10` | +| | `delete_timeout` | Delete timeout | Delete timeout (seconds). | `60` | +| | `delete_batch_size` | Delete batch size | Tokens per delete batch (recommended 10). | `10` | +| **nsfw** | `concurrent` | Concurrency | NSFW batch enable concurrency (recommended 10). | `10` | +| | `batch_size` | Batch size | NSFW batch size (recommended 50). | `50` | +| | `timeout` | Timeout | NSFW request timeout (seconds). | `60` | +| **usage** | `concurrent` | Concurrency | Usage refresh concurrency (recommended 10).
| `10` | +| | `batch_size` | Batch size | Usage batch size (recommended 50). | `50` | +| | `timeout` | Timeout | Usage request timeout (seconds). | `60` |
diff --git a/readme.md b/readme.md index 3850d655..8fce25c1 100644 --- a/readme.md +++ b/readme.md @@ -5,103 +5,105 @@ > [!NOTE] > 本项目仅供学习与研究,使用者必须在遵循 Grok 的 **使用条款** 以及 **法律法规** 的情况下使用,不得用于非法用途。 -基于 **FastAPI** 重构的 Grok2API,全面适配最新 Web 调用格式,支持流/非流式对话、图像生成/编辑、深度思考,号池并发与自动负载均衡一体化。 +基于 **FastAPI** 重构的 Grok2API,全面适配最新 Web 调用格式,支持流/非流式对话、图像生成/编辑、视频生成/超分、深度思考,号池并发与自动负载均衡一体化。 -### NOTE:项目近期停止接受 PR 和 暂停功能更新,最后优化一次项目结构~ - -image +image
-## 使用说明 - -### 如何启动 +## 快速开始 -- 本地开发 +### 本地开发 -``` +```bash uv sync - uv run main.py ``` -### 如何部署 +### Docker Compose - -#### docker compose 部署 -``` +```bash git clone https://github.com/chenyme/grok2api +cd grok2api docker compose up -d ``` -#### Vercel 部署 +### Vercel 部署 [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,SERVER_STORAGE_TYPE,SERVER_STORAGE_URL&envDefaults=%7B%22DATA_DIR%22%3A%22/tmp/data%22%2C%22LOG_FILE_ENABLED%22%3A%22false%22%2C%22LOG_LEVEL%22%3A%22INFO%22%2C%22SERVER_STORAGE_TYPE%22%3A%22local%22%2C%22SERVER_STORAGE_URL%22%3A%22%22%7D) -> 请务必设置 DATA_DIR=/tmp/data,并关闭文件日志 LOG_FILE_ENABLED=false。 +> 请务必设置 `DATA_DIR=/tmp/data` 并关闭文件日志 `LOG_FILE_ENABLED=false`。 > -> 持久化请使用 MySQL / Redis / PostgreSQL,在 Vercel 环境变量中设置:SERVER_STORAGE_TYPE(mysql/redis/pgsql)与 SERVER_STORAGE_URL。 +> 持久化请使用 MySQL / Redis / PostgreSQL,并设置:`SERVER_STORAGE_TYPE` 与 `SERVER_STORAGE_URL`。 -#### Render 部署 +### Render 部署 [![Deploy to Render](https://render.com/images/deploy-to-render-button.svg)](https://render.com/deploy?repo=https://github.com/chenyme/grok2api) -> Render 免费实例 15 分钟无访问会休眠,恢复/重启/重新部署会丢失。 +> Render 免费实例 15 分钟无访问会休眠;重启/重新部署会丢失数据。 > -> 持久化请使用 MySQL / Redis / PostgreSQL,在 Render 环境变量中设置:SERVER_STORAGE_TYPE(mysql/redis/pgsql)与 SERVER_STORAGE_URL。 +> 持久化请使用 MySQL / Redis / PostgreSQL,并设置:`SERVER_STORAGE_TYPE` 与 `SERVER_STORAGE_URL`。 + +
-### 管理面板 +## 管理面板 -访问地址:`http://<host>:8000/admin` -默认登录密码:`grok2api`(对应配置项 `app.app_key`,建议修改)。 +- 访问地址:`http://<host>:8000/admin` +- 默认密码:`grok2api`(配置项 `app.app_key`,建议修改) **功能说明**: - **Token 管理**:导入/添加/删除 Token,查看状态和配额 - **状态筛选**:按状态(正常/限流/失效)或 NSFW 状态筛选 - **批量操作**:批量刷新、导出、删除、开启 NSFW -- **NSFW 开启**:一键为 Token 开启 Unhinged 模式(需代理或 cf_clearance) +- **NSFW 开启**:一键为 Token 开启 Unhinged 模式(需代理或 `cf_clearance`) - **配置管理**:在线修改系统配置 - **缓存管理**:查看和清理媒体缓存 -### 环境变量 +
+ +## 环境变量 > 配置 `.env` 文件 -| 变量名 | 说明 | 默认值 | 示例 | -| :---------------------- | :-------------------------------------------------- | :---------- | :-------------------------------------------------- | -| `LOG_LEVEL` | 日志级别 | `INFO` | `DEBUG` | -| `LOG_FILE_ENABLED` | 是否启用文件日志 | `true` | `false` | -| `DATA_DIR` | 数据目录(配置/Token/锁) | `./data` | `/data` | -| `SERVER_HOST` | 服务监听地址 | `0.0.0.0` | `0.0.0.0` | -| `SERVER_PORT` | 服务端口 | `8000` | `8000` | -| `SERVER_WORKERS` | Uvicorn worker 数量 | `1` | `2` | -| `SERVER_STORAGE_TYPE` | 存储类型(`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | -| `SERVER_STORAGE_URL` | 存储连接串(local 时可为空) | `""` | `postgresql+asyncpg://user:password@host:5432/db` | +| 变量名 | 说明 | 默认值 | 示例 | +| :-- | :-- | :-- | :-- | +| `LOG_LEVEL` | 日志级别 | `INFO` | `DEBUG` | +| `LOG_FILE_ENABLED` | 是否启用文件日志 | `true` | `false` | +| `DATA_DIR` | 数据目录(配置/Token/锁) | `./data` | `/data` | +| `SERVER_HOST` | 服务监听地址 | `0.0.0.0` | `0.0.0.0` | +| `SERVER_PORT` | 服务端口 | `8000` | `8000` | +| `SERVER_WORKERS` | Uvicorn worker 数量 | `1` | `2` | +| `SERVER_STORAGE_TYPE` | 存储类型(`local`/`redis`/`mysql`/`pgsql`) | `local` | `pgsql` | +| `SERVER_STORAGE_URL` | 存储连接串(local 时可为空) | `""` | `postgresql+asyncpg://user:password@host:5432/db` | > MySQL 示例:`mysql+aiomysql://user:password@host:3306/db`(若填 `mysql://` 会自动转为 `mysql+aiomysql://`) -### 可用次数 +
+ +## 可用次数 - Basic 账号:80 次 / 20h - Super 账号:140 次 / 2h -### 可用模型 - -| 模型名 | 计次 | 可用账号 | 对话功能 | 图像功能 | 视频功能 | -| :------------------------- | :--: | :---------- | :------: | :------: | :------: | -| `grok-3` | 1 | Basic/Super | 支持 | 支持 | - | -| `grok-3-fast` | 1 | Basic/Super | 支持 | 支持 | - | -| `grok-4` | 1 | Basic/Super | 支持 | 支持 | - | -| `grok-4-mini` | 1 | Basic/Super | 支持 | 支持 | - | -| `grok-4-fast` | 1 | Basic/Super | 支持 | 支持 | - | -| `grok-4-heavy` | 4 | Super | 支持 | 支持 | - | -| `grok-4.1` | 1 | Basic/Super | 支持 | 支持 | - | -| `grok-4.1-thinking` | 4 | Basic/Super | 支持 | 支持 | - | -| `grok-imagine-1.0` | 4 | Basic/Super | - | 支持 | - | -| `grok-imagine-1.0-edit` | 4 | Basic/Super | - | 支持 | - | -| `grok-imagine-1.0-video` | - | Basic/Super | - | - | 支持 | +
+ +## 可用模型 + +| 模型名 | 计次 | 可用账号 | 对话功能 | 图像功能 | 视频功能 | +| :-- | :--: | :-- | :--: | :--: | :--: | +| `grok-3` | 1 | Basic/Super | 支持 | 支持 | - | +| `grok-3-fast` | 1 | Basic/Super | 支持 | 支持 | - | +| `grok-4` | 1 | Basic/Super | 支持 | 支持 | - | +| `grok-4-mini` | 1 | Basic/Super | 支持 | 支持 | - | +| `grok-4-fast` | 1 | Basic/Super | 支持 | 支持 | - | +| `grok-4-heavy` | 4 | Super | 支持 | 支持 | - | +| `grok-4.1` | 1 | Basic/Super | 支持 | 支持 | - | +| `grok-4.1-thinking` | 4 | Basic/Super | 支持 | 支持 | - | +| `grok-imagine-1.0` | 4 | Basic/Super | - | 支持 | - | +| `grok-imagine-1.0-edit` | 4 | Basic/Super | - | 支持 | - | +| `grok-imagine-1.0-video` | - | Basic/Super | - | - | 支持 |
@@ -126,44 +128,44 @@ curl http://localhost:8000/v1/chat/completions \
-| 字段 | 类型 | 说明 | 可用参数 | -| :---------------------- | :------ | :----------------------------- | :------------------------------------------------------------------------------------------------- | -| `model` | string | 模型名称 | 见上方模型列表 | -| `messages` | array | 消息列表 | 见下方消息格式 | -| `stream` | boolean | 是否开启流式输出 | `true`, `false` | -| `reasoning_effort` | string | 推理强度 | `none`, `minimal`, `low`, `medium`, `high`, `xhigh` | -| `temperature` | number | 采样温度 | `0` ~ `2` | -| `top_p` | number | nucleus 采样 | `0` ~ `1` | -| `video_config` | object | **视频模型专用配置对象** | 支持:`grok-imagine-1.0-video` | -| └─`aspect_ratio` | string | 视频宽高比 | `16:9`, `9:16`, `1:1`, `2:3`, `3:2`, `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| └─`video_length` | integer | 视频时长 (秒) | `6`, `10`, `15` | -| └─`resolution_name` | string | 分辨率 | `480p`, `720p` | -| └─`preset` | string | 风格预设 | `fun`, `normal`, `spicy`, `custom` | -| `image_config` | object | **图片模型专用配置对象** | 支持:`grok-imagine-1.0` / `grok-imagine-1.0-edit` | -| └─`n` | integer | 生成数量 | `1` ~ `10` | -| └─`size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| └─`response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | +| 字段 | 类型 | 说明 | 可用参数 | +| :-- | :-- | :-- | :-- | +| `model` | string | 模型名称 | 见上方模型列表 | +| `messages` | array | 消息列表 | 见下方消息格式 | +| `stream` | boolean | 是否开启流式输出 | `true`, `false` | +| `reasoning_effort` | string | 推理强度 | `none`, `minimal`, `low`, `medium`, `high`, `xhigh` | +| `temperature` | number | 采样温度 | `0` ~ `2` | +| `top_p` | number | nucleus 采样 | `0` ~ `1` | +| `video_config` | object | **视频模型专用配置对象** | 支持:`grok-imagine-1.0-video` | +| └─`aspect_ratio` | string | 视频宽高比 | `16:9`, `9:16`, `1:1`, `2:3`, `3:2`, `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─`video_length` | integer | 视频时长 (秒) | `6`, `10`, `15` | +| └─`resolution_name` | string | 分辨率 | `480p`, `720p` | +| └─`preset` | string | 风格预设 | `fun`, `normal`, `spicy`, `custom` | 
+| `image_config` | object | **图片模型专用配置对象** | 支持:`grok-imagine-1.0` / `grok-imagine-1.0-edit` | +| └─`n` | integer | 生成数量 | `1` ~ `10` | +| └─`size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| └─`response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | **消息格式 (messages)**: -| 字段 | 类型 | 说明 | -| :---------- | :----------- | :------------------------------------------------------- | -| `role` | string | 角色:`developer`, `system`, `user`, `assistant` | -| `content` | string/array | 消息内容,支持纯文本或多模态数组 | +| 字段 | 类型 | 说明 | +| :-- | :-- | :-- | +| `role` | string | 角色:`developer`, `system`, `user`, `assistant` | +| `content` | string/array | 消息内容,支持纯文本或多模态数组 | **多模态内容块类型 (content array)**: -| type | 说明 | 示例 | -| :------------ | :------- | :---------------------------------------------------------------- | -| `text` | 文本内容 | `{"type": "text", "text": "描述这张图片"}` | -| `image_url` | 图片 URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | -| `input_audio` | 音频 | `{"type": "input_audio", "input_audio": {"data": "https://..."}}` | -| `file` | 文件 | `{"type": "file", "file": {"file_data": "https://..."}}` | +| type | 说明 | 示例 | +| :-- | :-- | :-- | +| `text` | 文本内容 | `{"type": "text", "text": "描述这张图片"}` | +| `image_url` | 图片 URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | +| `input_audio` | 音频 | `{"type": "input_audio", "input_audio": {"data": "https://..."}}` | +| `file` | 文件 | `{"type": "file", "file": {"file_data": "https://..."}}` | **注意事项**: - `image_url/input_audio/file` 仅支持 URL 或 Data URI(`data:;base64,...`),裸 base64 会报错。 -- `reasoning_effort`:`none` 表示不输出思考,其他值都会输出思考内容。 +- `reasoning_effort`:`none` 表示不输出思考,其他值都会输出思考内容。 - `grok-imagine-1.0-edit` 必须提供图片,多图默认取最后一张与最后一个文本。 - 除上述外的其他参数将自动丢弃并忽略。 @@ -193,18 +195,19 @@ curl http://localhost:8000/v1/images/generations \
-| 字段 | 类型 | 说明 | 可用参数 | -| :------------------ | :------ | :--------------- | :------------------------------------------------------------ | -| `model` | string | 图像模型名 | `grok-imagine-1.0` | -| `prompt` | string | 图像描述提示词 | - | -| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | -| `stream` | boolean | 是否开启流式输出 | `true`, `false` | -| `size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| `quality` | string | 图片质量 | - (暂不支持) | -| `response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | -| `style` | string | 风格 | - | +| 字段 | 类型 | 说明 | 可用参数 | +| :-- | :-- | :-- | :-- | +| `model` | string | 图像模型名 | `grok-imagine-1.0` | +| `prompt` | string | 图像描述提示词 | - | +| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | +| `stream` | boolean | 是否开启流式输出 | `true`, `false` | +| `size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | 图片质量 | - (暂不支持) | +| `response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | +| `style` | string | 风格 | - (暂不支持) | **注意事项**: + - `quality`、`style` 参数为 OpenAI 兼容保留,当前版本暂不支持自定义。
@@ -231,19 +234,20 @@ curl http://localhost:8000/v1/images/edits \
-| 字段 | 类型 | 说明 | 可用参数 | -| :------------------ | :------ | :--------------- | :------------------------------------------------------------ | -| `model` | string | 图像模型名 | `grok-imagine-1.0-edit` | -| `prompt` | string | 编辑描述 | - | -| `image` | file | 待编辑图片 | `png`, `jpg`, `webp` | -| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | -| `stream` | boolean | 是否开启流式输出 | `true`, `false` | -| `size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | -| `quality` | string | 图片质量 | - (暂不支持) | -| `response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | -| `style` | string | 风格 | - (暂不支持) | +| 字段 | 类型 | 说明 | 可用参数 | +| :-- | :-- | :-- | :-- | +| `model` | string | 图像模型名 | `grok-imagine-1.0-edit` | +| `prompt` | string | 编辑描述 | - | +| `image` | file | 待编辑图片 | `png`, `jpg`, `webp` | +| `n` | integer | 生成数量 | `1` - `10` (流式模式仅限 `1` 或 `2`) | +| `stream` | boolean | 是否开启流式输出 | `true`, `false` | +| `size` | string | 图片尺寸 | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | +| `quality` | string | 图片质量 | - (暂不支持) | +| `response_format` | string | 响应格式 | `url`, `b64_json`, `base64` | +| `style` | string | 风格 | - (暂不支持) | **注意事项**: + - `quality`、`style` 参数为 OpenAI 兼容保留,当前版本暂不支持自定义。
@@ -264,67 +268,67 @@ curl http://localhost:8000/v1/images/edits \ > **v2.0 配置结构升级**:旧版本用户更新后,配置会**自动迁移**到新结构,无需手动修改。 > 旧的 `[grok]` 配置节中的自定义值会自动映射到对应的新配置节。 -| 模块 | 字段 | 配置名 | 说明 | 默认值 | -| :-------------------- | :------------------------------- | :----------------- | :---------------------------------------------------- | :-------------------------------------------------------- | -| **app** | `app_url` | 应用地址 | 当前 Grok2API 服务的外部访问 URL,用于文件链接访问。 | `http://127.0.0.1:8000` | -| | `app_key` | 后台密码 | 登录 Grok2API 管理后台的密码(必填)。 | `grok2api` | -| | `api_key` | API 密钥 | 调用 Grok2API 服务的 Token(可选)。 | `""` | -| | `image_format` | 图片格式 | 生成的图片格式(url 或 base64)。 | `url` | -| | `video_format` | 视频格式 | 生成的视频格式(html 或 url,url 为处理后的链接)。 | `html` | -| | `temporary` | 临时对话 | 是否启用临时对话模式。 | `true` | -| | `disable_memory` | 禁用记忆 | 禁用 Grok 记忆功能,防止响应中出现不相关上下文。 | `true` | -| | `stream` | 流式响应 | 是否默认启用流式输出。 | `true` | -| | `thinking` | 思维链 | 是否启用模型思维链输出。 | `true` | -| | `dynamic_statsig` | 动态指纹 | 是否启用动态生成 Statsig 值。 | `true` | -| | `filter_tags` | 过滤标签 | 自动过滤 Grok 响应中的特殊标签。 | `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | -| **proxy** | `base_proxy_url` | 基础代理 URL | 代理请求到 Grok 官网的基础服务地址。 | `""` | -| | `asset_proxy_url` | 资源代理 URL | 代理请求到 Grok 官网的静态资源(图片/视频)地址。 | `""` | -| | `cf_clearance` | CF Clearance | Cloudflare 验证 Cookie,用于绕过反爬虫验证。 | `""` | -| | `browser` | 浏览器指纹 | curl_cffi 浏览器指纹标识(如 chrome136)。 | `chrome136` | -| | `user_agent` | User-Agent | HTTP 请求的 User-Agent 字符串。 | `Mozilla/5.0 (Macintosh; ...)` | -| **voice** | `timeout` | 请求超时 | Voice 请求超时时间(秒)。 | `120` | -| **chat** | `concurrent` | 并发上限 | Reverse 接口并发上限。 | `10` | -| | `timeout` | 请求超时 | Reverse 接口超时时间(秒)。 | `60` | -| | `stream_timeout` | 流空闲超时 | 流式空闲超时时间(秒)。 | `60` | -| **video** | `concurrent` | 并发上限 | Reverse 接口并发上限。 | `10` | -| | `timeout` | 请求超时 | Reverse 接口超时时间(秒)。 | `60` | -| | `stream_timeout` | 流空闲超时 | 流式空闲超时时间(秒)。 | `60` | -| **retry** | `max_retry` | 最大重试 | 请求 Grok 服务失败时的最大重试次数。 | `3` | -| | 
`retry_status_codes` | 重试状态码 | 触发重试的 HTTP 状态码列表。 | `[401, 429, 403]` | -| | `retry_backoff_base` | 退避基数 | 重试退避的基础延迟(秒)。 | `0.5` | -| | `retry_backoff_factor` | 退避倍率 | 重试退避的指数放大系数。 | `2.0` | -| | `retry_backoff_max` | 退避上限 | 单次重试等待的最大延迟(秒)。 | `30.0` | -| | `retry_budget` | 退避预算 | 单次请求的最大重试总耗时(秒)。 | `90.0` | -| **image** | `timeout` | 请求超时 | WebSocket 请求超时时间(秒)。 | `120` | -| | `stream_timeout` | 流空闲超时 | WebSocket 流式空闲超时时间(秒)。 | `120` | -| | `final_timeout` | 最终图超时 | 收到中等图后等待最终图的超时秒数。 | `15` | -| | `nsfw` | NSFW 模式 | WebSocket 请求是否启用 NSFW。 | `true` | -| | `medium_min_bytes` | 中等图最小字节 | 判定中等质量图的最小字节数。 | `30000` | -| | `final_min_bytes` | 最终图最小字节 | 判定最终图的最小字节数(通常 JPG > 100KB)。 | `100000` | -| **token** | `auto_refresh` | 自动刷新 | 是否开启 Token 自动刷新机制。 | `true` | -| | `refresh_interval_hours` | 刷新间隔 | 普通 Token 刷新的时间间隔(小时)。 | `8` | -| | `super_refresh_interval_hours` | Super 刷新间隔 | Super Token 刷新的时间间隔(小时)。 | `2` | -| | `fail_threshold` | 失败阈值 | 单个 Token 连续失败多少次后被标记为不可用。 | `5` | -| | `save_delay_ms` | 保存延迟 | Token 变更合并写入的延迟(毫秒)。 | `500` | -| | `reload_interval_sec` | 同步间隔 | 多 worker 场景下 Token 状态刷新间隔(秒)。 | `30` | -| **cache** | `enable_auto_clean` | 自动清理 | 是否启用缓存自动清理,开启后按上限自动回收。 | `true` | -| | `limit_mb` | 清理阈值 | 缓存大小阈值(MB),超过阈值会触发清理。 | `1024` | -| **asset** | `upload_concurrent` | 上传并发 | 上传接口的最大并发数。推荐 30。 | `30` | -| | `upload_timeout` | 上传超时 | 上传接口超时时间(秒)。推荐 60。 | `60` | -| | `download_concurrent` | 下载并发 | 下载接口的最大并发数。推荐 30。 | `30` | -| | `download_timeout` | 下载超时 | 下载接口超时时间(秒)。推荐 60。 | `60` | -| | `list_concurrent` | 查询并发 | 资产查询接口的最大并发数。推荐 10。 | `10` | -| | `list_timeout` | 查询超时 | 资产查询接口超时时间(秒)。推荐 60。 | `60` | -| | `list_batch_size` | 查询批次大小 | 单次查询可处理的 Token 数量。推荐 10。 | `10` | -| | `delete_concurrent` | 删除并发 | 资产删除接口的最大并发数。推荐 10。 | `10` | -| | `delete_timeout` | 删除超时 | 资产删除接口超时时间(秒)。推荐 60。 | `60` | -| | `delete_batch_size` | 删除批次大小 | 单次删除可处理的 Token 数量。推荐 10。 | `10` | -| **nsfw** | `concurrent` | 并发上限 | 批量开启 NSFW 模式时的并发请求上限。推荐 10。 | `10` | -| | `batch_size` | 批次大小 | 批量开启 NSFW 
模式的单批处理数量。推荐 50。 | `50` | -| | `timeout` | 请求超时 | NSFW 开启相关请求的超时时间(秒)。推荐 60。 | `60` | -| **usage** | `concurrent` | 并发上限 | 批量刷新用量时的并发请求上限。推荐 10。 | `10` | -| | `batch_size` | 批次大小 | 批量刷新用量的单批处理数量。推荐 50。 | `50` | -| | `timeout` | 请求超时 | 用量查询接口的超时时间(秒)。推荐 60。 | `60` | +| 模块 | 字段 | 配置名 | 说明 | 默认值 | +| :-- | :-- | :-- | :-- | :-- | +| **app** | `app_url` | 应用地址 | 当前 Grok2API 服务的外部访问 URL,用于文件链接访问。 | `http://127.0.0.1:8000` | +| | `app_key` | 后台密码 | 登录 Grok2API 管理后台的密码(必填)。 | `grok2api` | +| | `api_key` | API 密钥 | 调用 Grok2API 服务的 Token(可选)。 | `""` | +| | `image_format` | 图片格式 | 生成的图片格式(url 或 base64)。 | `url` | +| | `video_format` | 视频格式 | 生成的视频格式(html 或 url,url 为处理后的链接)。 | `html` | +| | `temporary` | 临时对话 | 是否启用临时对话模式。 | `true` | +| | `disable_memory` | 禁用记忆 | 禁用 Grok 记忆功能,防止响应中出现不相关上下文。 | `true` | +| | `stream` | 流式响应 | 是否默认启用流式输出。 | `true` | +| | `thinking` | 思维链 | 是否启用模型思维链输出。 | `true` | +| | `dynamic_statsig` | 动态指纹 | 是否启用动态生成 Statsig 值。 | `true` | +| | `filter_tags` | 过滤标签 | 自动过滤 Grok 响应中的特殊标签。 | `["xaiartifact", "xai:tool_usage_card", "grok:render"]` | +| **proxy** | `base_proxy_url` | 基础代理 URL | 代理请求到 Grok 官网的基础服务地址。 | `""` | +| | `asset_proxy_url` | 资源代理 URL | 代理请求到 Grok 官网的静态资源(图片/视频)地址。 | `""` | +| | `cf_clearance` | CF Clearance | Cloudflare 验证 Cookie,用于绕过反爬虫验证。 | `""` | +| | `browser` | 浏览器指纹 | curl_cffi 浏览器指纹标识(如 chrome136)。 | `chrome136` | +| | `user_agent` | User-Agent | HTTP 请求的 User-Agent 字符串。 | `Mozilla/5.0 (Macintosh; ...)` | +| **voice** | `timeout` | 请求超时 | Voice 请求超时时间(秒)。 | `120` | +| **chat** | `concurrent` | 并发上限 | Reverse 接口并发上限。 | `10` | +| | `timeout` | 请求超时 | Reverse 接口超时时间(秒)。 | `60` | +| | `stream_timeout` | 流空闲超时 | 流式空闲超时时间(秒)。 | `60` | +| **video** | `concurrent` | 并发上限 | Reverse 接口并发上限。 | `10` | +| | `timeout` | 请求超时 | Reverse 接口超时时间(秒)。 | `60` | +| | `stream_timeout` | 流空闲超时 | 流式空闲超时时间(秒)。 | `60` | +| **retry** | `max_retry` | 最大重试 | 请求 Grok 服务失败时的最大重试次数。 | `3` | +| | `retry_status_codes` | 重试状态码 | 触发重试的 HTTP 状态码列表。 | `[401, 429, 403]` | 
+| | `retry_backoff_base` | 退避基数 | 重试退避的基础延迟(秒)。 | `0.5` | +| | `retry_backoff_factor` | 退避倍率 | 重试退避的指数放大系数。 | `2.0` | +| | `retry_backoff_max` | 退避上限 | 单次重试等待的最大延迟(秒)。 | `30.0` | +| | `retry_budget` | 退避预算 | 单次请求的最大重试总耗时(秒)。 | `90.0` | +| **image** | `timeout` | 请求超时 | WebSocket 请求超时时间(秒)。 | `120` | +| | `stream_timeout` | 流空闲超时 | WebSocket 流式空闲超时时间(秒)。 | `120` | +| | `final_timeout` | 最终图超时 | 收到中等图后等待最终图的超时秒数。 | `15` | +| | `nsfw` | NSFW 模式 | WebSocket 请求是否启用 NSFW。 | `true` | +| | `medium_min_bytes` | 中等图最小字节 | 判定中等质量图的最小字节数。 | `30000` | +| | `final_min_bytes` | 最终图最小字节 | 判定最终图的最小字节数(通常 JPG > 100KB)。 | `100000` | +| **token** | `auto_refresh` | 自动刷新 | 是否开启 Token 自动刷新机制。 | `true` | +| | `refresh_interval_hours` | 刷新间隔 | 普通 Token 刷新的时间间隔(小时)。 | `8` | +| | `super_refresh_interval_hours` | Super 刷新间隔 | Super Token 刷新的时间间隔(小时)。 | `2` | +| | `fail_threshold` | 失败阈值 | 单个 Token 连续失败多少次后被标记为不可用。 | `5` | +| | `save_delay_ms` | 保存延迟 | Token 变更合并写入的延迟(毫秒)。 | `500` | +| | `reload_interval_sec` | 同步间隔 | 多 worker 场景下 Token 状态刷新间隔(秒)。 | `30` | +| **cache** | `enable_auto_clean` | 自动清理 | 是否启用缓存自动清理,开启后按上限自动回收。 | `true` | +| | `limit_mb` | 清理阈值 | 缓存大小阈值(MB),超过阈值会触发清理。 | `1024` | +| **asset** | `upload_concurrent` | 上传并发 | 上传接口的最大并发数。推荐 30。 | `30` | +| | `upload_timeout` | 上传超时 | 上传接口超时时间(秒)。推荐 60。 | `60` | +| | `download_concurrent` | 下载并发 | 下载接口的最大并发数。推荐 30。 | `30` | +| | `download_timeout` | 下载超时 | 下载接口超时时间(秒)。推荐 60。 | `60` | +| | `list_concurrent` | 查询并发 | 资产查询接口的最大并发数。推荐 10。 | `10` | +| | `list_timeout` | 查询超时 | 资产查询接口超时时间(秒)。推荐 60。 | `60` | +| | `list_batch_size` | 查询批次大小 | 单次查询可处理的 Token 数量。推荐 10。 | `10` | +| | `delete_concurrent` | 删除并发 | 资产删除接口的最大并发数。推荐 10。 | `10` | +| | `delete_timeout` | 删除超时 | 资产删除接口超时时间(秒)。推荐 60。 | `60` | +| | `delete_batch_size` | 删除批次大小 | 单次删除可处理的 Token 数量。推荐 10。 | `10` | +| **nsfw** | `concurrent` | 并发上限 | 批量开启 NSFW 模式时的并发请求上限。推荐 10。 | `10` | +| | `batch_size` | 批次大小 | 批量开启 NSFW 模式的单批处理数量。推荐 50。 | `50` | +| | `timeout` | 请求超时 | NSFW 
开启相关请求的超时时间(秒)。推荐 60。 | `60` | +| **usage** | `concurrent` | 并发上限 | 批量刷新用量时的并发请求上限。推荐 10。 | `10` | +| | `batch_size` | 批次大小 | 批量刷新用量的单批处理数量。推荐 50。 | `50` | +| | `timeout` | 请求超时 | 用量查询接口的超时时间(秒)。推荐 60。 | `60` |
From 2611aad43e4481ca28f9be106c858891c3b7a802 Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Sun, 15 Feb 2026 18:08:10 +0800 Subject: [PATCH 24/27] chore: update version to 0.3.0, enhance configuration migration logic, and remove outdated README notes --- app/core/config.py | 79 ++++++++++++++++++++++++++++++++++++---------- docs/README.en.md | 3 -- pyproject.toml | 2 +- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index efb0ec27..d761bd86 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -72,6 +72,41 @@ def _migrate_deprecated_config( "grok.image_ws_blocked_seconds": "image.final_timeout", "grok.image_ws_final_min_bytes": "image.final_min_bytes", "grok.image_ws_medium_min_bytes": "image.medium_min_bytes", + # legacy sections + "network.base_proxy_url": "proxy.base_proxy_url", + "network.asset_proxy_url": "proxy.asset_proxy_url", + "network.timeout": [ + "chat.timeout", + "image.timeout", + "video.timeout", + "voice.timeout", + ], + "security.cf_clearance": "proxy.cf_clearance", + "security.browser": "proxy.browser", + "security.user_agent": "proxy.user_agent", + "timeout.stream_idle_timeout": [ + "chat.stream_timeout", + "image.stream_timeout", + "video.stream_timeout", + ], + "timeout.video_idle_timeout": "video.stream_timeout", + "image.image_ws_nsfw": "image.nsfw", + "image.image_ws_blocked_seconds": "image.final_timeout", + "image.image_ws_final_min_bytes": "image.final_min_bytes", + "image.image_ws_medium_min_bytes": "image.medium_min_bytes", + "performance.assets_max_concurrent": [ + "asset.upload_concurrent", + "asset.download_concurrent", + "asset.list_concurrent", + "asset.delete_concurrent", + ], + "performance.assets_delete_batch_size": "asset.delete_batch_size", + "performance.assets_batch_size": "asset.list_batch_size", + "performance.media_max_concurrent": ["chat.concurrent", "video.concurrent"], + "performance.usage_max_concurrent": 
"usage.concurrent", + "performance.usage_batch_size": "usage.batch_size", + "performance.nsfw_max_concurrent": "nsfw.concurrent", + "performance.nsfw_batch_size": "nsfw.batch_size", } deprecated_sections = set(config.keys()) - valid_sections @@ -81,25 +116,35 @@ def _migrate_deprecated_config( result = {k: deepcopy(v) for k, v in config.items() if k in valid_sections} migrated_count = 0 - # 处理废弃配置节中的配置项 - for old_section in deprecated_sections: - if old_section not in config or not isinstance(config[old_section], dict): + # 处理废弃配置节或旧配置键 + for old_section, old_values in config.items(): + if not isinstance(old_values, dict): continue - - for old_key, old_value in config[old_section].items(): - # 查找映射规则 + for old_key, old_value in old_values.items(): old_path = f"{old_section}.{old_key}" - new_path = MIGRATION_MAP.get(old_path) - - if new_path: - new_section, new_key = new_path.split(".", 1) - # 确保新配置节存在 - if new_section not in result: - result[new_section] = {} - # 迁移配置项(保留用户的自定义值) - result[new_section][new_key] = old_value - migrated_count += 1 - logger.debug(f"Migrated config: {old_path} -> {new_path} = {old_value}") + new_paths = MIGRATION_MAP.get(old_path) + if not new_paths: + continue + if isinstance(new_paths, str): + new_paths = [new_paths] + for new_path in new_paths: + try: + new_section, new_key = new_path.split(".", 1) + if new_section not in result: + result[new_section] = {} + if new_key not in result[new_section]: + result[new_section][new_key] = old_value + migrated_count += 1 + logger.debug( + f"Migrated config: {old_path} -> {new_path} = {old_value}" + ) + except Exception as e: + logger.warning( + f"Skip config migration for {old_path}: {e}" + ) + continue + if isinstance(result.get(old_section), dict): + result[old_section].pop(old_key, None) # 兼容旧 chat.* 配置键迁移到 app.* legacy_chat_map = { diff --git a/docs/README.en.md b/docs/README.en.md index e92283a8..c6746540 100644 --- a/docs/README.en.md +++ b/docs/README.en.md @@ -7,9 +7,6 @@ Grok2API 
rebuilt with **FastAPI**, fully aligned with the latest web call format. Supports streaming/non-streaming chat, image generation/editing, video generation/upscale, deep reasoning, token pool concurrency, and automatic load balancing. -> [!IMPORTANT] -> The project is no longer accepting PRs or new features; this is the last structure optimization. - image
diff --git a/pyproject.toml b/pyproject.toml index e2fc17b9..f808a049 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "grok2api" -version = "1.5.0" +version = "0.3.0" description = "Grok2API rebuilt with FastAPI, fully aligned with the latest web call format. Supports streaming and non-streaming chat, image generation/editing, deep thinking, token pool concurrency, and automatic load balancing." readme = "README.md" requires-python = ">=3.13" From 8c275163c44048a908b0c38aa2ee9683986093ef Mon Sep 17 00:00:00 2001 From: Chenyme <118253778+chenyme@users.noreply.github.com> Date: Sun, 15 Feb 2026 18:15:47 +0800 Subject: [PATCH 25/27] chore: update version to 0.3.0 across multiple files, including CSS and JS resources, and enhance page titles for better clarity --- app/static/admin/pages/cache.html | 14 +++++++------- app/static/admin/pages/config.html | 12 ++++++------ app/static/admin/pages/login.html | 10 +++++----- app/static/admin/pages/token.html | 14 +++++++------- app/static/common/html/public-header.html | 6 +++--- app/static/common/js/footer.js | 2 +- app/static/common/js/header.js | 2 +- app/static/common/js/public-header.js | 2 +- app/static/public/pages/imagine.html | 12 ++++++------ app/static/public/pages/login.html | 10 +++++----- app/static/public/pages/video.html | 16 ++++++++-------- app/static/public/pages/voice.html | 16 ++++++++-------- uv.lock | 2 +- 13 files changed, 59 insertions(+), 59 deletions(-) diff --git a/app/static/admin/pages/cache.html b/app/static/admin/pages/cache.html index 5acb03e2..52055d24 100644 --- a/app/static/admin/pages/cache.html +++ b/app/static/admin/pages/cache.html @@ -9,8 +9,8 @@ - - + + @@ -196,12 +196,12 @@

缓存管理

- - - - + + + + - + diff --git a/app/static/admin/pages/config.html b/app/static/admin/pages/config.html index 2225fff2..68786b0d 100644 --- a/app/static/admin/pages/config.html +++ b/app/static/admin/pages/config.html @@ -9,8 +9,8 @@ - - + + @@ -46,11 +46,11 @@

配置管理

- - - + + + - + diff --git a/app/static/admin/pages/login.html b/app/static/admin/pages/login.html index 9e94cd5f..913c33f7 100644 --- a/app/static/admin/pages/login.html +++ b/app/static/admin/pages/login.html @@ -24,8 +24,8 @@ } } - - + + @@ -55,10 +55,10 @@ - + - - + + diff --git a/app/static/admin/pages/token.html b/app/static/admin/pages/token.html index 6302d958..daf49759 100644 --- a/app/static/admin/pages/token.html +++ b/app/static/admin/pages/token.html @@ -9,8 +9,8 @@ - - + + @@ -292,12 +292,12 @@
- - - - + + + + - + diff --git a/app/static/common/html/public-header.html b/app/static/common/html/public-header.html index f94f08fb..313bbedd 100644 --- a/app/static/common/html/public-header.html +++ b/app/static/common/html/public-header.html @@ -14,9 +14,9 @@ class="text-xs text-[var(--accents-4)] hover:text-black">@chenyme
- Imagine - Video - Voice Live + Imagine 瀑布流 + Video 视频生成 + LiveKit 陪聊