diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index 9f7a4239..00000000 --- a/app/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""App Package""" diff --git a/app/api/pages/__init__.py b/app/api/pages/__init__.py new file mode 100644 index 00000000..6aada6e5 --- /dev/null +++ b/app/api/pages/__init__.py @@ -0,0 +1,13 @@ +"""UI pages router.""" + +from fastapi import APIRouter + +from app.api.pages.admin import router as admin_router +from app.api.pages.public import router as public_router + +router = APIRouter() + +router.include_router(public_router) +router.include_router(admin_router) + +__all__ = ["router"] diff --git a/app/api/pages/admin.py b/app/api/pages/admin.py new file mode 100644 index 00000000..bb581e89 --- /dev/null +++ b/app/api/pages/admin.py @@ -0,0 +1,32 @@ +from pathlib import Path + +from fastapi import APIRouter +from fastapi.responses import FileResponse, RedirectResponse + +router = APIRouter() +STATIC_DIR = Path(__file__).resolve().parents[2] / "static" + + +@router.get("/admin", include_in_schema=False) +async def admin_root(): + return RedirectResponse(url="/admin/login") + + +@router.get("/admin/login", include_in_schema=False) +async def admin_login(): + return FileResponse(STATIC_DIR / "admin/pages/login.html") + + +@router.get("/admin/config", include_in_schema=False) +async def admin_config(): + return FileResponse(STATIC_DIR / "admin/pages/config.html") + + +@router.get("/admin/cache", include_in_schema=False) +async def admin_cache(): + return FileResponse(STATIC_DIR / "admin/pages/cache.html") + + +@router.get("/admin/token", include_in_schema=False) +async def admin_token(): + return FileResponse(STATIC_DIR / "admin/pages/token.html") diff --git a/app/api/pages/public.py b/app/api/pages/public.py new file mode 100644 index 00000000..0792df99 --- /dev/null +++ b/app/api/pages/public.py @@ -0,0 +1,44 @@ +from pathlib import Path + +from fastapi import APIRouter, HTTPException +from fastapi.responses import FileResponse, RedirectResponse + +from app.core.auth import is_public_enabled + +router = APIRouter() +STATIC_DIR = Path(__file__).resolve().parents[2] / "static" + + +@router.get("/", include_in_schema=False) +async def root(): + if is_public_enabled(): + return RedirectResponse(url="/login") + return RedirectResponse(url="/admin/login") + + +@router.get("/login", include_in_schema=False) +async def public_login(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/login.html") + + +@router.get("/imagine", include_in_schema=False) +async def public_imagine(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/imagine.html") + + +@router.get("/voice", include_in_schema=False) +async def public_voice(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/voice.html") + + +@router.get("/video", include_in_schema=False) +async def public_video(): + if not is_public_enabled(): + raise HTTPException(status_code=404, detail="Not Found") + return FileResponse(STATIC_DIR / "public/pages/video.html") diff --git a/app/api/v1/admin.py b/app/api/v1/admin.py deleted file mode 100644 index c27103ad..00000000 --- a/app/api/v1/admin.py +++ /dev/null @@ -1,1783 +0,0 @@ -from fastapi import ( - APIRouter, - Depends, - HTTPException, - Request, - Query, - WebSocket, - WebSocketDisconnect, -) -from fastapi.responses import HTMLResponse, StreamingResponse, RedirectResponse -from typing import Optional -from pydantic import BaseModel -from app.core.auth import verify_api_key, verify_app_key, get_admin_api_key -from app.core.config import config, get_config -from app.core.batch_tasks import create_task, get_task, expire_task -from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage -from app.core.exceptions import AppException -from app.services.token.manager import get_token_manager -from app.services.grok.utils.batch import run_in_batches -import os -import time -import uuid -from pathlib import Path -import aiofiles -import asyncio -import orjson -from app.core.logger import logger -from app.api.v1.image import resolve_aspect_ratio -from app.services.grok.services.voice import VoiceService -from app.services.grok.services.image import image_service -from app.services.grok.models.model import ModelService -from app.services.grok.processors.image_ws_processors import ImageWSCollectProcessor -from app.services.token import EffortType - -TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static" - - -router = APIRouter() - -IMAGINE_SESSION_TTL = 600 -_IMAGINE_SESSIONS: dict[str, dict] = {} -_IMAGINE_SESSIONS_LOCK = asyncio.Lock() - - -async def _cleanup_imagine_sessions(now: float) -> None: - expired = [ - key - for key, info in _IMAGINE_SESSIONS.items() - if now - float(info.get("created_at") or 0) > IMAGINE_SESSION_TTL - ] - for key in expired: - _IMAGINE_SESSIONS.pop(key, None) - - -async def _create_imagine_session(prompt: str, aspect_ratio: str) -> str: - task_id = uuid.uuid4().hex - now = time.time() - async with _IMAGINE_SESSIONS_LOCK: - await _cleanup_imagine_sessions(now) - _IMAGINE_SESSIONS[task_id] = { - "prompt": prompt, - "aspect_ratio": aspect_ratio, - "created_at": now, - } - return task_id - - -async def _get_imagine_session(task_id: str) -> Optional[dict]: - if not task_id: - return None - now = time.time() - async with _IMAGINE_SESSIONS_LOCK: - await _cleanup_imagine_sessions(now) - info = _IMAGINE_SESSIONS.get(task_id) - if not info: - return None - created_at = float(info.get("created_at") or 0) - if now - created_at > IMAGINE_SESSION_TTL: - _IMAGINE_SESSIONS.pop(task_id, None) - return None - return dict(info) - - -async def _delete_imagine_session(task_id: str) -> None: - if not task_id: - return - async with _IMAGINE_SESSIONS_LOCK: - _IMAGINE_SESSIONS.pop(task_id, None) - - -async def _delete_imagine_sessions(task_ids: list[str]) -> int: - if not task_ids: - return 0 - removed = 0 - async with _IMAGINE_SESSIONS_LOCK: - for task_id in task_ids: - if task_id and task_id in _IMAGINE_SESSIONS: - _IMAGINE_SESSIONS.pop(task_id, None) - removed += 1 - return removed - - -def _collect_tokens(data: dict) -> list[str]: - """从请求数据中收集 token 列表""" - tokens = [] - if isinstance(data.get("token"), str) and data["token"].strip(): - tokens.append(data["token"].strip()) - if isinstance(data.get("tokens"), list): - tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) - return tokens - - -def _truncate_tokens( - tokens: list[str], max_tokens: int, operation: str = "operation" -) -> tuple[list[str], bool, int]: - """去重并截断 token 列表,返回 (unique_tokens, truncated, original_count)""" - unique_tokens = list(dict.fromkeys(tokens)) - original_count = len(unique_tokens) - truncated = False - - if len(unique_tokens) > max_tokens: - unique_tokens = unique_tokens[:max_tokens] - truncated = True - logger.warning( - f"{operation}: truncated from {original_count} to {max_tokens} tokens" - ) - - return unique_tokens, truncated, original_count - - -def _mask_token(token: str) -> str: - """掩码 token 显示""" - return f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token - - -async def render_template(filename: str): - """渲染指定模板""" - template_path = TEMPLATE_DIR / filename - if not template_path.exists(): - return HTMLResponse(f"Template {filename} not found.", status_code=404) - - async with aiofiles.open(template_path, "r", encoding="utf-8") as f: - content = await f.read() - return HTMLResponse(content) - - -def _sse_event(payload: dict) -> str: - return f"data: {orjson.dumps(payload).decode()}\n\n" - - -def _verify_stream_api_key(request: Request) -> None: - api_key = get_admin_api_key() - if not api_key: - return - key = request.query_params.get("api_key") - if key != api_key: - raise HTTPException(status_code=401, detail="Invalid authentication token") - - -@router.get("/api/v1/admin/batch/{task_id}/stream") -async def stream_batch(task_id: str, request: Request): - _verify_stream_api_key(request) - task = get_task(task_id) - if not task: - raise HTTPException(status_code=404, detail="Task not found") - - async def event_stream(): - queue = task.attach() - try: - yield _sse_event({"type": "snapshot", **task.snapshot()}) - - final = task.final_event() - if final: - yield _sse_event(final) - return - - while True: - try: - event = await asyncio.wait_for(queue.get(), timeout=15) - except asyncio.TimeoutError: - yield ": ping\n\n" - final = task.final_event() - if final: - yield _sse_event(final) - return - continue - - yield _sse_event(event) - if event.get("type") in ("done", "error", "cancelled"): - return - finally: - task.detach(queue) - - return StreamingResponse(event_stream(), media_type="text/event-stream") - - -@router.post( - "/api/v1/admin/batch/{task_id}/cancel", dependencies=[Depends(verify_api_key)] -) -async def cancel_batch(task_id: str): - task = get_task(task_id) - if not task: - raise HTTPException(status_code=404, detail="Task not found") - task.cancel() - return {"status": "success"} - - -@router.get("/admin", response_class=HTMLResponse, include_in_schema=False) -async def admin_login_page(): - """管理后台登录页""" - return await render_template("login/login.html") - - -@router.get("/", include_in_schema=False) -async def root_redirect(): - return RedirectResponse(url="/admin") - - -@router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False) -async def admin_config_page(): - """配置管理页""" - return await render_template("config/config.html") - - -@router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False) -async def admin_token_page(): - """Token 管理页""" - return await render_template("token/token.html") - - -@router.get("/admin/voice", response_class=HTMLResponse, include_in_schema=False) -async def admin_voice_page(): - """Voice Live 调试页""" - return await render_template("voice/voice.html") - - -@router.get("/admin/imagine", response_class=HTMLResponse, include_in_schema=False) -async def admin_imagine_page(): - """Imagine 图片瀑布流""" - return await render_template("imagine/imagine.html") - - -class VoiceTokenResponse(BaseModel): - token: str - url: str - participant_name: str = "" - room_name: str = "" - - -@router.get( - "/api/v1/admin/voice/token", - dependencies=[Depends(verify_api_key)], - response_model=VoiceTokenResponse, -) -async def admin_voice_token( - voice: str = "ara", - personality: str = "assistant", - speed: float = 1.0, -): - """获取 Grok Voice Mode (LiveKit) Token""" - token_mgr = await get_token_manager() - sso_token = None - for pool_name in ("ssoBasic", "ssoSuper"): - sso_token = token_mgr.get_token(pool_name) - if sso_token: - break - - if not sso_token: - raise AppException( - "No available tokens for voice mode", - code="no_token", - status_code=503, - ) - - service = VoiceService() - try: - data = await service.get_token( - token=sso_token, - voice=voice, - personality=personality, - speed=speed, - ) - token = data.get("token") - if not token: - raise AppException( - "Upstream returned no voice token", - code="upstream_error", - status_code=502, - ) - - return VoiceTokenResponse( - token=token, - url="wss://livekit.grok.com", - participant_name="", - room_name="", - ) - - except Exception as e: - if isinstance(e, AppException): - raise - raise AppException( - f"Voice token error: {str(e)}", - code="voice_error", - status_code=500, - ) - - -async def _verify_imagine_ws_auth(websocket: WebSocket) -> tuple[bool, Optional[str]]: - task_id = websocket.query_params.get("task_id") - if task_id: - info = await _get_imagine_session(task_id) - if info: - return True, task_id - - api_key = get_admin_api_key() - if not api_key: - return True, None - key = websocket.query_params.get("api_key") - return key == api_key, None - - -@router.websocket("/api/v1/admin/imagine/ws") -async def admin_imagine_ws(websocket: WebSocket): - ok, session_id = await _verify_imagine_ws_auth(websocket) - if not ok: - await websocket.close(code=1008) - return - - await websocket.accept() - stop_event = asyncio.Event() - run_task: Optional[asyncio.Task] = None - - async def _send(payload: dict) -> bool: - try: - await websocket.send_text(orjson.dumps(payload).decode()) - return True - except Exception: - return False - - async def _stop_run(): - nonlocal run_task - stop_event.set() - if run_task and not run_task.done(): - run_task.cancel() - try: - await run_task - except Exception: - pass - run_task = None - stop_event.clear() - - async def _run(prompt: str, aspect_ratio: str): - model_id = "grok-imagine-1.0" - model_info = ModelService.get(model_id) - if not model_info or not model_info.is_image: - await _send( - { - "type": "error", - "message": "Image model is not available.", - "code": "model_not_supported", - } - ) - return - - token_mgr = await get_token_manager() - enable_nsfw = bool(get_config("image.image_ws_nsfw", True)) - sequence = 0 - run_id = uuid.uuid4().hex - - await _send( - { - "type": "status", - "status": "running", - "prompt": prompt, - "aspect_ratio": aspect_ratio, - "run_id": run_id, - } - ) - - while not stop_event.is_set(): - try: - await token_mgr.reload_if_stale() - token = None - for pool_name in ModelService.pool_candidates_for_model( - model_info.model_id - ): - token = token_mgr.get_token(pool_name) - if token: - break - - if not token: - await _send( - { - "type": "error", - "message": "No available tokens. Please try again later.", - "code": "rate_limit_exceeded", - } - ) - await asyncio.sleep(2) - continue - - upstream = image_service.stream( - token=token, - prompt=prompt, - aspect_ratio=aspect_ratio, - n=6, - enable_nsfw=enable_nsfw, - ) - - processor = ImageWSCollectProcessor( - model_info.model_id, - token, - n=6, - response_format="b64_json", - ) - - start_at = time.time() - images = await processor.process(upstream) - elapsed_ms = int((time.time() - start_at) * 1000) - - if images and all(img and img != "error" for img in images): - # 一次发送所有 6 张图片 - for img_b64 in images: - sequence += 1 - await _send( - { - "type": "image", - "b64_json": img_b64, - "sequence": sequence, - "created_at": int(time.time() * 1000), - "elapsed_ms": elapsed_ms, - "aspect_ratio": aspect_ratio, - "run_id": run_id, - } - ) - - # 消耗 token(6 张图片按高成本计算) - try: - effort = ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW - ) - await token_mgr.consume(token, effort) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - else: - await _send( - { - "type": "error", - "message": "Image generation returned empty data.", - "code": "empty_image", - } - ) - - except asyncio.CancelledError: - break - except Exception as e: - logger.warning(f"Imagine stream error: {e}") - await _send( - { - "type": "error", - "message": str(e), - "code": "internal_error", - } - ) - await asyncio.sleep(1.5) - - await _send({"type": "status", "status": "stopped", "run_id": run_id}) - - try: - while True: - try: - raw = await websocket.receive_text() - except (RuntimeError, WebSocketDisconnect): - # WebSocket already closed or disconnected - break - - try: - payload = orjson.loads(raw) - except Exception: - await _send( - { - "type": "error", - "message": "Invalid message format.", - "code": "invalid_payload", - } - ) - continue - - msg_type = payload.get("type") - if msg_type == "start": - prompt = str(payload.get("prompt") or "").strip() - if not prompt: - await _send( - { - "type": "error", - "message": "Prompt cannot be empty.", - "code": "empty_prompt", - } - ) - continue - ratio = str(payload.get("aspect_ratio") or "2:3").strip() - if not ratio: - ratio = "2:3" - ratio = resolve_aspect_ratio(ratio) - await _stop_run() - stop_event.clear() - run_task = asyncio.create_task(_run(prompt, ratio)) - elif msg_type == "stop": - await _stop_run() - elif msg_type == "ping": - await _send({"type": "pong"}) - else: - await _send( - { - "type": "error", - "message": "Unknown command.", - "code": "unknown_command", - } - ) - except WebSocketDisconnect: - logger.debug("WebSocket disconnected by client") - except Exception as e: - logger.warning(f"WebSocket error: {e}") - finally: - await _stop_run() - - try: - from starlette.websockets import WebSocketState - if websocket.client_state == WebSocketState.CONNECTED: - await websocket.close(code=1000, reason="Server closing connection") - except Exception as e: - logger.debug(f"WebSocket close ignored: {e}") - if session_id: - await _delete_imagine_session(session_id) - - -class ImagineStartRequest(BaseModel): - prompt: str - aspect_ratio: Optional[str] = "2:3" - - -@router.post("/api/v1/admin/imagine/start", dependencies=[Depends(verify_api_key)]) -async def admin_imagine_start(data: ImagineStartRequest): - prompt = (data.prompt or "").strip() - if not prompt: - raise HTTPException(status_code=400, detail="Prompt cannot be empty") - ratio = resolve_aspect_ratio(str(data.aspect_ratio or "2:3").strip() or "2:3") - task_id = await _create_imagine_session(prompt, ratio) - return {"task_id": task_id, "aspect_ratio": ratio} - - -class ImagineStopRequest(BaseModel): - task_ids: list[str] - - -@router.post("/api/v1/admin/imagine/stop", dependencies=[Depends(verify_api_key)]) -async def admin_imagine_stop(data: ImagineStopRequest): - removed = await _delete_imagine_sessions(data.task_ids or []) - return {"status": "success", "removed": removed} - - -@router.get("/api/v1/admin/imagine/sse") -async def admin_imagine_sse( - request: Request, - task_id: str = Query(""), - prompt: str = Query(""), - aspect_ratio: str = Query("2:3"), -): - """Imagine 图片瀑布流(SSE 兜底)""" - session = None - if task_id: - session = await _get_imagine_session(task_id) - if not session: - raise HTTPException(status_code=404, detail="Task not found") - else: - _verify_stream_api_key(request) - - if session: - prompt = str(session.get("prompt") or "").strip() - ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3" - else: - prompt = (prompt or "").strip() - if not prompt: - raise HTTPException(status_code=400, detail="Prompt cannot be empty") - ratio = str(aspect_ratio or "2:3").strip() or "2:3" - ratio = resolve_aspect_ratio(ratio) - - async def event_stream(): - try: - model_id = "grok-imagine-1.0" - model_info = ModelService.get(model_id) - if not model_info or not model_info.is_image: - yield _sse_event( - { - "type": "error", - "message": "Image model is not available.", - "code": "model_not_supported", - } - ) - return - - token_mgr = await get_token_manager() - enable_nsfw = bool(get_config("image.image_ws_nsfw", True)) - sequence = 0 - run_id = uuid.uuid4().hex - - yield _sse_event( - { - "type": "status", - "status": "running", - "prompt": prompt, - "aspect_ratio": ratio, - "run_id": run_id, - } - ) - - while True: - if await request.is_disconnected(): - break - if task_id: - session_alive = await _get_imagine_session(task_id) - if not session_alive: - break - - try: - await token_mgr.reload_if_stale() - token = None - for pool_name in ModelService.pool_candidates_for_model( - model_info.model_id - ): - token = token_mgr.get_token(pool_name) - if token: - break - - if not token: - yield _sse_event( - { - "type": "error", - "message": "No available tokens. Please try again later.", - "code": "rate_limit_exceeded", - } - ) - await asyncio.sleep(2) - continue - - upstream = image_service.stream( - token=token, - prompt=prompt, - aspect_ratio=ratio, - n=6, - enable_nsfw=enable_nsfw, - ) - - processor = ImageWSCollectProcessor( - model_info.model_id, - token, - n=6, - response_format="b64_json", - ) - - start_at = time.time() - images = await processor.process(upstream) - elapsed_ms = int((time.time() - start_at) * 1000) - - if images and all(img and img != "error" for img in images): - for img_b64 in images: - sequence += 1 - yield _sse_event( - { - "type": "image", - "b64_json": img_b64, - "sequence": sequence, - "created_at": int(time.time() * 1000), - "elapsed_ms": elapsed_ms, - "aspect_ratio": ratio, - "run_id": run_id, - } - ) - - try: - effort = ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW - ) - await token_mgr.consume(token, effort) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - else: - yield _sse_event( - { - "type": "error", - "message": "Image generation returned empty data.", - "code": "empty_image", - } - ) - except asyncio.CancelledError: - break - except Exception as e: - logger.warning(f"Imagine SSE error: {e}") - yield _sse_event( - {"type": "error", "message": str(e), "code": "internal_error"} - ) - await asyncio.sleep(1.5) - - yield _sse_event({"type": "status", "status": "stopped", "run_id": run_id}) - finally: - if task_id: - await _delete_imagine_session(task_id) - - return StreamingResponse( - event_stream(), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - -@router.post("/api/v1/admin/login", dependencies=[Depends(verify_app_key)]) -async def admin_login_api(): - """管理后台登录验证(使用 app_key)""" - return {"status": "success", "api_key": get_admin_api_key()} - - -@router.get("/api/v1/admin/config", dependencies=[Depends(verify_api_key)]) -async def get_config_api(): - """获取当前配置""" - # 暴露原始配置字典 - return config._config - - -@router.post("/api/v1/admin/config", dependencies=[Depends(verify_api_key)]) -async def update_config_api(data: dict): - """更新配置""" - try: - await config.update(data) - return {"status": "success", "message": "配置已更新"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/api/v1/admin/storage", dependencies=[Depends(verify_api_key)]) -async def get_storage_info(): - """获取当前存储模式""" - storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower() - if not storage_type: - storage_type = str(get_config("storage.type")).lower() - if not storage_type: - storage = get_storage() - if isinstance(storage, LocalStorage): - storage_type = "local" - elif isinstance(storage, RedisStorage): - storage_type = "redis" - elif isinstance(storage, SQLStorage): - storage_type = { - "mysql": "mysql", - "mariadb": "mysql", - "postgres": "pgsql", - "postgresql": "pgsql", - "pgsql": "pgsql", - }.get(storage.dialect, storage.dialect) - return {"type": storage_type or "local"} - - -@router.get("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)]) -async def get_tokens_api(): - """获取所有 Token""" - storage = get_storage() - tokens = await storage.load_tokens() - return tokens or {} - - -@router.post("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)]) -async def update_tokens_api(data: dict): - """更新 Token 信息""" - storage = get_storage() - try: - from app.services.token.manager import get_token_manager - from app.services.token.models import TokenInfo - - async with storage.acquire_lock("tokens_save", timeout=10): - existing = await storage.load_tokens() or {} - normalized = {} - allowed_fields = set(TokenInfo.model_fields.keys()) - existing_map = {} - for pool_name, tokens in existing.items(): - if not isinstance(tokens, list): - continue - pool_map = {} - for item in tokens: - if isinstance(item, str): - token_data = {"token": item} - elif isinstance(item, dict): - token_data = dict(item) - else: - continue - raw_token = token_data.get("token") - if isinstance(raw_token, str) and raw_token.startswith("sso="): - token_data["token"] = raw_token[4:] - token_key = token_data.get("token") - if isinstance(token_key, str): - pool_map[token_key] = token_data - existing_map[pool_name] = pool_map - for pool_name, tokens in (data or {}).items(): - if not isinstance(tokens, list): - continue - pool_list = [] - for item in tokens: - if isinstance(item, str): - token_data = {"token": item} - elif isinstance(item, dict): - token_data = dict(item) - else: - continue - - raw_token = token_data.get("token") - if isinstance(raw_token, str) and raw_token.startswith("sso="): - token_data["token"] = raw_token[4:] - - base = existing_map.get(pool_name, {}).get( - token_data.get("token"), {} - ) - merged = dict(base) - merged.update(token_data) - if merged.get("tags") is None: - merged["tags"] = [] - - filtered = {k: v for k, v in merged.items() if k in allowed_fields} - try: - info = TokenInfo(**filtered) - pool_list.append(info.model_dump()) - except Exception as e: - logger.warning(f"Skip invalid token in pool '{pool_name}': {e}") - continue - normalized[pool_name] = pool_list - - await storage.save_tokens(normalized) - mgr = await get_token_manager() - await mgr.reload() - return {"status": "success", "message": "Token 已更新"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/api/v1/admin/tokens/refresh", dependencies=[Depends(verify_api_key)]) -async def refresh_tokens_api(data: dict): - """刷新 Token 状态""" - try: - mgr = await get_token_manager() - tokens = _collect_tokens(data) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens provided") - - # 去重并截断 - max_tokens = int(get_config("performance.usage_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "Usage refresh" - ) - - # 批量执行配置 - max_concurrent = get_config("performance.usage_max_concurrent") - batch_size = get_config("performance.usage_batch_size") - - async def _refresh_one(t): - return await mgr.sync_usage( - t, "grok-3", consume_on_fail=False, is_usage=False - ) - - raw_results = await run_in_batches( - unique_tokens, - _refresh_one, - max_concurrent=max_concurrent, - batch_size=batch_size, - ) - - results = {} - for token, res in raw_results.items(): - if res.get("ok"): - results[token] = res.get("data", False) - else: - results[token] = False - - response = {"status": "success", "results": results} - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - return response - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/api/v1/admin/tokens/refresh/async", dependencies=[Depends(verify_api_key)] -) -async def refresh_tokens_api_async(data: dict): - """刷新 Token 状态(异步批量 + SSE 进度)""" - mgr = await get_token_manager() - tokens = _collect_tokens(data) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens provided") - - # 去重并截断 - max_tokens = int(get_config("performance.usage_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "Usage refresh" - ) - - max_concurrent = get_config("performance.usage_max_concurrent") - batch_size = get_config("performance.usage_batch_size") - - task = create_task(len(unique_tokens)) - - async def _run(): - try: - - async def _refresh_one(t: str): - return await mgr.sync_usage( - t, "grok-3", consume_on_fail=False, is_usage=False - ) - - async def _on_item(item: str, res: dict): - task.record(bool(res.get("ok"))) - - raw_results = await run_in_batches( - unique_tokens, - _refresh_one, - max_concurrent=max_concurrent, - batch_size=batch_size, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - results: dict[str, bool] = {} - ok_count = 0 - fail_count = 0 - for token, res in raw_results.items(): - if res.get("ok") and res.get("data") is True: - ok_count += 1 - results[token] = True - else: - fail_count += 1 - results[token] = False - - await mgr._save() - - result = { - "status": "success", - "summary": { - "total": len(unique_tokens), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) - except Exception as e: - task.fail_task(str(e)) - finally: - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(unique_tokens), - } - - -@router.post("/api/v1/admin/tokens/nsfw/enable", dependencies=[Depends(verify_api_key)]) -async def enable_nsfw_api(data: dict): - """批量开启 NSFW (Unhinged) 模式""" - from app.services.grok.services.nsfw import NSFWService - - try: - mgr = await get_token_manager() - nsfw_service = NSFWService() - - # 收集 token 列表 - tokens = _collect_tokens(data) - - # 若未指定,则使用所有 pool 中的 token - if not tokens: - for pool_name, pool in mgr.pools.items(): - for info in pool.list(): - raw = ( - info.token[4:] if info.token.startswith("sso=") else info.token - ) - tokens.append(raw) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens available") - - # 去重并截断 - max_tokens = int(get_config("performance.nsfw_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "NSFW enable" - ) - - # 批量执行配置 - max_concurrent = get_config("performance.nsfw_max_concurrent") - batch_size = get_config("performance.nsfw_batch_size") - - # 定义 worker - async def _enable(token: str): - result = await nsfw_service.enable(token) - # 成功后添加 nsfw tag - if result.success: - await mgr.add_tag(token, "nsfw") - return { - "success": result.success, - "http_status": result.http_status, - "grpc_status": result.grpc_status, - "grpc_message": result.grpc_message, - "error": result.error, - } - - # 执行批量操作 - raw_results = await run_in_batches( - unique_tokens, _enable, max_concurrent=max_concurrent, batch_size=batch_size - ) - - # 构造返回结果(mask token) - results = {} - ok_count = 0 - fail_count = 0 - - for token, res in raw_results.items(): - masked = _mask_token(token) - if res.get("ok") and res.get("data", {}).get("success"): - ok_count += 1 - results[masked] = res.get("data", {}) - else: - fail_count += 1 - results[masked] = res.get("data") or {"error": res.get("error")} - - response = { - "status": "success", - "summary": { - "total": len(unique_tokens), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - - # 添加截断提示 - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - - return response - - except HTTPException: - raise - except Exception as e: - logger.error(f"Enable NSFW failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/api/v1/admin/tokens/nsfw/enable/async", dependencies=[Depends(verify_api_key)] -) -async def enable_nsfw_api_async(data: dict): - """批量开启 NSFW (Unhinged) 模式(异步批量 + SSE 进度)""" - from app.services.grok.services.nsfw import NSFWService - - mgr = await get_token_manager() - nsfw_service = NSFWService() - - tokens = _collect_tokens(data) - - if not tokens: - for pool_name, pool in mgr.pools.items(): - for info in pool.list(): - raw = info.token[4:] if info.token.startswith("sso=") else info.token - tokens.append(raw) - - if not tokens: - raise HTTPException(status_code=400, detail="No tokens available") - - # 去重并截断 - max_tokens = int(get_config("performance.nsfw_max_tokens")) - unique_tokens, truncated, original_count = _truncate_tokens( - tokens, max_tokens, "NSFW enable" - ) - - max_concurrent = get_config("performance.nsfw_max_concurrent") - batch_size = get_config("performance.nsfw_batch_size") - - task = create_task(len(unique_tokens)) - - async def _run(): - try: - - async def _enable(token: str): - result = await nsfw_service.enable(token) - if result.success: - await mgr.add_tag(token, "nsfw") - return { - "success": result.success, - "http_status": result.http_status, - "grpc_status": result.grpc_status, - "grpc_message": result.grpc_message, - "error": result.error, - } - - async def _on_item(item: str, res: dict): - ok = bool(res.get("ok") and res.get("data", {}).get("success")) - task.record(ok) - - raw_results = await run_in_batches( - unique_tokens, - _enable, - max_concurrent=max_concurrent, - batch_size=batch_size, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - results = {} - ok_count = 0 - fail_count = 0 - for token, res in raw_results.items(): - masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token - if res.get("ok") and res.get("data", {}).get("success"): - ok_count += 1 - results[masked] = res.get("data", {}) - else: - fail_count += 1 - results[masked] = res.get("data") or {"error": res.get("error")} - - await mgr._save() - - result = { - "status": "success", - "summary": { - "total": len(unique_tokens), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) - except Exception as e: - task.fail_task(str(e)) - finally: - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(unique_tokens), - } - - -@router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False) -async def admin_cache_page(): - """缓存管理页""" - return await render_template("cache/cache.html") - - -@router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)]) -async def get_cache_stats_api(request: Request): - """获取缓存统计""" - from app.services.grok.services.assets import DownloadService, ListService - from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches - - try: - dl_service = DownloadService() - image_stats = dl_service.get_stats("image") - video_stats = dl_service.get_stats("video") - - mgr = await get_token_manager() - pools = mgr.pools - accounts = [] - for pool_name, pool in pools.items(): - for info in pool.list(): - raw_token = ( - info.token[4:] if info.token.startswith("sso=") else info.token - ) - masked = ( - f"{raw_token[:8]}...{raw_token[-16:]}" - if len(raw_token) > 24 - else raw_token - ) - accounts.append( - { - "token": raw_token, - "token_masked": masked, - "pool": pool_name, - "status": info.status, - "last_asset_clear_at": info.last_asset_clear_at, - } - ) - - scope = request.query_params.get("scope") - selected_token = request.query_params.get("token") - tokens_param = request.query_params.get("tokens") - selected_tokens = [] - if tokens_param: - selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()] - - online_stats = { - "count": 0, - "status": "unknown", - "token": None, - "last_asset_clear_at": None, - } - online_details = [] - account_map = {a["token"]: a for a in accounts} - max_concurrent = max(1, int(get_config("performance.assets_max_concurrent"))) - batch_size = max(1, int(get_config("performance.assets_batch_size"))) - max_tokens = int(get_config("performance.assets_max_tokens")) - - truncated = False - original_count = 0 - - async def _fetch_assets(token: str): - list_service = ListService() - try: - return await list_service.count(token) - finally: - await list_service.close() - - async def _fetch_detail(token: str): - account = account_map.get(token) - try: - count = await _fetch_assets(token) - return { - "detail": { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": count, - "status": "ok", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - }, - "count": count, - } - except Exception as e: - return { - "detail": { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": 0, - "status": f"error: {str(e)}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - }, - "count": 0, - } - - if selected_tokens: - selected_tokens, truncated, original_count = _truncate_tokens( - selected_tokens, max_tokens, "Assets fetch" - ) - total = 0 - raw_results = await run_in_batches( - selected_tokens, - _fetch_detail, - max_concurrent=max_concurrent, - batch_size=batch_size, - ) - for token, res in raw_results.items(): - if res.get("ok"): - data = res.get("data", {}) - detail = data.get("detail") - total += data.get("count", 0) - else: - account = account_map.get(token) - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": 0, - "status": f"error: {res.get('error')}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - if detail: - online_details.append(detail) - online_stats = { - "count": total, - "status": "ok" if selected_tokens else "no_token", - "token": None, - "last_asset_clear_at": None, - } - scope = "selected" - elif scope == "all": - total = 0 - tokens = list(dict.fromkeys([account["token"] for account in accounts])) - original_count = len(tokens) - if len(tokens) > max_tokens: - tokens = tokens[:max_tokens] - truncated = True - raw_results = await run_in_batches( - tokens, - _fetch_detail, - max_concurrent=max_concurrent, - batch_size=batch_size, - ) - for token, res in raw_results.items(): - if res.get("ok"): - data = res.get("data", {}) - detail = data.get("detail") - total += data.get("count", 0) - else: - account = account_map.get(token) - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": 0, - "status": f"error: {res.get('error')}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - if detail: - online_details.append(detail) - online_stats = { - "count": total, - "status": "ok" if accounts else "no_token", - "token": None, - "last_asset_clear_at": None, - } - else: - token = selected_token - if token: - try: - count = await _fetch_assets(token) - match = next((a for a in accounts if a["token"] == token), None) - online_stats = { - "count": count, - "status": "ok", - "token": token, - "token_masked": match["token_masked"] if match else token, - "last_asset_clear_at": match["last_asset_clear_at"] - if match - else None, - } - except Exception as e: - match = next((a for a in accounts if a["token"] == token), None) - online_stats = { - "count": 0, - "status": f"error: {str(e)}", - "token": token, - "token_masked": match["token_masked"] if match else token, - "last_asset_clear_at": match["last_asset_clear_at"] - if match - else None, - } - else: - online_stats = { - "count": 0, - "status": "not_loaded", - "token": None, - "last_asset_clear_at": None, - } - - response = { - "local_image": image_stats, - "local_video": video_stats, - "online": online_stats, - "online_accounts": accounts, - "online_scope": scope or "none", - "online_details": online_details, - } - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - return response - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post( - "/api/v1/admin/cache/online/load/async", dependencies=[Depends(verify_api_key)] -) -async def load_online_cache_api_async(data: dict): - """在线资产统计(异步批量 + SSE 进度)""" - from app.services.grok.services.assets import DownloadService, ListService - from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches - - mgr = await get_token_manager() - - # 账号列表 - accounts = [] - for pool_name, pool in mgr.pools.items(): - for info in pool.list(): - raw_token = info.token[4:] if info.token.startswith("sso=") else info.token - masked = ( - f"{raw_token[:8]}...{raw_token[-16:]}" - if len(raw_token) > 24 - else raw_token - ) - accounts.append( - { - "token": raw_token, - "token_masked": masked, - "pool": pool_name, - "status": info.status, - "last_asset_clear_at": info.last_asset_clear_at, - } - ) - - account_map = {a["token"]: a for a in accounts} - - tokens = data.get("tokens") - scope = data.get("scope") - selected_tokens: list[str] = [] - if isinstance(tokens, list): - selected_tokens = [str(t).strip() for t in tokens if str(t).strip()] - - if not selected_tokens and scope == "all": - selected_tokens = [account["token"] for account in accounts] - scope = "all" - elif selected_tokens: - scope = "selected" - else: - raise HTTPException(status_code=400, detail="No tokens provided") - - max_tokens = int(get_config("performance.assets_max_tokens")) - selected_tokens, truncated, original_count = _truncate_tokens( - selected_tokens, max_tokens, "Assets load" - ) - - max_concurrent = get_config("performance.assets_max_concurrent") - batch_size = get_config("performance.assets_batch_size") - - task = create_task(len(selected_tokens)) - - async def _run(): - try: - dl_service = DownloadService() - image_stats = dl_service.get_stats("image") - video_stats = dl_service.get_stats("video") - - async def _fetch_detail(token: str): - account = account_map.get(token) - list_service = ListService() - try: - count = await list_service.count(token) - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": count, - "status": "ok", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - return {"ok": True, "detail": detail, "count": count} - except Exception as e: - detail = { - "token": token, - "token_masked": account["token_masked"] if account else token, - "count": 0, - "status": f"error: {str(e)}", - "last_asset_clear_at": account["last_asset_clear_at"] - if account - else None, - } - return {"ok": False, "detail": detail, "count": 0} - finally: - await list_service.close() - - async def _on_item(item: str, res: dict): - ok = bool(res.get("data", {}).get("ok")) - task.record(ok) - - raw_results = await run_in_batches( - selected_tokens, - _fetch_detail, - max_concurrent=max_concurrent, - batch_size=batch_size, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - online_details = [] - total = 0 - for token, res in raw_results.items(): - data = res.get("data", {}) - detail = data.get("detail") - if detail: - online_details.append(detail) - total += data.get("count", 0) - - online_stats = { - "count": total, - "status": "ok" if selected_tokens else "no_token", - "token": None, - "last_asset_clear_at": None, - } - - result = { - "local_image": image_stats, - "local_video": video_stats, - "online": online_stats, - "online_accounts": accounts, - "online_scope": scope or "none", - "online_details": online_details, - } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) - except Exception as e: - task.fail_task(str(e)) - finally: - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(selected_tokens), - } - - -@router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)]) -async def clear_local_cache_api(data: dict): - """清理本地缓存""" - from app.services.grok.services.assets import DownloadService - - cache_type = data.get("type", "image") - - try: - dl_service = DownloadService() - result = dl_service.clear(cache_type) - return {"status": "success", "result": result} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/api/v1/admin/cache/list", dependencies=[Depends(verify_api_key)]) -async def list_local_cache_api( - cache_type: str = "image", - type_: str = Query(default=None, alias="type"), - page: int = 1, - page_size: int = 1000, -): - """列出本地缓存文件""" - from app.services.grok.services.assets import DownloadService - - try: - if type_: - cache_type = type_ - dl_service = DownloadService() - result = dl_service.list_files(cache_type, page, page_size) - return {"status": "success", **result} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)]) -async def delete_local_cache_item_api(data: dict): - """删除单个本地缓存文件""" - from app.services.grok.services.assets import DownloadService - - cache_type = data.get("type", "image") - name = data.get("name") - if not name: - raise HTTPException(status_code=400, detail="Missing file name") - try: - dl_service = DownloadService() - result = dl_service.delete_file(cache_type, name) - return {"status": "success", "result": result} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/api/v1/admin/cache/online/clear", dependencies=[Depends(verify_api_key)]) -async def clear_online_cache_api(data: dict): - """清理在线缓存""" - from app.services.grok.services.assets import DeleteService - from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches - - delete_service = None - try: - mgr = await get_token_manager() - tokens = data.get("tokens") - delete_service = DeleteService() - - if isinstance(tokens, list): - token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] - if not token_list: - raise HTTPException(status_code=400, detail="No tokens provided") - - # 去重并保持顺序 - token_list = list(dict.fromkeys(token_list)) - - # 最大数量限制 - max_tokens = int(get_config("performance.assets_max_tokens")) - token_list, truncated, original_count = _truncate_tokens( - token_list, max_tokens, "Clear online cache" - ) - - results = {} - max_concurrent = max( - 1, int(get_config("performance.assets_max_concurrent")) - ) - batch_size = max(1, int(get_config("performance.assets_batch_size"))) - - async def _clear_one(t: str): - try: - result = await delete_service.delete_all(t) - await mgr.mark_asset_clear(t) - return {"status": "success", "result": result} - except Exception as e: - return {"status": "error", "error": str(e)} - - raw_results = await run_in_batches( - token_list, - _clear_one, - max_concurrent=max_concurrent, - batch_size=batch_size, - ) - for token, res in raw_results.items(): - if res.get("ok"): - results[token] = res.get("data", {}) - else: - results[token] = {"status": "error", "error": res.get("error")} - - response = {"status": "success", "results": results} - if truncated: - response["warning"] = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - return response - - token = data.get("token") or mgr.get_token() - if not token: - raise HTTPException( - status_code=400, detail="No available token to perform cleanup" - ) - - result = await delete_service.delete_all(token) - await mgr.mark_asset_clear(token) - return {"status": "success", "result": result} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - finally: - if delete_service: - await delete_service.close() - - -@router.post( - "/api/v1/admin/cache/online/clear/async", dependencies=[Depends(verify_api_key)] -) -async def clear_online_cache_api_async(data: dict): - """清理在线缓存(异步批量 + SSE 进度)""" - from app.services.grok.services.assets import DeleteService - from app.services.token.manager import get_token_manager - from app.services.grok.utils.batch import run_in_batches - - mgr = await get_token_manager() - tokens = data.get("tokens") - if not isinstance(tokens, list): - raise HTTPException(status_code=400, detail="No tokens provided") - - token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] - if not token_list: - raise HTTPException(status_code=400, detail="No tokens provided") - - max_tokens = int(get_config("performance.assets_max_tokens")) - token_list, truncated, original_count = _truncate_tokens( - token_list, max_tokens, "Clear online cache async" - ) - - max_concurrent = get_config("performance.assets_max_concurrent") - batch_size = get_config("performance.assets_batch_size") - - task = create_task(len(token_list)) - - async def _run(): - delete_service = DeleteService() - try: - - async def _clear_one(t: str): - try: - result = await delete_service.delete_all(t) - await mgr.mark_asset_clear(t) - return {"ok": True, "result": result} - except Exception as e: - return {"ok": False, "error": str(e)} - - async def _on_item(item: str, res: dict): - ok = bool(res.get("data", {}).get("ok")) - task.record(ok) - - raw_results = await run_in_batches( - token_list, - _clear_one, - max_concurrent=max_concurrent, - batch_size=batch_size, - on_item=_on_item, - should_cancel=lambda: task.cancelled, - ) - - if task.cancelled: - task.finish_cancelled() - return - - results = {} - ok_count = 0 - fail_count = 0 - for token, res in raw_results.items(): - data = res.get("data", {}) - if data.get("ok"): - ok_count += 1 - results[token] = {"status": "success", "result": data.get("result")} - else: - fail_count += 1 - results[token] = {"status": "error", "error": data.get("error")} - - result = { - "status": "success", - "summary": { - "total": len(token_list), - "ok": ok_count, - "fail": fail_count, - }, - "results": results, - } - warning = None - if truncated: - warning = ( - f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" - ) - task.finish(result, warning=warning) - except Exception as e: - task.fail_task(str(e)) - finally: - await delete_service.close() - asyncio.create_task(expire_task(task.id, 300)) - - asyncio.create_task(_run()) - - return { - "status": "success", - "task_id": task.id, - "total": len(token_list), - } diff --git a/app/api/v1/admin/__init__.py b/app/api/v1/admin/__init__.py new file mode 100644 index 00000000..63db93d7 --- /dev/null +++ b/app/api/v1/admin/__init__.py @@ -0,0 +1,15 @@ +"""Admin API router (app_key protected).""" + +from fastapi import APIRouter + +from app.api.v1.admin.cache import router as cache_router +from app.api.v1.admin.config import router as config_router +from app.api.v1.admin.token import router as tokens_router + +router = APIRouter() + +router.include_router(config_router) +router.include_router(tokens_router) +router.include_router(cache_router) + +__all__ = ["router"] diff --git a/app/api/v1/admin/cache.py b/app/api/v1/admin/cache.py new file mode 100644 index 00000000..0dc902a7 --- /dev/null +++ b/app/api/v1/admin/cache.py @@ -0,0 +1,445 @@ +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, Query, Request + +from app.core.auth import verify_app_key +from app.core.batch import create_task, expire_task +from app.services.grok.batch_services.assets import ListService, DeleteService +from app.services.token.manager import get_token_manager +router = APIRouter() + + +@router.get("/cache", dependencies=[Depends(verify_app_key)]) +async def cache_stats(request: Request): + """获取缓存统计""" + from app.services.grok.utils.cache import CacheService + + try: + cache_service = CacheService() + image_stats = cache_service.get_stats("image") + video_stats = cache_service.get_stats("video") + + mgr = await get_token_manager() + pools = mgr.pools + accounts = [] + for pool_name, pool in pools.items(): + for info in pool.list(): + raw_token = ( + info.token[4:] if info.token.startswith("sso=") else info.token + ) + masked = ( + f"{raw_token[:8]}...{raw_token[-16:]}" + if len(raw_token) > 24 + else raw_token + ) + accounts.append( + { + "token": raw_token, + "token_masked": masked, + "pool": pool_name, + "status": info.status, + "last_asset_clear_at": info.last_asset_clear_at, + } + ) + + scope = request.query_params.get("scope") + selected_token = request.query_params.get("token") + tokens_param = request.query_params.get("tokens") + selected_tokens = [] + if tokens_param: + selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()] + + online_stats = { + "count": 0, + "status": "unknown", + "token": None, + "last_asset_clear_at": None, + } + online_details = [] + account_map = {a["token"]: a for a in accounts} + if selected_tokens: + total = 0 + raw_results = await ListService.fetch_assets_details( + selected_tokens, + account_map, + ) + for token, res in raw_results.items(): + if res.get("ok"): + data = res.get("data", {}) + detail = data.get("detail") + total += data.get("count", 0) + else: + account = account_map.get(token) + detail = { + "token": token, + "token_masked": account["token_masked"] if account else token, + "count": 0, + "status": f"error: {res.get('error')}", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if detail: + online_details.append(detail) + online_stats = { + "count": total, + "status": "ok" if selected_tokens else "no_token", + "token": None, + "last_asset_clear_at": None, + } + scope = "selected" + elif scope == "all": + total = 0 + tokens = list(dict.fromkeys([account["token"] for account in accounts])) + raw_results = await ListService.fetch_assets_details( + tokens, + account_map, + ) + for token, res in raw_results.items(): + if res.get("ok"): + data = res.get("data", {}) + detail = data.get("detail") + total += data.get("count", 0) + else: + account = account_map.get(token) + detail = { + "token": token, + "token_masked": account["token_masked"] if account else token, + "count": 0, + "status": f"error: {res.get('error')}", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if detail: + online_details.append(detail) + online_stats = { + "count": total, + "status": "ok" if accounts else "no_token", + "token": None, + "last_asset_clear_at": None, + } + else: + token = selected_token + if token: + raw_results = await ListService.fetch_assets_details( + [token], + account_map, + ) + res = raw_results.get(token, {}) + data = res.get("data", {}) + detail = data.get("detail") if res.get("ok") else None + if detail: + online_stats = { + "count": data.get("count", 0), + "status": detail.get("status", "ok"), + "token": detail.get("token"), + "token_masked": detail.get("token_masked"), + "last_asset_clear_at": detail.get("last_asset_clear_at"), + } + else: + match = next((a for a in accounts if a["token"] == token), None) + online_stats = { + "count": 0, + "status": f"error: {res.get('error')}", + "token": token, + "token_masked": match["token_masked"] if match else token, + "last_asset_clear_at": match["last_asset_clear_at"] + if match + else None, + } + else: + online_stats = { + "count": 0, + "status": "not_loaded", + "token": None, + "last_asset_clear_at": None, + } + + response = { + "local_image": image_stats, + "local_video": video_stats, + "online": online_stats, + "online_accounts": accounts, + "online_scope": scope or "none", + "online_details": online_details, + } + return response + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/cache/list", dependencies=[Depends(verify_app_key)]) +async def list_local( + cache_type: str = "image", + type_: str = Query(default=None, alias="type"), + page: int = 1, + page_size: int = 1000, +): + """列出本地缓存文件""" + from app.services.grok.utils.cache import CacheService + + try: + if type_: + cache_type = type_ + cache_service = CacheService() + result = cache_service.list_files(cache_type, page, page_size) + return {"status": "success", **result} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/clear", dependencies=[Depends(verify_app_key)]) +async def clear_local(data: dict): + """清理本地缓存""" + from app.services.grok.utils.cache import CacheService + + cache_type = data.get("type", "image") + + try: + cache_service = CacheService() + result = cache_service.clear(cache_type) + return {"status": "success", "result": result} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/item/delete", dependencies=[Depends(verify_app_key)]) +async def delete_local_item(data: dict): + """删除单个本地缓存文件""" + from app.services.grok.utils.cache import CacheService + + cache_type = data.get("type", "image") + name = data.get("name") + if not name: + raise HTTPException(status_code=400, detail="Missing file name") + try: + cache_service = CacheService() + result = cache_service.delete_file(cache_type, name) + return {"status": "success", "result": result} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/online/clear", dependencies=[Depends(verify_app_key)]) +async def clear_online(data: dict): + """清理在线缓存""" + try: + mgr = await get_token_manager() + tokens = data.get("tokens") + + if isinstance(tokens, list): + token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] + if not token_list: + raise HTTPException(status_code=400, detail="No tokens provided") + + token_list = list(dict.fromkeys(token_list)) + + results = {} + raw_results = await DeleteService.clear_assets( + token_list, + mgr, + ) + for token, res in raw_results.items(): + if res.get("ok"): + results[token] = res.get("data", {}) + else: + results[token] = {"status": "error", "error": res.get("error")} + + return {"status": "success", "results": results} + + token = data.get("token") or mgr.get_token() + if not token: + raise HTTPException( + status_code=400, detail="No available token to perform cleanup" + ) + + raw_results = await DeleteService.clear_assets( + [token], + mgr, + ) + res = raw_results.get(token, {}) + data = res.get("data", {}) + if res.get("ok") and data.get("status") == "success": + return {"status": "success", "result": data.get("result")} + return {"status": "error", "error": data.get("error") or res.get("error")} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cache/online/clear/async", dependencies=[Depends(verify_app_key)]) +async def clear_online_async(data: dict): + """清理在线缓存(异步批量 + SSE 进度)""" + mgr = await get_token_manager() + tokens = data.get("tokens") + if not isinstance(tokens, list): + raise HTTPException(status_code=400, detail="No tokens provided") + + token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] + if not token_list: + raise HTTPException(status_code=400, detail="No tokens provided") + + task = create_task(len(token_list)) + + async def _run(): + try: + async def _on_item(item: str, res: dict): + ok = bool(res.get("data", {}).get("ok")) + task.record(ok) + + raw_results = await DeleteService.clear_assets( + token_list, + mgr, + include_ok=True, + on_item=_on_item, + should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + results = {} + ok_count = 0 + fail_count = 0 + for token, res in raw_results.items(): + data = res.get("data", {}) + if data.get("ok"): + ok_count += 1 + results[token] = {"status": "success", "result": data.get("result")} + else: + fail_count += 1 + results[token] = {"status": "error", "error": data.get("error")} + + result = { + "status": "success", + "summary": { + "total": len(token_list), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(token_list), + } + + +@router.post("/cache/online/load/async", dependencies=[Depends(verify_app_key)]) +async def load_cache_async(data: dict): + """在线资产统计(异步批量 + SSE 进度)""" + from app.services.grok.utils.cache import CacheService + + mgr = await get_token_manager() + + accounts = [] + for pool_name, pool in mgr.pools.items(): + for info in pool.list(): + raw_token = info.token[4:] if info.token.startswith("sso=") else info.token + masked = ( + f"{raw_token[:8]}...{raw_token[-16:]}" + if len(raw_token) > 24 + else raw_token + ) + accounts.append( + { + "token": raw_token, + "token_masked": masked, + "pool": pool_name, + "status": info.status, + "last_asset_clear_at": info.last_asset_clear_at, + } + ) + + account_map = {a["token"]: a for a in accounts} + + tokens = data.get("tokens") + scope = data.get("scope") + selected_tokens: List[str] = [] + if isinstance(tokens, list): + selected_tokens = [str(t).strip() for t in tokens if str(t).strip()] + + if not selected_tokens and scope == "all": + selected_tokens = [account["token"] for account in accounts] + scope = "all" + elif selected_tokens: + scope = "selected" + else: + raise HTTPException(status_code=400, detail="No tokens provided") + + task = create_task(len(selected_tokens)) + + async def _run(): + try: + cache_service = CacheService() + image_stats = cache_service.get_stats("image") + video_stats = cache_service.get_stats("video") + + async def _on_item(item: str, res: dict): + ok = bool(res.get("data", {}).get("ok")) + task.record(ok) + + raw_results = await ListService.fetch_assets_details( + selected_tokens, + account_map, + include_ok=True, + on_item=_on_item, + should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + online_details = [] + total = 0 + for token, res in raw_results.items(): + data = res.get("data", {}) + detail = data.get("detail") + if detail: + online_details.append(detail) + total += data.get("count", 0) + + online_stats = { + "count": total, + "status": "ok" if selected_tokens else "no_token", + "token": None, + "last_asset_clear_at": None, + } + + result = { + "local_image": image_stats, + "local_video": video_stats, + "online": online_stats, + "online_accounts": accounts, + "online_scope": scope or "none", + "online_details": online_details, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(selected_tokens), + } + diff --git a/app/api/v1/admin/config.py b/app/api/v1/admin/config.py new file mode 100644 index 00000000..f843b76b --- /dev/null +++ b/app/api/v1/admin/config.py @@ -0,0 +1,53 @@ +import os + +from fastapi import APIRouter, Depends, HTTPException + +from app.core.auth import verify_app_key +from app.core.config import config +from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage + +router = APIRouter() + + +@router.get("/verify", dependencies=[Depends(verify_app_key)]) +async def admin_verify(): + """验证后台访问密钥(app_key)""" + return {"status": "success"} + + +@router.get("/config", dependencies=[Depends(verify_app_key)]) +async def get_config(): + """获取当前配置""" + # 暴露原始配置字典 + return config._config + + +@router.post("/config", dependencies=[Depends(verify_app_key)]) +async def update_config(data: dict): + """更新配置""" + try: + await config.update(data) + return {"status": "success", "message": "配置已更新"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/storage", dependencies=[Depends(verify_app_key)]) +async def get_storage(): + """获取当前存储模式""" + storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower() + if not storage_type: + storage = get_storage() + if isinstance(storage, LocalStorage): + storage_type = "local" + elif isinstance(storage, RedisStorage): + storage_type = "redis" + elif isinstance(storage, SQLStorage): + storage_type = { + "mysql": "mysql", + "mariadb": "mysql", + "postgres": "pgsql", + "postgresql": "pgsql", + "pgsql": "pgsql", + }.get(storage.dialect, storage.dialect) + return {"type": storage_type or "local"} diff --git a/app/api/v1/admin/token.py b/app/api/v1/admin/token.py new file mode 100644 index 00000000..81b9fef2 --- /dev/null +++ b/app/api/v1/admin/token.py @@ -0,0 +1,395 @@ +import asyncio + +import orjson +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import StreamingResponse + +from app.core.auth import get_app_key, verify_app_key +from app.core.batch import create_task, expire_task, get_task +from app.core.logger import logger +from app.core.storage import get_storage +from app.services.grok.batch_services.usage import UsageService +from app.services.grok.batch_services.nsfw import NSFWService +from app.services.token.manager import get_token_manager + +router = APIRouter() + + +@router.get("/tokens", dependencies=[Depends(verify_app_key)]) +async def get_tokens(): + """获取所有 Token""" + storage = get_storage() + tokens = await storage.load_tokens() + return tokens or {} + + +@router.post("/tokens", dependencies=[Depends(verify_app_key)]) +async def update_tokens(data: dict): + """更新 Token 信息""" + storage = get_storage() + try: + from app.services.token.models import TokenInfo + + async with storage.acquire_lock("tokens_save", timeout=10): + existing = await storage.load_tokens() or {} + normalized = {} + allowed_fields = set(TokenInfo.model_fields.keys()) + existing_map = {} + for pool_name, tokens in existing.items(): + if not isinstance(tokens, list): + continue + pool_map = {} + for item in tokens: + if isinstance(item, str): + token_data = {"token": item} + elif isinstance(item, dict): + token_data = dict(item) + else: + continue + raw_token = token_data.get("token") + if isinstance(raw_token, str) and raw_token.startswith("sso="): + token_data["token"] = raw_token[4:] + token_key = token_data.get("token") + if isinstance(token_key, str): + pool_map[token_key] = token_data + existing_map[pool_name] = pool_map + for pool_name, tokens in (data or {}).items(): + if not isinstance(tokens, list): + continue + pool_list = [] + for item in tokens: + if isinstance(item, str): + token_data = {"token": item} + elif isinstance(item, dict): + token_data = dict(item) + else: + continue + + raw_token = token_data.get("token") + if isinstance(raw_token, str) and raw_token.startswith("sso="): + token_data["token"] = raw_token[4:] + + base = existing_map.get(pool_name, {}).get( + token_data.get("token"), {} + ) + merged = dict(base) + merged.update(token_data) + if merged.get("tags") is None: + merged["tags"] = [] + + filtered = {k: v for k, v in merged.items() if k in allowed_fields} + try: + info = TokenInfo(**filtered) + pool_list.append(info.model_dump()) + except Exception as e: + logger.warning(f"Skip invalid token in pool '{pool_name}': {e}") + continue + normalized[pool_name] = pool_list + + await storage.save_tokens(normalized) + mgr = await get_token_manager() + await mgr.reload() + return {"status": "success", "message": "Token 已更新"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/tokens/refresh", dependencies=[Depends(verify_app_key)]) +async def refresh_tokens(data: dict): + """刷新 Token 状态""" + try: + mgr = await get_token_manager() + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens provided") + + unique_tokens = list(dict.fromkeys(tokens)) + + raw_results = await UsageService.batch( + unique_tokens, + mgr, + ) + + results = {} + for token, res in raw_results.items(): + if res.get("ok"): + results[token] = res.get("data", False) + else: + results[token] = False + + response = {"status": "success", "results": results} + return response + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/tokens/refresh/async", dependencies=[Depends(verify_app_key)]) +async def refresh_tokens_async(data: dict): + """刷新 Token 状态(异步批量 + SSE 进度)""" + mgr = await get_token_manager() + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens provided") + + unique_tokens = list(dict.fromkeys(tokens)) + + task = create_task(len(unique_tokens)) + + async def _run(): + try: + + async def _on_item(item: str, res: dict): + task.record(bool(res.get("ok"))) + + raw_results = await UsageService.batch( + unique_tokens, + mgr, + on_item=_on_item, + should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + results: dict[str, bool] = {} + ok_count = 0 + fail_count = 0 + for token, res in raw_results.items(): + if res.get("ok") and res.get("data") is True: + ok_count += 1 + results[token] = True + else: + fail_count += 1 + results[token] = False + + await mgr._save() + + result = { + "status": "success", + "summary": { + "total": len(unique_tokens), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(unique_tokens), + } + + +@router.get("/batch/{task_id}/stream") +async def batch_stream(task_id: str, request: Request): + app_key = get_app_key() + if app_key: + key = request.query_params.get("app_key") + if key != app_key: + raise HTTPException(status_code=401, detail="Invalid authentication token") + task = get_task(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + + async def event_stream(): + queue = task.attach() + try: + yield f"data: {orjson.dumps({'type': 'snapshot', **task.snapshot()}).decode()}\n\n" + + final = task.final_event() + if final: + yield f"data: {orjson.dumps(final).decode()}\n\n" + return + + while True: + try: + event = await asyncio.wait_for(queue.get(), timeout=15) + except asyncio.TimeoutError: + yield ": ping\n\n" + final = task.final_event() + if final: + yield f"data: {orjson.dumps(final).decode()}\n\n" + return + continue + + yield f"data: {orjson.dumps(event).decode()}\n\n" + if event.get("type") in ("done", "error", "cancelled"): + return + finally: + task.detach(queue) + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + +@router.post("/batch/{task_id}/cancel", dependencies=[Depends(verify_app_key)]) +async def batch_cancel(task_id: str): + task = get_task(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + task.cancel() + return {"status": "success"} + + +@router.post("/tokens/nsfw/enable", dependencies=[Depends(verify_app_key)]) +async def enable_nsfw(data: dict): + """批量开启 NSFW (Unhinged) 模式""" + try: + mgr = await get_token_manager() + + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + for pool_name, pool in mgr.pools.items(): + for info in pool.list(): + raw = ( + info.token[4:] if info.token.startswith("sso=") else info.token + ) + tokens.append(raw) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens available") + + unique_tokens = list(dict.fromkeys(tokens)) + + raw_results = await NSFWService.batch( + unique_tokens, + mgr, + ) + + results = {} + ok_count = 0 + fail_count = 0 + + for token, res in raw_results.items(): + masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token + if res.get("ok") and res.get("data", {}).get("success"): + ok_count += 1 + results[masked] = res.get("data", {}) + else: + fail_count += 1 + results[masked] = res.get("data") or {"error": res.get("error")} + + response = { + "status": "success", + "summary": { + "total": len(unique_tokens), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + + return response + + except HTTPException: + raise + except Exception as e: + logger.error(f"Enable NSFW failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/tokens/nsfw/enable/async", dependencies=[Depends(verify_app_key)]) +async def enable_nsfw_async(data: dict): + """批量开启 NSFW (Unhinged) 模式(异步批量 + SSE 进度)""" + mgr = await get_token_manager() + + tokens = [] + if isinstance(data.get("token"), str) and data["token"].strip(): + tokens.append(data["token"].strip()) + if isinstance(data.get("tokens"), list): + tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) + + if not tokens: + for pool_name, pool in mgr.pools.items(): + for info in pool.list(): + raw = info.token[4:] if info.token.startswith("sso=") else info.token + tokens.append(raw) + + if not tokens: + raise HTTPException(status_code=400, detail="No tokens available") + + unique_tokens = list(dict.fromkeys(tokens)) + + task = create_task(len(unique_tokens)) + + async def _run(): + try: + + async def _on_item(item: str, res: dict): + ok = bool(res.get("ok") and res.get("data", {}).get("success")) + task.record(ok) + + raw_results = await NSFWService.batch( + unique_tokens, + mgr, + on_item=_on_item, + should_cancel=lambda: task.cancelled, + ) + + if task.cancelled: + task.finish_cancelled() + return + + results = {} + ok_count = 0 + fail_count = 0 + for token, res in raw_results.items(): + masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token + if res.get("ok") and res.get("data", {}).get("success"): + ok_count += 1 + results[masked] = res.get("data", {}) + else: + fail_count += 1 + results[masked] = res.get("data") or {"error": res.get("error")} + + await mgr._save() + + result = { + "status": "success", + "summary": { + "total": len(unique_tokens), + "ok": ok_count, + "fail": fail_count, + }, + "results": results, + } + task.finish(result) + except Exception as e: + task.fail_task(str(e)) + finally: + import asyncio + asyncio.create_task(expire_task(task.id, 300)) + + import asyncio + asyncio.create_task(_run()) + + return { + "status": "success", + "task_id": task.id, + "total": len(unique_tokens), + } diff --git a/app/api/v1/chat.py b/app/api/v1/chat.py index 6c420567..66f07a31 100644 --- a/app/api/v1/chat.py +++ b/app/api/v1/chat.py @@ -3,23 +3,22 @@ """ from typing import Any, Dict, List, Optional, Union +import base64 +import binascii +import time from fastapi import APIRouter from fastapi.responses import StreamingResponse, JSONResponse -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from app.services.grok.services.chat import ChatService -from app.services.grok.models.model import ModelService -from app.core.exceptions import ValidationException - - -router = APIRouter(tags=["Chat"]) - - -VALID_ROLES = ["developer", "system", "user", "assistant", "tool"] -# 角色别名映射 (OpenAI 兼容: function -> tool) -ROLE_ALIASES = {"function": "tool"} -USER_CONTENT_TYPES = ["text", "image_url", "input_audio", "file"] +from app.services.grok.services.image import ImageGenerationService +from app.services.grok.services.image_edit import ImageEditService +from app.services.grok.services.model import ModelService +from app.services.grok.services.video import VideoService +from app.services.token import get_token_manager +from app.core.config import get_config +from app.core.exceptions import ValidationException, AppException, ErrorType class MessageItem(BaseModel): @@ -27,81 +26,22 @@ class MessageItem(BaseModel): role: str content: Union[str, List[Dict[str, Any]]] - tool_call_id: Optional[str] = None # tool 角色需要的字段 - name: Optional[str] = None # function 角色的函数名 - - @field_validator("role") - @classmethod - def validate_role(cls, v): - # 大小写归一化 - v_lower = v.lower() if isinstance(v, str) else v - # 别名映射 - v_normalized = ROLE_ALIASES.get(v_lower, v_lower) - if v_normalized not in VALID_ROLES: - raise ValueError(f"role must be one of {VALID_ROLES}") - return v_normalized class VideoConfig(BaseModel): """视频生成配置""" - aspect_ratio: Optional[str] = Field( - "3:2", description="视频比例: 3:2, 16:9, 1:1 等" - ) + aspect_ratio: Optional[str] = Field("3:2", description="视频比例: 1280x720(16:9), 720x1280(9:16), 1792x1024(3:2), 1024x1792(2:3), 1024x1024(1:1)") video_length: Optional[int] = Field(6, description="视频时长(秒): 6 / 10 / 15") resolution_name: Optional[str] = Field("480p", description="视频分辨率: 480p, 720p") preset: Optional[str] = Field("custom", description="风格预设: fun, normal, spicy") - @field_validator("aspect_ratio") - @classmethod - def validate_aspect_ratio(cls, v): - allowed = ["2:3", "3:2", "1:1", "9:16", "16:9"] - if v and v not in allowed: - raise ValidationException( - message=f"aspect_ratio must be one of {allowed}", - param="video_config.aspect_ratio", - code="invalid_aspect_ratio", - ) - return v +class ImageConfig(BaseModel): + """图片生成配置""" - @field_validator("video_length") - @classmethod - def validate_video_length(cls, v): - if v is not None: - if v not in (6, 10, 15): - raise ValidationException( - message="video_length must be 6, 10, or 15 seconds", - param="video_config.video_length", - code="invalid_video_length", - ) - return v - - @field_validator("resolution_name") - @classmethod - def validate_resolution(cls, v): - allowed = ["480p", "720p"] - if v and v not in allowed: - raise ValidationException( - message=f"resolution_name must be one of {allowed}", - param="video_config.resolution_name", - code="invalid_resolution", - ) - return v - - @field_validator("preset") - @classmethod - def validate_preset(cls, v): - # 允许为空,默认 custom - if not v: - return "custom" - allowed = ["fun", "normal", "spicy", "custom"] - if v not in allowed: - raise ValidationException( - message=f"preset must be one of {allowed}", - param="video_config.preset", - code="invalid_preset", - ) - return v + n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)") + size: Optional[str] = Field("1024x1024", description="图片尺寸") + response_format: Optional[str] = Field(None, description="响应格式") class ChatCompletionRequest(BaseModel): @@ -110,36 +50,137 @@ class ChatCompletionRequest(BaseModel): model: str = Field(..., description="模型名称") messages: List[MessageItem] = Field(..., description="消息数组") stream: Optional[bool] = Field(None, description="是否流式输出") - thinking: Optional[str] = Field(None, description="思考模式: enabled/disabled/None") - + reasoning_effort: Optional[str] = Field(None, description="推理强度: none/minimal/low/medium/high/xhigh") + temperature: Optional[float] = Field(0.8, description="采样温度: 0-2") + top_p: Optional[float] = Field(0.95, description="nucleus 采样: 0-1") # 视频生成配置 video_config: Optional[VideoConfig] = Field(None, description="视频生成参数") + # 图片生成配置 + image_config: Optional[ImageConfig] = Field(None, description="图片生成参数") - @field_validator("stream", mode="before") - @classmethod - def validate_stream(cls, v): - """确保 stream 参数被正确解析为布尔值""" - if v is None: - return None - if isinstance(v, bool): - return v - if isinstance(v, str): - if v.lower() in ("true", "1", "yes"): - return True - if v.lower() in ("false", "0", "no"): - return False - # 未识别的字符串值抛出错误 - raise ValueError( - f"Invalid stream value '{v}'. Must be a boolean or one of: true, false, 1, 0, yes, no" - ) - # 非布尔非字符串类型抛出错误 - raise ValueError( - f"Invalid stream value type '{type(v).__name__}'. Must be a boolean or string." + +VALID_ROLES = {"developer", "system", "user", "assistant"} +USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"} +ALLOWED_IMAGE_SIZES = { + "1280x720", + "720x1280", + "1792x1024", + "1024x1792", + "1024x1024", +} + + +def _validate_media_input(value: str, field_name: str, param: str): + """Verify media input is a valid URL or data URI""" + if not isinstance(value, str) or not value.strip(): + raise ValidationException( + message=f"{field_name} cannot be empty", + param=param, + code="empty_media", ) + value = value.strip() + if value.startswith("data:"): + return + if value.startswith("http://") or value.startswith("https://"): + return + candidate = "".join(value.split()) + if len(candidate) >= 32 and len(candidate) % 4 == 0: + try: + base64.b64decode(candidate, validate=True) + raise ValidationException( + message=f"{field_name} base64 must be provided as a data URI (data:;base64,...)", + param=param, + code="invalid_media", + ) + except binascii.Error: + pass + raise ValidationException( + message=f"{field_name} must be a URL or data URI", + param=param, + code="invalid_media", + ) + + +def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]: + """Extract prompt text and image URLs from messages""" + last_text = "" + image_urls: List[str] = [] + + for msg in messages: + role = msg.role or "user" + content = msg.content + if isinstance(content, str): + text = content.strip() + if text: + last_text = text + continue + if not isinstance(content, list): + continue + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text": + text = block.get("text", "") + if isinstance(text, str) and text.strip(): + last_text = text.strip() + elif block_type == "image_url" and role == "user": + image = block.get("image_url") or {} + url = image.get("url", "") + if isinstance(url, str) and url.strip(): + image_urls.append(url.strip()) + + return last_text, image_urls + + +def _resolve_image_format(value: Optional[str]) -> str: + fmt = value or get_config("app.image_format") or "url" + if isinstance(fmt, str): + fmt = fmt.lower() + if fmt == "base64": + return "b64_json" + if fmt in ("b64_json", "url"): + return fmt + raise ValidationException( + message="image_format must be one of url, base64, b64_json", + param="image_format", + code="invalid_image_format", + ) - model_config = {"extra": "ignore"} +def _image_field(response_format: str) -> str: + if response_format == "url": + return "url" + return "b64_json" +def _validate_image_config(image_conf: ImageConfig, *, stream: bool): + n = image_conf.n or 1 + if n < 1 or n > 10: + raise ValidationException( + message="n must be between 1 and 10", + param="image_config.n", + code="invalid_n", + ) + if stream and n not in (1, 2): + raise ValidationException( + message="Streaming is only supported when n=1 or n=2", + param="image_config.n", + code="invalid_stream_n", + ) + if image_conf.response_format: + allowed_formats = {"b64_json", "base64", "url"} + if image_conf.response_format not in allowed_formats: + raise ValidationException( + message="response_format must be one of b64_json, base64, url", + param="image_config.response_format", + code="invalid_response_format", + ) + if image_conf.size and image_conf.size not in ALLOWED_IMAGE_SIZES: + raise ValidationException( + message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}", + param="image_config.size", + code="invalid_size", + ) def validate_request(request: ChatCompletionRequest): """验证请求参数""" # 验证模型 @@ -152,6 +193,12 @@ def validate_request(request: ChatCompletionRequest): # 验证消息 for idx, msg in enumerate(request.messages): + if not isinstance(msg.role, str) or msg.role not in VALID_ROLES: + raise ValidationException( + message=f"role must be one of {sorted(VALID_ROLES)}", + param=f"messages.{idx}.role", + code="invalid_role", + ) content = msg.content # 字符串内容 @@ -174,6 +221,12 @@ def validate_request(request: ChatCompletionRequest): for block_idx, block in enumerate(content): # 检查空对象 + if not isinstance(block, dict): + raise ValidationException( + message="Content block must be an object", + param=f"messages.{idx}.content.{block_idx}", + code="invalid_block", + ) if not block: raise ValidationException( message="Content block cannot be empty", @@ -211,20 +264,13 @@ def validate_request(request: ChatCompletionRequest): param=f"messages.{idx}.content.{block_idx}.type", code="invalid_type", ) - elif msg.role in ("tool", "function"): - # tool/function 角色只支持 text 类型,但内容可以是 JSON 字符串 + else: if block_type != "text": raise ValidationException( message=f"The `{msg.role}` role only supports 'text' type, got '{block_type}'", param=f"messages.{idx}.content.{block_idx}.type", code="invalid_type", ) - elif block_type != "text": - raise ValidationException( - message=f"The `{msg.role}` role only supports 'text' type, got '{block_type}'", - param=f"messages.{idx}.content.{block_idx}.type", - code="invalid_type", - ) # 验证字段是否存在 & 非空 if block_type == "text": @@ -237,14 +283,221 @@ def validate_request(request: ChatCompletionRequest): ) elif block_type == "image_url": image_url = block.get("image_url") - if not image_url or not ( - isinstance(image_url, dict) and image_url.get("url") - ): + if not image_url or not isinstance(image_url, dict): raise ValidationException( message="image_url must have a 'url' field", param=f"messages.{idx}.content.{block_idx}.image_url", code="missing_url", ) + _validate_media_input( + image_url.get("url", ""), + "image_url.url", + f"messages.{idx}.content.{block_idx}.image_url.url", + ) + elif block_type == "input_audio": + audio = block.get("input_audio") + if not audio or not isinstance(audio, dict): + raise ValidationException( + message="input_audio must have a 'data' field", + param=f"messages.{idx}.content.{block_idx}.input_audio", + code="missing_audio", + ) + _validate_media_input( + audio.get("data", ""), + "input_audio.data", + f"messages.{idx}.content.{block_idx}.input_audio.data", + ) + elif block_type == "file": + file_data = block.get("file") + if not file_data or not isinstance(file_data, dict): + raise ValidationException( + message="file must have a 'file_data' field", + param=f"messages.{idx}.content.{block_idx}.file", + code="missing_file", + ) + _validate_media_input( + file_data.get("file_data", ""), + "file.file_data", + f"messages.{idx}.content.{block_idx}.file.file_data", + ) + else: + raise ValidationException( + message="Message content must be a string or array", + param=f"messages.{idx}.content", + code="invalid_content", + ) + + # 默认验证 + if request.stream is not None: + if isinstance(request.stream, bool): + pass + elif isinstance(request.stream, str): + if request.stream.lower() in ("true", "1", "yes"): + request.stream = True + elif request.stream.lower() in ("false", "0", "no"): + request.stream = False + else: + raise ValidationException( + message="stream must be a boolean", + param="stream", + code="invalid_stream", + ) + else: + raise ValidationException( + message="stream must be a boolean", + param="stream", + code="invalid_stream", + ) + + allowed_efforts = {"none", "minimal", "low", "medium", "high", "xhigh"} + if request.reasoning_effort is not None: + if not isinstance(request.reasoning_effort, str) or ( + request.reasoning_effort not in allowed_efforts + ): + raise ValidationException( + message=f"reasoning_effort must be one of {sorted(allowed_efforts)}", + param="reasoning_effort", + code="invalid_reasoning_effort", + ) + + if request.temperature is None: + request.temperature = 0.8 + else: + try: + request.temperature = float(request.temperature) + except Exception: + raise ValidationException( + message="temperature must be a float", + param="temperature", + code="invalid_temperature", + ) + if not (0 <= request.temperature <= 2): + raise ValidationException( + message="temperature must be between 0 and 2", + param="temperature", + code="invalid_temperature", + ) + + if request.top_p is None: + request.top_p = 0.95 + else: + try: + request.top_p = float(request.top_p) + except Exception: + raise ValidationException( + message="top_p must be a float", + param="top_p", + code="invalid_top_p", + ) + if not (0 <= request.top_p <= 1): + raise ValidationException( + message="top_p must be between 0 and 1", + param="top_p", + code="invalid_top_p", + ) + + model_info = ModelService.get(request.model) + # image 验证 + if model_info and (model_info.is_image or model_info.is_image_edit): + prompt, image_urls = _extract_prompt_images(request.messages) + if not prompt: + raise ValidationException( + message="Prompt cannot be empty", + param="messages", + code="empty_prompt", + ) + image_conf = request.image_config or ImageConfig() + n = image_conf.n or 1 + if not (1 <= n <= 10): + raise ValidationException( + message="n must be between 1 and 10", + param="image_config.n", + code="invalid_n", + ) + if request.stream and n not in (1, 2): + raise ValidationException( + message="Streaming is only supported when n=1 or n=2", + param="stream", + code="invalid_stream_n", + ) + + response_format = _resolve_image_format(image_conf.response_format) + image_conf.n = n + image_conf.response_format = response_format + if not image_conf.size: + image_conf.size = "1024x1024" + allowed_sizes = { + "1280x720", + "720x1280", + "1792x1024", + "1024x1792", + "1024x1024", + } + if image_conf.size not in allowed_sizes: + raise ValidationException( + message=f"size must be one of {sorted(allowed_sizes)}", + param="image_config.size", + code="invalid_size", + ) + request.image_config = image_conf + + # image edit 验证 + if model_info and model_info.is_image_edit: + _, image_urls = _extract_prompt_images(request.messages) + if not image_urls: + raise ValidationException( + message="image_url is required for image edits", + param="messages", + code="missing_image", + ) + + # video 验证 + if model_info and model_info.is_video: + config = request.video_config or VideoConfig() + ratio_map = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", + "16:9": "16:9", + "9:16": "9:16", + "3:2": "3:2", + "2:3": "2:3", + "1:1": "1:1", + } + if config.aspect_ratio is None: + config.aspect_ratio = "3:2" + if config.aspect_ratio not in ratio_map: + raise ValidationException( + message=f"aspect_ratio must be one of {list(ratio_map.keys())}", + param="video_config.aspect_ratio", + code="invalid_aspect_ratio", + ) + config.aspect_ratio = ratio_map[config.aspect_ratio] + + if config.video_length not in (6, 10, 15): + raise ValidationException( + message="video_length must be 6, 10, or 15 seconds", + param="video_config.video_length", + code="invalid_video_length", + ) + if config.resolution_name not in ("480p", "720p"): + raise ValidationException( + message="resolution_name must be one of ['480p', '720p']", + param="video_config.resolution_name", + code="invalid_resolution", + ) + if config.preset not in ("fun", "normal", "spicy", "custom"): + raise ValidationException( + message="preset must be one of ['fun', 'normal', 'spicy', 'custom']", + param="video_config.preset", + code="invalid_preset", + ) + request.video_config = config + + +router = APIRouter(tags=["Chat"]) @router.post("/chat/completions") @@ -257,11 +510,149 @@ async def chat_completions(request: ChatCompletionRequest): logger.debug(f"Chat request: model={request.model}, stream={request.stream}") - # 检测视频模型 + # 检测模型类型 model_info = ModelService.get(request.model) - if model_info and model_info.is_video: - from app.services.grok.services.media import VideoService + if model_info and model_info.is_image_edit: + prompt, image_urls = _extract_prompt_images(request.messages) + if not image_urls: + raise ValidationException( + message="Image is required", + param="image", + code="missing_image", + ) + image_url = image_urls[-1] + + is_stream = ( + request.stream if request.stream is not None else get_config("app.stream") + ) + image_conf = request.image_config or ImageConfig() + _validate_image_config(image_conf, stream=bool(is_stream)) + response_format = _resolve_image_format(image_conf.response_format) + response_field = _image_field(response_format) + n = image_conf.n or 1 + + token_mgr = await get_token_manager() + await token_mgr.reload_if_stale() + + token = None + for pool_name in ModelService.pool_candidates_for_model(request.model): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + result = await ImageEditService().edit( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + images=[image_url], + n=n, + response_format=response_format, + stream=bool(is_stream), + ) + + if result.stream: + return StreamingResponse( + result.data, + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + data = [{response_field: img} for img in result.data] + return JSONResponse( + content={ + "created": int(time.time()), + "data": data, + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + }, + } + ) + + if model_info and model_info.is_image: + prompt, _ = _extract_prompt_images(request.messages) + + is_stream = ( + request.stream if request.stream is not None else get_config("app.stream") + ) + image_conf = request.image_config or ImageConfig() + _validate_image_config(image_conf, stream=bool(is_stream)) + response_format = _resolve_image_format(image_conf.response_format) + response_field = _image_field(response_format) + n = image_conf.n or 1 + size = image_conf.size or "1024x1024" + aspect_ratio_map = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", + } + aspect_ratio = aspect_ratio_map.get(size, "2:3") + + token_mgr = await get_token_manager() + await token_mgr.reload_if_stale() + + token = None + for pool_name in ModelService.pool_candidates_for_model(request.model): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + size=size, + aspect_ratio=aspect_ratio, + stream=bool(is_stream), + ) + + if result.stream: + return StreamingResponse( + result.data, + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + data = [{response_field: img} for img in result.data] + usage = result.usage_override or { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + } + return JSONResponse( + content={ + "created": int(time.time()), + "data": data, + "usage": usage, + } + ) + + if model_info and model_info.is_video: # 提取视频配置 (默认值在 Pydantic 模型中处理) v_conf = request.video_config or VideoConfig() @@ -269,7 +660,7 @@ async def chat_completions(request: ChatCompletionRequest): model=request.model, messages=[msg.model_dump() for msg in request.messages], stream=request.stream, - thinking=request.thinking, + reasoning_effort=request.reasoning_effort, aspect_ratio=v_conf.aspect_ratio, video_length=v_conf.video_length, resolution=v_conf.resolution_name, @@ -280,7 +671,9 @@ async def chat_completions(request: ChatCompletionRequest): model=request.model, messages=[msg.model_dump() for msg in request.messages], stream=request.stream, - thinking=request.thinking, + reasoning_effort=request.reasoning_effort, + temperature=request.temperature, + top_p=request.top_p, ) if isinstance(result, dict): diff --git a/app/api/v1/image.py b/app/api/v1/image.py index ec207c18..a7dcc9c1 100644 --- a/app/api/v1/image.py +++ b/app/api/v1/image.py @@ -2,11 +2,7 @@ Image Generation API 路由 """ -import asyncio import base64 -import math -import random -import re import time from pathlib import Path from typing import List, Optional, Union @@ -15,25 +11,32 @@ from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel, Field, ValidationError -from app.services.grok.services.chat import GrokChatService -from app.services.grok.services.image import image_service -from app.services.grok.services.assets import UploadService -from app.services.grok.services.media import VideoService -from app.services.grok.models.model import ModelService -from app.services.grok.processors import ( - ImageStreamProcessor, - ImageCollectProcessor, - ImageWSStreamProcessor, - ImageWSCollectProcessor, -) -from app.services.token import get_token_manager, EffortType +from app.services.grok.services.image import ImageGenerationService +from app.services.grok.services.image_edit import ImageEditService +from app.services.grok.services.model import ModelService +from app.services.token import get_token_manager from app.core.exceptions import ValidationException, AppException, ErrorType from app.core.config import get_config -from app.core.logger import logger router = APIRouter(tags=["Images"]) +ALLOWED_IMAGE_SIZES = { + "1280x720", + "720x1280", + "1792x1024", + "1024x1792", + "1024x1024", +} + +SIZE_TO_ASPECT = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", +} + class ImageGenerationRequest(BaseModel): """图片生成请求 - OpenAI 兼容""" @@ -41,7 +44,10 @@ class ImageGenerationRequest(BaseModel): prompt: str = Field(..., description="图片描述") model: Optional[str] = Field("grok-imagine-1.0", description="模型名称") n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)") - size: Optional[str] = Field("1024x1024", description="图片尺寸 (暂不支持)") + size: Optional[str] = Field( + "1024x1024", + description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024", + ) quality: Optional[str] = Field("standard", description="图片质量 (暂不支持)") response_format: Optional[str] = Field(None, description="响应格式") style: Optional[str] = Field(None, description="风格 (暂不支持)") @@ -55,7 +61,10 @@ class ImageEditRequest(BaseModel): model: Optional[str] = Field("grok-imagine-1.0-edit", description="模型名称") image: Optional[Union[str, List[str]]] = Field(None, description="待编辑图片文件") n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)") - size: Optional[str] = Field("1024x1024", description="图片尺寸 (暂不支持)") + size: Optional[str] = Field( + "1024x1024", + description="图片尺寸: 1280x720, 720x1280, 1792x1024, 1024x1792, 1024x1024", + ) quality: Optional[str] = Field("standard", description="图片质量 (暂不支持)") response_format: Optional[str] = Field(None, description="响应格式") style: Optional[str] = Field(None, description="风格 (暂不支持)") @@ -89,18 +98,14 @@ def _validate_common_request( ) if allow_ws_stream: - # WS 流式仅支持 b64_json (base64 视为同义) - if ( - request.stream - and get_config("image.image_ws") - and request.response_format - and request.response_format not in {"b64_json", "base64"} - ): - raise ValidationException( - message="Streaming with image_ws only supports response_format=b64_json/base64", - param="response_format", - code="invalid_response_format", - ) + if request.stream and request.response_format: + allowed_stream_formats = {"b64_json", "base64", "url"} + if request.response_format not in allowed_stream_formats: + raise ValidationException( + message="Streaming only supports response_format=b64_json/base64/url", + param="response_format", + code="invalid_response_format", + ) if request.response_format: allowed_formats = {"b64_json", "base64", "url"} @@ -111,6 +116,13 @@ def _validate_common_request( code="invalid_response_format", ) + if request.size and request.size not in ALLOWED_IMAGE_SIZES: + raise ValidationException( + message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}", + param="size", + code="invalid_size", + ) + def validate_generation_request(request: ImageGenerationRequest): """验证图片生成请求参数""" @@ -157,26 +169,8 @@ def response_field_name(response_format: str) -> str: def resolve_aspect_ratio(size: str) -> str: """Map OpenAI size to Grok Imagine aspect ratio.""" - size = (size or "").lower() - if size in {"16:9", "9:16", "1:1", "2:3", "3:2"}: - return size - mapping = { - "1024x1024": "1:1", - "512x512": "1:1", - "1024x576": "16:9", - "1280x720": "16:9", - "1536x864": "16:9", - "576x1024": "9:16", - "720x1280": "9:16", - "864x1536": "9:16", - "1024x1536": "2:3", - "512x768": "2:3", - "768x1024": "2:3", - "1536x1024": "3:2", - "768x512": "3:2", - "1024x768": "3:2", - } - return mapping.get(size) or "2:3" + size = (size or "").strip() + return SIZE_TO_ASPECT.get(size) or "2:3" def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]): @@ -187,6 +181,17 @@ def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]): param="model", code="model_not_supported", ) + model_info = ModelService.get(request.model) + if not model_info or not model_info.is_image_edit: + edit_models = [m.model_id for m in ModelService.MODELS if m.is_image_edit] + raise ValidationException( + message=( + f"The model `{request.model}` is not supported for image edits. " + f"Supported: {edit_models}" + ), + param="model", + code="model_not_supported", + ) _validate_common_request(request, allow_ws_stream=False) if not images: raise ValidationException( @@ -202,30 +207,6 @@ def validate_edit_request(request: ImageEditRequest, images: List[UploadFile]): ) -def _get_effort(model_info) -> EffortType: - """获取模型消耗级别""" - return ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW - ) - - -async def _wrap_stream_with_usage(stream, token_mgr, token, model_info): - """包装流式响应,成功完成时记录使用""" - success = False - try: - async for chunk in stream: - yield chunk - success = True - finally: - if success: - try: - await token_mgr.consume(token, _get_effort(model_info)) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - - async def _get_token(model: str): """获取可用 token""" token_mgr = await get_token_manager() @@ -248,46 +229,6 @@ async def _get_token(model: str): return token_mgr, token -async def call_grok( - token_mgr, - token: str, - prompt: str, - model_info, - file_attachments: Optional[List[str]] = None, - response_format: str = "b64_json", -) -> List[str]: - """调用 Grok 获取图片,返回 base64 列表""" - chat_service = GrokChatService() - success = False - - try: - response = await chat_service.chat( - token=token, - message=prompt, - model=model_info.grok_model, - mode=model_info.model_mode, - stream=True, - file_attachments=file_attachments, - ) - - processor = ImageCollectProcessor( - model_info.model_id, token, response_format=response_format - ) - images = await processor.process(response) - success = True - return images - - except Exception as e: - logger.error(f"Grok image call failed: {e}") - return [] - finally: - if success: - try: - await token_mgr.consume(token, _get_effort(model_info)) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - - @router.post("/images/generations") async def create_image(request: ImageGenerationRequest): """ @@ -320,161 +261,29 @@ async def create_image(request: ImageGenerationRequest): # 获取 token 和模型信息 token_mgr, token = await _get_token(request.model) model_info = ModelService.get(request.model) - use_ws = bool(get_config("image.image_ws")) - - # 流式模式 - if request.stream: - if use_ws: - aspect_ratio = resolve_aspect_ratio(request.size) - enable_nsfw = bool(get_config("image.image_ws_nsfw")) - upstream = image_service.stream( - token=token, - prompt=request.prompt, - aspect_ratio=aspect_ratio, - n=request.n, - enable_nsfw=enable_nsfw, - ) - processor = ImageWSStreamProcessor( - model_info.model_id, - token, - n=request.n, - response_format=response_format, - size=request.size, - ) - - return StreamingResponse( - _wrap_stream_with_usage( - processor.process(upstream), token_mgr, token, model_info - ), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - chat_service = GrokChatService() - response = await chat_service.chat( - token=token, - message=f"Image Generation: {request.prompt}", - model=model_info.grok_model, - mode=model_info.model_mode, - stream=True, - ) - - processor = ImageStreamProcessor( - model_info.model_id, token, n=request.n, response_format=response_format - ) + aspect_ratio = resolve_aspect_ratio(request.size) + + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=request.prompt, + n=request.n, + response_format=response_format, + size=request.size, + aspect_ratio=aspect_ratio, + stream=bool(request.stream), + ) + if result.stream: return StreamingResponse( - _wrap_stream_with_usage( - processor.process(response), token_mgr, token, model_info - ), + result.data, media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) - # 非流式模式 - n = request.n - - usage_override = None - if use_ws: - aspect_ratio = resolve_aspect_ratio(request.size) - enable_nsfw = bool(get_config("image.image_ws_nsfw")) - all_images = [] - seen = set() - expected_per_call = 6 - calls_needed = max(1, math.ceil(n / expected_per_call)) - calls_needed = min(calls_needed, n) - - async def _fetch_batch(call_target: int): - upstream = image_service.stream( - token=token, - prompt=request.prompt, - aspect_ratio=aspect_ratio, - n=call_target, - enable_nsfw=enable_nsfw, - ) - processor = ImageWSCollectProcessor( - model_info.model_id, - token, - n=call_target, - response_format=response_format, - ) - return await processor.process(upstream) - - tasks = [] - for i in range(calls_needed): - remaining = n - (i * expected_per_call) - call_target = min(expected_per_call, remaining) - tasks.append(_fetch_batch(call_target)) - - results = await asyncio.gather(*tasks, return_exceptions=True) - for batch in results: - if isinstance(batch, Exception): - logger.warning(f"WS batch failed: {batch}") - continue - for img in batch: - if img not in seen: - seen.add(img) - all_images.append(img) - if len(all_images) >= n: - break - if len(all_images) >= n: - break - try: - await token_mgr.consume(token, _get_effort(model_info)) - except Exception as e: - logger.warning(f"Failed to consume token: {e}") - usage_override = { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, - } - else: - calls_needed = (n + 1) // 2 - - if calls_needed == 1: - # 单次调用 - all_images = await call_grok( - token_mgr, - token, - f"Image Generation: {request.prompt}", - model_info, - response_format=response_format, - ) - else: - # 并发调用 - tasks = [ - call_grok( - token_mgr, - token, - f"Image Generation: {request.prompt}", - model_info, - response_format=response_format, - ) - for _ in range(calls_needed) - ] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 收集成功的图片 - all_images = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"Concurrent call failed: {result}") - elif isinstance(result, list): - all_images.extend(result) - - # 随机选取 n 张图片 - if len(all_images) >= n: - selected_images = random.sample(all_images, n) - else: - # 全部返回,error 填充缺失 - selected_images = all_images.copy() - while len(selected_images) < n: - selected_images.append("error") - - # 构建响应 - data = [{response_field: img} for img in selected_images] - usage = usage_override or { + data = [{response_field: img} for img in result.data] + usage = result.usage_override or { "total_tokens": 0, "input_tokens": 0, "output_tokens": 0, @@ -539,6 +348,8 @@ async def edit_image( edit_request.stream = False response_format = resolve_response_format(edit_request.response_format) + if response_format == "base64": + response_format = "b64_json" edit_request.response_format = response_format response_field = response_field_name(response_format) @@ -588,150 +399,25 @@ async def edit_image( token_mgr, token = await _get_token(edit_request.model) model_info = ModelService.get(edit_request.model) - # 上传图片 - image_urls: List[str] = [] - upload_service = UploadService() - try: - for image in images: - file_id, file_uri = await upload_service.upload(image, token) - if file_uri: - if file_uri.startswith("http"): - image_urls.append(file_uri) - else: - image_urls.append(f"https://assets.grok.com/{file_uri.lstrip('/')}") - finally: - await upload_service.close() - - if not image_urls: - raise AppException( - message="Image upload failed", - error_type=ErrorType.SERVER.value, - code="upload_failed", - ) - - parent_post_id = None - try: - media_service = VideoService() - parent_post_id = await media_service.create_image_post(token, image_urls[0]) - logger.debug(f"Parent post ID: {parent_post_id}") - except Exception as e: - logger.warning(f"Create image post failed: {e}") - - if not parent_post_id: - for url in image_urls: - match = re.search(r"/generated/([a-f0-9-]+)/", url) - if match: - parent_post_id = match.group(1) - logger.debug(f"Parent post ID: {parent_post_id}") - break - match = re.search(r"/users/[^/]+/([a-f0-9-]+)/content", url) - if match: - parent_post_id = match.group(1) - logger.debug(f"Parent post ID: {parent_post_id}") - break - - model_config_override = { - "modelMap": { - "imageEditModel": "imagine", - "imageEditModelConfig": { - "imageReferences": image_urls, - }, - } - } - - if parent_post_id: - model_config_override["modelMap"]["imageEditModelConfig"]["parentPostId"] = ( - parent_post_id - ) - - raw_payload = { - "temporary": bool(get_config("chat.temporary")), - "modelName": model_info.grok_model, - "message": edit_request.prompt, - "enableImageGeneration": True, - "returnImageBytes": False, - "returnRawGrokInXaiRequest": False, - "enableImageStreaming": True, - "imageGenerationCount": 2, - "forceConcise": False, - "toolOverrides": {"imageGen": True}, - "enableSideBySide": True, - "sendFinalMetadata": True, - "isReasoning": False, - "disableTextFollowUps": True, - "responseMetadata": {"modelConfigOverride": model_config_override}, - "disableMemory": False, - "forceSideBySide": False, - } - - # 流式模式 - if edit_request.stream: - chat_service = GrokChatService() - response = await chat_service.chat( - token=token, - message=edit_request.prompt, - model=model_info.grok_model, - mode=None, - stream=True, - raw_payload=raw_payload, - ) - - processor = ImageStreamProcessor( - model_info.model_id, - token, - n=edit_request.n, - response_format=response_format, - ) + result = await ImageEditService().edit( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=edit_request.prompt, + images=images, + n=edit_request.n, + response_format=response_format, + stream=bool(edit_request.stream), + ) + if result.stream: return StreamingResponse( - _wrap_stream_with_usage( - processor.process(response), token_mgr, token, model_info - ), + result.data, media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) - # 非流式模式 - n = edit_request.n - calls_needed = (n + 1) // 2 - - async def _call_edit(): - chat_service = GrokChatService() - response = await chat_service.chat( - token=token, - message=edit_request.prompt, - model=model_info.grok_model, - mode=None, - stream=True, - raw_payload=raw_payload, - ) - processor = ImageCollectProcessor( - model_info.model_id, token, response_format=response_format - ) - return await processor.process(response) - - if calls_needed == 1: - all_images = await _call_edit() - else: - tasks = [_call_edit() for _ in range(calls_needed)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - all_images = [] - for result in results: - if isinstance(result, Exception): - logger.error(f"Concurrent call failed: {result}") - elif isinstance(result, list): - all_images.extend(result) - - # 选择图片 - if len(all_images) >= n: - selected_images = random.sample(all_images, n) - else: - selected_images = all_images.copy() - while len(selected_images) < n: - selected_images.append("error") - - data = [{response_field: img} for img in selected_images] + data = [{response_field: img} for img in result.data] return JSONResponse( content={ diff --git a/app/api/v1/models.py b/app/api/v1/models.py index 13971669..fe5bdc0e 100644 --- a/app/api/v1/models.py +++ b/app/api/v1/models.py @@ -4,7 +4,7 @@ from fastapi import APIRouter -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService router = APIRouter(tags=["Models"]) @@ -18,7 +18,7 @@ async def list_models(): "id": m.model_id, "object": "model", "created": 0, - "owned_by": "grok2api", + "owned_by": "grok2api@chenyme", } for m in ModelService.list() ] diff --git a/app/api/v1/public/__init__.py b/app/api/v1/public/__init__.py new file mode 100644 index 00000000..984bf0d3 --- /dev/null +++ b/app/api/v1/public/__init__.py @@ -0,0 +1,15 @@ +"""Public API router (public_key protected).""" + +from fastapi import APIRouter + +from app.api.v1.public.imagine import router as imagine_router +from app.api.v1.public.video import router as video_router +from app.api.v1.public.voice import router as voice_router + +router = APIRouter() + +router.include_router(imagine_router) +router.include_router(video_router) +router.include_router(voice_router) + +__all__ = ["router"] diff --git a/app/api/v1/public/imagine.py b/app/api/v1/public/imagine.py new file mode 100644 index 00000000..83f59c34 --- /dev/null +++ b/app/api/v1/public/imagine.py @@ -0,0 +1,505 @@ +import asyncio +import time +import uuid +from typing import Optional, List, Dict, Any + +import orjson +from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.core.auth import verify_public_key, get_public_api_key, is_public_enabled +from app.core.config import get_config +from app.core.logger import logger +from app.api.v1.image import resolve_aspect_ratio +from app.services.grok.services.image import ImageGenerationService +from app.services.grok.services.model import ModelService +from app.services.token.manager import get_token_manager + +router = APIRouter() + +IMAGINE_SESSION_TTL = 600 +_IMAGINE_SESSIONS: dict[str, dict] = {} +_IMAGINE_SESSIONS_LOCK = asyncio.Lock() + + +async def _clean_sessions(now: float) -> None: + expired = [ + key + for key, info in _IMAGINE_SESSIONS.items() + if now - float(info.get("created_at") or 0) > IMAGINE_SESSION_TTL + ] + for key in expired: + _IMAGINE_SESSIONS.pop(key, None) + + +def _parse_sse_chunk(chunk: str) -> Optional[Dict[str, Any]]: + if not chunk: + return None + event = None + data_lines: List[str] = [] + for raw in str(chunk).splitlines(): + line = raw.strip() + if not line: + continue + if line.startswith("event:"): + event = line[6:].strip() + continue + if line.startswith("data:"): + data_lines.append(line[5:].strip()) + if not data_lines: + return None + data_str = "\n".join(data_lines) + if data_str == "[DONE]": + return None + try: + payload = orjson.loads(data_str) + except orjson.JSONDecodeError: + return None + if event and isinstance(payload, dict) and "type" not in payload: + payload["type"] = event + return payload + + +async def _new_session(prompt: str, aspect_ratio: str, nsfw: Optional[bool]) -> str: + task_id = uuid.uuid4().hex + now = time.time() + async with _IMAGINE_SESSIONS_LOCK: + await _clean_sessions(now) + _IMAGINE_SESSIONS[task_id] = { + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "nsfw": nsfw, + "created_at": now, + } + return task_id + + +async def _get_session(task_id: str) -> Optional[dict]: + if not task_id: + return None + now = time.time() + async with _IMAGINE_SESSIONS_LOCK: + await _clean_sessions(now) + info = _IMAGINE_SESSIONS.get(task_id) + if not info: + return None + created_at = float(info.get("created_at") or 0) + if now - created_at > IMAGINE_SESSION_TTL: + _IMAGINE_SESSIONS.pop(task_id, None) + return None + return dict(info) + + +async def _drop_session(task_id: str) -> None: + if not task_id: + return + async with _IMAGINE_SESSIONS_LOCK: + _IMAGINE_SESSIONS.pop(task_id, None) + + +async def _drop_sessions(task_ids: List[str]) -> int: + if not task_ids: + return 0 + removed = 0 + async with _IMAGINE_SESSIONS_LOCK: + for task_id in task_ids: + if task_id and task_id in _IMAGINE_SESSIONS: + _IMAGINE_SESSIONS.pop(task_id, None) + removed += 1 + return removed + + +@router.websocket("/imagine/ws") +async def public_imagine_ws(websocket: WebSocket): + session_id = None + task_id = websocket.query_params.get("task_id") + if task_id: + info = await _get_session(task_id) + if info: + session_id = task_id + + ok = True + if session_id is None: + public_key = get_public_api_key() + public_enabled = is_public_enabled() + if not public_key: + ok = public_enabled + else: + key = websocket.query_params.get("public_key") + ok = key == public_key + + if not ok: + await websocket.close(code=1008) + return + + await websocket.accept() + stop_event = asyncio.Event() + run_task: Optional[asyncio.Task] = None + + async def _send(payload: dict) -> bool: + try: + await websocket.send_text(orjson.dumps(payload).decode()) + return True + except Exception: + return False + + async def _stop_run(): + nonlocal run_task + stop_event.set() + if run_task and not run_task.done(): + run_task.cancel() + try: + await run_task + except Exception: + pass + run_task = None + stop_event.clear() + + async def _run(prompt: str, aspect_ratio: str, nsfw: Optional[bool]): + model_id = "grok-imagine-1.0" + model_info = ModelService.get(model_id) + if not model_info or not model_info.is_image: + await _send( + { + "type": "error", + "message": "Image model is not available.", + "code": "model_not_supported", + } + ) + return + + token_mgr = await get_token_manager() + run_id = uuid.uuid4().hex + + await _send( + { + "type": "status", + "status": "running", + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "run_id": run_id, + } + ) + + while not stop_event.is_set(): + try: + await token_mgr.reload_if_stale() + token = None + for pool_name in ModelService.pool_candidates_for_model( + model_info.model_id + ): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + await _send( + { + "type": "error", + "message": "No available tokens. Please try again later.", + "code": "rate_limit_exceeded", + } + ) + await asyncio.sleep(2) + continue + + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=6, + response_format="b64_json", + size="1024x1024", + aspect_ratio=aspect_ratio, + stream=True, + enable_nsfw=nsfw, + ) + if result.stream: + async for chunk in result.data: + payload = _parse_sse_chunk(chunk) + if not payload: + continue + if isinstance(payload, dict): + payload.setdefault("run_id", run_id) + await _send(payload) + else: + images = [img for img in result.data if img and img != "error"] + if images: + for img_b64 in images: + await _send( + { + "type": "image", + "b64_json": img_b64, + "created_at": int(time.time() * 1000), + "aspect_ratio": aspect_ratio, + "run_id": run_id, + } + ) + else: + await _send( + { + "type": "error", + "message": "Image generation returned empty data.", + "code": "empty_image", + } + ) + + except asyncio.CancelledError: + break + except Exception as e: + logger.warning(f"Imagine stream error: {e}") + await _send( + { + "type": "error", + "message": str(e), + "code": "internal_error", + } + ) + await asyncio.sleep(1.5) + + await _send({"type": "status", "status": "stopped", "run_id": run_id}) + + try: + while True: + try: + raw = await websocket.receive_text() + except (RuntimeError, WebSocketDisconnect): + break + + try: + payload = orjson.loads(raw) + except Exception: + await _send( + { + "type": "error", + "message": "Invalid message format.", + "code": "invalid_payload", + } + ) + continue + + action = payload.get("type") + if action == "start": + prompt = str(payload.get("prompt") or "").strip() + if not prompt: + await _send( + { + "type": "error", + "message": "Prompt cannot be empty.", + "code": "invalid_prompt", + } + ) + continue + aspect_ratio = resolve_aspect_ratio( + str(payload.get("aspect_ratio") or "2:3").strip() or "2:3" + ) + nsfw = payload.get("nsfw") + if nsfw is not None: + nsfw = bool(nsfw) + await _stop_run() + run_task = asyncio.create_task(_run(prompt, aspect_ratio, nsfw)) + elif action == "stop": + await _stop_run() + else: + await _send( + { + "type": "error", + "message": "Unknown action.", + "code": "invalid_action", + } + ) + + except WebSocketDisconnect: + logger.debug("WebSocket disconnected by client") + except Exception as e: + logger.warning(f"WebSocket error: {e}") + finally: + await _stop_run() + + try: + from starlette.websockets import WebSocketState + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close(code=1000, reason="Server closing connection") + except Exception as e: + logger.debug(f"WebSocket close ignored: {e}") + if session_id: + await _drop_session(session_id) + + +@router.get("/imagine/sse") +async def public_imagine_sse( + request: Request, + task_id: str = Query(""), + prompt: str = Query(""), + aspect_ratio: str = Query("2:3"), +): + """Imagine 图片瀑布流(SSE 兜底)""" + session = None + if task_id: + session = await _get_session(task_id) + if not session: + raise HTTPException(status_code=404, detail="Task not found") + else: + public_key = get_public_api_key() + public_enabled = is_public_enabled() + if not public_key: + if not public_enabled: + raise HTTPException(status_code=401, detail="Public access is disabled") + else: + key = request.query_params.get("public_key") + if key != public_key: + raise HTTPException(status_code=401, detail="Invalid authentication token") + + if session: + prompt = str(session.get("prompt") or "").strip() + ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3" + nsfw = session.get("nsfw") + else: + prompt = (prompt or "").strip() + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + ratio = str(aspect_ratio or "2:3").strip() or "2:3" + ratio = resolve_aspect_ratio(ratio) + nsfw = request.query_params.get("nsfw") + if nsfw is not None: + nsfw = str(nsfw).lower() in ("1", "true", "yes", "on") + + async def event_stream(): + try: + model_id = "grok-imagine-1.0" + model_info = ModelService.get(model_id) + if not model_info or not model_info.is_image: + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': 'Image model is not available.', 'code': 'model_not_supported'}).decode()}\n\n" + ) + return + + token_mgr = await get_token_manager() + sequence = 0 + run_id = uuid.uuid4().hex + + yield ( + f"data: {orjson.dumps({'type': 'status', 'status': 'running', 'prompt': prompt, 'aspect_ratio': ratio, 'run_id': run_id}).decode()}\n\n" + ) + + while True: + if await request.is_disconnected(): + break + if task_id: + session_alive = await _get_session(task_id) + if not session_alive: + break + + try: + await token_mgr.reload_if_stale() + token = None + for pool_name in ModelService.pool_candidates_for_model( + model_info.model_id + ): + token = token_mgr.get_token(pool_name) + if token: + break + + if not token: + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': 'No available tokens. Please try again later.', 'code': 'rate_limit_exceeded'}).decode()}\n\n" + ) + await asyncio.sleep(2) + continue + + result = await ImageGenerationService().generate( + token_mgr=token_mgr, + token=token, + model_info=model_info, + prompt=prompt, + n=6, + response_format="b64_json", + size="1024x1024", + aspect_ratio=ratio, + stream=True, + enable_nsfw=nsfw, + ) + if result.stream: + async for chunk in result.data: + payload = _parse_sse_chunk(chunk) + if not payload: + continue + if isinstance(payload, dict): + payload.setdefault("run_id", run_id) + yield f"data: {orjson.dumps(payload).decode()}\n\n" + else: + images = [img for img in result.data if img and img != "error"] + if images: + for img_b64 in images: + sequence += 1 + payload = { + "type": "image", + "b64_json": img_b64, + "sequence": sequence, + "created_at": int(time.time() * 1000), + "aspect_ratio": ratio, + "run_id": run_id, + } + yield f"data: {orjson.dumps(payload).decode()}\n\n" + else: + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': 'Image generation returned empty data.', 'code': 'empty_image'}).decode()}\n\n" + ) + except asyncio.CancelledError: + break + except Exception as e: + logger.warning(f"Imagine SSE error: {e}") + yield ( + f"data: {orjson.dumps({'type': 'error', 'message': str(e), 'code': 'internal_error'}).decode()}\n\n" + ) + await asyncio.sleep(1.5) + + yield ( + f"data: {orjson.dumps({'type': 'status', 'status': 'stopped', 'run_id': run_id}).decode()}\n\n" + ) + finally: + if task_id: + await _drop_session(task_id) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + +@router.get("/imagine/config") +async def public_imagine_config(): + return { + "final_min_bytes": int(get_config("image.final_min_bytes") or 0), + "medium_min_bytes": int(get_config("image.medium_min_bytes") or 0), + "nsfw": bool(get_config("image.nsfw")), + } + + +class ImagineStartRequest(BaseModel): + prompt: str + aspect_ratio: Optional[str] = "2:3" + nsfw: Optional[bool] = None + + +@router.post("/imagine/start", dependencies=[Depends(verify_public_key)]) +async def public_imagine_start(data: ImagineStartRequest): + prompt = (data.prompt or "").strip() + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + ratio = resolve_aspect_ratio(str(data.aspect_ratio or "2:3").strip() or "2:3") + task_id = await _new_session(prompt, ratio, data.nsfw) + return {"task_id": task_id, "aspect_ratio": ratio} + + +class ImagineStopRequest(BaseModel): + task_ids: List[str] + + +@router.post("/imagine/stop", dependencies=[Depends(verify_public_key)]) +async def public_imagine_stop(data: ImagineStopRequest): + removed = await _drop_sessions(data.task_ids or []) + return {"status": "success", "removed": removed} diff --git a/app/api/v1/public/video.py b/app/api/v1/public/video.py new file mode 100644 index 00000000..c88182c8 --- /dev/null +++ b/app/api/v1/public/video.py @@ -0,0 +1,274 @@ +import asyncio +import time +import uuid +from typing import Optional, List, Dict, Any + +import orjson +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.core.auth import verify_public_key +from app.core.logger import logger +from app.services.grok.services.video import VideoService +from app.services.grok.services.model import ModelService + +router = APIRouter() + +VIDEO_SESSION_TTL = 600 +_VIDEO_SESSIONS: dict[str, dict] = {} +_VIDEO_SESSIONS_LOCK = asyncio.Lock() + +_VIDEO_RATIO_MAP = { + "1280x720": "16:9", + "720x1280": "9:16", + "1792x1024": "3:2", + "1024x1792": "2:3", + "1024x1024": "1:1", + "16:9": "16:9", + "9:16": "9:16", + "3:2": "3:2", + "2:3": "2:3", + "1:1": "1:1", +} + + +async def _clean_sessions(now: float) -> None: + expired = [ + key + for key, info in _VIDEO_SESSIONS.items() + if now - float(info.get("created_at") or 0) > VIDEO_SESSION_TTL + ] + for key in expired: + _VIDEO_SESSIONS.pop(key, None) + + +async def _new_session( + prompt: str, + aspect_ratio: str, + video_length: int, + resolution_name: str, + preset: str, + image_url: Optional[str], + reasoning_effort: Optional[str], +) -> str: + task_id = uuid.uuid4().hex + now = time.time() + async with _VIDEO_SESSIONS_LOCK: + await _clean_sessions(now) + _VIDEO_SESSIONS[task_id] = { + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "video_length": video_length, + "resolution_name": resolution_name, + "preset": preset, + "image_url": image_url, + "reasoning_effort": reasoning_effort, + "created_at": now, + } + return task_id + + +async def _get_session(task_id: str) -> Optional[dict]: + if not task_id: + return None + now = time.time() + async with _VIDEO_SESSIONS_LOCK: + await _clean_sessions(now) + info = _VIDEO_SESSIONS.get(task_id) + if not info: + return None + created_at = float(info.get("created_at") or 0) + if now - created_at > VIDEO_SESSION_TTL: + _VIDEO_SESSIONS.pop(task_id, None) + return None + return dict(info) + + +async def _drop_session(task_id: str) -> None: + if not task_id: + return + async with _VIDEO_SESSIONS_LOCK: + _VIDEO_SESSIONS.pop(task_id, None) + + +async def _drop_sessions(task_ids: List[str]) -> int: + if not task_ids: + return 0 + removed = 0 + async with _VIDEO_SESSIONS_LOCK: + for task_id in task_ids: + if task_id and task_id in _VIDEO_SESSIONS: + _VIDEO_SESSIONS.pop(task_id, None) + removed += 1 + return removed + + +def _normalize_ratio(value: Optional[str]) -> str: + raw = (value or "").strip() + return _VIDEO_RATIO_MAP.get(raw, "") + + +def _validate_image_url(image_url: str) -> None: + value = (image_url or "").strip() + if not value: + return + if value.startswith("data:"): + return + if value.startswith("http://") or value.startswith("https://"): + return + raise HTTPException( + status_code=400, + detail="image_url must be a URL or data URI (data:;base64,...)", + ) + + +class VideoStartRequest(BaseModel): + prompt: str + aspect_ratio: Optional[str] = "3:2" + video_length: Optional[int] = 6 + resolution_name: Optional[str] = "480p" + preset: Optional[str] = "normal" + image_url: Optional[str] = None + reasoning_effort: Optional[str] = None + + +@router.post("/video/start", dependencies=[Depends(verify_public_key)]) +async def public_video_start(data: VideoStartRequest): + prompt = (data.prompt or "").strip() + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + + aspect_ratio = _normalize_ratio(data.aspect_ratio) + if not aspect_ratio: + raise HTTPException( + status_code=400, + detail="aspect_ratio must be one of ['16:9','9:16','3:2','2:3','1:1']", + ) + + video_length = int(data.video_length or 6) + if video_length not in (6, 10, 15): + raise HTTPException( + status_code=400, detail="video_length must be 6, 10, or 15 seconds" + ) + + resolution_name = str(data.resolution_name or "480p") + if resolution_name not in ("480p", "720p"): + raise HTTPException( + status_code=400, + detail="resolution_name must be one of ['480p','720p']", + ) + + preset = str(data.preset or "normal") + if preset not in ("fun", "normal", "spicy", "custom"): + raise HTTPException( + status_code=400, + detail="preset must be one of ['fun','normal','spicy','custom']", + ) + + image_url = (data.image_url or "").strip() or None + if image_url: + _validate_image_url(image_url) + + reasoning_effort = (data.reasoning_effort or "").strip() or None + if reasoning_effort: + allowed = {"none", "minimal", "low", "medium", "high", "xhigh"} + if reasoning_effort not in allowed: + raise HTTPException( + status_code=400, + detail=f"reasoning_effort must be one of {sorted(allowed)}", + ) + + task_id = await _new_session( + prompt, + aspect_ratio, + video_length, + resolution_name, + preset, + image_url, + reasoning_effort, + ) + return {"task_id": task_id, "aspect_ratio": aspect_ratio} + + +@router.get("/video/sse") +async def public_video_sse(request: Request, task_id: str = Query("")): + session = await _get_session(task_id) + if not session: + raise HTTPException(status_code=404, detail="Task not found") + + prompt = str(session.get("prompt") or "").strip() + aspect_ratio = str(session.get("aspect_ratio") or "3:2") + video_length = int(session.get("video_length") or 6) + resolution_name = str(session.get("resolution_name") or "480p") + preset = str(session.get("preset") or "normal") + image_url = session.get("image_url") + reasoning_effort = session.get("reasoning_effort") + + async def event_stream(): + try: + model_id = "grok-imagine-1.0-video" + model_info = ModelService.get(model_id) + if not model_info or not model_info.is_video: + payload = { + "error": "Video model is not available.", + "code": "model_not_supported", + } + yield f"data: {orjson.dumps(payload).decode()}\n\n" + yield "data: [DONE]\n\n" + return + + if image_url: + messages: List[Dict[str, Any]] = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ] + else: + messages = [{"role": "user", "content": prompt}] + + stream = await VideoService.completions( + model_id, + messages, + stream=True, + reasoning_effort=reasoning_effort, + aspect_ratio=aspect_ratio, + video_length=video_length, + resolution=resolution_name, + preset=preset, + ) + + async for chunk in stream: + if await request.is_disconnected(): + break + yield chunk + except Exception as e: + logger.warning(f"Public video SSE error: {e}") + payload = {"error": str(e), "code": "internal_error"} + yield f"data: {orjson.dumps(payload).decode()}\n\n" + yield "data: [DONE]\n\n" + finally: + await _drop_session(task_id) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + +class VideoStopRequest(BaseModel): + task_ids: List[str] + + +@router.post("/video/stop", dependencies=[Depends(verify_public_key)]) +async def public_video_stop(data: VideoStopRequest): + removed = await _drop_sessions(data.task_ids or []) + return {"status": "success", "removed": removed} + + +__all__ = ["router"] diff --git a/app/api/v1/public/voice.py b/app/api/v1/public/voice.py new file mode 100644 index 00000000..12612f09 --- /dev/null +++ b/app/api/v1/public/voice.py @@ -0,0 +1,80 @@ +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from app.core.auth import verify_public_key +from app.core.exceptions import AppException +from app.services.grok.services.voice import VoiceService +from app.services.token.manager import get_token_manager + +router = APIRouter() + + +class VoiceTokenResponse(BaseModel): + token: str + url: str + participant_name: str = "" + room_name: str = "" + + +@router.get( + "/voice/token", + dependencies=[Depends(verify_public_key)], + response_model=VoiceTokenResponse, +) +async def public_voice_token( + voice: str = "ara", + personality: str = "assistant", + speed: float = 1.0, +): + """获取 Grok Voice Mode (LiveKit) Token""" + token_mgr = await get_token_manager() + sso_token = None + for pool_name in ("ssoBasic", "ssoSuper"): + sso_token = token_mgr.get_token(pool_name) + if sso_token: + break + + if not sso_token: + raise AppException( + "No available tokens for voice mode", + code="no_token", + status_code=503, + ) + + service = VoiceService() + try: + data = await service.get_token( + token=sso_token, + voice=voice, + personality=personality, + speed=speed, + ) + token = data.get("token") + if not token: + raise AppException( + "Upstream returned no voice token", + code="upstream_error", + status_code=502, + ) + + return VoiceTokenResponse( + token=token, + url="wss://livekit.grok.com", + participant_name="", + room_name="", + ) + + except Exception as e: + if isinstance(e, AppException): + raise + raise AppException( + f"Voice token error: {str(e)}", + code="voice_error", + status_code=500, + ) + + +@router.get("/verify", dependencies=[Depends(verify_public_key)]) +async def public_verify_api(): + """验证 Public Key""" + return {"status": "success"} diff --git a/app/core/auth.py b/app/core/auth.py index e6bb3c37..2cb45820 100644 --- a/app/core/auth.py +++ b/app/core/auth.py @@ -10,6 +10,8 @@ DEFAULT_API_KEY = "" DEFAULT_APP_KEY = "grok2api" +DEFAULT_PUBLIC_KEY = "" +DEFAULT_PUBLIC_ENABLED = False # 定义 Bearer Scheme security = HTTPBearer( @@ -28,6 +30,28 @@ def get_admin_api_key() -> str: api_key = get_config("app.api_key", DEFAULT_API_KEY) return api_key or "" +def get_app_key() -> str: + """ + 获取 App Key(后台管理密码)。 + """ + app_key = get_config("app.app_key", DEFAULT_APP_KEY) + return app_key or "" + +def get_public_api_key() -> str: + """ + 获取 Public API Key。 + + 为空时表示不启用 public 接口认证。 + """ + public_key = get_config("app.public_key", DEFAULT_PUBLIC_KEY) + return public_key or "" + +def is_public_enabled() -> bool: + """ + 是否开启 public 功能入口。 + """ + return bool(get_config("app.public_enabled", DEFAULT_PUBLIC_ENABLED)) + async def verify_api_key( auth: Optional[HTTPAuthorizationCredentials] = Security(security), @@ -66,7 +90,7 @@ async def verify_app_key( app_key 必须配置,否则拒绝登录。 """ - app_key = get_config("app.app_key", DEFAULT_APP_KEY) + app_key = get_app_key() if not app_key: raise HTTPException( @@ -90,3 +114,40 @@ async def verify_app_key( ) return auth.credentials + + +async def verify_public_key( + auth: Optional[HTTPAuthorizationCredentials] = Security(security), +) -> Optional[str]: + """ + 验证 Public Key(public 接口使用)。 + + 默认不公开,需配置 public_key 才能访问;若开启 public_enabled 且未配置 public_key,则放开访问。 + """ + public_key = get_public_api_key() + public_enabled = is_public_enabled() + + if not public_key: + if public_enabled: + return None + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Public access is disabled", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not auth: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if auth.credentials != public_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return auth.credentials diff --git a/app/core/batch_tasks.py b/app/core/batch.py similarity index 61% rename from app/core/batch_tasks.py rename to app/core/batch.py index ff564ffb..7c62c015 100644 --- a/app/core/batch_tasks.py +++ b/app/core/batch.py @@ -1,13 +1,84 @@ """ -Batch task manager for admin batch operations (SSE progress). -""" +Batch utilities. -from __future__ import annotations +- run_batch: generic batch concurrency runner +- BatchTask: SSE task manager for admin batch operations +""" import asyncio import time import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar + +from app.core.logger import logger + +T = TypeVar("T") + + +async def run_batch( + items: List[str], + worker: Callable[[str], Awaitable[T]], + *, + batch_size: int = 50, + task: Optional["BatchTask"] = None, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, +) -> Dict[str, Dict[str, Any]]: + """ + 分批并发执行,单项失败不影响整体 + + Args: + items: 待处理项列表 + worker: 异步处理函数 + batch_size: 每批大小 + + Returns: + {item: {"ok": bool, "data": ..., "error": ...}} + """ + try: + batch_size = int(batch_size) + except Exception: + batch_size = 50 + + batch_size = max(1, batch_size) + + async def _one(item: str) -> tuple[str, dict]: + if (should_cancel and should_cancel()) or (task and task.cancelled): + return item, {"ok": False, "error": "cancelled", "cancelled": True} + try: + data = await worker(item) + result = {"ok": True, "data": data} + if task: + task.record(True) + if on_item: + try: + await on_item(item, result) + except Exception: + pass + return item, result + except Exception as e: + logger.warning(f"Batch item failed: {item[:16]}... - {e}") + result = {"ok": False, "error": str(e)} + if task: + task.record(False, error=str(e)) + if on_item: + try: + await on_item(item, result) + except Exception: + pass + return item, result + + results: Dict[str, dict] = {} + + # 分批执行,避免一次性创建所有 task + for i in range(0, len(items), batch_size): + if (should_cancel and should_cancel()) or (task and task.cancelled): + break + chunk = items[i : i + batch_size] + pairs = await asyncio.gather(*(_one(x) for x in chunk)) + results.update(dict(pairs)) + + return results class BatchTask: @@ -150,3 +221,13 @@ def delete_task(task_id: str) -> None: async def expire_task(task_id: str, delay: int = 300) -> None: await asyncio.sleep(delay) delete_task(task_id) + + +__all__ = [ + "run_batch", + "BatchTask", + "create_task", + "get_task", + "delete_task", + "expire_task", +] diff --git a/app/core/config.py b/app/core/config.py index bec87149..d761bd86 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -44,31 +44,69 @@ def _migrate_deprecated_config( # 配置映射规则:旧配置 -> 新配置 MIGRATION_MAP = { # grok.* -> 对应的新配置节 - "grok.temporary": "chat.temporary", - "grok.disable_memory": "chat.disable_memory", - "grok.stream": "chat.stream", - "grok.thinking": "chat.thinking", - "grok.dynamic_statsig": "chat.dynamic_statsig", - "grok.filter_tags": "chat.filter_tags", - "grok.timeout": "network.timeout", - "grok.base_proxy_url": "network.base_proxy_url", - "grok.asset_proxy_url": "network.asset_proxy_url", - "grok.cf_clearance": "security.cf_clearance", - "grok.browser": "security.browser", - "grok.user_agent": "security.user_agent", + "grok.temporary": "app.temporary", + "grok.disable_memory": "app.disable_memory", + "grok.stream": "app.stream", + "grok.thinking": "app.thinking", + "grok.dynamic_statsig": "app.dynamic_statsig", + "grok.filter_tags": "app.filter_tags", + "grok.timeout": "voice.timeout", + "grok.base_proxy_url": "proxy.base_proxy_url", + "grok.asset_proxy_url": "proxy.asset_proxy_url", + "network.base_proxy_url": "proxy.base_proxy_url", + "network.asset_proxy_url": "proxy.asset_proxy_url", + "grok.cf_clearance": "proxy.cf_clearance", + "grok.browser": "proxy.browser", + "grok.user_agent": "proxy.user_agent", + "security.cf_clearance": "proxy.cf_clearance", + "security.browser": "proxy.browser", + "security.user_agent": "proxy.user_agent", "grok.max_retry": "retry.max_retry", "grok.retry_status_codes": "retry.retry_status_codes", "grok.retry_backoff_base": "retry.retry_backoff_base", "grok.retry_backoff_factor": "retry.retry_backoff_factor", "grok.retry_backoff_max": "retry.retry_backoff_max", "grok.retry_budget": "retry.retry_budget", - "grok.stream_idle_timeout": "timeout.stream_idle_timeout", - "grok.video_idle_timeout": "timeout.video_idle_timeout", - "grok.image_ws": "image.image_ws", - "grok.image_ws_nsfw": "image.image_ws_nsfw", - "grok.image_ws_blocked_seconds": "image.image_ws_blocked_seconds", - "grok.image_ws_final_min_bytes": "image.image_ws_final_min_bytes", - "grok.image_ws_medium_min_bytes": "image.image_ws_medium_min_bytes", + "grok.video_idle_timeout": "video.stream_timeout", + "grok.image_ws_nsfw": "image.nsfw", + "grok.image_ws_blocked_seconds": "image.final_timeout", + "grok.image_ws_final_min_bytes": "image.final_min_bytes", + "grok.image_ws_medium_min_bytes": "image.medium_min_bytes", + # legacy sections + "network.base_proxy_url": "proxy.base_proxy_url", + "network.asset_proxy_url": "proxy.asset_proxy_url", + "network.timeout": [ + "chat.timeout", + "image.timeout", + "video.timeout", + "voice.timeout", + ], + "security.cf_clearance": "proxy.cf_clearance", + "security.browser": "proxy.browser", + "security.user_agent": "proxy.user_agent", + "timeout.stream_idle_timeout": [ + "chat.stream_timeout", + "image.stream_timeout", + "video.stream_timeout", + ], + "timeout.video_idle_timeout": "video.stream_timeout", + "image.image_ws_nsfw": "image.nsfw", + "image.image_ws_blocked_seconds": "image.final_timeout", + "image.image_ws_final_min_bytes": "image.final_min_bytes", + "image.image_ws_medium_min_bytes": "image.medium_min_bytes", + "performance.assets_max_concurrent": [ + "asset.upload_concurrent", + "asset.download_concurrent", + "asset.list_concurrent", + "asset.delete_concurrent", + ], + "performance.assets_delete_batch_size": "asset.delete_batch_size", + "performance.assets_batch_size": "asset.list_batch_size", + "performance.media_max_concurrent": ["chat.concurrent", "video.concurrent"], + "performance.usage_max_concurrent": "usage.concurrent", + "performance.usage_batch_size": "usage.batch_size", + "performance.nsfw_max_concurrent": "nsfw.concurrent", + "performance.nsfw_batch_size": "nsfw.batch_size", } deprecated_sections = set(config.keys()) - valid_sections @@ -78,28 +116,62 @@ def _migrate_deprecated_config( result = {k: deepcopy(v) for k, v in config.items() if k in valid_sections} migrated_count = 0 - # 处理废弃配置节中的配置项 - for old_section in deprecated_sections: - if old_section not in config or not isinstance(config[old_section], dict): + # 处理废弃配置节或旧配置键 + for old_section, old_values in config.items(): + if not isinstance(old_values, dict): continue - - for old_key, old_value in config[old_section].items(): - # 查找映射规则 + for old_key, old_value in old_values.items(): old_path = f"{old_section}.{old_key}" - new_path = MIGRATION_MAP.get(old_path) - - if new_path: - new_section, new_key = new_path.split(".", 1) - # 确保新配置节存在 - if new_section not in result: - result[new_section] = {} - # 迁移配置项(保留用户的自定义值) - result[new_section][new_key] = old_value + new_paths = MIGRATION_MAP.get(old_path) + if not new_paths: + continue + if isinstance(new_paths, str): + new_paths = [new_paths] + for new_path in new_paths: + try: + new_section, new_key = new_path.split(".", 1) + if new_section not in result: + result[new_section] = {} + if new_key not in result[new_section]: + result[new_section][new_key] = old_value + migrated_count += 1 + logger.debug( + f"Migrated config: {old_path} -> {new_path} = {old_value}" + ) + except Exception as e: + logger.warning( + f"Skip config migration for {old_path}: {e}" + ) + continue + if isinstance(result.get(old_section), dict): + result[old_section].pop(old_key, None) + + # 兼容旧 chat.* 配置键迁移到 app.* + legacy_chat_map = { + "temporary": "temporary", + "disable_memory": "disable_memory", + "stream": "stream", + "thinking": "thinking", + "dynamic_statsig": "dynamic_statsig", + "filter_tags": "filter_tags", + } + chat_section = config.get("chat") + if isinstance(chat_section, dict): + app_section = result.setdefault("app", {}) + for old_key, new_key in legacy_chat_map.items(): + if old_key in chat_section and new_key not in app_section: + app_section[new_key] = chat_section[old_key] + if isinstance(result.get("chat"), dict): + result["chat"].pop(old_key, None) migrated_count += 1 - logger.debug(f"Migrated config: {old_path} -> {new_path} = {old_value}") + logger.debug( + f"Migrated config: chat.{old_key} -> app.{new_key} = {chat_section[old_key]}" + ) if migrated_count > 0: - logger.info(f"Migrated {migrated_count} config items from deprecated sections") + logger.info( + f"Migrated {migrated_count} config items from deprecated/legacy sections" + ) return result, deprecated_sections diff --git a/app/core/exceptions.py b/app/core/exceptions.py index 9ceff005..24aa281b 100644 --- a/app/core/exceptions.py +++ b/app/core/exceptions.py @@ -101,6 +101,14 @@ def __init__(self, message: str, details: Any = None): self.details = details +class StreamIdleTimeoutError(Exception): + """流空闲超时错误""" + + def __init__(self, idle_seconds: float): + self.idle_seconds = idle_seconds + super().__init__(f"Stream idle timeout after {idle_seconds}s") + + # ============= 异常处理器 ============= @@ -210,7 +218,6 @@ def register_exception_handlers(app): app.add_exception_handler(HTTPException, http_exception_handler) app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_exception_handler(Exception, generic_exception_handler) - app.add_exception_handler(Exception, generic_exception_handler) __all__ = [ @@ -219,6 +226,7 @@ def register_exception_handlers(app): "ValidationException", "AuthenticationException", "UpstreamException", + "StreamIdleTimeoutError", "error_response", "register_exception_handlers", ] diff --git a/app/core/logger.py b/app/core/logger.py index a49b219d..0b0290f7 100644 --- a/app/core/logger.py +++ b/app/core/logger.py @@ -9,6 +9,10 @@ from pathlib import Path from loguru import logger +# Provide logging.Logger compatibility for legacy calls +if not hasattr(logger, "isEnabledFor"): + logger.isEnabledFor = lambda _level: True + # 日志目录 DEFAULT_LOG_DIR = Path(__file__).parent.parent.parent / "logs" LOG_DIR = Path(os.getenv("LOG_DIR", str(DEFAULT_LOG_DIR))) diff --git a/app/core/response_middleware.py b/app/core/response_middleware.py index 2cfa8b3e..4c0a07ec 100644 --- a/app/core/response_middleware.py +++ b/app/core/response_middleware.py @@ -25,6 +25,20 @@ async def dispatch(self, request: Request, call_next): request.state.trace_id = trace_id start_time = time.time() + path = request.url.path + + if path.startswith("/static/") or path in ( + "/", + "/login", + "/imagine", + "/voice", + "/admin", + "/admin/login", + "/admin/config", + "/admin/cache", + "/admin/token", + ): + return await call_next(request) # 记录请求信息 logger.info( diff --git a/app/services/__init__.py b/app/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/__init__.py b/app/services/grok/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/batch_services/assets.py b/app/services/grok/batch_services/assets.py new file mode 100644 index 00000000..7c3c31e3 --- /dev/null +++ b/app/services/grok/batch_services/assets.py @@ -0,0 +1,231 @@ +""" +Batch assets service. +""" + +import asyncio +from typing import Dict, List, Optional + +from curl_cffi.requests import AsyncSession + +from app.core.config import get_config +from app.core.logger import logger +from app.services.reverse.assets_list import AssetsListReverse +from app.services.reverse.assets_delete import AssetsDeleteReverse +from app.core.batch import run_batch + + +class BaseAssetsService: + """Base assets service.""" + + def __init__(self): + self._session: Optional[AsyncSession] = None + + async def _get_session(self) -> AsyncSession: + if self._session is None: + self._session = AsyncSession() + return self._session + + async def close(self): + if self._session: + await self._session.close() + self._session = None + + +_LIST_SEMAPHORE = None +_LIST_SEM_VALUE = None +_DELETE_SEMAPHORE = None +_DELETE_SEM_VALUE = None + + +def _get_list_semaphore() -> asyncio.Semaphore: + value = max(1, int(get_config("asset.list_concurrent"))) + global _LIST_SEMAPHORE, _LIST_SEM_VALUE + if _LIST_SEMAPHORE is None or value != _LIST_SEM_VALUE: + _LIST_SEM_VALUE = value + _LIST_SEMAPHORE = asyncio.Semaphore(value) + return _LIST_SEMAPHORE + + +def _get_delete_semaphore() -> asyncio.Semaphore: + value = max(1, int(get_config("asset.delete_concurrent"))) + global _DELETE_SEMAPHORE, _DELETE_SEM_VALUE + if _DELETE_SEMAPHORE is None or value != _DELETE_SEM_VALUE: + _DELETE_SEM_VALUE = value + _DELETE_SEMAPHORE = asyncio.Semaphore(value) + return _DELETE_SEMAPHORE + + +class ListService(BaseAssetsService): + """Assets list service.""" + + async def list(self, token: str) -> Dict[str, List[str] | int]: + params = { + "pageSize": 50, + "orderBy": "ORDER_BY_LAST_USE_TIME", + "source": "SOURCE_ANY", + "isLatest": "true", + } + page_token = None + seen_tokens = set() + asset_ids: List[str] = [] + session = await self._get_session() + while True: + if page_token: + if page_token in seen_tokens: + logger.warning("Pagination stopped: repeated page token") + break + seen_tokens.add(page_token) + params["pageToken"] = page_token + else: + params.pop("pageToken", None) + + async with _get_list_semaphore(): + response = await AssetsListReverse.request( + session, + token, + params, + ) + + result = response.json() + page_assets = result.get("assets", []) + if page_assets: + for asset in page_assets: + asset_id = asset.get("assetId") + if asset_id: + asset_ids.append(asset_id) + + page_token = result.get("nextPageToken") + if not page_token: + break + + logger.info(f"List success: {len(asset_ids)} files") + return {"asset_ids": asset_ids, "count": len(asset_ids)} + + @staticmethod + async def fetch_assets_details( + tokens: List[str], + account_map: dict, + *, + include_ok: bool = False, + on_item=None, + should_cancel=None, + ) -> dict: + """Batch fetch assets details for tokens.""" + account_map = account_map or {} + shared_service = ListService() + batch_size = max(1, int(get_config("asset.list_batch_size"))) + + async def _fetch_detail(token: str): + account = account_map.get(token) + try: + result = await shared_service.list(token) + asset_ids = result.get("asset_ids", []) + count = result.get("count", len(asset_ids)) + detail = { + "token": token, + "token_masked": account["token_masked"] if account else token, + "count": count, + "status": "ok", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if include_ok: + return {"ok": True, "detail": detail, "count": count} + return {"detail": detail, "count": count} + except Exception as e: + detail = { + "token": token, + "token_masked": account["token_masked"] if account else token, + "count": 0, + "status": f"error: {str(e)}", + "last_asset_clear_at": account["last_asset_clear_at"] + if account + else None, + } + if include_ok: + return {"ok": False, "detail": detail, "count": 0} + return {"detail": detail, "count": 0} + + try: + return await run_batch( + tokens, + _fetch_detail, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + finally: + await shared_service.close() + + +class DeleteService(BaseAssetsService): + """Assets delete service.""" + + async def delete(self, token: str, asset_ids: List[str]) -> Dict[str, int]: + if not asset_ids: + logger.info("No assets to delete") + return {"total": 0, "success": 0, "failed": 0, "skipped": True} + + total = len(asset_ids) + success = 0 + failed = 0 + session = await self._get_session() + + async def _delete_one(asset_id: str): + async with _get_delete_semaphore(): + await AssetsDeleteReverse.request(session, token, asset_id) + + tasks = [_delete_one(asset_id) for asset_id in asset_ids if asset_id] + results = await asyncio.gather(*tasks, return_exceptions=True) + for res in results: + if isinstance(res, Exception): + failed += 1 + else: + success += 1 + + logger.info(f"Delete all: total={total}, success={success}, failed={failed}") + return {"total": total, "success": success, "failed": failed} + + @staticmethod + async def clear_assets( + tokens: List[str], + mgr, + *, + include_ok: bool = False, + on_item=None, + should_cancel=None, + ) -> dict: + """Batch clear assets for tokens.""" + delete_service = DeleteService() + list_service = ListService() + batch_size = max(1, int(get_config("asset.delete_batch_size"))) + + async def _clear_one(token: str): + try: + result = await list_service.list(token) + asset_ids = result.get("asset_ids", []) + result = await delete_service.delete(token, asset_ids) + await mgr.mark_asset_clear(token) + if include_ok: + return {"ok": True, "result": result} + return {"status": "success", "result": result} + except Exception as e: + if include_ok: + return {"ok": False, "error": str(e)} + return {"status": "error", "error": str(e)} + + try: + return await run_batch( + tokens, + _clear_one, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + finally: + await delete_service.close() + await list_service.close() + + +__all__ = ["ListService", "DeleteService"] diff --git a/app/services/grok/batch_services/nsfw.py b/app/services/grok/batch_services/nsfw.py new file mode 100644 index 00000000..1c8faa0c --- /dev/null +++ b/app/services/grok/batch_services/nsfw.py @@ -0,0 +1,113 @@ +""" +Batch NSFW service. +""" + +import asyncio +from typing import Callable, Awaitable, Dict, Any, Optional + +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.reverse.accept_tos import AcceptTosReverse +from app.services.reverse.nsfw_mgmt import NsfwMgmtReverse +from app.services.reverse.set_birth import SetBirthReverse +from app.core.batch import run_batch + + +_NSFW_SEMAPHORE = None +_NSFW_SEM_VALUE = None + + +def _get_nsfw_semaphore() -> asyncio.Semaphore: + value = max(1, int(get_config("nsfw.concurrent"))) + global _NSFW_SEMAPHORE, _NSFW_SEM_VALUE + if _NSFW_SEMAPHORE is None or value != _NSFW_SEM_VALUE: + _NSFW_SEM_VALUE = value + _NSFW_SEMAPHORE = asyncio.Semaphore(value) + return _NSFW_SEMAPHORE + + +class NSFWService: + """NSFW 模式服务""" + @staticmethod + async def batch( + tokens: list[str], + mgr, + *, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, + ) -> Dict[str, Dict[str, Any]]: + """Batch enable NSFW.""" + batch_size = get_config("nsfw.batch_size") + async def _enable(token: str): + try: + browser = get_config("proxy.browser") + async with AsyncSession(impersonate=browser) as session: + async def _record_fail(err: UpstreamException, reason: str): + status = None + if err.details and "status" in err.details: + status = err.details["status"] + else: + status = getattr(err, "status_code", None) + if status == 401: + await mgr.record_fail(token, status, reason) + return status or 0 + + try: + async with _get_nsfw_semaphore(): + await AcceptTosReverse.request(session, token) + except UpstreamException as e: + status = await _record_fail(e, "tos_auth_failed") + return { + "success": False, + "http_status": status, + "error": f"Accept ToS failed: {str(e)}", + } + + try: + async with _get_nsfw_semaphore(): + await SetBirthReverse.request(session, token) + except UpstreamException as e: + status = await _record_fail(e, "set_birth_auth_failed") + return { + "success": False, + "http_status": status, + "error": f"Set birth date failed: {str(e)}", + } + + try: + async with _get_nsfw_semaphore(): + grpc_status = await NsfwMgmtReverse.request(session, token) + success = grpc_status.code in (-1, 0) + except UpstreamException as e: + status = await _record_fail(e, "nsfw_mgmt_auth_failed") + return { + "success": False, + "http_status": status, + "error": f"NSFW enable failed: {str(e)}", + } + if success: + await mgr.add_tag(token, "nsfw") + return { + "success": success, + "http_status": 200, + "grpc_status": grpc_status.code, + "grpc_message": grpc_status.message or None, + "error": None, + } + except Exception as e: + logger.error(f"NSFW enable failed: {e}") + return {"success": False, "http_status": 0, "error": str(e)[:100]} + + return await run_batch( + tokens, + _enable, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + + +__all__ = ["NSFWService"] diff --git a/app/services/grok/batch_services/usage.py b/app/services/grok/batch_services/usage.py new file mode 100644 index 00000000..66aab105 --- /dev/null +++ b/app/services/grok/batch_services/usage.py @@ -0,0 +1,81 @@ +""" +Batch usage service. +""" + +import asyncio +from typing import Callable, Awaitable, Dict, Any, Optional, List + +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.services.reverse.rate_limits import RateLimitsReverse +from app.core.batch import run_batch + +_USAGE_SEMAPHORE = None +_USAGE_SEM_VALUE = None + + +def _get_usage_semaphore() -> asyncio.Semaphore: + value = max(1, int(get_config("usage.concurrent"))) + global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE + if _USAGE_SEMAPHORE is None or value != _USAGE_SEM_VALUE: + _USAGE_SEM_VALUE = value + _USAGE_SEMAPHORE = asyncio.Semaphore(value) + return _USAGE_SEMAPHORE + + +class UsageService: + """用量查询服务""" + + async def get(self, token: str) -> Dict: + """ + 获取速率限制信息 + + Args: + token: 认证 Token + + Returns: + 响应数据 + + Raises: + UpstreamException: 当获取失败且重试耗尽时 + """ + async with _get_usage_semaphore(): + try: + async with AsyncSession() as session: + response = await RateLimitsReverse.request(session, token) + data = response.json() + remaining = data.get("remainingTokens", 0) + logger.info( + f"Usage sync success: remaining={remaining}, token={token[:10]}..." + ) + return data + + except Exception: + # 最后一次失败已经被记录 + raise + + + @staticmethod + async def batch( + tokens: List[str], + mgr, + *, + on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, + should_cancel: Optional[Callable[[], bool]] = None, + ) -> Dict[str, Dict[str, Any]]: + batch_size = get_config("usage.batch_size") + async def _refresh_one(t: str): + return await mgr.sync_usage(t, consume_on_fail=False, is_usage=False) + + return await run_batch( + tokens, + _refresh_one, + batch_size=batch_size, + on_item=on_item, + should_cancel=should_cancel, + ) + + +__all__ = ["UsageService"] diff --git a/app/services/grok/defaults.py b/app/services/grok/defaults.py index 03f1c10f..d7af7eb7 100644 --- a/app/services/grok/defaults.py +++ b/app/services/grok/defaults.py @@ -1,85 +1,33 @@ """ Grok 服务默认配置 -此文件定义所有 Grok 相关服务的默认值,会在应用启动时注册到配置系统中。 +此文件读取 config.defaults.toml,作为 Grok 服务的默认值来源。 """ -# Grok 服务默认配置 -GROK_DEFAULTS = { - "app": { - "app_url": "", - "app_key": "grok2api", - "api_key": "", - "image_format": "url", - "video_format": "html", - }, - "network": { - "timeout": 120, - "base_proxy_url": "", - "asset_proxy_url": "", - }, - "security": { - "cf_clearance": "", - "browser": "chrome136", - "user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36", - }, - "chat": { - "temporary": True, - "disable_memory": True, - "stream": True, - "thinking": False, - "dynamic_statsig": True, - "filter_tags": ["grok:render", "xaiartifact", "xai:tool_usage_card"], - }, - "retry": { - "max_retry": 3, - "retry_status_codes": [401, 429, 403], - "retry_backoff_base": 0.5, - "retry_backoff_factor": 2.0, - "retry_backoff_max": 30.0, - "retry_budget": 90.0, - }, - "timeout": { - "stream_idle_timeout": 45.0, - "video_idle_timeout": 90.0, - }, - "image": { - "image_ws": True, - "image_ws_nsfw": True, - "image_ws_blocked_seconds": 15, - "image_ws_final_min_bytes": 100000, - "image_ws_medium_min_bytes": 30000, - }, - "token": { - "auto_refresh": True, - "refresh_interval_hours": 8, - "super_refresh_interval_hours": 2, - "fail_threshold": 5, - "save_delay_ms": 500, - "reload_interval_sec": 30, - }, - "cache": { - "enable_auto_clean": True, - "limit_mb": 1024, - }, - "performance": { - "assets_max_concurrent": 25, - "assets_delete_batch_size": 10, - "assets_batch_size": 10, - "assets_max_tokens": 1000, - "media_max_concurrent": 50, - "usage_max_concurrent": 25, - "usage_batch_size": 50, - "usage_max_tokens": 1000, - "nsfw_max_concurrent": 10, - "nsfw_batch_size": 50, - "nsfw_max_tokens": 1000, - }, -} +from pathlib import Path +import tomllib + +from app.core.logger import logger + +DEFAULTS_FILE = Path(__file__).resolve().parent.parent.parent.parent / "config.defaults.toml" + +# Grok 服务默认配置(运行时从 config.defaults.toml 读取并缓存) +GROK_DEFAULTS: dict = {} def get_grok_defaults(): """获取 Grok 默认配置""" + global GROK_DEFAULTS + if GROK_DEFAULTS: + return GROK_DEFAULTS + if not DEFAULTS_FILE.exists(): + logger.warning(f"Defaults file not found: {DEFAULTS_FILE}") + return GROK_DEFAULTS + try: + with DEFAULTS_FILE.open("rb") as f: + GROK_DEFAULTS = tomllib.load(f) + except Exception as e: + logger.warning(f"Failed to load defaults from {DEFAULTS_FILE}: {e}") return GROK_DEFAULTS diff --git a/app/services/grok/models/__init__.py b/app/services/grok/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/processors/__init__.py b/app/services/grok/processors/__init__.py deleted file mode 100644 index 04773f67..00000000 --- a/app/services/grok/processors/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -OpenAI 响应格式处理器 -""" - -from .base import BaseProcessor, StreamIdleTimeoutError -from .chat_processors import StreamProcessor, CollectProcessor -from .video_processors import VideoStreamProcessor, VideoCollectProcessor -from .image_processors import ImageStreamProcessor, ImageCollectProcessor -from .image_ws_processors import ImageWSStreamProcessor, ImageWSCollectProcessor - -__all__ = [ - "BaseProcessor", - "StreamIdleTimeoutError", - "StreamProcessor", - "CollectProcessor", - "VideoStreamProcessor", - "VideoCollectProcessor", - "ImageStreamProcessor", - "ImageCollectProcessor", - "ImageWSStreamProcessor", - "ImageWSCollectProcessor", -] diff --git a/app/services/grok/processors/chat_processors.py b/app/services/grok/processors/chat_processors.py deleted file mode 100644 index 000e09b8..00000000 --- a/app/services/grok/processors/chat_processors.py +++ /dev/null @@ -1,382 +0,0 @@ -""" -聊天响应处理器 -""" - -import asyncio -import uuid -import re -from typing import Any, AsyncGenerator, AsyncIterable - -import orjson -from curl_cffi.requests.errors import RequestsError - -from app.core.config import get_config -from app.core.logger import logger -from app.core.exceptions import UpstreamException -from .base import ( - BaseProcessor, - StreamIdleTimeoutError, - _with_idle_timeout, - _normalize_stream_line, - _collect_image_urls, - _is_http2_stream_error, -) - - -class StreamProcessor(BaseProcessor): - """流式响应处理器""" - - def __init__(self, model: str, token: str = "", think: bool = None): - super().__init__(model, token) - self.response_id: str = None - self.fingerprint: str = "" - self.think_opened: bool = False - self.role_sent: bool = False - self.filter_tags = get_config("chat.filter_tags") - self.image_format = get_config("app.image_format") - self._tag_buffer: str = "" - self._in_filter_tag: bool = False - - if think is None: - self.show_think = get_config("chat.thinking") - else: - self.show_think = think - - def _filter_token(self, token: str) -> str: - """过滤 token 中的特殊标签(如 ...),支持跨 token 的标签过滤""" - if not self.filter_tags: - return token - - result = [] - i = 0 - while i < len(token): - char = token[i] - - if self._in_filter_tag: - self._tag_buffer += char - if char == ">": - if "/>" in self._tag_buffer: - self._in_filter_tag = False - self._tag_buffer = "" - else: - for tag in self.filter_tags: - if f"" in self._tag_buffer: - self._in_filter_tag = False - self._tag_buffer = "" - break - i += 1 - continue - - if char == "<": - remaining = token[i:] - tag_started = False - for tag in self.filter_tags: - if remaining.startswith(f"<{tag}"): - tag_started = True - break - if len(remaining) < len(tag) + 1: - for j in range(1, len(remaining) + 1): - if f"<{tag}".startswith(remaining[:j]): - tag_started = True - break - - if tag_started: - self._in_filter_tag = True - self._tag_buffer = char - i += 1 - continue - - result.append(char) - i += 1 - - return "".join(result) - - def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: - """构建 SSE 响应""" - delta = {} - if role: - delta["role"] = role - delta["content"] = "" - elif content: - delta["content"] = content - - chunk = { - "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "system_fingerprint": self.fingerprint, - "choices": [ - {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} - ], - } - return f"data: {orjson.dumps(chunk).decode()}\n\n" - - async def process( - self, response: AsyncIterable[bytes] - ) -> AsyncGenerator[str, None]: - """处理流式响应""" - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if (llm := resp.get("llmInfo")) and not self.fingerprint: - self.fingerprint = llm.get("modelHash", "") - if rid := resp.get("responseId"): - self.response_id = rid - - if not self.role_sent: - yield self._sse(role="assistant") - self.role_sent = True - - # 图像生成进度 - if img := resp.get("streamingImageGenerationResponse"): - if self.show_think: - if not self.think_opened: - yield self._sse("\n") - self.think_opened = True - idx = img.get("imageIndex", 0) + 1 - progress = img.get("progress", 0) - yield self._sse( - f"正在生成第{idx}张图片中,当前进度{progress}%\n" - ) - continue - - # modelResponse - if mr := resp.get("modelResponse"): - if self.think_opened and self.show_think: - if msg := mr.get("message"): - yield self._sse(msg + "\n") - yield self._sse("\n") - self.think_opened = False - - # 处理生成的图片 - for url in _collect_image_urls(mr): - parts = url.split("/") - img_id = parts[-2] if len(parts) >= 2 else "image" - - if self.image_format == "base64": - try: - dl_service = self._get_dl() - base64_data = await dl_service.to_base64( - url, self.token, "image" - ) - if base64_data: - yield self._sse(f"![{img_id}]({base64_data})\n") - else: - final_url = await self.process_url(url, "image") - yield self._sse(f"![{img_id}]({final_url})\n") - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - final_url = await self.process_url(url, "image") - yield self._sse(f"![{img_id}]({final_url})\n") - else: - final_url = await self.process_url(url, "image") - yield self._sse(f"![{img_id}]({final_url})\n") - - if ( - (meta := mr.get("metadata", {})) - .get("llm_info", {}) - .get("modelHash") - ): - self.fingerprint = meta["llm_info"]["modelHash"] - continue - - # 普通 token - if (token := resp.get("token")) is not None: - if token: - filtered = self._filter_token(token) - if filtered: - yield self._sse(filtered) - - if self.think_opened: - yield self._sse("\n") - yield self._sse(finish="stop") - yield "data: [DONE]\n\n" - except asyncio.CancelledError: - logger.debug("Stream cancelled by client", extra={"model": self.model}) - except StreamIdleTimeoutError as e: - raise UpstreamException( - message=f"Stream idle timeout after {e.idle_seconds}s", - status_code=504, - details={ - "error": str(e), - "type": "stream_idle_timeout", - "idle_seconds": e.idle_seconds, - }, - ) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning(f"HTTP/2 stream error: {e}", extra={"model": self.model}) - raise UpstreamException( - message="Upstream connection closed unexpectedly", - status_code=502, - details={"error": str(e), "type": "http2_stream_error"}, - ) - logger.error(f"Stream request error: {e}", extra={"model": self.model}) - raise UpstreamException( - message=f"Upstream request failed: {e}", - status_code=502, - details={"error": str(e)}, - ) - except Exception as e: - logger.error( - f"Stream processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - raise - finally: - await self.close() - - -class CollectProcessor(BaseProcessor): - """非流式响应处理器""" - - def __init__(self, model: str, token: str = ""): - super().__init__(model, token) - self.image_format = get_config("app.image_format") - self.filter_tags = get_config("chat.filter_tags") - - def _filter_content(self, content: str) -> str: - """过滤内容中的特殊标签""" - if not content or not self.filter_tags: - return content - - result = content - for tag in self.filter_tags: - pattern = rf"<{re.escape(tag)}[^>]*>.*?|<{re.escape(tag)}[^>]*/>" - result = re.sub(pattern, "", result, flags=re.DOTALL) - - return result - - async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: - """处理并收集完整响应""" - response_id = "" - fingerprint = "" - content = "" - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if (llm := resp.get("llmInfo")) and not fingerprint: - fingerprint = llm.get("modelHash", "") - - if mr := resp.get("modelResponse"): - response_id = mr.get("responseId", "") - content = mr.get("message", "") - - if urls := _collect_image_urls(mr): - content += "\n" - for url in urls: - parts = url.split("/") - img_id = parts[-2] if len(parts) >= 2 else "image" - - if self.image_format == "base64": - try: - dl_service = self._get_dl() - base64_data = await dl_service.to_base64( - url, self.token, "image" - ) - if base64_data: - content += f"![{img_id}]({base64_data})\n" - else: - final_url = await self.process_url(url, "image") - content += f"![{img_id}]({final_url})\n" - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - final_url = await self.process_url(url, "image") - content += f"![{img_id}]({final_url})\n" - else: - final_url = await self.process_url(url, "image") - content += f"![{img_id}]({final_url})\n" - - if ( - (meta := mr.get("metadata", {})) - .get("llm_info", {}) - .get("modelHash") - ): - fingerprint = meta["llm_info"]["modelHash"] - - except asyncio.CancelledError: - logger.debug("Collect cancelled by client", extra={"model": self.model}) - except StreamIdleTimeoutError as e: - logger.warning(f"Collect idle timeout: {e}", extra={"model": self.model}) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning( - f"HTTP/2 stream error in collect: {e}", extra={"model": self.model} - ) - else: - logger.error(f"Collect request error: {e}", extra={"model": self.model}) - except Exception as e: - logger.error( - f"Collect processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - finally: - await self.close() - - content = self._filter_content(content) - - return { - "id": response_id, - "object": "chat.completion", - "created": self.created, - "model": self.model, - "system_fingerprint": fingerprint, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": content, - "refusal": None, - "annotations": [], - }, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - "prompt_tokens_details": { - "cached_tokens": 0, - "text_tokens": 0, - "audio_tokens": 0, - "image_tokens": 0, - }, - "completion_tokens_details": { - "text_tokens": 0, - "audio_tokens": 0, - "reasoning_tokens": 0, - }, - }, - } - - -__all__ = ["StreamProcessor", "CollectProcessor"] diff --git a/app/services/grok/processors/image_processors.py b/app/services/grok/processors/image_processors.py deleted file mode 100644 index 8f78ac3f..00000000 --- a/app/services/grok/processors/image_processors.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -图片生成响应处理器(HTTP) -""" - -import asyncio -import random -from typing import AsyncGenerator, AsyncIterable, List - -import orjson -from curl_cffi.requests.errors import RequestsError - -from app.core.config import get_config -from app.core.logger import logger -from app.core.exceptions import UpstreamException -from .base import ( - BaseProcessor, - StreamIdleTimeoutError, - _with_idle_timeout, - _normalize_stream_line, - _collect_image_urls, - _is_http2_stream_error, -) - - -class ImageStreamProcessor(BaseProcessor): - """图片生成流式响应处理器""" - - def __init__( - self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" - ): - super().__init__(model, token) - self.partial_index = 0 - self.n = n - self.target_index = random.randint(0, 1) if n == 1 else None - self.response_format = response_format - if response_format == "url": - self.response_field = "url" - elif response_format == "base64": - self.response_field = "base64" - else: - self.response_field = "b64_json" - - def _sse(self, event: str, data: dict) -> str: - """构建 SSE 响应""" - return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" - - async def process( - self, response: AsyncIterable[bytes] - ) -> AsyncGenerator[str, None]: - """处理流式响应""" - final_images = [] - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - # 图片生成进度 - if img := resp.get("streamingImageGenerationResponse"): - image_index = img.get("imageIndex", 0) - progress = img.get("progress", 0) - - if self.n == 1 and image_index != self.target_index: - continue - - out_index = 0 if self.n == 1 else image_index - - yield self._sse( - "image_generation.partial_image", - { - "type": "image_generation.partial_image", - self.response_field: "", - "index": out_index, - "progress": progress, - }, - ) - continue - - # modelResponse - if mr := resp.get("modelResponse"): - if urls := _collect_image_urls(mr): - for url in urls: - if self.response_format == "url": - processed = await self.process_url(url, "image") - if processed: - final_images.append(processed) - continue - try: - dl_service = self._get_dl() - base64_data = await dl_service.to_base64( - url, self.token, "image" - ) - if base64_data: - if "," in base64_data: - b64 = base64_data.split(",", 1)[1] - else: - b64 = base64_data - final_images.append(b64) - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - processed = await self.process_url(url, "image") - if processed: - final_images.append(processed) - continue - - for index, b64 in enumerate(final_images): - if self.n == 1: - if index != self.target_index: - continue - out_index = 0 - else: - out_index = index - - yield self._sse( - "image_generation.completed", - { - "type": "image_generation.completed", - self.response_field: b64, - "index": out_index, - "usage": { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - "input_tokens_details": { - "text_tokens": 0, - "image_tokens": 0, - }, - }, - }, - ) - except asyncio.CancelledError: - logger.debug("Image stream cancelled by client") - except StreamIdleTimeoutError as e: - raise UpstreamException( - message=f"Image stream idle timeout after {e.idle_seconds}s", - status_code=504, - details={ - "error": str(e), - "type": "stream_idle_timeout", - "idle_seconds": e.idle_seconds, - }, - ) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning(f"HTTP/2 stream error in image: {e}") - raise UpstreamException( - message="Upstream connection closed unexpectedly", - status_code=502, - details={"error": str(e), "type": "http2_stream_error"}, - ) - logger.error(f"Image stream request error: {e}") - raise UpstreamException( - message=f"Upstream request failed: {e}", - status_code=502, - details={"error": str(e)}, - ) - except Exception as e: - logger.error( - f"Image stream processing error: {e}", - extra={"error_type": type(e).__name__}, - ) - raise - finally: - await self.close() - - -class ImageCollectProcessor(BaseProcessor): - """图片生成非流式响应处理器""" - - def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): - super().__init__(model, token) - self.response_format = response_format - - async def process(self, response: AsyncIterable[bytes]) -> List[str]: - """处理并收集图片""" - images = [] - idle_timeout = get_config("timeout.stream_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if mr := resp.get("modelResponse"): - if urls := _collect_image_urls(mr): - for url in urls: - if self.response_format == "url": - processed = await self.process_url(url, "image") - if processed: - images.append(processed) - continue - try: - dl_service = self._get_dl() - base64_data = await dl_service.to_base64( - url, self.token, "image" - ) - if base64_data: - if "," in base64_data: - b64 = base64_data.split(",", 1)[1] - else: - b64 = base64_data - images.append(b64) - except Exception as e: - logger.warning( - f"Failed to convert image to base64, falling back to URL: {e}" - ) - processed = await self.process_url(url, "image") - if processed: - images.append(processed) - - except asyncio.CancelledError: - logger.debug("Image collect cancelled by client") - except StreamIdleTimeoutError as e: - logger.warning(f"Image collect idle timeout: {e}") - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning(f"HTTP/2 stream error in image collect: {e}") - else: - logger.error(f"Image collect request error: {e}") - except Exception as e: - logger.error( - f"Image collect processing error: {e}", - extra={"error_type": type(e).__name__}, - ) - finally: - await self.close() - - return images - - -__all__ = ["ImageStreamProcessor", "ImageCollectProcessor"] diff --git a/app/services/grok/processors/image_ws_processors.py b/app/services/grok/processors/image_ws_processors.py deleted file mode 100644 index 788a442b..00000000 --- a/app/services/grok/processors/image_ws_processors.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -图片生成响应处理器(WebSocket) -""" - -import base64 -import time -from pathlib import Path -from typing import AsyncGenerator, AsyncIterable, List, Dict, Optional - -import orjson - -from app.core.config import get_config -from app.core.logger import logger -from app.core.storage import DATA_DIR -from app.core.exceptions import UpstreamException -from .base import BaseProcessor - - -class ImageWSBaseProcessor(BaseProcessor): - """WebSocket 图片处理基类""" - - def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): - super().__init__(model, token) - self.response_format = response_format - if response_format == "url": - self.response_field = "url" - elif response_format == "base64": - self.response_field = "base64" - else: - self.response_field = "b64_json" - self._image_dir: Optional[Path] = None - - def _ensure_image_dir(self) -> Path: - if self._image_dir is None: - base_dir = DATA_DIR / "tmp" / "image" - base_dir.mkdir(parents=True, exist_ok=True) - self._image_dir = base_dir - return self._image_dir - - def _strip_base64(self, blob: str) -> str: - if not blob: - return "" - if "," in blob and "base64" in blob.split(",", 1)[0]: - return blob.split(",", 1)[1] - return blob - - def _filename(self, image_id: str, is_final: bool) -> str: - ext = "jpg" if is_final else "png" - return f"{image_id}.{ext}" - - def _build_file_url(self, filename: str) -> str: - app_url = get_config("app.app_url") - if app_url: - return f"{app_url.rstrip('/')}/v1/files/image/{filename}" - return f"/v1/files/image/{filename}" - - def _save_blob(self, image_id: str, blob: str, is_final: bool) -> str: - data = self._strip_base64(blob) - if not data: - return "" - image_dir = self._ensure_image_dir() - filename = self._filename(image_id, is_final) - filepath = image_dir / filename - with open(filepath, "wb") as f: - f.write(base64.b64decode(data)) - return self._build_file_url(filename) - - def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict: - if not existing: - return incoming - if incoming.get("is_final") and not existing.get("is_final"): - return incoming - if existing.get("is_final") and not incoming.get("is_final"): - return existing - if incoming.get("blob_size", 0) > existing.get("blob_size", 0): - return incoming - return existing - - def _to_output(self, image_id: str, item: Dict) -> str: - try: - if self.response_format == "url": - return self._save_blob( - image_id, item.get("blob", ""), item.get("is_final", False) - ) - return self._strip_base64(item.get("blob", "")) - except Exception as e: - logger.warning(f"Image output failed: {e}") - return "" - - -class ImageWSStreamProcessor(ImageWSBaseProcessor): - """WebSocket 图片流式响应处理器""" - - def __init__( - self, - model: str, - token: str = "", - n: int = 1, - response_format: str = "b64_json", - size: str = "1024x1024", - ): - super().__init__(model, token, "b64_json") - self.n = n - self.size = size - self._target_id: Optional[str] = None - self._index_map: Dict[str, int] = {} - self._partial_map: Dict[str, int] = {} - - def _assign_index(self, image_id: str) -> Optional[int]: - if image_id in self._index_map: - return self._index_map[image_id] - if len(self._index_map) >= self.n: - return None - self._index_map[image_id] = len(self._index_map) - return self._index_map[image_id] - - def _sse(self, event: str, data: dict) -> str: - return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" - - async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]: - images: Dict[str, Dict] = {} - - async for item in response: - if item.get("type") == "error": - message = item.get("error") or "Upstream error" - code = item.get("error_code") or "upstream_error" - yield self._sse( - "error", - { - "error": { - "message": message, - "type": "server_error", - "code": code, - } - }, - ) - return - if item.get("type") != "image": - continue - - image_id = item.get("image_id") - if not image_id: - continue - - if self.n == 1: - if self._target_id is None: - self._target_id = image_id - index = 0 if image_id == self._target_id else None - else: - index = self._assign_index(image_id) - - images[image_id] = self._pick_best(images.get(image_id), item) - - if index is None: - continue - - if item.get("stage") != "final": - partial_b64 = self._strip_base64(item.get("blob", "")) - if not partial_b64: - continue - partial_index = self._partial_map.get(image_id, 0) - if item.get("stage") == "medium": - partial_index = max(partial_index, 1) - self._partial_map[image_id] = partial_index - yield self._sse( - "image_generation.partial_image", - { - "type": "image_generation.partial_image", - "b64_json": partial_b64, - "created_at": int(time.time()), - "size": self.size, - "index": index, - "partial_image_index": partial_index, - }, - ) - - if self.n == 1: - if self._target_id and self._target_id in images: - selected = [(self._target_id, images[self._target_id])] - else: - selected = ( - [ - max( - images.items(), - key=lambda x: ( - x[1].get("is_final", False), - x[1].get("blob_size", 0), - ), - ) - ] - if images - else [] - ) - else: - selected = [ - (image_id, images[image_id]) - for image_id in self._index_map - if image_id in images - ] - - for image_id, item in selected: - output = self._strip_base64(item.get("blob", "")) - if not output: - continue - - if self.n == 1: - index = 0 - else: - index = self._index_map.get(image_id, 0) - yield self._sse( - "image_generation.completed", - { - "type": "image_generation.completed", - "b64_json": output, - "created_at": int(time.time()), - "size": self.size, - "index": index, - "usage": { - "total_tokens": 0, - "input_tokens": 0, - "output_tokens": 0, - "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, - }, - }, - ) - - -class ImageWSCollectProcessor(ImageWSBaseProcessor): - """WebSocket 图片非流式响应处理器""" - - def __init__( - self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" - ): - super().__init__(model, token, response_format) - self.n = n - - async def process(self, response: AsyncIterable[dict]) -> List[str]: - images: Dict[str, Dict] = {} - - async for item in response: - if item.get("type") == "error": - message = item.get("error") or "Upstream error" - raise UpstreamException(message, details=item) - if item.get("type") != "image": - continue - image_id = item.get("image_id") - if not image_id: - continue - images[image_id] = self._pick_best(images.get(image_id), item) - - selected = sorted( - images.values(), - key=lambda x: (x.get("is_final", False), x.get("blob_size", 0)), - reverse=True, - ) - if self.n: - selected = selected[: self.n] - - results: List[str] = [] - for item in selected: - output = self._to_output(item.get("image_id", ""), item) - if output: - results.append(output) - - return results - - -__all__ = ["ImageWSStreamProcessor", "ImageWSCollectProcessor"] diff --git a/app/services/grok/processors/video_processors.py b/app/services/grok/processors/video_processors.py deleted file mode 100644 index a0ead8c3..00000000 --- a/app/services/grok/processors/video_processors.py +++ /dev/null @@ -1,276 +0,0 @@ -""" -视频响应处理器 -""" - -import asyncio -import uuid -from typing import Any, AsyncGenerator, AsyncIterable, Optional - -import orjson -from curl_cffi.requests.errors import RequestsError - -from app.core.config import get_config -from app.core.logger import logger -from app.core.exceptions import UpstreamException -from .base import ( - BaseProcessor, - StreamIdleTimeoutError, - _with_idle_timeout, - _normalize_stream_line, - _is_http2_stream_error, -) - - -class VideoStreamProcessor(BaseProcessor): - """视频流式响应处理器""" - - def __init__(self, model: str, token: str = "", think: bool = None): - super().__init__(model, token) - self.response_id: Optional[str] = None - self.think_opened: bool = False - self.role_sent: bool = False - self.video_format = str(get_config("app.video_format")).lower() - - if think is None: - self.show_think = get_config("chat.thinking") - else: - self.show_think = think - - def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: - """构建 SSE 响应""" - delta = {} - if role: - delta["role"] = role - delta["content"] = "" - elif content: - delta["content"] = content - - chunk = { - "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", - "object": "chat.completion.chunk", - "created": self.created, - "model": self.model, - "choices": [ - {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} - ], - } - return f"data: {orjson.dumps(chunk).decode()}\n\n" - - def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str: - """构建视频 HTML 标签""" - import html - - safe_video_url = html.escape(video_url) - safe_thumbnail_url = html.escape(thumbnail_url) - poster_attr = f' poster="{safe_thumbnail_url}"' if safe_thumbnail_url else "" - return f'''''' - - async def process( - self, response: AsyncIterable[bytes] - ) -> AsyncGenerator[str, None]: - """处理视频流式响应""" - idle_timeout = get_config("timeout.video_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if rid := resp.get("responseId"): - self.response_id = rid - - if not self.role_sent: - yield self._sse(role="assistant") - self.role_sent = True - - # 视频生成进度 - if video_resp := resp.get("streamingVideoGenerationResponse"): - progress = video_resp.get("progress", 0) - - if self.show_think: - if not self.think_opened: - yield self._sse("\n") - self.think_opened = True - yield self._sse(f"正在生成视频中,当前进度{progress}%\n") - - if progress == 100: - video_url = video_resp.get("videoUrl", "") - thumbnail_url = video_resp.get("thumbnailImageUrl", "") - - if self.think_opened and self.show_think: - yield self._sse("\n") - self.think_opened = False - - if video_url: - final_video_url = await self.process_url(video_url, "video") - final_thumbnail_url = "" - if thumbnail_url: - final_thumbnail_url = await self.process_url( - thumbnail_url, "image" - ) - - if self.video_format == "url": - yield self._sse(final_video_url) - else: - video_html = self._build_video_html( - final_video_url, final_thumbnail_url - ) - yield self._sse(video_html) - - logger.info(f"Video generated: {video_url}") - continue - - if self.think_opened: - yield self._sse("\n") - yield self._sse(finish="stop") - yield "data: [DONE]\n\n" - except asyncio.CancelledError: - logger.debug( - "Video stream cancelled by client", extra={"model": self.model} - ) - except StreamIdleTimeoutError as e: - raise UpstreamException( - message=f"Video stream idle timeout after {e.idle_seconds}s", - status_code=504, - details={ - "error": str(e), - "type": "stream_idle_timeout", - "idle_seconds": e.idle_seconds, - }, - ) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning( - f"HTTP/2 stream error in video: {e}", extra={"model": self.model} - ) - raise UpstreamException( - message="Upstream connection closed unexpectedly", - status_code=502, - details={"error": str(e), "type": "http2_stream_error"}, - ) - logger.error( - f"Video stream request error: {e}", extra={"model": self.model} - ) - raise UpstreamException( - message=f"Upstream request failed: {e}", - status_code=502, - details={"error": str(e)}, - ) - except Exception as e: - logger.error( - f"Video stream processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - finally: - await self.close() - - -class VideoCollectProcessor(BaseProcessor): - """视频非流式响应处理器""" - - def __init__(self, model: str, token: str = ""): - super().__init__(model, token) - self.video_format = str(get_config("app.video_format")).lower() - - def _build_video_html(self, video_url: str, thumbnail_url: str = "") -> str: - poster_attr = f' poster="{thumbnail_url}"' if thumbnail_url else "" - return f'''''' - - async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: - """处理并收集视频响应""" - response_id = "" - content = "" - idle_timeout = get_config("timeout.video_idle_timeout") - - try: - async for line in _with_idle_timeout(response, idle_timeout, self.model): - line = _normalize_stream_line(line) - if not line: - continue - try: - data = orjson.loads(line) - except orjson.JSONDecodeError: - continue - - resp = data.get("result", {}).get("response", {}) - - if video_resp := resp.get("streamingVideoGenerationResponse"): - if video_resp.get("progress") == 100: - response_id = resp.get("responseId", "") - video_url = video_resp.get("videoUrl", "") - thumbnail_url = video_resp.get("thumbnailImageUrl", "") - - if video_url: - final_video_url = await self.process_url(video_url, "video") - final_thumbnail_url = "" - if thumbnail_url: - final_thumbnail_url = await self.process_url( - thumbnail_url, "image" - ) - - if self.video_format == "url": - content = final_video_url - else: - content = self._build_video_html( - final_video_url, final_thumbnail_url - ) - logger.info(f"Video generated: {video_url}") - - except asyncio.CancelledError: - logger.debug( - "Video collect cancelled by client", extra={"model": self.model} - ) - except StreamIdleTimeoutError as e: - logger.warning( - f"Video collect idle timeout: {e}", extra={"model": self.model} - ) - except RequestsError as e: - if _is_http2_stream_error(e): - logger.warning( - f"HTTP/2 stream error in video collect: {e}", - extra={"model": self.model}, - ) - else: - logger.error( - f"Video collect request error: {e}", extra={"model": self.model} - ) - except Exception as e: - logger.error( - f"Video collect processing error: {e}", - extra={"model": self.model, "error_type": type(e).__name__}, - ) - finally: - await self.close() - - return { - "id": response_id, - "object": "chat.completion", - "created": self.created, - "model": self.model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": content, - "refusal": None, - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - } - - -__all__ = ["VideoStreamProcessor", "VideoCollectProcessor"] diff --git a/app/services/grok/protocols/__init__.py b/app/services/grok/protocols/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/protocols/grpc_web.py b/app/services/grok/protocols/grpc_web.py deleted file mode 100644 index 0724727d..00000000 --- a/app/services/grok/protocols/grpc_web.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -gRPC-Web 协议工具 - -提供 framing 编码/解码、trailer 解析等通用功能。 -支持 application/grpc-web+proto 和 application/grpc-web-text (base64) 两种格式。 -""" - -from __future__ import annotations - -import base64 -import re -import struct -from dataclasses import dataclass -from typing import Dict, List, Mapping, Tuple -from urllib.parse import unquote - - -_B64_RE = re.compile(rb"^[A-Za-z0-9+/=\r\n]+$") - - -def encode_grpc_web_payload(data: bytes) -> bytes: - """ - 编码 gRPC-Web data frame - - Frame format: - 1-byte flags + 4-byte big-endian length + message bytes - """ - return b"\x00" + struct.pack(">I", len(data)) + data - - -def _maybe_decode_grpc_web_text(body: bytes, content_type: str | None) -> bytes: - """处理 grpc-web-text 模式的 base64 解码""" - ct = (content_type or "").lower() - if "grpc-web-text" in ct: - compact = b"".join(body.split()) - return base64.b64decode(compact, validate=False) - - # 启发式:body 仅包含 base64 字符才尝试解码 - head = body[: min(len(body), 2048)] - if head and _B64_RE.fullmatch(head): - compact = b"".join(body.split()) - try: - return base64.b64decode(compact, validate=True) - except Exception: - return body - return body - - -def _parse_trailer_block(payload: bytes) -> Dict[str, str]: - """解析 trailer frame 内容""" - text = payload.decode("utf-8", errors="replace") - lines = [ln for ln in re.split(r"\r\n|\n", text) if ln] - - trailers: Dict[str, str] = {} - for ln in lines: - if ":" not in ln: - continue - k, v = ln.split(":", 1) - trailers[k.strip().lower()] = v.strip() - - # grpc-message 可能是 percent-encoding - if "grpc-message" in trailers: - trailers["grpc-message"] = unquote(trailers["grpc-message"]) - - return trailers - - -def parse_grpc_web_response( - body: bytes, - content_type: str | None = None, - headers: Mapping[str, str] | None = None, -) -> Tuple[List[bytes], Dict[str, str]]: - """ - 解析 gRPC-Web 响应 - - Returns: - (messages, trailers): data frames 列表和合并后的 trailers - """ - decoded = _maybe_decode_grpc_web_text(body, content_type) - - messages: List[bytes] = [] - trailers: Dict[str, str] = {} - - i = 0 - n = len(decoded) - while i < n: - if n - i < 5: - break - - flag = decoded[i] - length = int.from_bytes(decoded[i + 1 : i + 5], "big") - i += 5 - - if n - i < length: - break - - payload = decoded[i : i + length] - i += length - - if flag & 0x80: # trailer frame - trailers.update(_parse_trailer_block(payload)) - elif flag & 0x01: # compressed (不支持) - raise ValueError("grpc-web compressed flag not supported") - else: - messages.append(payload) - - # 兼容:grpc-status 可能在 response headers 中 - if headers: - lower = {k.lower(): v for k, v in headers.items()} - if "grpc-status" in lower and "grpc-status" not in trailers: - trailers["grpc-status"] = str(lower["grpc-status"]).strip() - if "grpc-message" in lower and "grpc-message" not in trailers: - trailers["grpc-message"] = unquote(str(lower["grpc-message"]).strip()) - - return messages, trailers - - -@dataclass(frozen=True) -class GrpcStatus: - code: int - message: str = "" - - @property - def ok(self) -> bool: - return self.code == 0 - - @property - def http_equiv(self) -> int: - """映射到类 HTTP 状态码""" - mapping = { - 0: 200, # OK - 16: 401, # UNAUTHENTICATED - 7: 403, # PERMISSION_DENIED - 8: 429, # RESOURCE_EXHAUSTED - 4: 504, # DEADLINE_EXCEEDED - 14: 503, # UNAVAILABLE - } - return mapping.get(self.code, 502) - - -def get_grpc_status(trailers: Mapping[str, str]) -> GrpcStatus: - """从 trailers 提取 gRPC 状态""" - raw = str(trailers.get("grpc-status", "")).strip() - msg = str(trailers.get("grpc-message", "")).strip() - try: - code = int(raw) - except Exception: - code = -1 - return GrpcStatus(code=code, message=msg) - - -__all__ = [ - "encode_grpc_web_payload", - "parse_grpc_web_response", - "get_grpc_status", - "GrpcStatus", -] diff --git a/app/services/grok/services/__init__.py b/app/services/grok/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/services/assets.py b/app/services/grok/services/assets.py deleted file mode 100644 index 9197f853..00000000 --- a/app/services/grok/services/assets.py +++ /dev/null @@ -1,780 +0,0 @@ -""" -Grok 文件资产服务 -""" - -import asyncio -import base64 -import hashlib -import os -import re -import time -from contextlib import asynccontextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple -from urllib.parse import urlparse - -try: - import fcntl -except ImportError: - fcntl = None - -import aiofiles -from curl_cffi.requests import AsyncSession - -from app.core.config import get_config -from app.core.exceptions import AppException, UpstreamException, ValidationException -from app.core.logger import logger -from app.core.storage import DATA_DIR -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie -from app.services.token.service import TokenService - -# ==================== 常量 ==================== - -UPLOAD_API = "https://grok.com/rest/app-chat/upload-file" -LIST_API = "https://grok.com/rest/assets" -DELETE_API = "https://grok.com/rest/assets-metadata" -DOWNLOAD_API = "https://assets.grok.com" -LOCK_DIR = DATA_DIR / ".locks" - -# 全局信号量(运行时动态初始化) -_ASSETS_SEMAPHORE = None -_ASSETS_SEM_VALUE = None - -# 常用 MIME 类型(业务数据,非配置) -MIME_TYPES = { - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".png": "image/png", - ".gif": "image/gif", - ".webp": "image/webp", - ".bmp": "image/bmp", - ".pdf": "application/pdf", - ".txt": "text/plain", - ".md": "text/markdown", - ".csv": "text/csv", - ".json": "application/json", - ".xml": "application/xml", - ".py": "text/x-python-script", - ".js": "application/javascript", - ".html": "text/html", - ".css": "text/css", - ".mp4": "video/mp4", - ".webm": "video/webm", -} - -IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} -VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"} - -# ==================== 工具函数 ==================== - - -def _get_assets_semaphore() -> asyncio.Semaphore: - """获取全局并发控制信号量""" - value = max(1, int(get_config("performance.assets_max_concurrent"))) - - global _ASSETS_SEMAPHORE, _ASSETS_SEM_VALUE - if _ASSETS_SEMAPHORE is None or value != _ASSETS_SEM_VALUE: - _ASSETS_SEM_VALUE = value - _ASSETS_SEMAPHORE = asyncio.Semaphore(value) - return _ASSETS_SEMAPHORE - - -@asynccontextmanager -async def _file_lock(name: str, timeout: int = 10): - """文件锁""" - if fcntl is None: - yield - return - - LOCK_DIR.mkdir(parents=True, exist_ok=True) - lock_path = LOCK_DIR / f"{name}.lock" - fd = None - locked = False - start = time.monotonic() - - try: - fd = open(lock_path, "a+") - while True: - try: - fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) - locked = True - break - except BlockingIOError: - if time.monotonic() - start >= timeout: - break - await asyncio.sleep(0.05) - yield - finally: - if fd: - if locked: - try: - fcntl.flock(fd, fcntl.LOCK_UN) - except Exception: - pass - fd.close() - - -@dataclass -class ServiceConfig: - """服务配置""" - - proxy: str - timeout: int - browser: str - user_agent: str - - @classmethod - def from_settings(cls, proxy: Optional[str] = None): - return cls( - proxy=proxy - or get_config("network.asset_proxy_url") - or get_config("network.base_proxy_url"), - timeout=get_config("network.timeout"), - browser=get_config("security.browser"), - user_agent=get_config("security.user_agent"), - ) - - def get_proxies(self) -> Optional[dict]: - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - - -# ==================== 基础服务 ==================== - - -class BaseService: - """基础服务类""" - - def __init__(self, proxy: Optional[str] = None): - self.config = ServiceConfig.from_settings(proxy) - self._session: Optional[AsyncSession] = None - - def _build_headers( - self, token: str, referer: str = "https://grok.com/", download: bool = False - ) -> dict: - """构建请求头""" - if download: - headers = { - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", - "Sec-Fetch-Dest": "document", - "Sec-Fetch-Mode": "navigate", - "Sec-Fetch-Site": "same-site", - "Sec-Fetch-User": "?1", - "Referer": referer, - "User-Agent": self.config.user_agent, - } - else: - headers = { - "Accept": "*/*", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Referer": referer, - "User-Agent": self.config.user_agent, - } - apply_statsig(headers) - - headers["Cookie"] = build_sso_cookie(token) - return headers - - async def _get_session(self) -> AsyncSession: - """获取复用 Session""" - if self._session is None: - self._session = AsyncSession() - return self._session - - async def close(self): - """关闭 Session""" - if self._session: - await self._session.close() - self._session = None - - @staticmethod - def is_url(s: str) -> bool: - """检查是否为 URL""" - try: - r = urlparse(s) - return bool(r.scheme and r.netloc and r.scheme in ["http", "https"]) - except Exception: - return False - - @staticmethod - async def fetch(url: str) -> Tuple[str, str, str]: - """获取远程资源并转 Base64""" - try: - async with AsyncSession() as session: - response = await session.get(url, timeout=10) - if response.status_code >= 400: - raise UpstreamException( - message=f"Failed to fetch: {response.status_code}", - details={"url": url, "status": response.status_code}, - ) - - filename = url.split("/")[-1].split("?")[0] or "download" - content_type = response.headers.get( - "content-type", "application/octet-stream" - ).split(";")[0] - b64 = base64.b64encode(response.content).decode() - - logger.debug(f"Fetched: {url}") - return filename, b64, content_type - except Exception as e: - if isinstance(e, AppException): - raise - logger.error(f"Fetch failed: {url} - {e}") - raise UpstreamException(f"Fetch failed: {str(e)}", details={"url": url}) - - @staticmethod - def parse_b64(data_uri: str) -> Tuple[str, str, str]: - """解析 Base64 数据""" - if not data_uri.startswith("data:"): - return "file.bin", data_uri, "application/octet-stream" - - try: - header, b64 = data_uri.split(",", 1) - except ValueError: - return "file.bin", data_uri, "application/octet-stream" - - if ";base64" not in header: - return "file.bin", data_uri, "application/octet-stream" - - mime = header[5:].split(";", 1)[0] or "application/octet-stream" - b64 = re.sub(r"\s+", "", b64) - ext = mime.split("/")[-1] if "/" in mime else "bin" - return f"file.{ext}", b64, mime - - @staticmethod - def to_b64(file_path: Path, mime_type: str) -> str: - """文件转 base64 data URI""" - try: - if not file_path.exists(): - logger.warning(f"File not found for base64 conversion: {file_path}") - raise AppException( - f"File not found: {file_path}", code="file_not_found" - ) - - if not file_path.is_file(): - logger.warning(f"Path is not a file: {file_path}") - raise AppException( - f"Invalid file path: {file_path}", code="invalid_file_path" - ) - - b64_data = base64.b64encode(file_path.read_bytes()).decode() - return f"data:{mime_type};base64,{b64_data}" - except AppException: - raise - except Exception as e: - logger.error(f"File to base64 failed: {file_path} - {e}") - raise AppException( - f"Failed to read file: {file_path}", code="file_read_error" - ) - - -# ==================== 上传服务 ==================== - - -class UploadService(BaseService): - """文件上传服务""" - - async def upload(self, file_input: str, token: str) -> Tuple[str, str]: - """ - 上传文件到 Grok - - Returns: - (file_id, file_uri) - """ - async with _get_assets_semaphore(): - # 处理输入 - if self.is_url(file_input): - filename, b64, mime = await self.fetch(file_input) - else: - filename, b64, mime = self.parse_b64(file_input) - - logger.debug( - f"Upload prepare: filename={filename}, type={mime}, size={len(b64)}" - ) - - if not b64: - raise ValidationException("Invalid file input: empty content") - - # 执行上传 - session = await self._get_session() - response = await session.post( - UPLOAD_API, - headers=self._build_headers(token), - json={"fileName": filename, "fileMimeType": mime, "content": b64}, - impersonate=self.config.browser, - timeout=self.config.timeout, - proxies=self.config.get_proxies(), - ) - - # 处理响应 - if response.status_code == 200: - result = response.json() - file_id = result.get("fileMetadataId", "") - file_uri = result.get("fileUri", "") - logger.info(f"Upload success: {filename} -> {file_id}") - return file_id, file_uri - - # 认证失败 - if response.status_code in (401, 403): - logger.warning(f"Upload auth failed: {response.status_code}") - try: - await TokenService.record_fail( - token, response.status_code, "upload_auth_failed" - ) - except Exception as e: - logger.error(f"Failed to record token failure: {e}") - - raise UpstreamException( - message=f"Upload authentication failed: {response.status_code}", - details={"status": response.status_code, "token_invalidated": True}, - ) - - # 其他错误 - logger.error(f"Upload failed: {filename} - {response.status_code}") - raise UpstreamException( - message=f"Upload failed: {response.status_code}", - details={"status": response.status_code}, - ) - - -# ==================== 列表服务 ==================== - - -class ListService(BaseService): - """文件列表查询服务""" - - async def iter_assets(self, token: str): - """分页迭代资产列表""" - headers = self._build_headers(token, referer="https://grok.com/files") - params = { - "pageSize": 50, - "orderBy": "ORDER_BY_LAST_USE_TIME", - "source": "SOURCE_ANY", - "isLatest": "true", - } - page_token = None - seen_tokens = set() - - async with AsyncSession() as session: - while True: - if page_token: - if page_token in seen_tokens: - logger.warning("Pagination stopped: repeated page token") - break - seen_tokens.add(page_token) - params["pageToken"] = page_token - else: - params.pop("pageToken", None) - - response = await session.get( - LIST_API, - headers=headers, - params=params, - impersonate=self.config.browser, - timeout=self.config.timeout, - proxies=self.config.get_proxies(), - ) - - if response.status_code != 200: - raise UpstreamException( - message=f"List failed: {response.status_code}", - details={"status": response.status_code}, - ) - - result = response.json() - page_assets = result.get("assets", []) - yield page_assets - - page_token = result.get("nextPageToken") - if not page_token: - break - - async def list(self, token: str) -> List[Dict]: - """查询文件列表""" - assets = [] - async for page_assets in self.iter_assets(token): - assets.extend(page_assets) - logger.info(f"List success: {len(assets)} files") - return assets - - async def count(self, token: str) -> int: - """统计资产数量""" - total = 0 - async for page_assets in self.iter_assets(token): - total += len(page_assets) - logger.debug(f"Asset count: {total}") - return total - - -# ==================== 删除服务 ==================== - - -class DeleteService(BaseService): - """文件删除服务""" - - async def delete(self, token: str, asset_id: str) -> bool: - """删除单个文件""" - async with _get_assets_semaphore(): - session = await self._get_session() - response = await session.delete( - f"{DELETE_API}/{asset_id}", - headers=self._build_headers(token, referer="https://grok.com/files"), - impersonate=self.config.browser, - timeout=self.config.timeout, - proxies=self.config.get_proxies(), - ) - - if response.status_code == 200: - logger.debug(f"Deleted: {asset_id}") - return True - - logger.error(f"Delete failed: {asset_id} - {response.status_code}") - raise UpstreamException( - message=f"Delete failed: {asset_id}", - details={"status": response.status_code}, - ) - - async def delete_all(self, token: str) -> Dict[str, int]: - """删除所有文件""" - total = success = failed = 0 - list_service = ListService(self.config.proxy) - - try: - async for assets in list_service.iter_assets(token): - if not assets: - continue - - total += len(assets) - batch_result = await self._delete_batch(token, assets) - success += batch_result["success"] - failed += batch_result["failed"] - - if total == 0: - logger.info("No assets to delete") - return {"total": 0, "success": 0, "failed": 0, "skipped": True} - finally: - await list_service.close() - - logger.info(f"Delete all: total={total}, success={success}, failed={failed}") - return {"total": total, "success": success, "failed": failed} - - async def _delete_batch(self, token: str, assets: List[Dict]) -> Dict[str, int]: - """批量删除""" - batch_size = max(1, int(get_config("performance.assets_delete_batch_size"))) - success = failed = 0 - - for i in range(0, len(assets), batch_size): - batch = assets[i : i + batch_size] - results = await asyncio.gather( - *[ - self._delete_one(token, asset, idx) - for idx, asset in enumerate(batch) - ], - return_exceptions=True, - ) - success += sum(1 for r in results if r is True) - failed += sum(1 for r in results if r is not True) - - return {"success": success, "failed": failed} - - async def _delete_one(self, token: str, asset: Dict, index: int) -> bool: - """删除单个资产(带延迟)""" - await asyncio.sleep(0.01 * index) - asset_id = asset.get("assetId", "") - if not asset_id: - return False - try: - return await self.delete(token, asset_id) - except Exception: - return False - - -# ==================== 下载服务 ==================== - - -class DownloadService(BaseService): - """文件下载服务""" - - def __init__(self, proxy: Optional[str] = None): - super().__init__(proxy) - self.base_dir = DATA_DIR / "tmp" - self.image_dir = self.base_dir / "image" - self.video_dir = self.base_dir / "video" - self.image_dir.mkdir(parents=True, exist_ok=True) - self.video_dir.mkdir(parents=True, exist_ok=True) - self._cleanup_running = False - - def _cache_path(self, file_path: str, media_type: str) -> Path: - """获取缓存路径""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - filename = file_path.lstrip("/").replace("/", "-") - return cache_dir / filename - - def _get_mime(self, cache_path: Path, response=None) -> str: - """获取 MIME 类型""" - if response: - return response.headers.get( - "content-type", "application/octet-stream" - ).split(";")[0] - return MIME_TYPES.get(cache_path.suffix.lower(), "application/octet-stream") - - async def download( - self, file_path: str, token: str, media_type: str = "image" - ) -> Tuple[Optional[Path], str]: - """下载文件到本地""" - async with _get_assets_semaphore(): - cache_path = self._cache_path(file_path, media_type) - - # 检查缓存 - if cache_path.exists(): - logger.debug(f"Cache hit: {cache_path}") - return cache_path, self._get_mime(cache_path) - - # 文件锁防止并发下载 - lock_name = f"dl_{media_type}_{hashlib.sha1(str(cache_path).encode()).hexdigest()[:16]}" - async with _file_lock(lock_name, timeout=10): - # 双重检查 - if cache_path.exists(): - return cache_path, self._get_mime(cache_path) - - # 执行下载 - mime = await self._download_file(file_path, token, cache_path) - logger.info(f"Downloaded: {file_path}") - - # 异步检查缓存限制 - asyncio.create_task(self.check_limit()) - - return cache_path, mime - - async def _download_file(self, file_path: str, token: str, cache_path: Path) -> str: - """执行下载""" - if not file_path.startswith("/"): - file_path = f"/{file_path}" - - url = f"{DOWNLOAD_API}{file_path}" - headers = self._build_headers(token, download=True) - - session = await self._get_session() - response = await session.get( - url, - headers=headers, - proxies=self.config.get_proxies(), - timeout=self.config.timeout, - allow_redirects=True, - impersonate=self.config.browser, - stream=True, - ) - - if response.status_code != 200: - raise UpstreamException( - message=f"Download failed: {response.status_code}", - details={"path": file_path, "status": response.status_code}, - ) - - # 保存文件 - tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") - try: - async with aiofiles.open(tmp_path, "wb") as f: - # 尝试流式写入 - if hasattr(response, "aiter_content"): - async for chunk in response.aiter_content(): - if chunk: - await f.write(chunk) - else: - await f.write(response.content) - os.replace(tmp_path, cache_path) - finally: - if tmp_path.exists() and not cache_path.exists(): - try: - tmp_path.unlink() - except Exception: - pass - - return self._get_mime(cache_path, response) - - async def to_base64( - self, file_path: str, token: str, media_type: str = "image" - ) -> str: - """下载并转 base64""" - try: - cache_path, mime = await self.download(file_path, token, media_type) - if not cache_path or not cache_path.exists(): - logger.warning(f"Download failed for {file_path}: invalid path") - raise AppException( - "Download failed: invalid path", code="download_failed" - ) - - data_uri = self.to_b64(cache_path, mime) - - # 删除临时文件 - if data_uri: - try: - cache_path.unlink() - except Exception as e: - logger.debug(f"Failed to cleanup temp file {cache_path}: {e}") - - return data_uri - except Exception as e: - logger.error(f"Failed to convert {file_path} to base64: {e}") - raise - - def get_stats(self, media_type: str = "image") -> Dict[str, Any]: - """获取缓存统计""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - if not cache_dir.exists(): - return {"count": 0, "size_mb": 0.0} - - allowed = IMAGE_EXTS if media_type == "image" else VIDEO_EXTS - files = [ - f - for f in cache_dir.glob("*") - if f.is_file() and f.suffix.lower() in allowed - ] - total_size = sum(f.stat().st_size for f in files) - return {"count": len(files), "size_mb": round(total_size / 1024 / 1024, 2)} - - def list_files( - self, media_type: str = "image", page: int = 1, page_size: int = 1000 - ) -> Dict[str, Any]: - """列出缓存文件""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - if not cache_dir.exists(): - return {"total": 0, "page": page, "page_size": page_size, "items": []} - - allowed = IMAGE_EXTS if media_type == "image" else VIDEO_EXTS - files = [ - f - for f in cache_dir.glob("*") - if f.is_file() and f.suffix.lower() in allowed - ] - - # 构建文件列表 - items = [] - for f in files: - try: - stat = f.stat() - items.append( - { - "name": f.name, - "size_bytes": stat.st_size, - "mtime_ms": int(stat.st_mtime * 1000), - } - ) - except Exception: - continue - - items.sort(key=lambda x: x["mtime_ms"], reverse=True) - - # 分页 - total = len(items) - start = max(0, (page - 1) * page_size) - paged = items[start : start + page_size] - - # 添加 URL - for item in paged: - item["view_url"] = f"/v1/files/{media_type}/{item['name']}" - - return {"total": total, "page": page, "page_size": page_size, "items": paged} - - def delete_file(self, media_type: str, name: str) -> Dict[str, Any]: - """删除缓存文件""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - file_path = cache_dir / name.replace("/", "-") - - if file_path.exists(): - try: - file_path.unlink() - return {"deleted": True} - except Exception: - pass - return {"deleted": False} - - def clear(self, media_type: str = "image") -> Dict[str, Any]: - """清空缓存""" - cache_dir = self.image_dir if media_type == "image" else self.video_dir - if not cache_dir.exists(): - return {"count": 0, "size_mb": 0.0} - - files = list(cache_dir.glob("*")) - total_size = sum(f.stat().st_size for f in files if f.is_file()) - count = 0 - - for f in files: - if f.is_file(): - try: - f.unlink() - count += 1 - except Exception: - pass - - return {"count": count, "size_mb": round(total_size / 1024 / 1024, 2)} - - async def check_limit(self): - """检查并清理缓存""" - if self._cleanup_running or not get_config("cache.enable_auto_clean"): - return - - self._cleanup_running = True - try: - async with _file_lock("cache_cleanup", timeout=5): - limit_mb = get_config("cache.limit_mb") - all_files, total_size = self._collect_files() - current_mb = total_size / 1024 / 1024 - - if current_mb <= limit_mb: - return - - # 清理到 80% - logger.info( - f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..." - ) - all_files.sort(key=lambda x: x[1]) # 按时间排序 - - deleted_count = 0 - deleted_size = 0 - target_mb = limit_mb * 0.8 - - for f, _, size in all_files: - try: - f.unlink() - deleted_count += 1 - deleted_size += size - total_size -= size - if (total_size / 1024 / 1024) <= target_mb: - break - except Exception: - pass - - logger.info( - f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)" - ) - finally: - self._cleanup_running = False - - def _collect_files(self) -> Tuple[List[Tuple[Path, float, int]], int]: - """收集所有缓存文件""" - total_size = 0 - all_files = [] - - for d in [self.image_dir, self.video_dir]: - if d.exists(): - for f in d.glob("*"): - if f.is_file(): - try: - stat = f.stat() - total_size += stat.st_size - all_files.append((f, stat.st_mtime, stat.st_size)) - except Exception: - pass - - return all_files, total_size - - -__all__ = [ - "BaseService", - "UploadService", - "ListService", - "DeleteService", - "DownloadService", -] diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index 0fee80e8..260bb438 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -2,56 +2,59 @@ Grok Chat 服务 """ -import orjson -from typing import Dict, List, Any -from dataclasses import dataclass +import asyncio +import re +import uuid +from typing import Dict, List, Any, AsyncGenerator, AsyncIterable +import orjson from curl_cffi.requests import AsyncSession +from curl_cffi.requests.errors import RequestsError from app.core.logger import logger from app.core.config import get_config from app.core.exceptions import ( AppException, - UpstreamException, ValidationException, ErrorType, + UpstreamException, + StreamIdleTimeoutError, ) -from app.services.grok.models.model import ModelService -from app.services.grok.services.assets import UploadService -from app.services.grok.processors import StreamProcessor, CollectProcessor -from app.services.grok.utils.retry import retry_on_status -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie +from app.services.grok.services.model import ModelService +from app.services.grok.utils.upload import UploadService +from app.services.grok.utils import process as proc_base +from app.services.grok.utils.retry import pick_token, rate_limited +from app.services.reverse.app_chat import AppChatReverse from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.token import get_token_manager, EffortType -CHAT_API = "https://grok.com/rest/app-chat/conversations/new" +_CHAT_SEMAPHORE = None +_CHAT_SEM_VALUE = None -@dataclass -class ChatRequest: - """聊天请求数据""" - - model: str - messages: List[Dict[str, Any]] - stream: bool = None - think: bool = None +def _get_chat_semaphore() -> asyncio.Semaphore: + global _CHAT_SEMAPHORE, _CHAT_SEM_VALUE + value = max(1, int(get_config("chat.concurrent"))) + if value != _CHAT_SEM_VALUE: + _CHAT_SEM_VALUE = value + _CHAT_SEMAPHORE = asyncio.Semaphore(value) + return _CHAT_SEMAPHORE class MessageExtractor: """消息内容提取器""" @staticmethod - def extract( - messages: List[Dict[str, Any]], is_video: bool = False - ) -> tuple[str, List[tuple[str, str]]]: - """从 OpenAI 消息格式提取内容,返回 (text, attachments)""" + def extract(messages: List[Dict[str, Any]]) -> tuple[str, List[str], List[str]]: + """从 OpenAI 消息格式提取内容,返回 (text, file_attachments, image_attachments)""" texts = [] - attachments = [] + file_attachments: List[str] = [] + image_attachments: List[str] = [] extracted = [] for msg in messages: - role = msg.get("role", "") + role = msg.get("role", "") or "user" content = msg.get("content", "") parts = [] @@ -68,35 +71,21 @@ def extract( elif item_type == "image_url": image_data = item.get("image_url", {}) - url = ( - image_data.get("url", "") - if isinstance(image_data, dict) - else str(image_data) - ) + url = image_data.get("url", "") if url: - attachments.append(("image", url)) + image_attachments.append(url) elif item_type == "input_audio": - if is_video: - raise ValueError("视频模型不支持 input_audio 类型") audio_data = item.get("input_audio", {}) - data = ( - audio_data.get("data", "") - if isinstance(audio_data, dict) - else str(audio_data) - ) + data = audio_data.get("data", "") if data: - attachments.append(("audio", data)) + file_attachments.append(data) elif item_type == "file": - if is_video: - raise ValueError("视频模型不支持 file 类型") file_data = item.get("file", {}) - url = file_data.get("url", "") or file_data.get("data", "") - if isinstance(file_data, str): - url = file_data - if url: - attachments.append(("file", url)) + raw = file_data.get("file_data", "") + if raw: + file_attachments.append(raw) if parts: extracted.append({"role": role, "text": "\n".join(parts)}) @@ -116,101 +105,12 @@ def extract( text = item["text"] texts.append(text if i == last_user_index else f"{role}: {text}") - return "\n\n".join(texts), attachments - - -class ChatRequestBuilder: - """请求构造器""" - - @staticmethod - def build_headers(token: str) -> Dict[str, str]: - """构造请求头""" - user_agent = get_config("security.user_agent") - headers = { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "zh-CN,zh;q=0.9", - "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", - "Cache-Control": "no-cache", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Pragma": "no-cache", - "Priority": "u=1, i", - "Referer": "https://grok.com/", - "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', - "Sec-Ch-Ua-Arch": "arm", - "Sec-Ch-Ua-Bitness": "64", - "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Model": "", - "Sec-Ch-Ua-Platform": '"macOS"', - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-origin", - "User-Agent": user_agent, - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - - @staticmethod - def build_payload( - message: str, - model: str, - mode: str = None, - file_attachments: List[str] = None, - image_attachments: List[str] = None, - ) -> Dict[str, Any]: - """构造请求体""" - merged_attachments = [] - if file_attachments: - merged_attachments.extend(file_attachments) - if image_attachments: - merged_attachments.extend(image_attachments) - - payload = { - "temporary": get_config("chat.temporary"), - "modelName": model, - "message": message, - "fileAttachments": merged_attachments, - "imageAttachments": [], - "disableSearch": False, - "enableImageGeneration": True, - "returnImageBytes": False, - "enableImageStreaming": True, - "imageGenerationCount": 2, - "forceConcise": False, - "toolOverrides": {}, - "enableSideBySide": True, - "sendFinalMetadata": True, - "responseMetadata": { - "modelConfigOverride": {"modelMap": {}}, - "requestModelDetails": {"modelId": model}, - }, - "disableMemory": get_config("chat.disable_memory"), - "deviceEnvInfo": { - "darkModeEnabled": False, - "devicePixelRatio": 2, - "screenWidth": 2056, - "screenHeight": 1329, - "viewportWidth": 2056, - "viewportHeight": 1083, - }, - } - - if mode: - payload["modelMode"] = mode - - return payload + return "\n\n".join(texts), file_attachments, image_attachments class GrokChatService: """Grok API 调用服务""" - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - async def chat( self, token: str, @@ -219,149 +119,97 @@ async def chat( mode: str = None, stream: bool = None, file_attachments: List[str] = None, - image_attachments: List[str] = None, - raw_payload: Dict[str, Any] = None, + tool_overrides: Dict[str, Any] = None, + model_config_override: Dict[str, Any] = None, ): """发送聊天请求""" if stream is None: - stream = get_config("chat.stream") - - headers = ChatRequestBuilder.build_headers(token) - payload = ( - raw_payload - if raw_payload is not None - else ChatRequestBuilder.build_payload( - message, model, mode, file_attachments, image_attachments - ) - ) - proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None - timeout = get_config("network.timeout") + stream = get_config("app.stream") logger.debug( f"Chat request: model={model}, mode={mode}, stream={stream}, attachments={len(file_attachments or [])}" ) - # 建立连接 - async def establish_connection(): - browser = get_config("security.browser") + browser = get_config("proxy.browser") + + async def _stream(): session = AsyncSession(impersonate=browser) try: - response = await session.post( - CHAT_API, - headers=headers, - data=orjson.dumps(payload), - timeout=timeout, - stream=True, - proxies=proxies, - ) - - if response.status_code != 200: - content = "" - try: - content = await response.text() - except Exception: - pass - - logger.error( - f"Chat failed: status={response.status_code}, token={token[:10]}..." + async with _get_chat_semaphore(): + stream_response = await AppChatReverse.request( + session, + token, + message=message, + model=model, + mode=mode, + file_attachments=file_attachments, + tool_overrides=tool_overrides, + model_config_override=model_config_override, ) - + logger.info(f"Chat connected: model={model}, stream={stream}") + async for line in stream_response: + yield line + except Exception: + try: await session.close() - raise UpstreamException( - message=f"Grok API request failed: {response.status_code}", - details={"status": response.status_code, "body": content}, - ) - - logger.info(f"Chat connected: model={model}, stream={stream}") - return session, response - - except UpstreamException: + except Exception: + pass raise - except Exception as e: - logger.error(f"Chat request error: {e}") - await session.close() - raise UpstreamException( - message=f"Chat connection failed: {str(e)}", - details={"error": str(e)}, - ) - # 重试机制 - def extract_status(e: Exception) -> int | None: - if isinstance(e, UpstreamException) and e.details: - status = e.details.get("status") - # 429 不在内层重试,由外层跨 token 重试处理 - if status == 429: - return None - return status - return None - - session = None - response = None - try: - session, response = await retry_on_status( - establish_connection, extract_status=extract_status - ) - except Exception as e: - status_code = extract_status(e) - if status_code: - token_mgr = await get_token_manager() - reason = str(e) - if isinstance(e, UpstreamException) and e.details: - body = e.details.get("body") - if body: - reason = f"{reason} | body: {body}" - await token_mgr.record_fail(token, status_code, reason) - raise + return _stream() - # 流式传输 - async def stream_response(): - try: - async for line in response.aiter_lines(): - yield line - finally: - if session: - await session.close() - - return stream_response() - - async def chat_openai(self, token: str, request: ChatRequest): + async def chat_openai( + self, + token: str, + model: str, + messages: List[Dict[str, Any]], + stream: bool = None, + reasoning_effort: str | None = None, + temperature: float = 0.8, + top_p: float = 0.95, + ): """OpenAI 兼容接口""" - model_info = ModelService.get(request.model) + model_info = ModelService.get(model) if not model_info: - raise ValidationException(f"Unknown model: {request.model}") + raise ValidationException(f"Unknown model: {model}") grok_model = model_info.grok_model mode = model_info.model_mode - is_video = model_info.is_video - # 提取消息和附件 - try: - message, attachments = MessageExtractor.extract( - request.messages, is_video=is_video - ) - logger.debug( - f"Extracted message length={len(message)}, attachments={len(attachments)}" - ) - except ValueError as e: - raise ValidationException(str(e)) + message, file_attachments, image_attachments = MessageExtractor.extract(messages) + logger.debug( + "Extracted message length=%s, files=%s, images=%s", + len(message), + len(file_attachments), + len(image_attachments), + ) # 上传附件 - file_ids = [] - if attachments: + file_ids: List[str] = [] + image_ids: List[str] = [] + if file_attachments or image_attachments: upload_service = UploadService() try: - for attach_type, attach_data in attachments: - file_id, _ = await upload_service.upload(attach_data, token) + for attach_data in file_attachments: + file_id, _ = await upload_service.upload_file(attach_data, token) file_ids.append(file_id) - logger.debug( - f"Attachment uploaded: type={attach_type}, file_id={file_id}" - ) + logger.debug(f"Attachment uploaded: type=file, file_id={file_id}") + for attach_data in image_attachments: + file_id, _ = await upload_service.upload_file(attach_data, token) + image_ids.append(file_id) + logger.debug(f"Attachment uploaded: type=image, file_id={file_id}") finally: await upload_service.close() - stream = ( - request.stream if request.stream is not None else get_config("chat.stream") - ) + all_attachments = file_ids + image_ids + stream = stream if stream is not None else get_config("app.stream") + + model_config_override = { + "temperature": temperature, + "topP": top_p, + } + if reasoning_effort is not None: + model_config_override["reasoningEffort"] = reasoning_effort response = await self.chat( token, @@ -369,11 +217,11 @@ async def chat_openai(self, token: str, request: ChatRequest): grok_model, mode, stream, - file_attachments=file_ids, - image_attachments=[], + file_attachments=all_attachments, + model_config_override=model_config_override, ) - return response, stream, request.model + return response, stream, model class ChatService: @@ -384,21 +232,21 @@ async def completions( model: str, messages: List[Dict[str, Any]], stream: bool = None, - thinking: str = None, + reasoning_effort: str | None = None, + temperature: float = 0.8, + top_p: float = 0.95, ): """Chat Completions 入口""" # 获取 token token_mgr = await get_token_manager() await token_mgr.reload_if_stale() - # 解析参数(只需解析一次) - think = {"enabled": True, "disabled": False}.get(thinking) - is_stream = stream if stream is not None else get_config("chat.stream") - - # 构造请求(只需构造一次) - chat_request = ChatRequest( - model=model, messages=messages, stream=is_stream, think=think - ) + # 解析参数 + if reasoning_effort is None: + show_think = get_config("app.thinking") + else: + show_think = reasoning_effort != "none" + is_stream = stream if stream is not None else get_config("app.stream") # 跨 Token 重试循环 tried_tokens = set() @@ -406,23 +254,8 @@ async def completions( last_error = None for attempt in range(max_token_retries): - # 选择 token(排除已失败的) - token = None - for pool_name in ModelService.pool_candidates_for_model(model): - token = token_mgr.get_token(pool_name, exclude=tried_tokens) - if token: - break - - if not token and not tried_tokens: - # 首次就无 token,尝试刷新 - logger.info("No available tokens, attempting to refresh cooling tokens...") - result = await token_mgr.refresh_cooling_tokens() - if result.get("recovered", 0) > 0: - for pool_name in ModelService.pool_candidates_for_model(model): - token = token_mgr.get_token(pool_name) - if token: - break - + # 选择 token + token = await pick_token(token_mgr, model, tried_tokens) if not token: if last_error: raise last_error @@ -438,12 +271,20 @@ async def completions( try: # 请求 Grok service = GrokChatService() - response, _, model_name = await service.chat_openai(token, chat_request) + response, _, model_name = await service.chat_openai( + token, + model, + messages, + stream=is_stream, + reasoning_effort=reasoning_effort, + temperature=temperature, + top_p=top_p, + ) # 处理响应 if is_stream: logger.debug(f"Processing stream response: model={model}") - processor = StreamProcessor(model_name, token, think) + processor = StreamProcessor(model_name, token, show_think) return wrap_stream_with_usage( processor.process(response), token_mgr, token, model ) @@ -465,10 +306,9 @@ async def completions( return result except UpstreamException as e: - status_code = e.details.get("status") if e.details else None last_error = e - if status_code == 429: + if rate_limited(e): # 配额不足,标记 token 为 cooling 并换 token 重试 await token_mgr.mark_rate_limited(token) logger.warning( @@ -491,10 +331,364 @@ async def completions( ) +class StreamProcessor(proc_base.BaseProcessor): + """Stream response processor.""" + + def __init__(self, model: str, token: str = "", show_think: bool = None): + super().__init__(model, token) + self.response_id: str = None + self.fingerprint: str = "" + self.think_opened: bool = False + self.role_sent: bool = False + self.filter_tags = get_config("app.filter_tags") + + self.show_think = bool(show_think) + + def _filter_token(self, token: str) -> str: + """Filter special tags in current token only.""" + if not self.filter_tags or not token: + return token + + for tag in self.filter_tags: + if f"<{tag}" in token or f" str: + """Build SSE response.""" + delta = {} + if role: + delta["role"] = role + delta["content"] = "" + elif content: + delta["content"] = content + + chunk = { + "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "system_fingerprint": self.fingerprint, + "choices": [ + {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} + ], + } + return f"data: {orjson.dumps(chunk).decode()}\n\n" + + async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]: + """Process stream response. + + Args: + response: AsyncIterable[bytes], async iterable of bytes + + Returns: + AsyncGenerator[str, None], async generator of strings + """ + idle_timeout = get_config("chat.stream_timeout") + + try: + async for line in proc_base._with_idle_timeout( + response, idle_timeout, self.model + ): + line = proc_base._normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + is_thinking = bool(resp.get("isThinking")) + # isThinking controls tagging + # when absent, treat as False + + if (llm := resp.get("llmInfo")) and not self.fingerprint: + self.fingerprint = llm.get("modelHash", "") + if rid := resp.get("responseId"): + self.response_id = rid + + if not self.role_sent: + yield self._sse(role="assistant") + self.role_sent = True + + if img := resp.get("streamingImageGenerationResponse"): + if not self.show_think: + continue + if is_thinking and not self.think_opened: + yield self._sse("\n") + self.think_opened = True + if (not is_thinking) and self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + idx = img.get("imageIndex", 0) + 1 + progress = img.get("progress", 0) + yield self._sse( + f"正在生成第{idx}张图片中,当前进度{progress}%\n" + ) + continue + + if mr := resp.get("modelResponse"): + for url in proc_base._collect_images(mr): + parts = url.split("/") + img_id = parts[-2] if len(parts) >= 2 else "image" + dl_service = self._get_dl() + rendered = await dl_service.render_image( + url, self.token, img_id + ) + yield self._sse(f"{rendered}\n") + + if ( + (meta := mr.get("metadata", {})) + .get("llm_info", {}) + .get("modelHash") + ): + self.fingerprint = meta["llm_info"]["modelHash"] + continue + + if card := resp.get("cardAttachment"): + json_data = card.get("jsonData") + if isinstance(json_data, str) and json_data.strip(): + try: + card_data = orjson.loads(json_data) + except orjson.JSONDecodeError: + card_data = None + if isinstance(card_data, dict): + image = card_data.get("image") or {} + original = image.get("original") + title = image.get("title") or "" + if original: + title_safe = title.replace("\n", " ").strip() + if title_safe: + yield self._sse(f"![{title_safe}]({original})\n") + else: + yield self._sse(f"![image]({original})\n") + continue + + if (token := resp.get("token")) is not None: + if not token: + continue + filtered = self._filter_token(token) + if not filtered: + continue + if is_thinking: + if not self.show_think: + continue + if not self.think_opened: + yield self._sse("\n") + self.think_opened = True + else: + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + yield self._sse(filtered) + + if self.think_opened: + yield self._sse("\n") + yield self._sse(finish="stop") + yield "data: [DONE]\n\n" + except asyncio.CancelledError: + logger.debug("Stream cancelled by client", extra={"model": self.model}) + except StreamIdleTimeoutError as e: + raise UpstreamException( + message=f"Stream idle timeout after {e.idle_seconds}s", + status_code=504, + details={ + "error": str(e), + "type": "stream_idle_timeout", + "idle_seconds": e.idle_seconds, + }, + ) + except RequestsError as e: + if proc_base._is_http2_error(e): + logger.warning(f"HTTP/2 stream error: {e}", extra={"model": self.model}) + raise UpstreamException( + message="Upstream connection closed unexpectedly", + status_code=502, + details={"error": str(e), "type": "http2_stream_error"}, + ) + logger.error(f"Stream request error: {e}", extra={"model": self.model}) + raise UpstreamException( + message=f"Upstream request failed: {e}", + status_code=502, + details={"error": str(e)}, + ) + except Exception as e: + logger.error( + f"Stream processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + raise + finally: + await self.close() + + +class CollectProcessor(proc_base.BaseProcessor): + """Non-stream response processor.""" + + def __init__(self, model: str, token: str = ""): + super().__init__(model, token) + self.filter_tags = get_config("app.filter_tags") + + def _filter_content(self, content: str) -> str: + """Filter special tags in content.""" + if not content or not self.filter_tags: + return content + + result = content + for tag in self.filter_tags: + pattern = rf"<{re.escape(tag)}[^>]*>.*?|<{re.escape(tag)}[^>]*/>" + result = re.sub(pattern, "", result, flags=re.DOTALL) + + return result + + async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: + """Process and collect full response.""" + response_id = "" + fingerprint = "" + content = "" + idle_timeout = get_config("chat.stream_timeout") + + try: + async for line in proc_base._with_idle_timeout( + response, idle_timeout, self.model + ): + line = proc_base._normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + if (llm := resp.get("llmInfo")) and not fingerprint: + fingerprint = llm.get("modelHash", "") + + if mr := resp.get("modelResponse"): + response_id = mr.get("responseId", "") + content = mr.get("message", "") + + card_map: dict[str, tuple[str, str]] = {} + for raw in mr.get("cardAttachmentsJson") or []: + if not isinstance(raw, str) or not raw.strip(): + continue + try: + card_data = orjson.loads(raw) + except orjson.JSONDecodeError: + continue + if not isinstance(card_data, dict): + continue + card_id = card_data.get("id") + image = card_data.get("image") or {} + original = image.get("original") + if not card_id or not original: + continue + title = image.get("title") or "" + card_map[card_id] = (title, original) + + if content and card_map: + def _render_card(match: re.Match) -> str: + card_id = match.group(1) + item = card_map.get(card_id) + if not item: + return "" + title, original = item + title_safe = title.replace("\n", " ").strip() or "image" + prefix = "" + if match.start() > 0: + prev = content[match.start() - 1] + if prev not in ("\n", "\r"): + prefix = "\n" + return f"{prefix}![{title_safe}]({original})" + + content = re.sub( + r']*card_id="([^"]+)"[^>]*>.*?', + _render_card, + content, + flags=re.DOTALL, + ) + + if urls := proc_base._collect_images(mr): + content += "\n" + for url in urls: + parts = url.split("/") + img_id = parts[-2] if len(parts) >= 2 else "image" + dl_service = self._get_dl() + rendered = await dl_service.render_image( + url, self.token, img_id + ) + content += f"{rendered}\n" + + if ( + (meta := mr.get("metadata", {})) + .get("llm_info", {}) + .get("modelHash") + ): + fingerprint = meta["llm_info"]["modelHash"] + + except asyncio.CancelledError: + logger.debug("Collect cancelled by client", extra={"model": self.model}) + except StreamIdleTimeoutError as e: + logger.warning(f"Collect idle timeout: {e}", extra={"model": self.model}) + except RequestsError as e: + if proc_base._is_http2_error(e): + logger.warning( + f"HTTP/2 stream error in collect: {e}", extra={"model": self.model} + ) + else: + logger.error(f"Collect request error: {e}", extra={"model": self.model}) + except Exception as e: + logger.error( + f"Collect processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + finally: + await self.close() + + content = self._filter_content(content) + + return { + "id": response_id, + "object": "chat.completion", + "created": self.created, + "model": self.model, + "system_fingerprint": fingerprint, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + "refusal": None, + "annotations": [], + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "prompt_tokens_details": { + "cached_tokens": 0, + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + }, + }, + } + + __all__ = [ "GrokChatService", - "ChatRequest", - "ChatRequestBuilder", "MessageExtractor", "ChatService", ] diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 573e1a21..75e26987 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -1,289 +1,614 @@ """ -Grok Imagine WebSocket image service. +Grok image services. """ import asyncio -import certifi -import json -import re -import ssl +import base64 +import math import time -import uuid -from typing import AsyncGenerator, Dict, Optional -from urllib.parse import urlparse +from dataclasses import dataclass +from pathlib import Path +from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union -import aiohttp -from aiohttp_socks import ProxyConnector +import orjson from app.core.config import get_config from app.core.logger import logger -from app.services.grok.utils.headers import build_sso_cookie +from app.core.storage import DATA_DIR +from app.core.exceptions import AppException, ErrorType, UpstreamException +from app.services.grok.utils.process import BaseProcessor +from app.services.grok.utils.retry import pick_token, rate_limited +from app.services.grok.utils.stream import wrap_stream_with_usage +from app.services.token import EffortType +from app.services.reverse.ws_imagine import ImagineWebSocketReverse -WS_URL = "wss://grok.com/ws/imagine/listen" +image_service = ImagineWebSocketReverse() -class _BlockedError(Exception): - pass +@dataclass +class ImageGenerationResult: + stream: bool + data: Union[AsyncGenerator[str, None], List[str]] + usage_override: Optional[dict] = None -class ImageService: - """Grok Imagine WebSocket image service.""" - def __init__(self): - self._ssl_context = ssl.create_default_context() - self._ssl_context.load_verify_locations(certifi.where()) - self._url_pattern = re.compile(r"/images/([a-f0-9-]+)\.(png|jpg|jpeg)") +class ImageGenerationService: + """Image generation orchestration service.""" - def _resolve_proxy(self) -> tuple[aiohttp.BaseConnector, Optional[str]]: - proxy_url = get_config("network.base_proxy_url") - if not proxy_url: - return aiohttp.TCPConnector(ssl=self._ssl_context), None - - scheme = urlparse(proxy_url).scheme.lower() - if scheme.startswith("socks"): - logger.info(f"Using SOCKS proxy: {proxy_url}") - return ProxyConnector.from_url(proxy_url, ssl=self._ssl_context), None - - logger.info(f"Using HTTP proxy: {proxy_url}") - return aiohttp.TCPConnector(ssl=self._ssl_context), proxy_url - - def _get_ws_headers(self, token: str) -> Dict[str, str]: - cookie = build_sso_cookie(token, include_rw=True) - user_agent = get_config("security.user_agent") - return { - "Cookie": cookie, - "Origin": "https://grok.com", - "User-Agent": user_agent, - "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", - "Cache-Control": "no-cache", - "Pragma": "no-cache", - } - - def _extract_image_id(self, url: str) -> Optional[str]: - match = self._url_pattern.search(url or "") - return match.group(1) if match else None - - def _is_final_image(self, url: str, blob_size: int) -> bool: - return (url or "").lower().endswith( - (".jpg", ".jpeg") - ) and blob_size > get_config("image.image_ws_final_min_bytes") - - def _classify_image(self, url: str, blob: str) -> Optional[Dict[str, object]]: - if not url or not blob: - return None - - image_id = self._extract_image_id(url) or uuid.uuid4().hex - blob_size = len(blob) - is_final = self._is_final_image(url, blob_size) - - stage = ( - "final" - if is_final - else ( - "medium" - if blob_size > get_config("image.image_ws_medium_min_bytes") - else "preview" + async def generate( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + size: str, + aspect_ratio: str, + stream: bool, + enable_nsfw: Optional[bool] = None, + ) -> ImageGenerationResult: + max_token_retries = int(get_config("retry.max_retry")) + tried_tokens: set[str] = set() + last_error: Optional[Exception] = None + + if stream: + async def _stream_retry() -> AsyncGenerator[str, None]: + nonlocal last_error + for attempt in range(max_token_retries): + preferred = token if attempt == 0 else None + current_token = await pick_token( + token_mgr, model_info.model_id, tried_tokens, preferred=preferred + ) + if not current_token: + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + tried_tokens.add(current_token) + yielded = False + try: + result = await self._stream_ws( + token_mgr=token_mgr, + token=current_token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + size=size, + aspect_ratio=aspect_ratio, + enable_nsfw=enable_nsfw, + ) + async for chunk in result.data: + yielded = True + yield chunk + return + except UpstreamException as e: + last_error = e + if rate_limited(e): + if yielded: + raise + await token_mgr.mark_rate_limited(current_token) + logger.warning( + f"Token {current_token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise + + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + return ImageGenerationResult(stream=True, data=_stream_retry()) + + for attempt in range(max_token_retries): + preferred = token if attempt == 0 else None + current_token = await pick_token( + token_mgr, model_info.model_id, tried_tokens, preferred=preferred ) + if not current_token: + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + tried_tokens.add(current_token) + try: + return await self._collect_ws( + token_mgr=token_mgr, + token=current_token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + aspect_ratio=aspect_ratio, + enable_nsfw=enable_nsfw, + ) + except UpstreamException as e: + last_error = e + if rate_limited(e): + await token_mgr.mark_rate_limited(current_token) + logger.warning( + f"Token {current_token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise + + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, ) - return { - "type": "image", - "image_id": image_id, - "stage": stage, - "blob": blob, - "blob_size": blob_size, - "url": url, - "is_final": is_final, - } - - async def stream( + async def _stream_ws( self, + *, + token_mgr: Any, token: str, + model_info: Any, prompt: str, - aspect_ratio: str = "2:3", - n: int = 1, - enable_nsfw: bool = True, - max_retries: int = None, - ) -> AsyncGenerator[Dict[str, object], None]: - retries = max(1, max_retries if max_retries is not None else 1) - logger.info( - f"Image generation: prompt='{prompt[:50]}...', n={n}, ratio={aspect_ratio}, nsfw={enable_nsfw}" + n: int, + response_format: str, + size: str, + aspect_ratio: str, + enable_nsfw: Optional[bool] = None, + ) -> ImageGenerationResult: + if enable_nsfw is None: + enable_nsfw = bool(get_config("image.nsfw")) + upstream = image_service.stream( + token=token, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=n, + enable_nsfw=enable_nsfw, ) + processor = ImageWSStreamProcessor( + model_info.model_id, + token, + n=n, + response_format=response_format, + size=size, + ) + stream = wrap_stream_with_usage( + processor.process(upstream), + token_mgr, + token, + model_info.model_id, + ) + return ImageGenerationResult(stream=True, data=stream) - for attempt in range(retries): - try: - yielded_any = False - async for item in self._stream_once( - token, prompt, aspect_ratio, n, enable_nsfw - ): - yielded_any = True - yield item - return - except _BlockedError: - if yielded_any or attempt + 1 >= retries: - if not yielded_any: - yield { - "type": "error", - "error_code": "blocked", - "error": "blocked_no_final_image", - } - return - logger.warning(f"WebSocket blocked, retry {attempt + 1}/{retries}") - except Exception as e: - logger.error(f"WebSocket stream failed: {e}") - return - - async def _stream_once( + async def _collect_ws( self, + *, + token_mgr: Any, token: str, + model_info: Any, prompt: str, - aspect_ratio: str, n: int, - enable_nsfw: bool, - ) -> AsyncGenerator[Dict[str, object], None]: - request_id = str(uuid.uuid4()) - headers = self._get_ws_headers(token) - timeout = float(get_config("network.timeout")) - blocked_seconds = float(get_config("image.image_ws_blocked_seconds")) + response_format: str, + aspect_ratio: str, + enable_nsfw: Optional[bool] = None, + ) -> ImageGenerationResult: + if enable_nsfw is None: + enable_nsfw = bool(get_config("image.nsfw")) + all_images: List[str] = [] + seen = set() + expected_per_call = 6 + calls_needed = max(1, int(math.ceil(n / expected_per_call))) + calls_needed = min(calls_needed, n) + + async def _fetch_batch(call_target: int): + upstream = image_service.stream( + token=token, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=call_target, + enable_nsfw=enable_nsfw, + ) + processor = ImageWSCollectProcessor( + model_info.model_id, + token, + n=call_target, + response_format=response_format, + ) + return await processor.process(upstream) + + tasks = [] + for i in range(calls_needed): + remaining = n - (i * expected_per_call) + call_target = min(expected_per_call, remaining) + tasks.append(_fetch_batch(call_target)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + for batch in results: + if isinstance(batch, Exception): + logger.warning(f"WS batch failed: {batch}") + continue + for img in batch: + if img not in seen: + seen.add(img) + all_images.append(img) + if len(all_images) >= n: + break + if len(all_images) >= n: + break try: - connector, proxy = self._resolve_proxy() + await token_mgr.consume(token, self._get_effort(model_info)) except Exception as e: - logger.error(f"WebSocket proxy setup failed: {e}") - return + logger.warning(f"Failed to consume token: {e}") + + selected = self._select_images(all_images, n) + usage_override = { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + } + return ImageGenerationResult( + stream=False, data=selected, usage_override=usage_override + ) + @staticmethod + def _get_effort(model_info: Any) -> EffortType: + return ( + EffortType.HIGH + if (model_info and model_info.cost.value == "high") + else EffortType.LOW + ) + + @staticmethod + def _select_images(images: List[str], n: int) -> List[str]: + if len(images) >= n: + return images[:n] + selected = images.copy() + while len(selected) < n: + selected.append("error") + return selected + + +class ImageWSBaseProcessor(BaseProcessor): + """WebSocket image processor base.""" + + def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): + if response_format == "base64": + response_format = "b64_json" + super().__init__(model, token) + self.response_format = response_format + if response_format == "url": + self.response_field = "url" + elif response_format == "base64": + self.response_field = "base64" + else: + self.response_field = "b64_json" + self._image_dir: Optional[Path] = None + + def _ensure_image_dir(self) -> Path: + if self._image_dir is None: + base_dir = DATA_DIR / "tmp" / "image" + base_dir.mkdir(parents=True, exist_ok=True) + self._image_dir = base_dir + return self._image_dir + + def _strip_base64(self, blob: str) -> str: + if not blob: + return "" + if "," in blob and "base64" in blob.split(",", 1)[0]: + return blob.split(",", 1)[1] + return blob + + def _guess_ext(self, blob: str) -> Optional[str]: + if not blob: + return None + header = "" + data = blob + if "," in blob and "base64" in blob.split(",", 1)[0]: + header, data = blob.split(",", 1) + header = header.lower() + if "image/png" in header: + return "png" + if "image/jpeg" in header or "image/jpg" in header: + return "jpg" + if data.startswith("iVBORw0KGgo"): + return "png" + if data.startswith("/9j/"): + return "jpg" + return None + + def _filename(self, image_id: str, is_final: bool, ext: Optional[str] = None) -> str: + if ext: + ext = ext.lower() + if ext == "jpeg": + ext = "jpg" + if not ext: + ext = "jpg" if is_final else "png" + return f"{image_id}.{ext}" + + def _build_file_url(self, filename: str) -> str: + app_url = get_config("app.app_url") + if app_url: + return f"{app_url.rstrip('/')}/v1/files/image/{filename}" + return f"/v1/files/image/{filename}" + + async def _save_blob( + self, image_id: str, blob: str, is_final: bool, ext: Optional[str] = None + ) -> str: + data = self._strip_base64(blob) + if not data: + return "" + image_dir = self._ensure_image_dir() + ext = ext or self._guess_ext(blob) + filename = self._filename(image_id, is_final, ext=ext) + filepath = image_dir / filename + + def _write_file(): + with open(filepath, "wb") as f: + f.write(base64.b64decode(data)) + + await asyncio.to_thread(_write_file) + return self._build_file_url(filename) + + def _pick_best(self, existing: Optional[Dict], incoming: Dict) -> Dict: + if not existing: + return incoming + if incoming.get("is_final") and not existing.get("is_final"): + return incoming + if existing.get("is_final") and not incoming.get("is_final"): + return existing + if incoming.get("blob_size", 0) > existing.get("blob_size", 0): + return incoming + return existing + + async def _to_output(self, image_id: str, item: Dict) -> str: try: - async with aiohttp.ClientSession(connector=connector) as session: - async with session.ws_connect( - WS_URL, - headers=headers, - heartbeat=20, - receive_timeout=timeout, - proxy=proxy, - ) as ws: - message = { - "type": "conversation.item.create", - "timestamp": int(time.time() * 1000), - "item": { - "type": "message", - "content": [ - { - "requestId": request_id, - "text": prompt, - "type": "input_text", - "properties": { - "section_count": 0, - "is_kids_mode": False, - "enable_nsfw": enable_nsfw, - "skip_upsampler": False, - "is_initial": False, - "aspect_ratio": aspect_ratio, - }, - } - ], - }, - } - - await ws.send_json(message) - logger.info(f"WebSocket request sent: {prompt[:80]}...") - - images = {} - completed = 0 - start_time = last_activity = time.time() - medium_received_time = None - - while time.time() - start_time < timeout: - try: - ws_msg = await asyncio.wait_for(ws.receive(), timeout=5.0) - except asyncio.TimeoutError: - if ( - medium_received_time - and completed == 0 - and time.time() - medium_received_time - > min(10, blocked_seconds) - ): - raise _BlockedError() - if completed > 0 and time.time() - last_activity > 10: - logger.info( - f"WebSocket idle timeout, collected {completed} images" - ) - break - continue + if self.response_format == "url": + return await self._save_blob( + image_id, + item.get("blob", ""), + item.get("is_final", False), + ext=item.get("ext"), + ) + return self._strip_base64(item.get("blob", "")) + except Exception as e: + logger.warning(f"Image output failed: {e}") + return "" + + +class ImageWSStreamProcessor(ImageWSBaseProcessor): + """WebSocket image stream processor.""" + + def __init__( + self, + model: str, + token: str = "", + n: int = 1, + response_format: str = "b64_json", + size: str = "1024x1024", + ): + super().__init__(model, token, response_format) + self.n = n + self.size = size + self._target_id: Optional[str] = None + self._index_map: Dict[str, int] = {} + self._partial_map: Dict[str, int] = {} + self._initial_sent: set[str] = set() + + def _assign_index(self, image_id: str) -> Optional[int]: + if image_id in self._index_map: + return self._index_map[image_id] + if len(self._index_map) >= self.n: + return None + self._index_map[image_id] = len(self._index_map) + return self._index_map[image_id] + + def _sse(self, event: str, data: dict) -> str: + return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" + + async def process(self, response: AsyncIterable[dict]) -> AsyncGenerator[str, None]: + images: Dict[str, Dict] = {} + + async for item in response: + if item.get("type") == "error": + message = item.get("error") or "Upstream error" + code = item.get("error_code") or "upstream_error" + status = item.get("status") + if code == "rate_limit_exceeded" or status == 429: + raise UpstreamException(message, details=item) + yield self._sse( + "error", + { + "error": { + "message": message, + "type": "server_error", + "code": code, + } + }, + ) + return + if item.get("type") != "image": + continue + + image_id = item.get("image_id") + if not image_id: + continue + + if self.n == 1: + if self._target_id is None: + self._target_id = image_id + index = 0 if image_id == self._target_id else None + else: + index = self._assign_index(image_id) + + images[image_id] = self._pick_best(images.get(image_id), item) + + if index is None: + continue + + if item.get("stage") != "final": + if image_id not in self._initial_sent: + self._initial_sent.add(image_id) + stage = item.get("stage") or "preview" + if stage == "medium": + partial_index = 1 + self._partial_map[image_id] = 1 + else: + partial_index = 0 + self._partial_map[image_id] = 0 + else: + stage = item.get("stage") or "partial" + if stage == "preview": + continue + partial_index = self._partial_map.get(image_id, 0) + if stage == "medium": + partial_index = max(partial_index, 1) + self._partial_map[image_id] = partial_index + + if self.response_format == "url": + partial_id = f"{image_id}-{stage}-{partial_index}" + partial_out = await self._save_blob( + partial_id, + item.get("blob", ""), + False, + ext=item.get("ext"), + ) + else: + partial_out = self._strip_base64(item.get("blob", "")) + if not partial_out: + continue + yield self._sse( + "image_generation.partial_image", + { + "type": "image_generation.partial_image", + self.response_field: partial_out, + "created_at": int(time.time()), + "size": self.size, + "index": index, + "partial_image_index": partial_index, + "image_id": image_id, + "stage": stage, + }, + ) + + if self.n == 1: + if self._target_id and self._target_id in images: + selected = [(self._target_id, images[self._target_id])] + else: + selected = ( + [ + max( + images.items(), + key=lambda x: ( + x[1].get("is_final", False), + x[1].get("blob_size", 0), + ), + ) + ] + if images + else [] + ) + else: + selected = [ + (image_id, images[image_id]) + for image_id in self._index_map + if image_id in images + ] + + for image_id, item in selected: + if self.response_format == "url": + output = await self._save_blob( + f"{image_id}-final", + item.get("blob", ""), + item.get("is_final", False), + ext=item.get("ext"), + ) + else: + output = await self._to_output(image_id, item) + if not output: + continue + + if self.n == 1: + index = 0 + else: + index = self._index_map.get(image_id, 0) + yield self._sse( + "image_generation.completed", + { + "type": "image_generation.completed", + self.response_field: output, + "created_at": int(time.time()), + "size": self.size, + "index": index, + "image_id": image_id, + "stage": "final", + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + }, + }, + ) + + +class ImageWSCollectProcessor(ImageWSBaseProcessor): + """WebSocket image non-stream processor.""" + + def __init__( + self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" + ): + super().__init__(model, token, response_format) + self.n = n + + async def process(self, response: AsyncIterable[dict]) -> List[str]: + images: Dict[str, Dict] = {} + + async for item in response: + if item.get("type") == "error": + message = item.get("error") or "Upstream error" + raise UpstreamException(message, details=item) + if item.get("type") != "image": + continue + image_id = item.get("image_id") + if not image_id: + continue + images[image_id] = self._pick_best(images.get(image_id), item) + + selected = sorted( + images.values(), + key=lambda x: (x.get("is_final", False), x.get("blob_size", 0)), + reverse=True, + ) + if self.n: + selected = selected[: self.n] + + results: List[str] = [] + for item in selected: + output = await self._to_output(item.get("image_id", ""), item) + if output: + results.append(output) + + return results + - if ws_msg.type == aiohttp.WSMsgType.TEXT: - last_activity = time.time() - msg = json.loads(ws_msg.data) - msg_type = msg.get("type") - - if msg_type == "image": - info = self._classify_image( - msg.get("url", ""), msg.get("blob", "") - ) - if not info: - continue - - image_id = info["image_id"] - existing = images.get(image_id, {}) - - if ( - info["stage"] == "medium" - and medium_received_time is None - ): - medium_received_time = time.time() - - if info["is_final"] and not existing.get("is_final"): - completed += 1 - logger.debug( - f"Final image received: id={image_id}, size={info['blob_size']}" - ) - - images[image_id] = { - "is_final": info["is_final"] - or existing.get("is_final") - } - yield info - - elif msg_type == "error": - logger.warning( - f"WebSocket error: {msg.get('err_code', '')} - {msg.get('err_msg', '')}" - ) - yield { - "type": "error", - "error_code": msg.get("err_code", ""), - "error": msg.get("err_msg", ""), - } - return - - if completed >= n: - logger.info( - f"WebSocket collected {completed} final images" - ) - break - - if ( - medium_received_time - and completed == 0 - and time.time() - medium_received_time > blocked_seconds - ): - raise _BlockedError() - - elif ws_msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.ERROR, - ): - logger.warning(f"WebSocket closed/error: {ws_msg.type}") - yield { - "type": "error", - "error_code": "ws_closed", - "error": f"websocket closed: {ws_msg.type}", - } - break - - except aiohttp.ClientError as e: - logger.error(f"WebSocket connection error: {e}") - yield {"type": "error", "error_code": "connection_failed", "error": str(e)} - - -image_service = ImageService() - -__all__ = ["image_service", "ImageService"] +__all__ = ["ImageGenerationService"] diff --git a/app/services/grok/services/image_edit.py b/app/services/grok/services/image_edit.py new file mode 100644 index 00000000..eba6f1f3 --- /dev/null +++ b/app/services/grok/services/image_edit.py @@ -0,0 +1,509 @@ +""" +Grok image edit service. +""" + +import asyncio +import random +import re +from dataclasses import dataclass +from typing import AsyncGenerator, AsyncIterable, List, Union, Any + +import orjson +from curl_cffi.requests.errors import RequestsError + +from app.core.config import get_config +from app.core.exceptions import ( + AppException, + ErrorType, + UpstreamException, + StreamIdleTimeoutError, +) +from app.core.logger import logger +from app.services.grok.utils.process import ( + BaseProcessor, + _with_idle_timeout, + _normalize_line, + _collect_images, + _is_http2_error, +) +from app.services.grok.utils.upload import UploadService +from app.services.grok.utils.retry import pick_token, rate_limited +from app.services.grok.services.chat import GrokChatService +from app.services.grok.services.video import VideoService +from app.services.grok.utils.stream import wrap_stream_with_usage +from app.services.token import EffortType + + +@dataclass +class ImageEditResult: + stream: bool + data: Union[AsyncGenerator[str, None], List[str]] + + +class ImageEditService: + """Image edit orchestration service.""" + + async def edit( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + images: List[str], + n: int, + response_format: str, + stream: bool, + ) -> ImageEditResult: + max_token_retries = int(get_config("retry.max_retry")) + tried_tokens: set[str] = set() + last_error: Exception | None = None + + for attempt in range(max_token_retries): + preferred = token if attempt == 0 else None + current_token = await pick_token( + token_mgr, model_info.model_id, tried_tokens, preferred=preferred + ) + if not current_token: + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + tried_tokens.add(current_token) + try: + image_urls = await self._upload_images(images, current_token) + parent_post_id = await self._get_parent_post_id( + current_token, image_urls + ) + + model_config_override = { + "modelMap": { + "imageEditModel": "imagine", + "imageEditModelConfig": { + "imageReferences": image_urls, + }, + } + } + if parent_post_id: + model_config_override["modelMap"]["imageEditModelConfig"][ + "parentPostId" + ] = parent_post_id + + tool_overrides = {"imageGen": True} + + if stream: + response = await GrokChatService().chat( + token=current_token, + message=prompt, + model=model_info.grok_model, + mode=None, + stream=True, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + processor = ImageStreamProcessor( + model_info.model_id, + current_token, + n=n, + response_format=response_format, + ) + return ImageEditResult( + stream=True, + data=wrap_stream_with_usage( + processor.process(response), + token_mgr, + current_token, + model_info.model_id, + ), + ) + + images_out = await self._collect_images( + token=current_token, + prompt=prompt, + model_info=model_info, + n=n, + response_format=response_format, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + try: + effort = ( + EffortType.HIGH + if (model_info and model_info.cost.value == "high") + else EffortType.LOW + ) + await token_mgr.consume(current_token, effort) + logger.debug( + f"Image edit completed, recorded usage (effort={effort.value})" + ) + except Exception as e: + logger.warning(f"Failed to record image edit usage: {e}") + return ImageEditResult(stream=False, data=images_out) + + except UpstreamException as e: + last_error = e + if rate_limited(e): + await token_mgr.mark_rate_limited(current_token) + logger.warning( + f"Token {current_token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise + + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + async def _upload_images(self, images: List[str], token: str) -> List[str]: + image_urls: List[str] = [] + upload_service = UploadService() + try: + for image in images: + _, file_uri = await upload_service.upload_file(image, token) + if file_uri: + if file_uri.startswith("http"): + image_urls.append(file_uri) + else: + image_urls.append( + f"https://assets.grok.com/{file_uri.lstrip('/')}" + ) + finally: + await upload_service.close() + + if not image_urls: + raise AppException( + message="Image upload failed", + error_type=ErrorType.SERVER.value, + code="upload_failed", + ) + + return image_urls + + async def _get_parent_post_id(self, token: str, image_urls: List[str]) -> str: + parent_post_id = None + try: + media_service = VideoService() + parent_post_id = await media_service.create_image_post(token, image_urls[0]) + logger.debug(f"Parent post ID: {parent_post_id}") + except Exception as e: + logger.warning(f"Create image post failed: {e}") + + if parent_post_id: + return parent_post_id + + for url in image_urls: + match = re.search(r"/generated/([a-f0-9-]+)/", url) + if match: + parent_post_id = match.group(1) + logger.debug(f"Parent post ID: {parent_post_id}") + break + match = re.search(r"/users/[^/]+/([a-f0-9-]+)/content", url) + if match: + parent_post_id = match.group(1) + logger.debug(f"Parent post ID: {parent_post_id}") + break + + return parent_post_id or "" + + async def _collect_images( + self, + *, + token: str, + prompt: str, + model_info: Any, + n: int, + response_format: str, + tool_overrides: dict, + model_config_override: dict, + ) -> List[str]: + calls_needed = (n + 1) // 2 + + async def _call_edit(): + response = await GrokChatService().chat( + token=token, + message=prompt, + model=model_info.grok_model, + mode=None, + stream=True, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + processor = ImageCollectProcessor( + model_info.model_id, token, response_format=response_format + ) + return await processor.process(response) + + last_error: Exception | None = None + rate_limit_error: Exception | None = None + + if calls_needed == 1: + all_images = await _call_edit() + else: + tasks = [_call_edit() for _ in range(calls_needed)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_images: List[str] = [] + for result in results: + if isinstance(result, Exception): + logger.error(f"Concurrent call failed: {result}") + last_error = result + if rate_limited(result): + rate_limit_error = result + elif isinstance(result, list): + all_images.extend(result) + + if not all_images: + if rate_limit_error: + raise rate_limit_error + if last_error: + raise last_error + raise UpstreamException( + "Image edit returned no results", details={"error": "empty_result"} + ) + + if len(all_images) >= n: + return all_images[:n] + + selected_images = all_images.copy() + while len(selected_images) < n: + selected_images.append("error") + return selected_images + + +class ImageStreamProcessor(BaseProcessor): + """HTTP image stream processor.""" + + def __init__( + self, model: str, token: str = "", n: int = 1, response_format: str = "b64_json" + ): + super().__init__(model, token) + self.partial_index = 0 + self.n = n + self.target_index = 0 if n == 1 else None + self.response_format = response_format + if response_format == "url": + self.response_field = "url" + elif response_format == "base64": + self.response_field = "base64" + else: + self.response_field = "b64_json" + + def _sse(self, event: str, data: dict) -> str: + """Build SSE response.""" + return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" + + async def process( + self, response: AsyncIterable[bytes] + ) -> AsyncGenerator[str, None]: + """Process stream response.""" + final_images = [] + idle_timeout = get_config("image.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + # Image generation progress + if img := resp.get("streamingImageGenerationResponse"): + image_index = img.get("imageIndex", 0) + progress = img.get("progress", 0) + + if self.n == 1 and image_index != self.target_index: + continue + + out_index = 0 if self.n == 1 else image_index + + yield self._sse( + "image_generation.partial_image", + { + "type": "image_generation.partial_image", + self.response_field: "", + "index": out_index, + "progress": progress, + }, + ) + continue + + # modelResponse + if mr := resp.get("modelResponse"): + if urls := _collect_images(mr): + for url in urls: + if self.response_format == "url": + processed = await self.process_url(url, "image") + if processed: + final_images.append(processed) + continue + try: + dl_service = self._get_dl() + base64_data = await dl_service.parse_b64( + url, self.token, "image" + ) + if base64_data: + if "," in base64_data: + b64 = base64_data.split(",", 1)[1] + else: + b64 = base64_data + final_images.append(b64) + except Exception as e: + logger.warning( + f"Failed to convert image to base64, falling back to URL: {e}" + ) + processed = await self.process_url(url, "image") + if processed: + final_images.append(processed) + continue + + for index, b64 in enumerate(final_images): + if self.n == 1: + if index != self.target_index: + continue + out_index = 0 + else: + out_index = index + + yield self._sse( + "image_generation.completed", + { + "type": "image_generation.completed", + self.response_field: b64, + "index": out_index, + "usage": { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": { + "text_tokens": 0, + "image_tokens": 0, + }, + }, + }, + ) + except asyncio.CancelledError: + logger.debug("Image stream cancelled by client") + except StreamIdleTimeoutError as e: + raise UpstreamException( + message=f"Image stream idle timeout after {e.idle_seconds}s", + status_code=504, + details={ + "error": str(e), + "type": "stream_idle_timeout", + "idle_seconds": e.idle_seconds, + }, + ) + except RequestsError as e: + if _is_http2_error(e): + logger.warning(f"HTTP/2 stream error in image: {e}") + raise UpstreamException( + message="Upstream connection closed unexpectedly", + status_code=502, + details={"error": str(e), "type": "http2_stream_error"}, + ) + logger.error(f"Image stream request error: {e}") + raise UpstreamException( + message=f"Upstream request failed: {e}", + status_code=502, + details={"error": str(e)}, + ) + except Exception as e: + logger.error( + f"Image stream processing error: {e}", + extra={"error_type": type(e).__name__}, + ) + raise + finally: + await self.close() + + +class ImageCollectProcessor(BaseProcessor): + """HTTP image non-stream processor.""" + + def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): + if response_format == "base64": + response_format = "b64_json" + super().__init__(model, token) + self.response_format = response_format + + async def process(self, response: AsyncIterable[bytes]) -> List[str]: + """Process and collect images.""" + images = [] + idle_timeout = get_config("image.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + if mr := resp.get("modelResponse"): + if urls := _collect_images(mr): + for url in urls: + if self.response_format == "url": + processed = await self.process_url(url, "image") + if processed: + images.append(processed) + continue + try: + dl_service = self._get_dl() + base64_data = await dl_service.parse_b64( + url, self.token, "image" + ) + if base64_data: + if "," in base64_data: + b64 = base64_data.split(",", 1)[1] + else: + b64 = base64_data + images.append(b64) + except Exception as e: + logger.warning( + f"Failed to convert image to base64, falling back to URL: {e}" + ) + processed = await self.process_url(url, "image") + if processed: + images.append(processed) + + except asyncio.CancelledError: + logger.debug("Image collect cancelled by client") + except StreamIdleTimeoutError as e: + logger.warning(f"Image collect idle timeout: {e}") + except RequestsError as e: + if _is_http2_error(e): + logger.warning(f"HTTP/2 stream error in image collect: {e}") + else: + logger.error(f"Image collect request error: {e}") + except Exception as e: + logger.error( + f"Image collect processing error: {e}", + extra={"error_type": type(e).__name__}, + ) + finally: + await self.close() + + return images + + +__all__ = ["ImageEditService", "ImageEditResult"] diff --git a/app/services/grok/services/media.py b/app/services/grok/services/media.py deleted file mode 100644 index 79d9b677..00000000 --- a/app/services/grok/services/media.py +++ /dev/null @@ -1,388 +0,0 @@ -""" -Grok 视频生成服务 -""" - -import asyncio -from typing import AsyncGenerator, Optional - -import orjson -from curl_cffi.requests import AsyncSession - -from app.core.logger import logger -from app.core.config import get_config -from app.core.exceptions import ( - UpstreamException, - AppException, - ValidationException, - ErrorType, -) -from app.services.grok.models.model import ModelService -from app.services.token import get_token_manager, EffortType -from app.services.grok.processors import VideoStreamProcessor, VideoCollectProcessor -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie -from app.services.grok.utils.stream import wrap_stream_with_usage - -CREATE_POST_API = "https://grok.com/rest/media/post/create" -CHAT_API = "https://grok.com/rest/app-chat/conversations/new" - -_MEDIA_SEMAPHORE = None -_MEDIA_SEM_VALUE = 0 - - -def _get_semaphore() -> asyncio.Semaphore: - """获取或更新信号量""" - global _MEDIA_SEMAPHORE, _MEDIA_SEM_VALUE - value = max(1, int(get_config("performance.media_max_concurrent"))) - if value != _MEDIA_SEM_VALUE: - _MEDIA_SEM_VALUE = value - _MEDIA_SEMAPHORE = asyncio.Semaphore(value) - return _MEDIA_SEMAPHORE - - -class VideoService: - """视频生成服务""" - - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - self.timeout = get_config("network.timeout") - - def _build_headers( - self, token: str, referer: str = "https://grok.com/imagine" - ) -> dict: - """构建请求头""" - user_agent = get_config("security.user_agent") - headers = { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "zh-CN,zh;q=0.9", - "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", - "Cache-Control": "no-cache", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Pragma": "no-cache", - "Priority": "u=1, i", - "Referer": referer, - "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', - "Sec-Ch-Ua-Arch": "arm", - "Sec-Ch-Ua-Bitness": "64", - "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Model": "", - "Sec-Ch-Ua-Platform": '"macOS"', - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-origin", - "User-Agent": user_agent, - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - - def _build_proxies(self) -> Optional[dict]: - """构建代理""" - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - - async def create_post( - self, - token: str, - prompt: str, - media_type: str = "MEDIA_POST_TYPE_VIDEO", - media_url: str = None, - ) -> str: - """创建媒体帖子,返回 post ID""" - try: - headers = self._build_headers(token) - - # 根据类型构建不同的载荷 - if media_type == "MEDIA_POST_TYPE_IMAGE" and media_url: - payload = {"mediaType": media_type, "mediaUrl": media_url} - else: - payload = {"mediaType": media_type, "prompt": prompt} - - async with AsyncSession() as session: - response = await session.post( - CREATE_POST_API, - headers=headers, - json=payload, - impersonate=get_config("security.browser"), - timeout=30, - proxies=self._build_proxies(), - ) - - if response.status_code != 200: - logger.error(f"Create post failed: {response.status_code}") - raise UpstreamException( - f"Failed to create post: {response.status_code}" - ) - - post_id = response.json().get("post", {}).get("id", "") - if not post_id: - raise UpstreamException("No post ID in response") - - logger.info(f"Media post created: {post_id} (type={media_type})") - return post_id - - except AppException: - raise - except Exception as e: - logger.error(f"Create post error: {e}") - raise UpstreamException(f"Create post error: {str(e)}") - - async def create_image_post(self, token: str, image_url: str) -> str: - """创建图片帖子,返回 post ID""" - return await self.create_post( - token, prompt="", media_type="MEDIA_POST_TYPE_IMAGE", media_url=image_url - ) - - def _build_payload( - self, - prompt: str, - post_id: str, - aspect_ratio: str = "3:2", - video_length: int = 6, - resolution_name: str = "480p", - preset: str = "normal", - ) -> dict: - """构建视频生成载荷""" - mode_map = { - "fun": "--mode=extremely-crazy", - "normal": "--mode=normal", - "spicy": "--mode=extremely-spicy-or-crazy", - } - mode_flag = mode_map.get(preset, "--mode=custom") - - payload = { - "temporary": True, - "modelName": "grok-3", - "message": f"{prompt} {mode_flag}", - "toolOverrides": {"videoGen": True}, - "enableSideBySide": True, - "deviceEnvInfo": { - "darkModeEnabled": False, - "devicePixelRatio": 2, - "screenWidth": 1920, - "screenHeight": 1080, - "viewportWidth": 1920, - "viewportHeight": 1080, - }, - "responseMetadata": { - "experiments": [], - "modelConfigOverride": { - "modelMap": { - "videoGenModelConfig": { - "aspectRatio": aspect_ratio, - "parentPostId": post_id, - "resolutionName": resolution_name, - "videoLength": video_length, - } - } - }, - }, - } - - logger.debug(f"Video generation payload: {payload}") - - return payload - - async def _generate_internal( - self, - token: str, - post_id: str, - prompt: str, - aspect_ratio: str, - video_length: int, - resolution_name: str, - preset: str, - ) -> AsyncGenerator[bytes, None]: - """内部生成逻辑""" - session = None - try: - headers = self._build_headers(token) - payload = self._build_payload( - prompt, post_id, aspect_ratio, video_length, resolution_name, preset - ) - - session = AsyncSession(impersonate=get_config("security.browser")) - response = await session.post( - CHAT_API, - headers=headers, - data=orjson.dumps(payload), - timeout=self.timeout, - stream=True, - proxies=self._build_proxies(), - ) - - if response.status_code != 200: - logger.error( - f"Video generation failed: status={response.status_code}, post_id={post_id}" - ) - raise UpstreamException( - message=f"Video generation failed: {response.status_code}", - details={"status": response.status_code}, - ) - - logger.info(f"Video generation started: post_id={post_id}") - - async def stream_response(): - try: - async for line in response.aiter_lines(): - yield line - finally: - await session.close() - - return stream_response() - - except Exception as e: - if session: - try: - await session.close() - except Exception: - pass - logger.error(f"Video generation error: {e}") - if isinstance(e, AppException): - raise - raise UpstreamException(f"Video generation error: {str(e)}") - - async def generate( - self, - token: str, - prompt: str, - aspect_ratio: str = "3:2", - video_length: int = 6, - resolution_name: str = "480p", - preset: str = "normal", - ) -> AsyncGenerator[bytes, None]: - """生成视频""" - logger.info( - f"Video generation: prompt='{prompt[:50]}...', ratio={aspect_ratio}, length={video_length}s, preset={preset}" - ) - async with _get_semaphore(): - post_id = await self.create_post(token, prompt) - return await self._generate_internal( - token, - post_id, - prompt, - aspect_ratio, - video_length, - resolution_name, - preset, - ) - - async def generate_from_image( - self, - token: str, - prompt: str, - image_url: str, - aspect_ratio: str = "3:2", - video_length: int = 6, - resolution: str = "480p", - preset: str = "normal", - ) -> AsyncGenerator[bytes, None]: - """从图片生成视频""" - logger.info( - f"Image to video: prompt='{prompt[:50]}...', image={image_url[:80]}" - ) - async with _get_semaphore(): - post_id = await self.create_image_post(token, image_url) - return await self._generate_internal( - token, post_id, prompt, aspect_ratio, video_length, resolution, preset - ) - - @staticmethod - async def completions( - model: str, - messages: list, - stream: bool = None, - thinking: str = None, - aspect_ratio: str = "3:2", - video_length: int = 6, - resolution: str = "480p", - preset: str = "normal", - ): - """视频生成入口""" - # 获取 token(使用智能路由) - token_mgr = await get_token_manager() - await token_mgr.reload_if_stale() - - # 使用智能路由选择 token(根据视频需求与候选池) - pool_candidates = ModelService.pool_candidates_for_model(model) - token_info = token_mgr.get_token_for_video( - resolution=resolution, - video_length=video_length, - pool_candidates=pool_candidates, - ) - - if not token_info: - raise AppException( - message="No available tokens. Please try again later.", - error_type=ErrorType.RATE_LIMIT.value, - code="rate_limit_exceeded", - status_code=429, - ) - - # 从 TokenInfo 对象中提取 token 字符串 - token = token_info.token - if token.startswith("sso="): - token = token[4:] - - think = {"enabled": True, "disabled": False}.get(thinking) - is_stream = stream if stream is not None else get_config("chat.stream") - - # 提取内容 - from app.services.grok.services.chat import MessageExtractor - from app.services.grok.services.assets import UploadService - - try: - prompt, attachments = MessageExtractor.extract(messages, is_video=True) - except ValueError as e: - raise ValidationException(str(e)) - - # 处理图片附件 - image_url = None - if attachments: - upload_service = UploadService() - try: - for attach_type, attach_data in attachments: - if attach_type == "image": - _, file_uri = await upload_service.upload(attach_data, token) - image_url = f"https://assets.grok.com/{file_uri}" - logger.info(f"Image uploaded for video: {image_url}") - break - finally: - await upload_service.close() - - # 生成视频 - service = VideoService() - if image_url: - response = await service.generate_from_image( - token, prompt, image_url, aspect_ratio, video_length, resolution, preset - ) - else: - response = await service.generate( - token, prompt, aspect_ratio, video_length, resolution, preset - ) - - # 处理响应 - if is_stream: - processor = VideoStreamProcessor(model, token, think) - return wrap_stream_with_usage( - processor.process(response), token_mgr, token, model - ) - - result = await VideoCollectProcessor(model, token).process(response) - try: - model_info = ModelService.get(model) - effort = ( - EffortType.HIGH - if (model_info and model_info.cost.value == "high") - else EffortType.LOW - ) - await token_mgr.consume(token, effort) - logger.debug(f"Video completed, recorded usage (effort={effort.value})") - except Exception as e: - logger.warning(f"Failed to record video usage: {e}") - return result - - -__all__ = ["VideoService"] diff --git a/app/services/grok/models/model.py b/app/services/grok/services/model.py similarity index 77% rename from app/services/grok/models/model.py rename to app/services/grok/services/model.py index e7cd8d61..f5c1e257 100644 --- a/app/services/grok/models/model.py +++ b/app/services/grok/services/model.py @@ -33,8 +33,9 @@ class ModelInfo(BaseModel): cost: Cost = Field(default=Cost.LOW) display_name: str description: str = "" - is_video: bool = False is_image: bool = False + is_image_edit: bool = False + is_video: bool = False class ModelService: @@ -45,105 +46,157 @@ class ModelService: model_id="grok-3", grok_model="grok-3", model_mode="MODEL_MODE_GROK_3", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-3", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-3-mini", grok_model="grok-3", model_mode="MODEL_MODE_GROK_3_MINI_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-3-MINI", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-3-thinking", grok_model="grok-3", model_mode="MODEL_MODE_GROK_3_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-3-THINKING", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4", grok_model="grok-4", model_mode="MODEL_MODE_GROK_4", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4-mini", grok_model="grok-4-mini", model_mode="MODEL_MODE_GROK_4_MINI_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4-MINI", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4-thinking", grok_model="grok-4", model_mode="MODEL_MODE_GROK_4_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4-THINKING", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4-heavy", grok_model="grok-4", model_mode="MODEL_MODE_HEAVY", - cost=Cost.HIGH, tier=Tier.SUPER, + cost=Cost.HIGH, display_name="GROK-4-HEAVY", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4.1-mini", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_GROK_4_1_MINI_THINKING", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4.1-MINI", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4.1-fast", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.LOW, display_name="GROK-4.1-FAST", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4.1-expert", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_EXPERT", + tier=Tier.SUPER, cost=Cost.HIGH, display_name="GROK-4.1-EXPERT", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-4.1-thinking", grok_model="grok-4-1-thinking-1129", model_mode="MODEL_MODE_GROK_4_1_THINKING", + tier=Tier.SUPER, cost=Cost.HIGH, display_name="GROK-4.1-THINKING", + is_image=False, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-imagine-1.0", grok_model="grok-3", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.HIGH, display_name="Grok Image", description="Image generation model", is_image=True, + is_image_edit=False, + is_video=False, ), ModelInfo( model_id="grok-imagine-1.0-edit", grok_model="imagine-image-edit", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.HIGH, display_name="Grok Image Edit", description="Image edit model", - is_image=True, + is_image=False, + is_image_edit=True, + is_video=False, ), ModelInfo( model_id="grok-imagine-1.0-video", grok_model="grok-3", model_mode="MODEL_MODE_FAST", + tier=Tier.BASIC, cost=Cost.HIGH, display_name="Grok Video", description="Video generation model", + is_image=False, + is_image_edit=False, is_video=True, ), ] diff --git a/app/services/grok/services/nsfw.py b/app/services/grok/services/nsfw.py deleted file mode 100644 index 26a4f261..00000000 --- a/app/services/grok/services/nsfw.py +++ /dev/null @@ -1,183 +0,0 @@ -""" -NSFW (Unhinged) 模式服务 - -使用 gRPC-Web 协议开启账号的 NSFW 功能。 -""" - -from dataclasses import dataclass -from typing import Optional -import datetime -import random - -from curl_cffi.requests import AsyncSession - -from app.core.config import get_config -from app.core.logger import logger -from app.services.grok.protocols.grpc_web import ( - encode_grpc_web_payload, - parse_grpc_web_response, - get_grpc_status, -) -from app.services.grok.utils.headers import build_sso_cookie - -NSFW_API = "https://grok.com/auth_mgmt.AuthManagement/UpdateUserFeatureControls" -BIRTH_DATE_API = "https://grok.com/rest/auth/set-birth-date" - - -@dataclass -class NSFWResult: - """NSFW 操作结果""" - - success: bool - http_status: int - grpc_status: Optional[int] = None - grpc_message: Optional[str] = None - error: Optional[str] = None - - -class NSFWService: - """NSFW 模式服务""" - - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - self.timeout = float(get_config("network.timeout")) - - def _build_proxies(self) -> Optional[dict]: - """构建代理配置""" - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - - @staticmethod - def _random_birth_date() -> str: - """生成随机出生日期(20-40岁)""" - today = datetime.date.today() - birth_year = today.year - random.randint(20, 40) - birth_month = random.randint(1, 12) - birth_day = random.randint(1, 28) - hour = random.randint(0, 23) - minute = random.randint(0, 59) - second = random.randint(0, 59) - microsecond = random.randint(0, 999) - return f"{birth_year:04d}-{birth_month:02d}-{birth_day:02d}T{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}Z" - - def _build_headers(self, token: str) -> dict: - """构造 gRPC-Web 请求头""" - cookie = build_sso_cookie(token, include_rw=True) - user_agent = get_config("security.user_agent") - return { - "accept": "*/*", - "content-type": "application/grpc-web+proto", - "origin": "https://grok.com", - "referer": "https://grok.com/", - "user-agent": user_agent, - "x-grpc-web": "1", - "x-user-agent": "connect-es/2.1.1", - "cookie": cookie, - } - - def _build_birth_headers(self, token: str) -> dict: - """构造设置出生日期请求头""" - cookie = build_sso_cookie(token, include_rw=True) - user_agent = get_config("security.user_agent") - return { - "accept": "*/*", - "content-type": "application/json", - "origin": "https://grok.com", - "referer": "https://grok.com/?_s=account", - "user-agent": user_agent, - "cookie": cookie, - } - - @staticmethod - def _build_payload() -> bytes: - """构造请求 payload""" - # protobuf (match captured HAR): - # 0a 02 10 01 -> field 1 (len=2) with inner bool=true - # 12 1a -> field 2, length 26 - # 0a 18 -> nested message with name string - name = b"always_show_nsfw_content" - inner = b"\x0a" + bytes([len(name)]) + name - protobuf = b"\x0a\x02\x10\x01\x12" + bytes([len(inner)]) + inner - return encode_grpc_web_payload(protobuf) - - async def _set_birth_date( - self, session: AsyncSession, token: str - ) -> tuple[bool, int, Optional[str]]: - """设置出生日期""" - headers = self._build_birth_headers(token) - payload = {"birthDate": self._random_birth_date()} - - try: - response = await session.post( - BIRTH_DATE_API, - json=payload, - headers=headers, - timeout=self.timeout, - proxies=self._build_proxies(), - ) - if response.status_code in (200, 204): - return True, response.status_code, None - return False, response.status_code, f"HTTP {response.status_code}" - except Exception as e: - return False, 0, str(e)[:100] - - async def enable(self, token: str) -> NSFWResult: - """为单个 token 开启 NSFW 模式""" - headers = self._build_headers(token) - payload = self._build_payload() - logger.debug(f"NSFW payload: len={len(payload)} hex={payload.hex()}") - - try: - browser = get_config("security.browser") - async with AsyncSession(impersonate=browser) as session: - # 先设置出生日期 - ok, birth_status, birth_err = await self._set_birth_date(session, token) - if not ok: - return NSFWResult( - success=False, - http_status=birth_status, - error=f"Set birth date failed: {birth_err}", - ) - - # 开启 NSFW - response = await session.post( - NSFW_API, - data=payload, - headers=headers, - timeout=self.timeout, - proxies=self._build_proxies(), - ) - - if response.status_code != 200: - return NSFWResult( - success=False, - http_status=response.status_code, - error=f"HTTP {response.status_code}", - ) - - # 解析 gRPC-Web 响应 - _, trailers = parse_grpc_web_response( - response.content, content_type=response.headers.get("content-type") - ) - - grpc_status = get_grpc_status(trailers) - logger.debug( - f"NSFW response: http={response.status_code} grpc={grpc_status.code} " - f"msg={grpc_status.message} trailers={trailers}" - ) - - # HTTP 200 且无 grpc-status(空响应)或 grpc-status=0 都算成功 - success = grpc_status.code == -1 or grpc_status.ok - - return NSFWResult( - success=success, - http_status=response.status_code, - grpc_status=grpc_status.code, - grpc_message=grpc_status.message or None, - ) - - except Exception as e: - logger.error(f"NSFW enable failed: {e}") - return NSFWResult(success=False, http_status=0, error=str(e)[:100]) - - -__all__ = ["NSFWService", "NSFWResult"] diff --git a/app/services/grok/services/usage.py b/app/services/grok/services/usage.py deleted file mode 100644 index 7550c822..00000000 --- a/app/services/grok/services/usage.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Grok 用量服务 -""" - -import asyncio -from typing import Dict - -from curl_cffi.requests import AsyncSession - -from app.core.logger import logger -from app.core.config import get_config -from app.core.exceptions import UpstreamException -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie -from app.services.grok.utils.retry import retry_on_status - -LIMITS_API = "https://grok.com/rest/rate-limits" - -_USAGE_SEMAPHORE = asyncio.Semaphore(25) -_USAGE_SEM_VALUE = 25 - - -class UsageService: - """用量查询服务""" - - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - self.timeout = get_config("network.timeout") - - def _build_headers(self, token: str) -> dict: - """构建请求头""" - user_agent = get_config("security.user_agent") - headers = { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "zh-CN,zh;q=0.9", - "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", - "Cache-Control": "no-cache", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Pragma": "no-cache", - "Priority": "u=1, i", - "Referer": "https://grok.com/", - "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', - "Sec-Ch-Ua-Arch": "arm", - "Sec-Ch-Ua-Bitness": "64", - "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Model": "", - "Sec-Ch-Ua-Platform": '"macOS"', - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-origin", - "User-Agent": user_agent, - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - - def _build_proxies(self) -> dict: - """构建代理配置""" - return {"http": self.proxy, "https": self.proxy} if self.proxy else None - - async def get(self, token: str, model_name: str = "grok-4-1-thinking-1129") -> Dict: - """ - 获取速率限制信息 - - Args: - token: 认证 Token - model_name: 模型名称 - - Returns: - 响应数据 - - Raises: - UpstreamException: 当获取失败且重试耗尽时 - """ - value = get_config("performance.usage_max_concurrent") - try: - value = int(value) - except Exception: - value = 25 - value = max(1, value) - global _USAGE_SEMAPHORE, _USAGE_SEM_VALUE - if value != _USAGE_SEM_VALUE: - _USAGE_SEM_VALUE = value - _USAGE_SEMAPHORE = asyncio.Semaphore(value) - async with _USAGE_SEMAPHORE: - # 定义状态码提取器 - def extract_status(e: Exception) -> int | None: - if isinstance(e, UpstreamException) and e.details: - return e.details.get("status") - return None - - # 定义实际的请求函数 - async def do_request(): - try: - headers = self._build_headers(token) - payload = {"requestKind": "DEFAULT", "modelName": model_name} - browser = get_config("security.browser") - - async with AsyncSession() as session: - response = await session.post( - LIMITS_API, - headers=headers, - json=payload, - impersonate=browser, - timeout=self.timeout, - proxies=self._build_proxies(), - ) - - if response.status_code == 200: - data = response.json() - remaining = data.get("remainingTokens", 0) - logger.info( - f"Usage sync success: remaining={remaining}, token={token[:10]}..." - ) - return data - - logger.error( - f"Usage sync failed: status={response.status_code}, token={token[:10]}..." - ) - - raise UpstreamException( - message=f"Failed to get usage stats: {response.status_code}", - details={"status": response.status_code}, - ) - - except Exception as e: - if isinstance(e, UpstreamException): - raise - logger.error(f"Usage error: {e}") - raise UpstreamException( - message=f"Usage service error: {str(e)}", - details={"error": str(e)}, - ) - - # 带重试的执行 - try: - result = await retry_on_status( - do_request, extract_status=extract_status - ) - return result - - except Exception: - # 最后一次失败已经被记录 - raise - - -__all__ = ["UsageService"] diff --git a/app/services/grok/services/video.py b/app/services/grok/services/video.py new file mode 100644 index 00000000..70f477e3 --- /dev/null +++ b/app/services/grok/services/video.py @@ -0,0 +1,678 @@ +""" +Grok video generation service. +""" + +import asyncio +import uuid +import re +from typing import Any, AsyncGenerator, AsyncIterable, Optional + +import orjson +from curl_cffi.requests import AsyncSession +from curl_cffi.requests.errors import RequestsError + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import ( + UpstreamException, + AppException, + ValidationException, + ErrorType, + StreamIdleTimeoutError, +) +from app.services.grok.services.model import ModelService +from app.services.token import get_token_manager, EffortType +from app.services.grok.utils.stream import wrap_stream_with_usage +from app.services.grok.utils.process import ( + BaseProcessor, + _with_idle_timeout, + _normalize_line, + _is_http2_error, +) +from app.services.grok.utils.retry import rate_limited +from app.services.reverse.app_chat import AppChatReverse +from app.services.reverse.media_post import MediaPostReverse +from app.services.reverse.video_upscale import VideoUpscaleReverse +from app.services.token.manager import BASIC_POOL_NAME + +_VIDEO_SEMAPHORE = None +_VIDEO_SEM_VALUE = 0 + +def _get_video_semaphore() -> asyncio.Semaphore: + """Reverse 接口并发控制(video 服务)。""" + global _VIDEO_SEMAPHORE, _VIDEO_SEM_VALUE + value = max(1, int(get_config("video.concurrent"))) + if value != _VIDEO_SEM_VALUE: + _VIDEO_SEM_VALUE = value + _VIDEO_SEMAPHORE = asyncio.Semaphore(value) + return _VIDEO_SEMAPHORE + + +class VideoService: + """Video generation service.""" + + def __init__(self): + self.timeout = None + + async def create_post( + self, + token: str, + prompt: str, + media_type: str = "MEDIA_POST_TYPE_VIDEO", + media_url: str = None, + ) -> str: + """Create media post and return post ID.""" + try: + if media_type == "MEDIA_POST_TYPE_IMAGE" and not media_url: + raise ValidationException("media_url is required for image posts") + + prompt_value = prompt if media_type == "MEDIA_POST_TYPE_VIDEO" else "" + media_value = media_url or "" + + async with AsyncSession() as session: + async with _get_video_semaphore(): + response = await MediaPostReverse.request( + session, + token, + media_type, + media_value, + prompt=prompt_value, + ) + + post_id = response.json().get("post", {}).get("id", "") + if not post_id: + raise UpstreamException("No post ID in response") + + logger.info(f"Media post created: {post_id} (type={media_type})") + return post_id + + except AppException: + raise + except Exception as e: + logger.error(f"Create post error: {e}") + raise UpstreamException(f"Create post error: {str(e)}") + + async def create_image_post(self, token: str, image_url: str) -> str: + """Create image post and return post ID.""" + return await self.create_post( + token, prompt="", media_type="MEDIA_POST_TYPE_IMAGE", media_url=image_url + ) + + async def generate( + self, + token: str, + prompt: str, + aspect_ratio: str = "3:2", + video_length: int = 6, + resolution_name: str = "480p", + preset: str = "normal", + ) -> AsyncGenerator[bytes, None]: + """Generate video.""" + logger.info( + f"Video generation: prompt='{prompt[:50]}...', ratio={aspect_ratio}, length={video_length}s, preset={preset}" + ) + post_id = await self.create_post(token, prompt) + mode_map = { + "fun": "--mode=extremely-crazy", + "normal": "--mode=normal", + "spicy": "--mode=extremely-spicy-or-crazy", + } + mode_flag = mode_map.get(preset, "--mode=custom") + message = f"{prompt} {mode_flag}" + model_config_override = { + "modelMap": { + "videoGenModelConfig": { + "aspectRatio": aspect_ratio, + "parentPostId": post_id, + "resolutionName": resolution_name, + "videoLength": video_length, + } + } + } + + async def _stream(): + session = AsyncSession() + try: + async with _get_video_semaphore(): + stream_response = await AppChatReverse.request( + session, + token, + message=message, + model="grok-3", + tool_overrides={"videoGen": True}, + model_config_override=model_config_override, + ) + logger.info(f"Video generation started: post_id={post_id}") + async for line in stream_response: + yield line + except Exception as e: + try: + await session.close() + except Exception: + pass + logger.error(f"Video generation error: {e}") + if isinstance(e, AppException): + raise + raise UpstreamException(f"Video generation error: {str(e)}") + + return _stream() + + async def generate_from_image( + self, + token: str, + prompt: str, + image_url: str, + aspect_ratio: str = "3:2", + video_length: int = 6, + resolution: str = "480p", + preset: str = "normal", + ) -> AsyncGenerator[bytes, None]: + """Generate video from image.""" + logger.info( + f"Image to video: prompt='{prompt[:50]}...', image={image_url[:80]}" + ) + post_id = await self.create_image_post(token, image_url) + mode_map = { + "fun": "--mode=extremely-crazy", + "normal": "--mode=normal", + "spicy": "--mode=extremely-spicy-or-crazy", + } + mode_flag = mode_map.get(preset, "--mode=custom") + message = f"{prompt} {mode_flag}" + model_config_override = { + "modelMap": { + "videoGenModelConfig": { + "aspectRatio": aspect_ratio, + "parentPostId": post_id, + "resolutionName": resolution, + "videoLength": video_length, + } + } + } + + async def _stream(): + session = AsyncSession() + try: + async with _get_video_semaphore(): + stream_response = await AppChatReverse.request( + session, + token, + message=message, + model="grok-3", + tool_overrides={"videoGen": True}, + model_config_override=model_config_override, + ) + logger.info(f"Video generation started: post_id={post_id}") + async for line in stream_response: + yield line + except Exception as e: + try: + await session.close() + except Exception: + pass + logger.error(f"Video generation error: {e}") + if isinstance(e, AppException): + raise + raise UpstreamException(f"Video generation error: {str(e)}") + + return _stream() + + @staticmethod + async def completions( + model: str, + messages: list, + stream: bool = None, + reasoning_effort: str | None = None, + aspect_ratio: str = "3:2", + video_length: int = 6, + resolution: str = "480p", + preset: str = "normal", + ): + """Video generation entrypoint.""" + # Get token via intelligent routing. + token_mgr = await get_token_manager() + await token_mgr.reload_if_stale() + + max_token_retries = int(get_config("retry.max_retry")) + last_error: Exception | None = None + + if reasoning_effort is None: + show_think = get_config("app.thinking") + else: + show_think = reasoning_effort != "none" + is_stream = stream if stream is not None else get_config("app.stream") + + # Extract content. + from app.services.grok.services.chat import MessageExtractor + from app.services.grok.utils.upload import UploadService + + prompt, file_attachments, image_attachments = MessageExtractor.extract(messages) + + for attempt in range(max_token_retries): + # Select token based on video requirements and pool candidates. + pool_candidates = ModelService.pool_candidates_for_model(model) + token_info = token_mgr.get_token_for_video( + resolution=resolution, + video_length=video_length, + pool_candidates=pool_candidates, + ) + + if not token_info: + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + # Extract token string from TokenInfo. + token = token_info.token + if token.startswith("sso="): + token = token[4:] + pool_name = token_mgr.get_pool_name_for_token(token) + should_upscale = resolution == "720p" and pool_name == BASIC_POOL_NAME + + try: + # Handle image attachments. + image_url = None + if image_attachments: + upload_service = UploadService() + try: + for attach_data in image_attachments: + _, file_uri = await upload_service.upload_file( + attach_data, token + ) + image_url = f"https://assets.grok.com/{file_uri}" + logger.info(f"Image uploaded for video: {image_url}") + break + finally: + await upload_service.close() + + # Generate video. + service = VideoService() + if image_url: + response = await service.generate_from_image( + token, + prompt, + image_url, + aspect_ratio, + video_length, + resolution, + preset, + ) + else: + response = await service.generate( + token, + prompt, + aspect_ratio, + video_length, + resolution, + preset, + ) + + # Process response. + if is_stream: + processor = VideoStreamProcessor( + model, + token, + show_think, + upscale_on_finish=should_upscale, + ) + return wrap_stream_with_usage( + processor.process(response), token_mgr, token, model + ) + + result = await VideoCollectProcessor( + model, token, upscale_on_finish=should_upscale + ).process(response) + try: + model_info = ModelService.get(model) + effort = ( + EffortType.HIGH + if (model_info and model_info.cost.value == "high") + else EffortType.LOW + ) + await token_mgr.consume(token, effort) + logger.debug( + f"Video completed, recorded usage (effort={effort.value})" + ) + except Exception as e: + logger.warning(f"Failed to record video usage: {e}") + return result + + except UpstreamException as e: + last_error = e + if rate_limited(e): + await token_mgr.mark_rate_limited(token) + logger.warning( + f"Token {token[:10]}... rate limited (429), " + f"trying next token (attempt {attempt + 1}/{max_token_retries})" + ) + continue + raise + + if last_error: + raise last_error + raise AppException( + message="No available tokens. Please try again later.", + error_type=ErrorType.RATE_LIMIT.value, + code="rate_limit_exceeded", + status_code=429, + ) + + +class VideoStreamProcessor(BaseProcessor): + """Video stream response processor.""" + + def __init__( + self, + model: str, + token: str = "", + show_think: bool = None, + upscale_on_finish: bool = False, + ): + super().__init__(model, token) + self.response_id: Optional[str] = None + self.think_opened: bool = False + self.role_sent: bool = False + + self.show_think = bool(show_think) + self.upscale_on_finish = bool(upscale_on_finish) + + @staticmethod + def _extract_video_id(video_url: str) -> str: + if not video_url: + return "" + match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url) + if match: + return match.group(1) + match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url) + if match: + return match.group(1) + return "" + + async def _upscale_video_url(self, video_url: str) -> str: + if not video_url or not self.upscale_on_finish: + return video_url + video_id = self._extract_video_id(video_url) + if not video_id: + logger.warning("Video upscale skipped: unable to extract video id") + return video_url + try: + async with AsyncSession() as session: + response = await VideoUpscaleReverse.request( + session, self.token, video_id + ) + payload = response.json() if response is not None else {} + hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None + if hd_url: + logger.info(f"Video upscale completed: {hd_url}") + return hd_url + except Exception as e: + logger.warning(f"Video upscale failed: {e}") + return video_url + + def _sse(self, content: str = "", role: str = None, finish: str = None) -> str: + """Build SSE response.""" + delta = {} + if role: + delta["role"] = role + delta["content"] = "" + elif content: + delta["content"] = content + + chunk = { + "id": self.response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model, + "choices": [ + {"index": 0, "delta": delta, "logprobs": None, "finish_reason": finish} + ], + } + return f"data: {orjson.dumps(chunk).decode()}\n\n" + + async def process( + self, response: AsyncIterable[bytes] + ) -> AsyncGenerator[str, None]: + """Process video stream response.""" + idle_timeout = get_config("video.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + is_thinking = bool(resp.get("isThinking")) + + if rid := resp.get("responseId"): + self.response_id = rid + + if not self.role_sent: + yield self._sse(role="assistant") + self.role_sent = True + + if token := resp.get("token"): + if is_thinking: + if not self.show_think: + continue + if not self.think_opened: + yield self._sse("\n") + self.think_opened = True + else: + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + yield self._sse(token) + continue + + if video_resp := resp.get("streamingVideoGenerationResponse"): + progress = video_resp.get("progress", 0) + + if is_thinking: + if not self.show_think: + continue + if not self.think_opened: + yield self._sse("\n") + self.think_opened = True + else: + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + if self.show_think: + yield self._sse(f"正在生成视频中,当前进度{progress}%\n") + + if progress == 100: + video_url = video_resp.get("videoUrl", "") + thumbnail_url = video_resp.get("thumbnailImageUrl", "") + + if self.think_opened: + yield self._sse("\n\n") + self.think_opened = False + + if video_url: + if self.upscale_on_finish: + yield self._sse("正在对视频进行超分辨率\n") + video_url = await self._upscale_video_url(video_url) + dl_service = self._get_dl() + rendered = await dl_service.render_video( + video_url, self.token, thumbnail_url + ) + yield self._sse(rendered) + + logger.info(f"Video generated: {video_url}") + continue + + if self.think_opened: + yield self._sse("\n") + yield self._sse(finish="stop") + yield "data: [DONE]\n\n" + except asyncio.CancelledError: + logger.debug( + "Video stream cancelled by client", extra={"model": self.model} + ) + except StreamIdleTimeoutError as e: + raise UpstreamException( + message=f"Video stream idle timeout after {e.idle_seconds}s", + status_code=504, + details={ + "error": str(e), + "type": "stream_idle_timeout", + "idle_seconds": e.idle_seconds, + }, + ) + except RequestsError as e: + if _is_http2_error(e): + logger.warning( + f"HTTP/2 stream error in video: {e}", extra={"model": self.model} + ) + raise UpstreamException( + message="Upstream connection closed unexpectedly", + status_code=502, + details={"error": str(e), "type": "http2_stream_error"}, + ) + logger.error( + f"Video stream request error: {e}", extra={"model": self.model} + ) + raise UpstreamException( + message=f"Upstream request failed: {e}", + status_code=502, + details={"error": str(e)}, + ) + except Exception as e: + logger.error( + f"Video stream processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + finally: + await self.close() + + +class VideoCollectProcessor(BaseProcessor): + """Video non-stream response processor.""" + + def __init__(self, model: str, token: str = "", upscale_on_finish: bool = False): + super().__init__(model, token) + self.upscale_on_finish = bool(upscale_on_finish) + + @staticmethod + def _extract_video_id(video_url: str) -> str: + if not video_url: + return "" + match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", video_url) + if match: + return match.group(1) + match = re.search(r"/([0-9a-fA-F-]{32,36})/generated_video", video_url) + if match: + return match.group(1) + return "" + + async def _upscale_video_url(self, video_url: str) -> str: + if not video_url or not self.upscale_on_finish: + return video_url + video_id = self._extract_video_id(video_url) + if not video_id: + logger.warning("Video upscale skipped: unable to extract video id") + return video_url + try: + async with AsyncSession() as session: + response = await VideoUpscaleReverse.request( + session, self.token, video_id + ) + payload = response.json() if response is not None else {} + hd_url = payload.get("hdMediaUrl") if isinstance(payload, dict) else None + if hd_url: + logger.info(f"Video upscale completed: {hd_url}") + return hd_url + except Exception as e: + logger.warning(f"Video upscale failed: {e}") + return video_url + + async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: + """Process and collect video response.""" + response_id = "" + content = "" + idle_timeout = get_config("video.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = data.get("result", {}).get("response", {}) + + if video_resp := resp.get("streamingVideoGenerationResponse"): + if video_resp.get("progress") == 100: + response_id = resp.get("responseId", "") + video_url = video_resp.get("videoUrl", "") + thumbnail_url = video_resp.get("thumbnailImageUrl", "") + + if video_url: + if self.upscale_on_finish: + video_url = await self._upscale_video_url(video_url) + dl_service = self._get_dl() + content = await dl_service.render_video( + video_url, self.token, thumbnail_url + ) + logger.info(f"Video generated: {video_url}") + + except asyncio.CancelledError: + logger.debug( + "Video collect cancelled by client", extra={"model": self.model} + ) + except StreamIdleTimeoutError as e: + logger.warning( + f"Video collect idle timeout: {e}", extra={"model": self.model} + ) + except RequestsError as e: + if _is_http2_error(e): + logger.warning( + f"HTTP/2 stream error in video collect: {e}", + extra={"model": self.model}, + ) + else: + logger.error( + f"Video collect request error: {e}", extra={"model": self.model} + ) + except Exception as e: + logger.error( + f"Video collect processing error: {e}", + extra={"model": self.model, "error_type": type(e).__name__}, + ) + finally: + await self.close() + + return { + "id": response_id, + "object": "chat.completion", + "created": self.created, + "model": self.model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + "refusal": None, + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +__all__ = ["VideoService"] diff --git a/app/services/grok/services/voice.py b/app/services/grok/services/voice.py index 006fc547..b72fce3e 100644 --- a/app/services/grok/services/voice.py +++ b/app/services/grok/services/voice.py @@ -2,25 +2,17 @@ Grok Voice Mode Service """ -import orjson -from typing import Dict, Any +from typing import Any, Dict from curl_cffi.requests import AsyncSession -from app.core.logger import logger from app.core.config import get_config -from app.core.exceptions import UpstreamException -from app.services.grok.utils.headers import apply_statsig, build_sso_cookie - -LIVEKIT_TOKEN_API = "https://grok.com/rest/livekit/tokens" +from app.services.reverse.ws_livekit import LivekitTokenReverse class VoiceService: """Voice Mode Service (LiveKit)""" - def __init__(self, proxy: str = None): - self.proxy = proxy or get_config("network.base_proxy_url") - async def get_token( self, token: str, @@ -28,86 +20,13 @@ async def get_token( personality: str = "assistant", speed: float = 1.0, ) -> Dict[str, Any]: - """ - Get LiveKit token - - Args: - token: Auth token - Returns: - Dict containing token and livekitUrl - """ - logger.debug( - f"Voice token request: voice={voice}, personality={personality}, speed={speed}" - ) - headers = self._build_headers(token) - payload = self._build_payload(voice, personality, speed) - - proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None - - try: - browser = get_config("security.browser") - timeout = get_config("network.timeout") - async with AsyncSession(impersonate=browser) as session: - response = await session.post( - LIVEKIT_TOKEN_API, - headers=headers, - data=orjson.dumps(payload), - timeout=timeout, - proxies=proxies, - ) - - if response.status_code != 200: - body = response.text[:200] - logger.error( - f"Voice token failed: status={response.status_code}, body={body}" - ) - raise UpstreamException( - message=f"Failed to get voice token: {response.status_code}", - details={"status": response.status_code, "body": response.text}, - ) - - result = response.json() - logger.info(f"Voice token obtained: voice={voice}") - return result - - except Exception as e: - logger.error(f"Voice service error: {e}") - if isinstance(e, UpstreamException): - raise - raise UpstreamException(f"Voice service error: {str(e)}") - - def _build_headers(self, token: str) -> Dict[str, str]: - headers = { - "Accept": "*/*", - "Content-Type": "application/json", - "Origin": "https://grok.com", - "Referer": "https://grok.com/", - # Statsig ID is crucial - } - - apply_statsig(headers) - headers["Cookie"] = build_sso_cookie(token) - - return headers - - def _build_payload( - self, - voice: str = "ara", - personality: str = "assistant", - speed: float = 1.0, - ) -> Dict[str, Any]: - """Construct payload with voice settings""" - return { - "sessionPayload": orjson.dumps( - { - "voice": voice, - "personality": personality, - "playback_speed": speed, - "enable_vision": False, - "turn_detection": {"type": "server_vad"}, - } - ).decode(), - "requestAgentDispatch": False, - "livekitUrl": "wss://livekit.grok.com", - "params": {"enable_markdown_transcript": "true"}, - } + browser = get_config("proxy.browser") + async with AsyncSession(impersonate=browser) as session: + response = await LivekitTokenReverse.request( + session, + token=token, + voice=voice, + personality=personality, + speed=speed, + ) + return response.json() diff --git a/app/services/grok/utils/__init__.py b/app/services/grok/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/grok/utils/batch.py b/app/services/grok/utils/batch.py deleted file mode 100644 index adb64ea2..00000000 --- a/app/services/grok/utils/batch.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -批量执行工具 - -提供分批并发、单项失败隔离的通用批量处理能力。 -""" - -import asyncio -from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar - -from app.core.logger import logger - -T = TypeVar("T") - - -async def run_in_batches( - items: List[str], - worker: Callable[[str], Awaitable[T]], - *, - max_concurrent: int = 10, - batch_size: int = 50, - on_item: Optional[Callable[[str, Dict[str, Any]], Awaitable[None]]] = None, - should_cancel: Optional[Callable[[], bool]] = None, -) -> Dict[str, Dict[str, Any]]: - """ - 分批并发执行,单项失败不影响整体 - - Args: - items: 待处理项列表 - worker: 异步处理函数 - max_concurrent: 最大并发数 - batch_size: 每批大小 - - Returns: - {item: {"ok": bool, "data": ..., "error": ...}} - """ - try: - max_concurrent = int(max_concurrent) - except Exception: - max_concurrent = 10 - try: - batch_size = int(batch_size) - except Exception: - batch_size = 50 - - max_concurrent = max(1, max_concurrent) - batch_size = max(1, batch_size) - - sem = asyncio.Semaphore(max_concurrent) - - async def _one(item: str) -> tuple[str, dict]: - if should_cancel and should_cancel(): - return item, {"ok": False, "error": "cancelled", "cancelled": True} - async with sem: - try: - data = await worker(item) - result = {"ok": True, "data": data} - if on_item: - try: - await on_item(item, result) - except Exception: - pass - return item, result - except Exception as e: - logger.warning(f"Batch item failed: {item[:16]}... - {e}") - result = {"ok": False, "error": str(e)} - if on_item: - try: - await on_item(item, result) - except Exception: - pass - return item, result - - results: Dict[str, dict] = {} - - # 分批执行,避免一次性创建所有 task - for i in range(0, len(items), batch_size): - if should_cancel and should_cancel(): - break - chunk = items[i : i + batch_size] - pairs = await asyncio.gather(*(_one(x) for x in chunk)) - results.update(dict(pairs)) - - return results - - -__all__ = ["run_in_batches"] diff --git a/app/services/grok/utils/cache.py b/app/services/grok/utils/cache.py new file mode 100644 index 00000000..a728df15 --- /dev/null +++ b/app/services/grok/utils/cache.py @@ -0,0 +1,110 @@ +""" +Local cache utilities. +""" + +from typing import Any, Dict + +from app.core.storage import DATA_DIR + +IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} +VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"} + + +class CacheService: + """Local cache service.""" + + def __init__(self): + base_dir = DATA_DIR / "tmp" + self.image_dir = base_dir / "image" + self.video_dir = base_dir / "video" + self.image_dir.mkdir(parents=True, exist_ok=True) + self.video_dir.mkdir(parents=True, exist_ok=True) + + def _cache_dir(self, media_type: str): + return self.image_dir if media_type == "image" else self.video_dir + + def _allowed_exts(self, media_type: str): + return IMAGE_EXTS if media_type == "image" else VIDEO_EXTS + + def get_stats(self, media_type: str = "image") -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + if not cache_dir.exists(): + return {"count": 0, "size_mb": 0.0} + + allowed = self._allowed_exts(media_type) + files = [ + f for f in cache_dir.glob("*") if f.is_file() and f.suffix.lower() in allowed + ] + total_size = sum(f.stat().st_size for f in files) + return {"count": len(files), "size_mb": round(total_size / 1024 / 1024, 2)} + + def list_files( + self, media_type: str = "image", page: int = 1, page_size: int = 1000 + ) -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + if not cache_dir.exists(): + return {"total": 0, "page": page, "page_size": page_size, "items": []} + + allowed = self._allowed_exts(media_type) + files = [ + f for f in cache_dir.glob("*") if f.is_file() and f.suffix.lower() in allowed + ] + + items = [] + for f in files: + try: + stat = f.stat() + items.append( + { + "name": f.name, + "size_bytes": stat.st_size, + "mtime_ms": int(stat.st_mtime * 1000), + } + ) + except Exception: + continue + + items.sort(key=lambda x: x["mtime_ms"], reverse=True) + + total = len(items) + start = max(0, (page - 1) * page_size) + paged = items[start : start + page_size] + + for item in paged: + item["view_url"] = f"/v1/files/{media_type}/{item['name']}" + + return {"total": total, "page": page, "page_size": page_size, "items": paged} + + def delete_file(self, media_type: str, name: str) -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + file_path = cache_dir / name.replace("/", "-") + + if file_path.exists(): + try: + file_path.unlink() + return {"deleted": True} + except Exception: + pass + return {"deleted": False} + + def clear(self, media_type: str = "image") -> Dict[str, Any]: + cache_dir = self._cache_dir(media_type) + if not cache_dir.exists(): + return {"count": 0, "size_mb": 0.0} + + files = list(cache_dir.glob("*")) + total_size = sum(f.stat().st_size for f in files if f.is_file()) + count = 0 + + for f in files: + if f.is_file(): + try: + f.unlink() + count += 1 + except Exception: + pass + + return {"count": count, "size_mb": round(total_size / 1024 / 1024, 2)} + + +__all__ = ["CacheService"] diff --git a/app/services/grok/utils/download.py b/app/services/grok/utils/download.py new file mode 100644 index 00000000..edfb279e --- /dev/null +++ b/app/services/grok/utils/download.py @@ -0,0 +1,295 @@ +""" +Download service. + +Download service for assets.grok.com. +""" + +import asyncio +import base64 +import hashlib +import os +from pathlib import Path +from typing import List, Optional, Tuple +from urllib.parse import urlparse + +import aiofiles +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.storage import DATA_DIR +from app.core.config import get_config +from app.core.exceptions import AppException +from app.services.reverse.assets_download import AssetsDownloadReverse +from app.services.grok.utils.locks import _get_download_semaphore, _file_lock + + +class DownloadService: + """Assets download service.""" + + def __init__(self): + self._session: Optional[AsyncSession] = None + base_dir = DATA_DIR / "tmp" + self.image_dir = base_dir / "image" + self.video_dir = base_dir / "video" + self.image_dir.mkdir(parents=True, exist_ok=True) + self.video_dir.mkdir(parents=True, exist_ok=True) + self._cleanup_running = False + + async def create(self) -> AsyncSession: + """Create or reuse a session.""" + if self._session is None: + self._session = AsyncSession() + return self._session + + async def close(self): + """Close the session.""" + if self._session: + await self._session.close() + self._session = None + + async def resolve_url( + self, path_or_url: str, token: str, media_type: str = "image" + ) -> str: + asset_url = path_or_url + path = path_or_url + if path_or_url.startswith("http"): + parsed = urlparse(path_or_url) + path = parsed.path or "" + asset_url = path_or_url + else: + if not path_or_url.startswith("/"): + path_or_url = f"/{path_or_url}" + path = path_or_url + asset_url = f"https://assets.grok.com{path_or_url}" + + app_url = get_config("app.app_url") + if app_url: + await self.download_file(asset_url, token, media_type) + return f"{app_url.rstrip('/')}/v1/files/{media_type}{path}" + return asset_url + + async def render_image( + self, url: str, token: str, image_id: str = "image" + ) -> str: + fmt = get_config("app.image_format") + fmt = fmt.lower() if isinstance(fmt, str) else "url" + if fmt not in ("base64", "url", "markdown"): + fmt = "url" + try: + if fmt == "base64": + data_uri = await self.parse_b64(url, token, "image") + return f"![{image_id}]({data_uri})" + final_url = await self.resolve_url(url, token, "image") + return f"![{image_id}]({final_url})" + except Exception as e: + logger.warning(f"Image render failed, fallback to URL: {e}") + final_url = await self.resolve_url(url, token, "image") + return f"![{image_id}]({final_url})" + + async def render_video( + self, video_url: str, token: str, thumbnail_url: str = "" + ) -> str: + fmt = get_config("app.video_format") + fmt = fmt.lower() if isinstance(fmt, str) else "url" + if fmt not in ("url", "markdown", "html"): + fmt = "url" + final_video_url = await self.resolve_url(video_url, token, "video") + final_thumb_url = "" + if thumbnail_url: + final_thumb_url = await self.resolve_url(thumbnail_url, token, "image") + if fmt == "url": + return final_video_url + if fmt == "markdown": + return f"[video]({final_video_url})" + import html + + safe_video_url = html.escape(final_video_url) + safe_thumbnail_url = html.escape(final_thumb_url) + poster_attr = f' poster="{safe_thumbnail_url}"' if safe_thumbnail_url else "" + return f'''''' + + @staticmethod + def _is_url(value: str) -> bool: + """Check if the value is a URL.""" + try: + parsed = urlparse(value) + return bool( + parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"] + ) + except Exception: + return False + + async def parse_b64(self, file_path: str, token: str, media_type: str = "image") -> str: + """Download and return data URI.""" + try: + if not isinstance(file_path, str) or not file_path.strip(): + raise AppException("Invalid file path", code="invalid_file_path") + if file_path.startswith("data:"): + raise AppException("Invalid file path", code="invalid_file_path") + if not self._is_url(file_path): + raise AppException("Invalid file path", code="invalid_file_path") + + file_path = self._normalize_path(file_path) + lock_name = f"dl_b64_{hashlib.sha1(file_path.encode()).hexdigest()[:16]}" + lock_timeout = max(1, int(get_config("asset.download_timeout"))) + async with _get_download_semaphore(): + async with _file_lock(lock_name, timeout=lock_timeout): + session = await self.create() + response = await AssetsDownloadReverse.request( + session, token, file_path + ) + + if hasattr(response, "aiter_content"): + data = bytearray() + async for chunk in response.aiter_content(): + if chunk: + data.extend(chunk) + raw = bytes(data) + else: + raw = response.content + + content_type = response.headers.get( + "content-type", "application/octet-stream" + ).split(";")[0] + data_uri = f"data:{content_type};base64,{base64.b64encode(raw).decode()}" + + return data_uri + except Exception as e: + logger.error(f"Failed to convert {file_path} to base64: {e}") + raise + + def _normalize_path(self, file_path: str) -> str: + """Normalize file path for download.""" + if not isinstance(file_path, str) or not file_path.strip(): + raise AppException("Invalid file path", code="invalid_file_path") + parsed = urlparse(file_path) + if not (parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"]): + raise AppException("Invalid file path", code="invalid_file_path") + path = parsed.path or "" + if parsed.query: + path = f"{path}?{parsed.query}" + file_path = path + if not file_path.startswith("/"): + file_path = f"/{file_path}" + return file_path + + async def download_file(self, file_path: str, token: str, media_type: str = "image") -> Tuple[Optional[Path], str]: + """Download asset to local cache. + + Args: + file_path: str, the path of the file to download. + token: str, the SSO token. + media_type: str, the media type of the file. + + Returns: + Tuple[Optional[Path], str]: The path of the downloaded file and the MIME type. + """ + async with _get_download_semaphore(): + file_path = self._normalize_path(file_path) + cache_dir = self.image_dir if media_type == "image" else self.video_dir + filename = file_path.lstrip("/").replace("/", "-") + cache_path = cache_dir / filename + + lock_name = ( + f"dl_{media_type}_{hashlib.sha1(str(cache_path).encode()).hexdigest()[:16]}" + ) + lock_timeout = max(1, int(get_config("asset.download_timeout"))) + async with _file_lock(lock_name, timeout=lock_timeout): + session = await self.create() + response = await AssetsDownloadReverse.request(session, token, file_path) + + tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") + try: + async with aiofiles.open(tmp_path, "wb") as f: + if hasattr(response, "aiter_content"): + async for chunk in response.aiter_content(): + if chunk: + await f.write(chunk) + else: + await f.write(response.content) + os.replace(tmp_path, cache_path) + finally: + if tmp_path.exists() and not cache_path.exists(): + try: + tmp_path.unlink() + except Exception: + pass + + mime = response.headers.get( + "content-type", "application/octet-stream" + ).split(";")[0] + logger.info(f"Downloaded: {file_path}") + + asyncio.create_task(self._check_limit()) + + return cache_path, mime + + async def _check_limit(self): + """Check cache limit and cleanup. + + Args: + self: DownloadService, the download service instance. + + Returns: + None + """ + if self._cleanup_running or not get_config("cache.enable_auto_clean"): + return + + self._cleanup_running = True + try: + try: + async with _file_lock("cache_cleanup", timeout=5): + limit_mb = get_config("cache.limit_mb") + total_size = 0 + all_files: List[Tuple[Path, float, int]] = [] + + for d in [self.image_dir, self.video_dir]: + if d.exists(): + for f in d.glob("*"): + if f.is_file(): + try: + stat = f.stat() + total_size += stat.st_size + all_files.append( + (f, stat.st_mtime, stat.st_size) + ) + except Exception: + pass + current_mb = total_size / 1024 / 1024 + + if current_mb <= limit_mb: + return + + logger.info( + f"Cache limit exceeded ({current_mb:.2f}MB > {limit_mb}MB), cleaning..." + ) + all_files.sort(key=lambda x: x[1]) + + deleted_count = 0 + deleted_size = 0 + target_mb = limit_mb * 0.8 + + for f, _, size in all_files: + try: + f.unlink() + deleted_count += 1 + deleted_size += size + total_size -= size + if (total_size / 1024 / 1024) <= target_mb: + break + except Exception: + pass + + logger.info( + f"Cache cleanup: {deleted_count} files ({deleted_size / 1024 / 1024:.2f}MB)" + ) + except Exception as e: + logger.warning(f"Cache cleanup failed: {e}") + finally: + self._cleanup_running = False + + +__all__ = ["DownloadService"] diff --git a/app/services/grok/utils/headers.py b/app/services/grok/utils/headers.py deleted file mode 100644 index 7a5e1c2a..00000000 --- a/app/services/grok/utils/headers.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -Common header helpers for Grok services. -""" - -from __future__ import annotations - -import uuid -from typing import Dict - -from app.core.config import get_config -from app.services.grok.utils.statsig import StatsigService - - -def _normalize_token(token: str) -> str: - return token[4:] if token.startswith("sso=") else token - - -def build_sso_cookie(token: str, include_rw: bool = False) -> str: - token = _normalize_token(token) - cf = get_config("security.cf_clearance") - cookie = f"sso={token}" - if include_rw: - cookie = f"{cookie}; sso-rw={token}" - if cf: - cookie = f"{cookie};cf_clearance={cf}" - return cookie - - -def apply_statsig(headers: Dict[str, str]) -> None: - headers["x-statsig-id"] = StatsigService.gen_id() - headers["x-xai-request-id"] = str(uuid.uuid4()) diff --git a/app/services/grok/utils/locks.py b/app/services/grok/utils/locks.py new file mode 100644 index 00000000..0ad227f5 --- /dev/null +++ b/app/services/grok/utils/locks.py @@ -0,0 +1,86 @@ +""" +Shared locking helpers for assets operations. +""" + +import asyncio +import time +from contextlib import asynccontextmanager +from pathlib import Path + +from app.core.config import get_config +from app.core.storage import DATA_DIR + +try: + import fcntl +except ImportError: + fcntl = None + + +LOCK_DIR = DATA_DIR / ".locks" + +_UPLOAD_SEMAPHORE = None +_UPLOAD_SEM_VALUE = None +_DOWNLOAD_SEMAPHORE = None +_DOWNLOAD_SEM_VALUE = None + + +def _get_upload_semaphore() -> asyncio.Semaphore: + """Return global semaphore for upload operations.""" + value = max(1, int(get_config("asset.upload_concurrent"))) + + global _UPLOAD_SEMAPHORE, _UPLOAD_SEM_VALUE + if _UPLOAD_SEMAPHORE is None or value != _UPLOAD_SEM_VALUE: + _UPLOAD_SEM_VALUE = value + _UPLOAD_SEMAPHORE = asyncio.Semaphore(value) + return _UPLOAD_SEMAPHORE + + +def _get_download_semaphore() -> asyncio.Semaphore: + """Return global semaphore for download operations.""" + value = max(1, int(get_config("asset.download_concurrent"))) + + global _DOWNLOAD_SEMAPHORE, _DOWNLOAD_SEM_VALUE + if _DOWNLOAD_SEMAPHORE is None or value != _DOWNLOAD_SEM_VALUE: + _DOWNLOAD_SEM_VALUE = value + _DOWNLOAD_SEMAPHORE = asyncio.Semaphore(value) + return _DOWNLOAD_SEMAPHORE + + +@asynccontextmanager +async def _file_lock(name: str, timeout: int = 10): + """File lock guard.""" + if fcntl is None: + yield + return + + LOCK_DIR.mkdir(parents=True, exist_ok=True) + lock_path = Path(LOCK_DIR) / f"{name}.lock" + fd = None + locked = False + start = time.monotonic() + + try: + fd = open(lock_path, "a+") + while True: + try: + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + locked = True + break + except BlockingIOError: + if time.monotonic() - start >= timeout: + break + await asyncio.sleep(0.05) + if not locked: + raise TimeoutError(f"Failed to acquire lock: {name}") + yield + finally: + if fd: + if locked: + try: + fcntl.flock(fd, fcntl.LOCK_UN) + except Exception: + pass + fd.close() + + +__all__ = ["_get_upload_semaphore", "_get_download_semaphore", "_file_lock"] diff --git a/app/services/grok/processors/base.py b/app/services/grok/utils/process.py similarity index 76% rename from app/services/grok/processors/base.py rename to app/services/grok/utils/process.py index 76b4838d..69353c65 100644 --- a/app/services/grok/processors/base.py +++ b/app/services/grok/utils/process.py @@ -8,21 +8,20 @@ from app.core.config import get_config from app.core.logger import logger -from app.services.grok.services.assets import DownloadService +from app.core.exceptions import StreamIdleTimeoutError +from app.services.grok.utils.download import DownloadService -ASSET_URL = "https://assets.grok.com/" - T = TypeVar("T") -def _is_http2_stream_error(e: Exception) -> bool: +def _is_http2_error(e: Exception) -> bool: """检查是否为 HTTP/2 流错误""" err_str = str(e).lower() return "http/2" in err_str or "curl: (92)" in err_str or "stream" in err_str -def _normalize_stream_line(line: Any) -> Optional[str]: +def _normalize_line(line: Any) -> Optional[str]: """规范化流式响应行,兼容 SSE data 前缀与空行""" if line is None: return None @@ -40,7 +39,7 @@ def _normalize_stream_line(line: Any) -> Optional[str]: return text -def _collect_image_urls(obj: Any) -> List[str]: +def _collect_images(obj: Any) -> List[str]: """递归收集响应中的图片 URL""" urls: List[str] = [] seen = set() @@ -71,14 +70,6 @@ def walk(value: Any): return urls -class StreamIdleTimeoutError(Exception): - """流空闲超时错误""" - - def __init__(self, idle_seconds: float): - self.idle_seconds = idle_seconds - super().__init__(f"Stream idle timeout after {idle_seconds}s") - - async def _with_idle_timeout( iterable: AsyncIterable[T], idle_timeout: float, model: str = "" ) -> AsyncGenerator[T, None]: @@ -96,6 +87,16 @@ async def _with_idle_timeout( return iterator = iterable.__aiter__() + + async def _maybe_aclose(it): + aclose = getattr(it, "aclose", None) + if not aclose: + return + try: + await aclose() + except Exception: + pass + while True: try: item = await asyncio.wait_for(iterator.__anext__(), timeout=idle_timeout) @@ -105,7 +106,11 @@ async def _with_idle_timeout( f"Stream idle timeout after {idle_timeout}s", extra={"model": model, "idle_timeout": idle_timeout}, ) + await _maybe_aclose(iterator) raise StreamIdleTimeoutError(idle_timeout) + except asyncio.CancelledError: + await _maybe_aclose(iterator) + raise except StopAsyncIteration: break @@ -134,27 +139,14 @@ async def close(self): async def process_url(self, path: str, media_type: str = "image") -> str: """处理资产 URL""" - if path.startswith("http"): - from urllib.parse import urlparse - - path = urlparse(path).path - - if not path.startswith("/"): - path = f"/{path}" - - if self.app_url: - dl_service = self._get_dl() - await dl_service.download(path, self.token, media_type) - return f"{self.app_url.rstrip('/')}/v1/files/{media_type}{path}" - else: - return f"{ASSET_URL.rstrip('/')}{path}" + dl_service = self._get_dl() + return await dl_service.resolve_url(path, self.token, media_type) __all__ = [ "BaseProcessor", - "StreamIdleTimeoutError", "_with_idle_timeout", - "_normalize_stream_line", - "_collect_image_urls", - "_is_http2_stream_error", + "_normalize_line", + "_collect_images", + "_is_http2_error", ] diff --git a/app/services/grok/utils/retry.py b/app/services/grok/utils/retry.py index 162c4f8c..e0b1edb5 100644 --- a/app/services/grok/utils/retry.py +++ b/app/services/grok/utils/retry.py @@ -1,264 +1,45 @@ """ -Grok API 重试工具 - -提供可配置的重试机制,支持: -- 指数退避 + decorrelated jitter -- Retry-After header 支持 -- 429 专用退避策略 -- 重试预算控制 +Retry helpers for token switching. """ -import asyncio -import random -from typing import Callable, Any, Optional -from functools import wraps +from typing import Optional, Set -from app.core.logger import logger -from app.core.config import get_config from app.core.exceptions import UpstreamException +from app.services.grok.services.model import ModelService -class RetryContext: - """重试上下文""" - - def __init__(self): - self.attempt = 0 - self.max_retry = int(get_config("retry.max_retry")) - self.retry_codes = get_config("retry.retry_status_codes") - self.last_error = None - self.last_status = None - self.total_delay = 0.0 - self.retry_budget = float(get_config("retry.retry_budget")) - - # 退避参数 - self.backoff_base = float(get_config("retry.retry_backoff_base")) - self.backoff_factor = float(get_config("retry.retry_backoff_factor")) - self.backoff_max = float(get_config("retry.retry_backoff_max")) - - # decorrelated jitter 状态 - self._last_delay = self.backoff_base - - def should_retry(self, status_code: int) -> bool: - """判断是否重试""" - if self.attempt >= self.max_retry: - return False - if status_code not in self.retry_codes: - return False - if self.total_delay >= self.retry_budget: - return False - return True - - def record_error(self, status_code: int, error: Exception): - """记录错误信息""" - self.last_status = status_code - self.last_error = error - self.attempt += 1 - - def calculate_delay( - self, status_code: int, retry_after: Optional[float] = None - ) -> float: - """ - 计算退避延迟时间 - - Args: - status_code: HTTP 状态码 - retry_after: Retry-After header 值(秒) - - Returns: - 延迟时间(秒) - """ - # 优先使用 Retry-After - if retry_after is not None and retry_after > 0: - delay = min(retry_after, self.backoff_max) - self._last_delay = delay - return delay - - # 429 使用 decorrelated jitter - if status_code == 429: - # decorrelated jitter: delay = random(base, last_delay * 3) - delay = random.uniform(self.backoff_base, self._last_delay * 3) - delay = min(delay, self.backoff_max) - self._last_delay = delay - return delay - - # 其他状态码使用指数退避 + full jitter - exp_delay = self.backoff_base * (self.backoff_factor**self.attempt) - delay = random.uniform(0, min(exp_delay, self.backoff_max)) - return delay +async def pick_token( + token_mgr, + model_id: str, + tried: Set[str], + preferred: Optional[str] = None, +) -> Optional[str]: + if preferred and preferred not in tried: + return preferred - def record_delay(self, delay: float): - """记录延迟时间""" - self.total_delay += delay + token = None + for pool_name in ModelService.pool_candidates_for_model(model_id): + token = token_mgr.get_token(pool_name, exclude=tried) + if token: + break + if not token and not tried: + result = await token_mgr.refresh_cooling_tokens() + if result.get("recovered", 0) > 0: + for pool_name in ModelService.pool_candidates_for_model(model_id): + token = token_mgr.get_token(pool_name) + if token: + break -def extract_retry_after(error: Exception) -> Optional[float]: - """ - 从异常中提取 Retry-After 值 + return token - Args: - error: 异常对象 - Returns: - Retry-After 秒数,或 None - """ +def rate_limited(error: Exception) -> bool: if not isinstance(error, UpstreamException): - return None - - details = error.details or {} - - # 尝试从 details 中获取 - retry_after = details.get("retry_after") - if retry_after is not None: - try: - return float(retry_after) - except (ValueError, TypeError): - pass - - # 尝试从 headers 中获取 - headers = details.get("headers", {}) - if isinstance(headers, dict): - retry_after = headers.get("Retry-After") or headers.get("retry-after") - if retry_after is not None: - try: - return float(retry_after) - except (ValueError, TypeError): - pass - - return None - - -async def retry_on_status( - func: Callable, - *args, - extract_status: Callable[[Exception], Optional[int]] = None, - on_retry: Callable[[int, int, Exception, float], None] = None, - **kwargs, -) -> Any: - """ - 通用重试函数 - - Args: - func: 重试的异步函数 - *args: 函数参数 - extract_status: 异常提取状态码的函数 - on_retry: 重试时的回调函数 (attempt, status_code, error, delay) - **kwargs: 函数关键字参数 - - Returns: - 函数执行结果 - - Raises: - 最后一次失败的异常 - """ - ctx = RetryContext() - - # 状态码提取器 - if extract_status is None: - - def extract_status(e: Exception) -> Optional[int]: - if isinstance(e, UpstreamException): - # 优先从 details 获取,回退到 status_code 属性 - if e.details and "status" in e.details: - return e.details["status"] - return getattr(e, "status_code", None) - return None - - while ctx.attempt <= ctx.max_retry: - try: - result = await func(*args, **kwargs) - - # 记录日志 - if ctx.attempt > 0: - logger.info( - f"Retry succeeded after {ctx.attempt} attempts, " - f"total delay: {ctx.total_delay:.2f}s" - ) - - return result - - except Exception as e: - # 提取状态码 - status_code = extract_status(e) - - if status_code is None: - # 错误无法识别 - logger.error(f"Non-retryable error: {e}") - raise - - # 记录错误 - ctx.record_error(status_code, e) - - # 判断是否重试 - if ctx.should_retry(status_code): - # 提取 Retry-After - retry_after = extract_retry_after(e) - - # 计算延迟 - delay = ctx.calculate_delay(status_code, retry_after) - - # 检查是否超出预算 - if ctx.total_delay + delay > ctx.retry_budget: - logger.warning( - f"Retry budget exhausted: {ctx.total_delay:.2f}s + {delay:.2f}s > {ctx.retry_budget}s" - ) - raise - - ctx.record_delay(delay) - - logger.warning( - f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, " - f"waiting {delay:.2f}s (total: {ctx.total_delay:.2f}s)" - + (f", Retry-After: {retry_after}s" if retry_after else "") - ) - - # 回调 - if on_retry: - on_retry(ctx.attempt, status_code, e, delay) - - await asyncio.sleep(delay) - continue - else: - # 不可重试或重试次数耗尽 - if status_code in ctx.retry_codes: - logger.error( - f"Retry exhausted after {ctx.attempt} attempts, " - f"last status: {status_code}, total delay: {ctx.total_delay:.2f}s" - ) - else: - logger.error(f"Non-retryable status code: {status_code}") - - # 抛出最后一次的错误 - raise - - -def with_retry( - extract_status: Callable[[Exception], Optional[int]] = None, - on_retry: Callable[[int, int, Exception, float], None] = None, -): - """ - 重试装饰器 - - Usage: - @with_retry() - async def my_api_call(): - ... - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(*args, **kwargs): - return await retry_on_status( - func, *args, extract_status=extract_status, on_retry=on_retry, **kwargs - ) - - return wrapper - - return decorator + return False + status = error.details.get("status") if error.details else None + code = error.details.get("error_code") if error.details else None + return status == 429 or code == "rate_limit_exceeded" -__all__ = [ - "RetryContext", - "retry_on_status", - "with_retry", - "extract_retry_after", -] +__all__ = ["pick_token", "rate_limited"] diff --git a/app/services/grok/utils/statsig.py b/app/services/grok/utils/statsig.py deleted file mode 100644 index c2cd15f8..00000000 --- a/app/services/grok/utils/statsig.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Statsig ID 生成服务 -""" - -import base64 -import random -import string - -from app.core.config import get_config - - -class StatsigService: - """Statsig ID 生成服务""" - - @staticmethod - def _rand(length: int, alphanumeric: bool = False) -> str: - """生成随机字符串""" - chars = ( - string.ascii_lowercase + string.digits - if alphanumeric - else string.ascii_lowercase - ) - return "".join(random.choices(chars, k=length)) - - @staticmethod - def gen_id() -> str: - """ - 生成 Statsig ID - - Returns: - Base64 编码的 ID - """ - dynamic = get_config("chat.dynamic_statsig") - - if not dynamic: - return "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk=" - - # 随机格式 - if random.choice([True, False]): - rand = StatsigService._rand(5, alphanumeric=True) - message = f"e:TypeError: Cannot read properties of null (reading 'children['{rand}']')" - else: - rand = StatsigService._rand(10) - message = ( - f"e:TypeError: Cannot read properties of undefined (reading '{rand}')" - ) - - return base64.b64encode(message.encode()).decode() - - -__all__ = ["StatsigService"] diff --git a/app/services/grok/utils/stream.py b/app/services/grok/utils/stream.py index c9e64dd2..053c18d9 100644 --- a/app/services/grok/utils/stream.py +++ b/app/services/grok/utils/stream.py @@ -5,7 +5,7 @@ from typing import AsyncGenerator from app.core.logger import logger -from app.services.grok.models.model import ModelService +from app.services.grok.services.model import ModelService from app.services.token import EffortType diff --git a/app/services/grok/utils/upload.py b/app/services/grok/utils/upload.py new file mode 100644 index 00000000..0861a4ed --- /dev/null +++ b/app/services/grok/utils/upload.py @@ -0,0 +1,242 @@ +""" +Upload service. + +Upload service for assets.grok.com. +""" + +import base64 +import hashlib +import mimetypes +import re +from pathlib import Path +from typing import AsyncIterator, Optional, Tuple +from urllib.parse import urlparse + +import aiofiles +from curl_cffi.requests import AsyncSession + +from app.core.config import get_config +from app.core.exceptions import AppException, UpstreamException, ValidationException +from app.core.logger import logger +from app.core.storage import DATA_DIR +from app.services.reverse.assets_upload import AssetsUploadReverse +from app.services.grok.utils.locks import _get_upload_semaphore, _file_lock + + +class UploadService: + """Assets upload service.""" + + def __init__(self): + self._session: Optional[AsyncSession] = None + self._chunk_size = 64 * 1024 + + async def create(self) -> AsyncSession: + """Create or reuse a session.""" + if self._session is None: + self._session = AsyncSession() + return self._session + + async def close(self): + """Close the session.""" + if self._session: + await self._session.close() + self._session = None + + @staticmethod + def _is_url(value: str) -> bool: + """Check if the value is a URL.""" + try: + parsed = urlparse(value) + return bool( + parsed.scheme and parsed.netloc and parsed.scheme in ["http", "https"] + ) + except Exception: + return False + + @staticmethod + def _infer_mime(filename: str, fallback: str = "application/octet-stream") -> str: + mime, _ = mimetypes.guess_type(filename) + return mime or fallback + + @staticmethod + async def _encode_b64_stream(chunks: AsyncIterator[bytes]) -> str: + parts = [] + remain = b"" + async for chunk in chunks: + if not chunk: + continue + chunk = remain + chunk + keep = len(chunk) % 3 + if keep: + remain = chunk[-keep:] + chunk = chunk[:-keep] + else: + remain = b"" + if chunk: + parts.append(base64.b64encode(chunk).decode()) + if remain: + parts.append(base64.b64encode(remain).decode()) + return "".join(parts) + + async def _read_local_file(self, local_type: str, name: str) -> Tuple[str, str, str]: + base_dir = DATA_DIR / "tmp" + if local_type == "video": + local_dir = base_dir / "video" + mime = "video/mp4" + else: + local_dir = base_dir / "image" + suffix = Path(name).suffix.lower() + if suffix == ".png": + mime = "image/png" + elif suffix == ".webp": + mime = "image/webp" + elif suffix == ".gif": + mime = "image/gif" + else: + mime = "image/jpeg" + + local_path = local_dir / name + lock_name = f"ul_local_{hashlib.sha1(str(local_path).encode()).hexdigest()[:16]}" + lock_timeout = max(1, int(get_config("asset.upload_timeout"))) + async with _file_lock(lock_name, timeout=lock_timeout): + if not local_path.exists(): + raise ValidationException(f"Local file not found: {local_path}") + if not local_path.is_file(): + raise ValidationException(f"Invalid local file: {local_path}") + + async def _iter_file() -> AsyncIterator[bytes]: + async with aiofiles.open(local_path, "rb") as f: + while True: + chunk = await f.read(self._chunk_size) + if not chunk: + break + yield chunk + + b64 = await self._encode_b64_stream(_iter_file()) + filename = name or "file" + return filename, b64, mime + + async def parse_b64(self, url: str) -> Tuple[str, str, str]: + """Fetch URL content and return (filename, base64, mime).""" + try: + app_url = get_config("app.app_url") or "" + if app_url and self._is_url(url): + parsed = urlparse(url) + app_parsed = urlparse(app_url) + if ( + parsed.scheme == app_parsed.scheme + and parsed.netloc == app_parsed.netloc + and parsed.path.startswith("/v1/files/") + ): + parts = parsed.path.strip("/").split("/", 3) + if len(parts) >= 4: + local_type = parts[2] + name = parts[3].replace("/", "-") + return await self._read_local_file(local_type, name) + + lock_name = f"ul_url_{hashlib.sha1(url.encode()).hexdigest()[:16]}" + timeout = float(get_config("asset.upload_timeout")) + proxy_url = get_config("proxy.base_proxy_url") + proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None + + lock_timeout = max(1, int(get_config("asset.upload_timeout"))) + async with _file_lock(lock_name, timeout=lock_timeout): + session = await self.create() + response = await session.get(url, timeout=timeout, proxies=proxies) + if response.status_code >= 400: + raise UpstreamException( + message=f"Failed to fetch: {response.status_code}", + details={"url": url, "status": response.status_code}, + ) + + filename = url.split("/")[-1].split("?")[0] or "download" + content_type = response.headers.get( + "content-type", "" + ).split(";")[0].strip() + if not content_type: + content_type = self._infer_mime(filename) + if hasattr(response, "aiter_content"): + b64 = await self._encode_b64_stream(response.aiter_content()) + else: + b64 = base64.b64encode(response.content).decode() + + logger.debug(f"Fetched: {url}") + return filename, b64, content_type + except Exception as e: + if isinstance(e, AppException): + raise + logger.error(f"Fetch failed: {url} - {e}") + raise UpstreamException(f"Fetch failed: {str(e)}", details={"url": url}) + + @staticmethod + def format_b64(data_uri: str) -> Tuple[str, str, str]: + """Format data URI to (filename, base64, mime).""" + if not data_uri.startswith("data:"): + raise ValidationException("Invalid file input: not a data URI") + + try: + header, b64 = data_uri.split(",", 1) + except ValueError: + raise ValidationException("Invalid data URI format") + + if ";base64" not in header: + raise ValidationException("Invalid data URI: missing base64 marker") + + mime = header[5:].split(";", 1)[0] or "application/octet-stream" + b64 = re.sub(r"\s+", "", b64) + if not mime or not b64: + raise ValidationException("Invalid data URI: empty content") + ext = mime.split("/")[-1] if "/" in mime else "bin" + return f"file.{ext}", b64, mime + + async def check_format(self, file_input: str) -> Tuple[str, str, str]: + """Check file input format and return (filename, base64, mime).""" + if not isinstance(file_input, str) or not file_input.strip(): + raise ValidationException("Invalid file input: empty content") + + if self._is_url(file_input): + return await self.parse_b64(file_input) + + if file_input.startswith("data:"): + return self.format_b64(file_input) + + raise ValidationException("Invalid file input: must be URL or base64") + + async def upload_file(self, file_input: str, token: str) -> Tuple[str, str]: + """ + Upload file to Grok. + + Args: + file_input: str, the file input. + token: str, the SSO token. + + Returns: + Tuple[str, str]: The file ID and URI. + """ + async with _get_upload_semaphore(): + filename, b64, mime = await self.check_format(file_input) + + logger.debug( + f"Upload prepare: filename={filename}, type={mime}, size={len(b64)}" + ) + + if not b64: + raise ValidationException("Invalid file input: empty content") + + session = await self.create() + response = await AssetsUploadReverse.request( + session, + token, + filename, + mime, + b64, + ) + + result = response.json() + file_id = result.get("fileMetadataId", "") + file_uri = result.get("fileUri", "") + logger.info(f"Upload success: {filename} -> {file_id}") + return file_id, file_uri + + +__all__ = ["UploadService"] diff --git a/app/services/reverse/__init__.py b/app/services/reverse/__init__.py new file mode 100644 index 00000000..6e8aebfa --- /dev/null +++ b/app/services/reverse/__init__.py @@ -0,0 +1,34 @@ +"""Reverse interfaces for Grok endpoints.""" + +from .app_chat import AppChatReverse +from .assets_delete import AssetsDeleteReverse +from .assets_download import AssetsDownloadReverse +from .assets_list import AssetsListReverse +from .assets_upload import AssetsUploadReverse +from .media_post import MediaPostReverse +from .nsfw_mgmt import NsfwMgmtReverse +from .rate_limits import RateLimitsReverse +from .set_birth import SetBirthReverse +from .video_upscale import VideoUpscaleReverse +from .ws_livekit import LivekitTokenReverse, LivekitWebSocketReverse +from .ws_imagine import ImagineWebSocketReverse +from .utils.headers import build_headers +from .utils.statsig import StatsigGenerator + +__all__ = [ + "AppChatReverse", + "AssetsDeleteReverse", + "AssetsDownloadReverse", + "AssetsListReverse", + "AssetsUploadReverse", + "MediaPostReverse", + "NsfwMgmtReverse", + "RateLimitsReverse", + "SetBirthReverse", + "VideoUpscaleReverse", + "LivekitTokenReverse", + "LivekitWebSocketReverse", + "ImagineWebSocketReverse", + "StatsigGenerator", + "build_headers", +] diff --git a/app/services/reverse/accept_tos.py b/app/services/reverse/accept_tos.py new file mode 100644 index 00000000..8459be46 --- /dev/null +++ b/app/services/reverse/accept_tos.py @@ -0,0 +1,118 @@ +""" +Reverse interface: accept ToS (gRPC-Web). +""" + +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status +from app.services.reverse.utils.grpc import GrpcClient, GrpcStatus + +ACCEPT_TOS_API = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion" + + +class AcceptTosReverse: + """/auth_mgmt.AuthManagement/SetTosAcceptedVersion reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> GrpcStatus: + """Accept ToS via gRPC-Web. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + GrpcStatus: Parsed gRPC status. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + origin="https://accounts.x.ai", + referer="https://accounts.x.ai/accept-tos", + ) + headers["Content-Type"] = "application/grpc-web+proto" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + headers["x-grpc-web"] = "1" + headers["x-user-agent"] = "connect-es/2.1.1" + headers["Cache-Control"] = "no-cache" + headers["Pragma"] = "no-cache" + + # Build payload + payload = GrpcClient.encode_payload(b"\x10\x01") + + # Curl Config + timeout = get_config("nsfw.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + ACCEPT_TOS_API, + headers=headers, + data=payload, + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"AcceptTosReverse: Request failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AcceptTosReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + logger.debug(f"AcceptTosReverse: Request successful, {response.status_code}") + + return response + + response = await retry_on_status(_do_request) + + _, trailers = GrpcClient.parse_response( + response.content, + content_type=response.headers.get("content-type"), + headers=response.headers, + ) + grpc_status = GrpcClient.get_status(trailers) + + if grpc_status.code not in (-1, 0): + raise UpstreamException( + message=f"AcceptTosReverse: gRPC failed, {grpc_status.code}", + details={ + "status": grpc_status.http_equiv, + "grpc_status": grpc_status.code, + "grpc_message": grpc_status.message, + }, + ) + + return grpc_status + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + raise + + # Handle other non-upstream exceptions + logger.error( + f"AcceptTosReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AcceptTosReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AcceptTosReverse"] diff --git a/app/services/reverse/app_chat.py b/app/services/reverse/app_chat.py new file mode 100644 index 00000000..d7075a05 --- /dev/null +++ b/app/services/reverse/app_chat.py @@ -0,0 +1,215 @@ +""" +Reverse interface: app chat conversations. +""" + +import orjson +from typing import Any, Dict, List, Optional +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +CHAT_API = "https://grok.com/rest/app-chat/conversations/new" + + +class AppChatReverse: + """/rest/app-chat/conversations/new reverse interface.""" + + @staticmethod + def build_payload( + message: str, + model: str, + mode: str = None, + file_attachments: List[str] = None, + tool_overrides: Dict[str, Any] = None, + model_config_override: Dict[str, Any] = None, + ) -> Dict[str, Any]: + """Build chat payload for Grok app-chat API.""" + + attachments = file_attachments or [] + + payload = { + "deviceEnvInfo": { + "darkModeEnabled": False, + "devicePixelRatio": 2, + "screenWidth": 2056, + "screenHeight": 1329, + "viewportWidth": 2056, + "viewportHeight": 1083, + }, + "disableMemory": get_config("app.disable_memory"), + "disableSearch": False, + "disableSelfHarmShortCircuit": False, + "disableTextFollowUps": False, + "enableImageGeneration": True, + "enableImageStreaming": True, + "enableSideBySide": True, + "fileAttachments": attachments, + "forceConcise": False, + "forceSideBySide": False, + "imageAttachments": [], + "imageGenerationCount": 2, + "isAsyncChat": False, + "isReasoning": False, + "message": message, + "modelMode": mode, + "modelName": model, + "responseMetadata": { + "requestModelDetails": {"modelId": model}, + }, + "returnImageBytes": False, + "returnRawGrokInXaiRequest": False, + "sendFinalMetadata": True, + "temporary": get_config("app.temporary"), + "toolOverrides": tool_overrides or {}, + } + + if model_config_override: + payload["responseMetadata"]["modelConfigOverride"] = model_config_override + + return payload + + @staticmethod + async def request( + session: AsyncSession, + token: str, + message: str, + model: str, + mode: str = None, + file_attachments: List[str] = None, + tool_overrides: Dict[str, Any] = None, + model_config_override: Dict[str, Any] = None, + ) -> Any: + """Send app chat request to Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + message: str, the message to send. + model: str, the model to use. + mode: str, the mode to use. + file_attachments: List[str], the file attachments to send. + tool_overrides: Dict[str, Any], the tool overrides to use. + model_config_override: Dict[str, Any], the model config override to use. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = AppChatReverse.build_payload( + message=message, + model=model, + mode=mode, + file_attachments=file_attachments, + tool_overrides=tool_overrides, + model_config_override=model_config_override, + ) + + # Curl Config + timeout = max( + float(get_config("chat.timeout") or 0), + float(get_config("video.timeout") or 0), + float(get_config("image.timeout") or 0), + ) + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + CHAT_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + stream=True, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + + # Get response content + content = "" + try: + content = await response.text() + except Exception: + pass + + logger.error( + f"AppChatReverse: Chat failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AppChatReverse: Chat failed, {response.status_code}", + details={"status": response.status_code, "body": content}, + ) + + return response + + def extract_status(e: Exception) -> Optional[int]: + if isinstance(e, UpstreamException): + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 429: + return None + return status + return None + + response = await retry_on_status(_do_request, extract_status=extract_status) + + # Stream response + async def stream_response(): + try: + async for line in response.aiter_lines(): + yield line + finally: + await session.close() + + return stream_response() + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "app_chat_auth_failed" + ) + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AppChatReverse: Chat failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AppChatReverse: Chat failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AppChatReverse"] diff --git a/app/services/reverse/assets_delete.py b/app/services/reverse/assets_delete.py new file mode 100644 index 00000000..79423107 --- /dev/null +++ b/app/services/reverse/assets_delete.py @@ -0,0 +1,102 @@ +""" +Reverse interface: delete asset metadata. +""" + +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +DELETE_API = "https://grok.com/rest/assets-metadata" + + +class AssetsDeleteReverse: + """/rest/assets-metadata/{file_id} reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, asset_id: str) -> Any: + """Delete asset from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + asset_id: str, the ID of the asset to delete. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + assert_proxy = get_config("proxy.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/files", + ) + + # Curl Config + timeout = get_config("asset.delete_timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.delete( + f"{DELETE_API}/{asset_id}", + headers=headers, + proxies=proxies, + timeout=timeout, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"AssetsDeleteReverse: Delete failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsDeleteReverse: Delete failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "assets_delete_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsDeleteReverse: Delete failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsDeleteReverse: Delete failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + +__all__ = ["AssetsDeleteReverse"] diff --git a/app/services/reverse/assets_download.py b/app/services/reverse/assets_download.py new file mode 100644 index 00000000..ec03794d --- /dev/null +++ b/app/services/reverse/assets_download.py @@ -0,0 +1,132 @@ +""" +Reverse interface: download asset. +""" + +import urllib.parse +from typing import Any +from pathlib import Path +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +DOWNLOAD_API = "https://assets.grok.com" + +_CONTENT_TYPES = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".webp": "image/webp", + ".mp4": "video/mp4", + ".webm": "video/webm", +} + + +class AssetsDownloadReverse: + """assets.grok.com/{path} reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, file_path: str) -> Any: + """Download asset from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + file_path: str, the path of the file to download. + + Returns: + Any: The response from the request. + """ + try: + # Normalize path + if not file_path.startswith("/"): + file_path = f"/{file_path}" + url = f"{DOWNLOAD_API}{file_path}" + + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + assert_proxy = get_config("proxy.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Guess content type by extension for Accept/Sec-Fetch-Dest + content_type = _CONTENT_TYPES.get(Path(urllib.parse.urlparse(file_path).path).suffix.lower()) + + # Build headers + headers = build_headers( + cookie_token=token, + content_type=content_type, + origin="https://assets.grok.com", + referer="https://grok.com/", + ) + ## Align with browser download navigation headers + headers["Cache-Control"] = "no-cache" + headers["Pragma"] = "no-cache" + headers["Priority"] = "u=0, i" + headers["Sec-Fetch-Mode"] = "navigate" + headers["Sec-Fetch-User"] = "?1" + headers["Upgrade-Insecure-Requests"] = "1" + + # Curl Config + timeout = get_config("asset.download_timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.get( + url, + headers=headers, + proxies=proxies, + timeout=timeout, + allow_redirects=True, + impersonate=browser, + stream=True, + ) + + if response.status_code != 200: + logger.error( + f"AssetsDownloadReverse: Download failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsDownloadReverse: Download failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + + if status == 401: + try: + await TokenService.record_fail(token, status, "assets_download_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsDownloadReverse: Download failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsDownloadReverse: Download failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AssetsDownloadReverse"] diff --git a/app/services/reverse/assets_list.py b/app/services/reverse/assets_list.py new file mode 100644 index 00000000..5c84fe99 --- /dev/null +++ b/app/services/reverse/assets_list.py @@ -0,0 +1,104 @@ +""" +Reverse interface: list assets. +""" + +from typing import Any, Dict +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +LIST_API = "https://grok.com/rest/assets" + + +class AssetsListReverse: + """/rest/assets reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, params: Dict[str, Any]) -> Any: + """List assets from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + params: Dict[str, Any], the parameters for the request. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + assert_proxy = get_config("proxy.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/files", + ) + + # Curl Config + timeout = get_config("asset.list_timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.get( + LIST_API, + headers=headers, + params=params, + proxies=proxies, + timeout=timeout, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"AssetsListReverse: List failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsListReverse: List failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "assets_list_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsListReverse: List failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsListReverse: List failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AssetsListReverse"] diff --git a/app/services/reverse/assets_upload.py b/app/services/reverse/assets_upload.py new file mode 100644 index 00000000..b9d96731 --- /dev/null +++ b/app/services/reverse/assets_upload.py @@ -0,0 +1,111 @@ +""" +Reverse interface: upload asset. +""" + +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +UPLOAD_API = "https://grok.com/rest/app-chat/upload-file" + + +class AssetsUploadReverse: + """/rest/app-chat/upload-file reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, fileName: str, fileMimeType: str, content: str) -> Any: + """Upload asset to Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + fileName: str, the name of the file. + fileMimeType: str, the MIME type of the file. + content: str, the content of the file. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + assert_proxy = get_config("proxy.asset_proxy_url") + if assert_proxy: + proxies = {"http": assert_proxy, "https": assert_proxy} + else: + proxies = {"http": base_proxy, "https": base_proxy} + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = { + "fileName": fileName, + "fileMimeType": fileMimeType, + "content": content, + } + + # Curl Config + timeout = get_config("asset.upload_timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + UPLOAD_API, + headers=headers, + json=payload, + proxies=proxies, + timeout=timeout, + impersonate=browser, + ) + if response.status_code != 200: + logger.error( + f"AssetsUploadReverse: Upload failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"AssetsUploadReverse: Upload failed, {response.status_code}", + details={"status": response.status_code}, + ) + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "assets_upload_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"AssetsUploadReverse: Upload failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AssetsUploadReverse: Upload failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AssetsUploadReverse"] diff --git a/app/services/reverse/media_post.py b/app/services/reverse/media_post.py new file mode 100644 index 00000000..6e70d539 --- /dev/null +++ b/app/services/reverse/media_post.py @@ -0,0 +1,120 @@ +""" +Reverse interface: media post create. +""" + +import orjson +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +MEDIA_POST_API = "https://grok.com/rest/media/post/create" + + +class MediaPostReverse: + """/rest/media/post/create reverse interface.""" + + @staticmethod + async def request( + session: AsyncSession, + token: str, + mediaType: str, + mediaUrl: str, + prompt: str = "", + ) -> Any: + """Create media post in Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + mediaType: str, the media type. + mediaUrl: str, the media URL. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com", + ) + + # Build payload + payload = {"mediaType": mediaType} + if mediaUrl: + payload["mediaUrl"] = mediaUrl + if prompt: + payload["prompt"] = prompt + + # Curl Config + timeout = get_config("video.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + MEDIA_POST_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + content = "" + try: + content = await response.text() + except Exception: + pass + logger.error( + f"MediaPostReverse: Media post create failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"MediaPostReverse: Media post create failed, {response.status_code}", + details={"status": response.status_code, "body": content}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "media_post_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"MediaPostReverse: Media post create failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"MediaPostReverse: Media post create failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["MediaPostReverse"] diff --git a/app/services/reverse/nsfw_mgmt.py b/app/services/reverse/nsfw_mgmt.py new file mode 100644 index 00000000..ca5afc46 --- /dev/null +++ b/app/services/reverse/nsfw_mgmt.py @@ -0,0 +1,126 @@ +""" +Reverse interface: NSFW feature controls (gRPC-Web). +""" + +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status +from app.services.reverse.utils.grpc import GrpcClient, GrpcStatus + +NSFW_MGMT_API = "https://grok.com/auth_mgmt.AuthManagement/UpdateUserFeatureControls" + + +class NsfwMgmtReverse: + """/auth_mgmt.AuthManagement/UpdateUserFeatureControls reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> GrpcStatus: + """Enable NSFW feature control via gRPC-Web. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + GrpcStatus: Parsed gRPC status. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + origin="https://grok.com", + referer="https://grok.com/?_s=data", + ) + headers["Content-Type"] = "application/grpc-web+proto" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + headers["x-grpc-web"] = "1" + headers["x-user-agent"] = "connect-es/2.1.1" + headers["Cache-Control"] = "no-cache" + headers["Pragma"] = "no-cache" + + # Build payload + name = "always_show_nsfw_content".encode("utf-8") + inner = b"\x0a" + bytes([len(name)]) + name + protobuf = b"\x0a\x02\x10\x01\x12" + bytes([len(inner)]) + inner + payload = GrpcClient.encode_payload(protobuf) + + # Curl Config + timeout = get_config("nsfw.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + NSFW_MGMT_API, + headers=headers, + data=payload, + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"NsfwMgmtReverse: Request failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"NsfwMgmtReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + logger.debug(f"NsfwMgmtReverse: Request successful, {response.status_code}") + + return response + + response = await retry_on_status(_do_request) + + _, trailers = GrpcClient.parse_response( + response.content, + content_type=response.headers.get("content-type"), + headers=response.headers, + ) + grpc_status = GrpcClient.get_status(trailers) + + if grpc_status.code not in (-1, 0): + raise UpstreamException( + message=f"NsfwMgmtReverse: gRPC failed, {grpc_status.code}", + details={ + "status": grpc_status.http_equiv, + "grpc_status": grpc_status.code, + "grpc_message": grpc_status.message, + }, + ) + + return grpc_status + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + raise + + # Handle other non-upstream exceptions + logger.error( + f"NsfwMgmtReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"NsfwMgmtReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["NsfwMgmtReverse"] diff --git a/app/services/reverse/rate_limits.py b/app/services/reverse/rate_limits.py new file mode 100644 index 00000000..10e6d71f --- /dev/null +++ b/app/services/reverse/rate_limits.py @@ -0,0 +1,100 @@ +""" +Reverse interface: rate limits. +""" + +import orjson +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +RATE_LIMITS_API = "https://grok.com/rest/rate-limits" + + +class RateLimitsReverse: + """/rest/rate-limits reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> Any: + """Fetch rate limits from Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = { + "requestKind": "DEFAULT", + "modelName": "grok-4-1-thinking-1129", + } + + # Curl Config + timeout = get_config("usage.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + RATE_LIMITS_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + logger.error( + f"RateLimitsReverse: Request failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"RateLimitsReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + raise + + # Handle other non-upstream exceptions + logger.error( + f"RateLimitsReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"RateLimitsReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["RateLimitsReverse"] diff --git a/app/services/reverse/set_birth.py b/app/services/reverse/set_birth.py new file mode 100644 index 00000000..d76c4c60 --- /dev/null +++ b/app/services/reverse/set_birth.py @@ -0,0 +1,111 @@ +""" +Reverse interface: set birth date. +""" + +import datetime +import random +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +SET_BIRTH_API = "https://grok.com/rest/auth/set-birth-date" + + +class SetBirthReverse: + """/rest/auth/set-birth-date reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str) -> Any: + """Set birth date in Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/?_s=home", + ) + + # Build payload + today = datetime.date.today() + birth_year = today.year - random.randint(20, 48) + birth_month = random.randint(1, 12) + birth_day = random.randint(1, 28) + hour = random.randint(0, 23) + minute = random.randint(0, 59) + second = random.randint(0, 59) + microsecond = random.randint(0, 999) + payload = { + "birthDate": f"{birth_year:04d}-{birth_month:02d}-{birth_day:02d}" + f"T{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}Z" + } + + # Curl Config + timeout = get_config("nsfw.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + SET_BIRTH_API, + headers=headers, + json=payload, + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code not in (200, 204): + logger.error( + f"SetBirthReverse: Request failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"SetBirthReverse: Request failed, {response.status_code}", + details={"status": response.status_code}, + ) + + logger.debug(f"SetBirthReverse: Request successful, {response.status_code}") + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + raise + + # Handle other non-upstream exceptions + logger.error( + f"SetBirthReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"SetBirthReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["SetBirthReverse"] diff --git a/app/services/reverse/utils/grpc.py b/app/services/reverse/utils/grpc.py new file mode 100644 index 00000000..39eb6787 --- /dev/null +++ b/app/services/reverse/utils/grpc.py @@ -0,0 +1,185 @@ +""" +gRPC-Web helpers for reverse interfaces. +""" + +import base64 +import json +import re +import struct +from dataclasses import dataclass +from typing import Dict, List, Mapping, Optional, Tuple +from urllib.parse import unquote + +from app.core.logger import logger + +# Base64 正则 +B64_RE = re.compile(rb"^[A-Za-z0-9+/=\r\n]+$") + + +@dataclass(frozen=True) +class GrpcStatus: + code: int + message: str = "" + + @property + def ok(self) -> bool: + return self.code == 0 + + @property + def http_equiv(self) -> int: + mapping = { + 0: 200, + 16: 401, + 7: 403, + 8: 429, + 4: 504, + 14: 503, + } + return mapping.get(self.code, 502) + + +class GrpcClient: + """gRPC-Web helpers wrapper.""" + + @staticmethod + def _safe_headers(headers: Optional[Mapping[str, str]]) -> Dict[str, str]: + if not headers: + return {} + safe: Dict[str, str] = {} + for k, v in headers.items(): + if k.lower() in ("set-cookie", "cookie", "authorization"): + safe[k] = "" + else: + safe[k] = str(v) + return safe + + @staticmethod + def _b64(data: bytes) -> str: + return base64.b64encode(data).decode() + + @staticmethod + def encode_payload(data: bytes) -> bytes: + """Encode gRPC-Web data frame.""" + return b"\x00" + struct.pack(">I", len(data)) + data + + @staticmethod + def _maybe_decode_grpc_web_text(body: bytes, content_type: Optional[str]) -> bytes: + ct = (content_type or "").lower() + if "grpc-web-text" in ct: + compact = b"".join(body.split()) + return base64.b64decode(compact, validate=False) + + head = body[: min(len(body), 2048)] + if head and B64_RE.fullmatch(head): + compact = b"".join(body.split()) + try: + return base64.b64decode(compact, validate=True) + except Exception: + return body + return body + + @staticmethod + def _parse_trailer_block(payload: bytes) -> Dict[str, str]: + text = payload.decode("utf-8", errors="replace") + lines = [ln for ln in re.split(r"\r\n|\n", text) if ln] + + trailers: Dict[str, str] = {} + for ln in lines: + if ":" not in ln: + continue + k, v = ln.split(":", 1) + trailers[k.strip().lower()] = v.strip() + + if "grpc-message" in trailers: + trailers["grpc-message"] = unquote(trailers["grpc-message"]) + + return trailers + + @classmethod + def parse_response( + cls, + body: bytes, + content_type: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> Tuple[List[bytes], Dict[str, str]]: + decoded = cls._maybe_decode_grpc_web_text(body, content_type) + + messages: List[bytes] = [] + trailers: Dict[str, str] = {} + + i = 0 + n = len(decoded) + while i < n: + if n - i < 5: + break + + flag = decoded[i] + length = int.from_bytes(decoded[i + 1 : i + 5], "big") + i += 5 + + if n - i < length: + break + + payload = decoded[i : i + length] + i += length + + if flag & 0x80: + trailers.update(cls._parse_trailer_block(payload)) + elif flag & 0x01: + raise ValueError("grpc-web compressed flag not supported") + else: + messages.append(payload) + + if headers: + lower = {k.lower(): v for k, v in headers.items()} + if "grpc-status" in lower and "grpc-status" not in trailers: + trailers["grpc-status"] = str(lower["grpc-status"]).strip() + if "grpc-message" in lower and "grpc-message" not in trailers: + trailers["grpc-message"] = unquote(str(lower["grpc-message"]).strip()) + + # Log full response details on gRPC error + raw_status = str(trailers.get("grpc-status", "")).strip() + try: + status_code = int(raw_status) + except Exception: + status_code = -1 + + if status_code not in (0, -1): + try: + payload = { + "grpc_status": status_code, + "grpc_message": trailers.get("grpc-message", ""), + "content_type": content_type or "", + "headers": cls._safe_headers(headers), + "trailers": trailers, + "messages_b64": [cls._b64(m) for m in messages], + "body_b64": cls._b64(body), + } + logger.error( + "gRPC response error: {}", + json.dumps(payload, ensure_ascii=False), + extra={"error_type": "GrpcError"}, + ) + except Exception as e: + logger.error( + f"gRPC response error: failed to log payload ({e})", + extra={"error_type": "GrpcError"}, + ) + + return messages, trailers + + @staticmethod + def get_status(trailers: Mapping[str, str]) -> GrpcStatus: + raw = str(trailers.get("grpc-status", "")).strip() + msg = str(trailers.get("grpc-message", "")).strip() + try: + code = int(raw) + except Exception: + code = -1 + return GrpcStatus(code=code, message=msg) + + +__all__ = [ + "GrpcStatus", + "GrpcClient", +] diff --git a/app/services/reverse/utils/headers.py b/app/services/reverse/utils/headers.py new file mode 100644 index 00000000..e0c534e6 --- /dev/null +++ b/app/services/reverse/utils/headers.py @@ -0,0 +1,134 @@ +"""Shared header builders for reverse interfaces.""" + +import uuid +import orjson +from urllib.parse import urlparse +from typing import Dict, Optional + +from app.core.logger import logger +from app.core.config import get_config +from app.services.reverse.utils.statsig import StatsigGenerator + + +def build_sso_cookie(sso_token: str) -> str: + """ + Build SSO Cookie string. + + Args: + sso_token: str, the SSO token. + + Returns: + str: The SSO Cookie string. + """ + # Format + sso_token = sso_token[4:] if sso_token.startswith("sso=") else sso_token + + # SSO Cookie + cookie = f"sso={sso_token}; sso-rw={sso_token}" + + # CF Clearance + cf_clearance = get_config("proxy.cf_clearance") + if cf_clearance: + cookie += f";cf_clearance={cf_clearance}" + + return cookie + + +def build_ws_headers(token: Optional[str] = None, origin: Optional[str] = None, extra: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """ + Build headers for WebSocket requests. + + Args: + token: Optional[str], the SSO token for Cookie. Defaults to None. + origin: Optional[str], the Origin value. Defaults to "https://grok.com" if not provided. + extra: Optional[Dict[str, str]], extra headers to merge. Defaults to None. + + Returns: + Dict[str, str]: The headers dictionary. + """ + headers = { + "Origin": origin or "https://grok.com", + "User-Agent": get_config("proxy.user_agent"), + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Cache-Control": "no-cache", + "Pragma": "no-cache", + } + + if token: + headers["Cookie"] = build_sso_cookie(token) + + if extra: + headers.update(extra) + + return headers + + +def build_headers(cookie_token: str, content_type: Optional[str] = None, origin: Optional[str] = None, referer: Optional[str] = None) -> Dict[str, str]: + """ + Build headers for reverse interfaces. + + Args: + cookie_token: str, the SSO token. + content_type: Optional[str], the Content-Type value. + origin: Optional[str], the Origin value. Defaults to "https://grok.com" if not provided. + referer: Optional[str], the Referer value. Defaults to "https://grok.com/" if not provided. + + Returns: + Dict[str, str]: The headers dictionary. + """ + headers = { + "Accept-Encoding": "gzip, deflate, br, zstd", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Baggage": "sentry-environment=production,sentry-release=d6add6fb0460641fd482d767a335ef72b9b6abb8,sentry-public_key=b311e0f2690c81f25e2c4cf6d4f7ce1c", + "Origin": origin or "https://grok.com", + "Priority": "u=1, i", + "Referer": referer or "https://grok.com/", + "Sec-Ch-Ua": '"Google Chrome";v="136", "Chromium";v="136", "Not(A:Brand";v="24"', + "Sec-Ch-Ua-Arch": "arm", + "Sec-Ch-Ua-Bitness": "64", + "Sec-Ch-Ua-Mobile": "?0", + "Sec-Ch-Ua-Model": "", + "Sec-Ch-Ua-Platform": '"macOS"', + "Sec-Fetch-Mode": "cors", + "User-Agent": get_config("proxy.user_agent"), + } + + # Cookie + headers["Cookie"] = build_sso_cookie(cookie_token) + + # Content-Type and Accept/Sec-Fetch-Dest + if content_type and content_type == "application/json": + headers["Content-Type"] = "application/json" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + elif content_type in ["image/jpeg", "image/png", "video/mp4", "video/webm"]: + headers["Content-Type"] = content_type + headers["Accept"] = "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7" + headers["Sec-Fetch-Dest"] = "document" + else: + headers["Content-Type"] = "application/json" + headers["Accept"] = "*/*" + headers["Sec-Fetch-Dest"] = "empty" + + # Sec-Fetch-Site + origin_domain = urlparse(headers.get("Origin", "")).hostname + referer_domain = urlparse(headers.get("Referer", "")).hostname + if origin_domain and referer_domain and origin_domain == referer_domain: + headers["Sec-Fetch-Site"] = "same-origin" + else: + headers["Sec-Fetch-Site"] = "same-site" + + # X-Statsig-ID and X-XAI-Request-ID + headers["x-statsig-id"] = StatsigGenerator.gen_id() + headers["x-xai-request-id"] = str(uuid.uuid4()) + + # Print headers without Cookie + safe_headers = dict(headers) + if "Cookie" in safe_headers: + safe_headers["Cookie"] = "" + logger.debug(f"Built headers: {orjson.dumps(safe_headers).decode()}") + + return headers + + +__all__ = ["build_headers", "build_sso_cookie", "build_ws_headers"] diff --git a/app/services/reverse/utils/retry.py b/app/services/reverse/utils/retry.py new file mode 100644 index 00000000..0de15b6f --- /dev/null +++ b/app/services/reverse/utils/retry.py @@ -0,0 +1,229 @@ +""" +Reverse retry utilities. +""" + +import asyncio +import random +from typing import Callable, Any, Optional + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException + + +class RetryContext: + """Retry context.""" + + def __init__(self): + self.attempt = 0 + self.max_retry = int(get_config("retry.max_retry")) + self.retry_codes = get_config("retry.retry_status_codes") + self.last_error = None + self.last_status = None + self.total_delay = 0.0 + self.retry_budget = float(get_config("retry.retry_budget")) + + # Backoff parameters + self.backoff_base = float(get_config("retry.retry_backoff_base")) + self.backoff_factor = float(get_config("retry.retry_backoff_factor")) + self.backoff_max = float(get_config("retry.retry_backoff_max")) + + # Decorrelated jitter state + self._last_delay = self.backoff_base + + def should_retry(self, status_code: int) -> bool: + """Check if should retry.""" + if self.attempt >= self.max_retry: + return False + if status_code not in self.retry_codes: + return False + if self.total_delay >= self.retry_budget: + return False + return True + + def record_error(self, status_code: int, error: Exception): + """Record error information.""" + self.last_status = status_code + self.last_error = error + self.attempt += 1 + + def calculate_delay(self, status_code: int, retry_after: Optional[float] = None) -> float: + """ + Calculate backoff delay time. + + Args: + status_code: HTTP status code + retry_after: Retry-After header value (seconds) + + Returns: + Delay time (seconds) + """ + # Use Retry-After if available + if retry_after is not None and retry_after > 0: + delay = min(retry_after, self.backoff_max) + self._last_delay = delay + return delay + + # Use decorrelated jitter for 429 + if status_code == 429: + # decorrelated jitter: delay = random(base, last_delay * 3) + delay = random.uniform(self.backoff_base, self._last_delay * 3) + delay = min(delay, self.backoff_max) + self._last_delay = delay + return delay + + # Use exponential backoff + full jitter for other status codes + exp_delay = self.backoff_base * (self.backoff_factor**self.attempt) + delay = random.uniform(0, min(exp_delay, self.backoff_max)) + return delay + + def record_delay(self, delay: float): + """Record delay time.""" + self.total_delay += delay + + +def extract_retry_after(error: Exception) -> Optional[float]: + """ + Extract Retry-After value from exception. + + Args: + error: Exception object + + Returns: + Retry-After value (seconds), or None + """ + if not isinstance(error, UpstreamException): + return None + + details = error.details or {} + + # Try to get Retry-After from details + retry_after = details.get("retry_after") + if retry_after is not None: + try: + return float(retry_after) + except (ValueError, TypeError): + pass + + # Try to get Retry-After from headers + headers = details.get("headers", {}) + if isinstance(headers, dict): + retry_after = headers.get("Retry-After") or headers.get("retry-after") + if retry_after is not None: + try: + return float(retry_after) + except (ValueError, TypeError): + pass + + return None + + +async def retry_on_status( + func: Callable, + *args, + extract_status: Callable[[Exception], Optional[int]] = None, + on_retry: Callable[[int, int, Exception, float], None] = None, + **kwargs, +) -> Any: + """ + Generic retry function. + + Args: + func: Retry function + *args: Function arguments + extract_status: Function to extract status code from exception + on_retry: Callback function for retry (attempt, status_code, error, delay) + **kwargs: Function keyword arguments + + Returns: + Function execution result + + Raises: + Last failed exception + """ + ctx = RetryContext() + + # Status code extractor + if extract_status is None: + + def extract_status(e: Exception) -> Optional[int]: + if isinstance(e, UpstreamException): + # Try to get status code from details, fallback to status_code attribute + if e.details and "status" in e.details: + return e.details["status"] + return getattr(e, "status_code", None) + return None + + while ctx.attempt <= ctx.max_retry: + try: + result = await func(*args, **kwargs) + + # Record log + if ctx.attempt > 0: + logger.info( + f"Retry succeeded after {ctx.attempt} attempts, " + f"total delay: {ctx.total_delay:.2f}s" + ) + + return result + + except Exception as e: + # Extract status code + status_code = extract_status(e) + + if status_code is None: + # Error cannot be identified as retryable + logger.error(f"Non-retryable error: {e}") + raise + + # Record error + ctx.record_error(status_code, e) + + # Check if should retry + if ctx.should_retry(status_code): + # Extract Retry-After + retry_after = extract_retry_after(e) + + # Calculate delay + delay = ctx.calculate_delay(status_code, retry_after) + + # Check if exceeds budget + if ctx.total_delay + delay > ctx.retry_budget: + logger.warning( + f"Retry budget exhausted: {ctx.total_delay:.2f}s + {delay:.2f}s > {ctx.retry_budget}s" + ) + raise + + ctx.record_delay(delay) + + logger.warning( + f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, " + f"waiting {delay:.2f}s (total: {ctx.total_delay:.2f}s)" + + (f", Retry-After: {retry_after}s" if retry_after else "") + ) + + # Callback + if on_retry: + on_retry(ctx.attempt, status_code, e, delay) + + await asyncio.sleep(delay) + continue + else: + # Not retryable or retry budget exhausted + if status_code in ctx.retry_codes: + logger.error( + f"Retry exhausted after {ctx.attempt} attempts, " + f"last status: {status_code}, total delay: {ctx.total_delay:.2f}s" + ) + else: + logger.error(f"Non-retryable status code: {status_code}") + + # Raise last failed exception + raise + + +__all__ = [ + "RetryContext", + "retry_on_status", + "extract_retry_after", +] diff --git a/app/services/reverse/utils/statsig.py b/app/services/reverse/utils/statsig.py new file mode 100644 index 00000000..485885f1 --- /dev/null +++ b/app/services/reverse/utils/statsig.py @@ -0,0 +1,56 @@ +""" +Statsig ID generator for reverse interfaces. +""" + +import base64 +import random +import string + +from app.core.logger import logger +from app.core.config import get_config + + +class StatsigGenerator: + """Statsig ID generator for reverse interfaces.""" + + @staticmethod + def _rand(length: int, alphanumeric: bool = False) -> str: + """Generate random string.""" + chars = ( + string.ascii_lowercase + string.digits + if alphanumeric + else string.ascii_lowercase + ) + return "".join(random.choices(chars, k=length)) + + @staticmethod + def gen_id() -> str: + """ + Generate Statsig ID. + + Returns: + Base64 encoded ID. + """ + dynamic = get_config("app.dynamic_statsig") + + # Dynamic Statsig ID + if dynamic: + logger.debug("Generating dynamic Statsig ID") + + if random.choice([True, False]): + rand = StatsigGenerator._rand(5, alphanumeric=True) + message = f"e:TypeError: Cannot read properties of null (reading 'children['{rand}']')" + else: + rand = StatsigGenerator._rand(10) + message = ( + f"e:TypeError: Cannot read properties of undefined (reading '{rand}')" + ) + + return base64.b64encode(message.encode()).decode() + + # Static Statsig ID + logger.debug("Generating static Statsig ID") + return "ZTpUeXBlRXJyb3I6IENhbm5vdCByZWFkIHByb3BlcnRpZXMgb2YgdW5kZWZpbmVkIChyZWFkaW5nICdjaGlsZE5vZGVzJyk=" + + +__all__ = ["StatsigGenerator"] diff --git a/app/services/reverse/utils/websocket.py b/app/services/reverse/utils/websocket.py new file mode 100644 index 00000000..f13586a1 --- /dev/null +++ b/app/services/reverse/utils/websocket.py @@ -0,0 +1,142 @@ +""" +WebSocket helpers for reverse interfaces. +""" + +import ssl +import certifi +import aiohttp +from aiohttp_socks import ProxyConnector +from typing import Mapping, Optional +from urllib.parse import urlparse + +from app.core.logger import logger +from app.core.config import get_config + + +def _default_ssl_context() -> ssl.SSLContext: + context = ssl.create_default_context() + context.load_verify_locations(certifi.where()) + return context + + +def _normalize_socks_proxy(proxy_url: str) -> tuple[str, Optional[bool]]: + scheme = urlparse(proxy_url).scheme.lower() + rdns: Optional[bool] = None + base_scheme = scheme + + if scheme == "socks5h": + base_scheme = "socks5" + rdns = True + elif scheme == "socks4a": + base_scheme = "socks4" + rdns = True + + if base_scheme != scheme: + proxy_url = proxy_url.replace(f"{scheme}://", f"{base_scheme}://", 1) + + return proxy_url, rdns + + +def resolve_proxy(proxy_url: Optional[str] = None, ssl_context: ssl.SSLContext = _default_ssl_context()) -> tuple[aiohttp.BaseConnector, Optional[str]]: + """Resolve proxy connector. + + Args: + proxy_url: Optional[str], the proxy URL. Defaults to None. + ssl_context: ssl.SSLContext, the SSL context. Defaults to _default_ssl_context(). + + Returns: + tuple[aiohttp.BaseConnector, Optional[str]]: The proxy connector and the proxy URL. + """ + if not proxy_url: + return aiohttp.TCPConnector(ssl=ssl_context), None + + scheme = urlparse(proxy_url).scheme.lower() + if scheme.startswith("socks"): + normalized, rdns = _normalize_socks_proxy(proxy_url) + logger.info(f"Using SOCKS proxy: {proxy_url}") + try: + if rdns is not None: + return ( + ProxyConnector.from_url(normalized, rdns=rdns, ssl=ssl_context), + None, + ) + except TypeError: + return ProxyConnector.from_url(normalized, ssl=ssl_context), None + return ProxyConnector.from_url(normalized, ssl=ssl_context), None + + logger.info(f"Using HTTP proxy: {proxy_url}") + return aiohttp.TCPConnector(ssl=ssl_context), proxy_url + + +class WebSocketConnection: + """WebSocket connection wrapper.""" + + def __init__(self, session: aiohttp.ClientSession, ws: aiohttp.ClientWebSocketResponse) -> None: + self.session = session + self.ws = ws + + async def close(self) -> None: + if not self.ws.closed: + await self.ws.close() + await self.session.close() + + async def __aenter__(self) -> aiohttp.ClientWebSocketResponse: + return self.ws + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + +class WebSocketClient: + """WebSocket client with proxy support.""" + + def __init__(self, proxy: Optional[str] = None) -> None: + self.proxy = proxy or get_config("proxy.base_proxy_url") + self._ssl_context = _default_ssl_context() + + async def connect( + self, + url: str, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + ws_kwargs: Optional[Mapping[str, object]] = None, + ) -> WebSocketConnection: + """Connect to the WebSocket. + + Args: + url: str, the URL to connect to. + headers: Optional[Mapping[str, str]], the headers to send. Defaults to None. + ws_kwargs: Optional[Mapping[str, object]], extra ws_connect kwargs. Defaults to None. + + Returns: + WebSocketConnection: The WebSocket connection. + """ + # Resolve proxy + connector, proxy = resolve_proxy(self.proxy, self._ssl_context) + + # Build client timeout + total_timeout = ( + float(timeout) + if timeout is not None + else float(get_config("voice.timeout") or 120) + ) + client_timeout = aiohttp.ClientTimeout(total=total_timeout) + + # Create session + session = aiohttp.ClientSession(connector=connector, timeout=client_timeout) + try: + extra_kwargs = dict(ws_kwargs or {}) + ws = await session.ws_connect( + url, + headers=headers, + proxy=proxy, + ssl=self._ssl_context, + **extra_kwargs, + ) + return WebSocketConnection(session, ws) + except Exception: + await session.close() + raise + + +__all__ = ["WebSocketClient", "WebSocketConnection", "resolve_proxy"] diff --git a/app/services/reverse/video_upscale.py b/app/services/reverse/video_upscale.py new file mode 100644 index 00000000..f6c70e17 --- /dev/null +++ b/app/services/reverse/video_upscale.py @@ -0,0 +1,109 @@ +""" +Reverse interface: video upscale. +""" + +import orjson +from typing import Any +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status + +VIDEO_UPSCALE_API = "https://grok.com/rest/media/video/upscale" + + +class VideoUpscaleReverse: + """/rest/media/video/upscale reverse interface.""" + + @staticmethod + async def request(session: AsyncSession, token: str, video_id: str) -> Any: + """Upscale video (image upscaling endpoint) in Grok. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + video_id: str, the video id. + + Returns: + Any: The response from the request. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com", + ) + + # Build payload + payload = {"videoId": video_id} + + # Curl Config + timeout = get_config("video.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + VIDEO_UPSCALE_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + content = "" + try: + content = await response.text() + except Exception: + pass + logger.error( + f"VideoUpscaleReverse: Upscale failed, {response.status_code}", + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=f"VideoUpscaleReverse: Upscale failed, {response.status_code}", + details={"status": response.status_code, "body": content}, + ) + + return response + + return await retry_on_status(_do_request) + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail(token, status, "video_upscale_auth_failed") + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"VideoUpscaleReverse: Upscale failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"VideoUpscaleReverse: Upscale failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["VideoUpscaleReverse"] diff --git a/app/services/reverse/ws_imagine.py b/app/services/reverse/ws_imagine.py new file mode 100644 index 00000000..e9e648af --- /dev/null +++ b/app/services/reverse/ws_imagine.py @@ -0,0 +1,281 @@ +""" +Reverse interface: Imagine WebSocket image stream. +""" + +import asyncio +import orjson +import re +import time +import uuid +from typing import AsyncGenerator, Dict, Optional + +import aiohttp + +from app.core.config import get_config +from app.core.logger import logger +from app.services.reverse.utils.headers import build_ws_headers +from app.services.reverse.utils.websocket import WebSocketClient + +WS_IMAGINE_URL = "wss://grok.com/ws/imagine/listen" + + +class _BlockedError(Exception): + pass + + +class ImagineWebSocketReverse: + """Imagine WebSocket reverse interface.""" + + def __init__(self) -> None: + self._url_pattern = re.compile(r"/images/([a-f0-9-]+)\.(png|jpg|jpeg)") + self._client = WebSocketClient() + + def _parse_image_url(self, url: str) -> tuple[Optional[str], Optional[str]]: + match = self._url_pattern.search(url or "") + if not match: + return None, None + return match.group(1), match.group(2).lower() + + def _is_final_image(self, url: str, blob_size: int, final_min_bytes: int) -> bool: + url_lower = (url or "").lower() + if url_lower.endswith((".jpg", ".jpeg")): + return True + return blob_size > final_min_bytes + + def _classify_image(self, url: str, blob: str, final_min_bytes: int, medium_min_bytes: int) -> Optional[Dict[str, object]]: + if not url or not blob: + return None + + image_id, ext = self._parse_image_url(url) + image_id = image_id or uuid.uuid4().hex + blob_size = len(blob) + is_final = self._is_final_image(url, blob_size, final_min_bytes) + + stage = ( + "final" + if is_final + else ("medium" if blob_size > medium_min_bytes else "preview") + ) + + return { + "type": "image", + "image_id": image_id, + "ext": ext, + "stage": stage, + "blob": blob, + "blob_size": blob_size, + "url": url, + "is_final": is_final, + } + + def _build_request_message(self, request_id: str, prompt: str, aspect_ratio: str, enable_nsfw: bool) -> Dict[str, object]: + return { + "type": "conversation.item.create", + "timestamp": int(time.time() * 1000), + "item": { + "type": "message", + "content": [ + { + "requestId": request_id, + "text": prompt, + "type": "input_text", + "properties": { + "section_count": 0, + "is_kids_mode": False, + "enable_nsfw": enable_nsfw, + "skip_upsampler": False, + "is_initial": False, + "aspect_ratio": aspect_ratio, + }, + } + ], + }, + } + + async def stream( + self, + token: str, + prompt: str, + aspect_ratio: str = "2:3", + n: int = 1, + enable_nsfw: bool = True, + max_retries: Optional[int] = None, + ) -> AsyncGenerator[Dict[str, object], None]: + retries = max(1, max_retries if max_retries is not None else 1) + logger.info( + f"Image generation: prompt='{prompt[:50]}...', n={n}, ratio={aspect_ratio}, nsfw={enable_nsfw}" + ) + + for attempt in range(retries): + try: + yielded_any = False + async for item in self._stream_once( + token, prompt, aspect_ratio, n, enable_nsfw + ): + yielded_any = True + yield item + return + except _BlockedError: + if yielded_any or attempt + 1 >= retries: + if not yielded_any: + yield { + "type": "error", + "error_code": "blocked", + "error": "blocked_no_final_image", + } + return + logger.warning(f"WebSocket blocked, retry {attempt + 1}/{retries}") + except Exception as e: + logger.error(f"WebSocket stream failed: {e}") + yield { + "type": "error", + "error_code": "ws_stream_failed", + "error": str(e), + } + return + + async def _stream_once( + self, + token: str, + prompt: str, + aspect_ratio: str, + n: int, + enable_nsfw: bool, + ) -> AsyncGenerator[Dict[str, object], None]: + request_id = str(uuid.uuid4()) + headers = build_ws_headers(token=token) + timeout = float(get_config("image.timeout")) + stream_timeout = float(get_config("image.stream_timeout")) + final_timeout = float(get_config("image.final_timeout")) + blocked_grace = min(10.0, final_timeout) + final_min_bytes = int(get_config("image.final_min_bytes")) + medium_min_bytes = int(get_config("image.medium_min_bytes")) + + try: + conn = await self._client.connect( + WS_IMAGINE_URL, + headers=headers, + timeout=timeout, + ws_kwargs={ + "heartbeat": 20, + "receive_timeout": stream_timeout, + }, + ) + except Exception as e: + status = getattr(e, "status", None) + error_code = ( + "rate_limit_exceeded" if status == 429 else "connection_failed" + ) + logger.error(f"WebSocket connect failed: {e}") + yield { + "type": "error", + "error_code": error_code, + "status": status, + "error": str(e), + } + return + + try: + async with conn as ws: + message = self._build_request_message( + request_id, prompt, aspect_ratio, enable_nsfw + ) + await ws.send_json(message) + logger.info(f"WebSocket request sent: {prompt[:80]}...") + + final_ids: set[str] = set() + completed = 0 + start_time = last_activity = time.monotonic() + medium_received_time: Optional[float] = None + + while time.monotonic() - start_time < timeout: + try: + ws_msg = await asyncio.wait_for(ws.receive(), timeout=5.0) + except asyncio.TimeoutError: + now = time.monotonic() + if ( + medium_received_time + and completed == 0 + and now - medium_received_time > blocked_grace + ): + raise _BlockedError() + if completed > 0 and now - last_activity > 10: + logger.info( + f"WebSocket idle timeout, collected {completed} images" + ) + break + continue + + if ws_msg.type == aiohttp.WSMsgType.TEXT: + last_activity = time.monotonic() + try: + msg = orjson.loads(ws_msg.data) + except orjson.JSONDecodeError as e: + logger.warning(f"WebSocket message decode failed: {e}") + continue + + msg_type = msg.get("type") + + if msg_type == "image": + info = self._classify_image( + msg.get("url", ""), + msg.get("blob", ""), + final_min_bytes, + medium_min_bytes, + ) + if not info: + continue + + image_id = info["image_id"] + if info["stage"] == "medium" and medium_received_time is None: + medium_received_time = time.monotonic() + + if info["is_final"] and image_id not in final_ids: + final_ids.add(image_id) + completed += 1 + logger.debug( + f"Final image received: id={image_id}, size={info['blob_size']}" + ) + + yield info + + elif msg_type == "error": + logger.warning( + f"WebSocket error: {msg.get('err_code', '')} - {msg.get('err_msg', '')}" + ) + yield { + "type": "error", + "error_code": msg.get("err_code", ""), + "error": msg.get("err_msg", ""), + } + return + + if completed >= n: + logger.info(f"WebSocket collected {completed} final images") + break + + if ( + medium_received_time + and completed == 0 + and time.monotonic() - medium_received_time > final_timeout + ): + raise _BlockedError() + + elif ws_msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.ERROR, + ): + logger.warning(f"WebSocket closed/error: {ws_msg.type}") + yield { + "type": "error", + "error_code": "ws_closed", + "error": f"websocket closed: {ws_msg.type}", + } + break + + except aiohttp.ClientError as e: + logger.error(f"WebSocket connection error: {e}") + yield {"type": "error", "error_code": "connection_failed", "error": str(e)} + + +__all__ = ["ImagineWebSocketReverse", "WS_IMAGINE_URL"] diff --git a/app/services/reverse/ws_livekit.py b/app/services/reverse/ws_livekit.py new file mode 100644 index 00000000..bf3d92ae --- /dev/null +++ b/app/services/reverse/ws_livekit.py @@ -0,0 +1,182 @@ +""" +Reverse interface: LiveKit token + WebSocket. +""" + +import orjson +from typing import Any, Dict +from urllib.parse import urlencode +from curl_cffi.requests import AsyncSession + +from app.core.logger import logger +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.services.token.service import TokenService +from app.services.reverse.utils.headers import build_headers, build_ws_headers +from app.services.reverse.utils.retry import retry_on_status +from app.services.reverse.utils.websocket import WebSocketClient, WebSocketConnection + +LIVEKIT_TOKEN_API = "https://grok.com/rest/livekit/tokens" +LIVEKIT_WS_URL = "wss://livekit.grok.com" + + +class LivekitTokenReverse: + """/rest/livekit/tokens reverse interface.""" + + @staticmethod + async def request( + session: AsyncSession, + token: str, + voice: str = "ara", + personality: str = "assistant", + speed: float = 1.0, + ) -> Dict[str, Any]: + """Fetch LiveKit token. + + Args: + session: AsyncSession, the session to use for the request. + token: str, the SSO token. + voice: str, the voice to use for the request. + personality: str, the personality to use for the request. + speed: float, the speed to use for the request. + + Returns: + Dict[str, Any]: The LiveKit token. + """ + try: + # Get proxies + base_proxy = get_config("proxy.base_proxy_url") + proxies = {"http": base_proxy, "https": base_proxy} if base_proxy else None + + # Build headers + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer="https://grok.com/", + ) + + # Build payload + payload = { + "sessionPayload": orjson.dumps( + { + "voice": voice, + "personality": personality, + "playback_speed": speed, + "enable_vision": False, + "turn_detection": {"type": "server_vad"}, + } + ).decode(), + "requestAgentDispatch": False, + "livekitUrl": LIVEKIT_WS_URL, + "params": {"enable_markdown_transcript": "true"}, + } + + # Curl Config + timeout = get_config("voice.timeout") + browser = get_config("proxy.browser") + + async def _do_request(): + response = await session.post( + LIVEKIT_TOKEN_API, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + body = response.text[:200] + logger.error( + f"LivekitTokenReverse: Request failed, {response.status_code}, body={body}" + ) + raise UpstreamException( + message=f"LivekitTokenReverse: Request failed, {response.status_code}", + details={"status": response.status_code, "body": response.text}, + ) + + return response + + response = await retry_on_status(_do_request) + return response + + except Exception as e: + # Handle upstream exception + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "livekit_token_auth_failed" + ) + except Exception: + pass + raise + + # Handle other non-upstream exceptions + logger.error( + f"LivekitTokenReverse: Request failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"LivekitTokenReverse: Request failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +class LivekitWebSocketReverse: + """LiveKit WebSocket reverse interface.""" + + def __init__(self) -> None: + self._client = WebSocketClient() + + async def connect(self, token: str) -> WebSocketConnection: + """Connect to the LiveKit WebSocket. + + Args: + token: str, the SSO token. + + Returns: + WebSocketConnection: The LiveKit WebSocket connection. + """ + # Format URL + base = LIVEKIT_WS_URL.rstrip("/") + if not base.endswith("/rtc"): + base = f"{base}/rtc" + + # Build parameters + params = { + "access_token": token, + "auto_subscribe": "1", + "sdk": "js", + "version": "2.11.4", + "protocol": "15", + } + + # Build URL + url = f"{base}?{urlencode(params)}" + + # Build WebSocket headers + ws_headers = build_ws_headers() + + try: + return await self._client.connect( + url, headers=ws_headers, timeout=get_config("voice.timeout") + ) + except Exception as e: + logger.error(f"LivekitWebSocketReverse: Connect failed, {e}") + raise UpstreamException( + f"LivekitWebSocketReverse: Connect failed, {str(e)}" + ) + + +__all__ = [ + "LivekitTokenReverse", + "LivekitWebSocketReverse", + "LIVEKIT_TOKEN_API", + "LIVEKIT_WS_URL", +] diff --git a/app/services/token/manager.py b/app/services/token/manager.py index 39423ed9..1530cbea 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -14,9 +14,11 @@ BASIC__DEFAULT_QUOTA, SUPER_DEFAULT_QUOTA, ) -from app.core.storage import get_storage +from app.core.storage import get_storage, LocalStorage from app.core.config import get_config +from app.core.exceptions import UpstreamException from app.services.token.pool import TokenPool +from app.services.grok.batch_services.usage import UsageService DEFAULT_REFRESH_BATCH_SIZE = 10 @@ -70,8 +72,6 @@ async def _load(self): # 如果后端返回 None 或空数据,尝试从本地 data/token.json 初始化后端 if not data: - from app.core.storage import LocalStorage - local_storage = LocalStorage() local_data = await local_storage.load_tokens() if local_data: @@ -299,6 +299,14 @@ def get_token_for_video( ) return None + def get_pool_name_for_token(self, token_str: str) -> Optional[str]: + """Return pool name for the given token string.""" + raw_token = token_str.replace("sso=", "") + for pool_name, pool in self.pools.items(): + if pool.get(raw_token): + return pool_name + return None + async def consume( self, token_str: str, effort: EffortType = EffortType.LOW ) -> bool: @@ -330,7 +338,6 @@ async def consume( async def sync_usage( self, token_str: str, - model_name: str, fallback_effort: EffortType = EffortType.LOW, consume_on_fail: bool = True, is_usage: bool = True, @@ -342,7 +349,6 @@ async def sync_usage( Args: token_str: Token 字符串(可带 sso= 前缀) - model_name: 模型名称(用于 API 查询) fallback_effort: 降级时的消耗力度 consume_on_fail: 失败时是否降级扣费 is_usage: 是否记录为一次使用(影响 use_count) @@ -365,10 +371,8 @@ async def sync_usage( # 尝试 API 同步 try: - from app.services.grok.services.usage import UsageService - usage_service = UsageService() - result = await usage_service.get(token_str, model_name=model_name) + result = await usage_service.get(token_str) if result and "remainingTokens" in result: old_quota = target_token.quota @@ -387,6 +391,14 @@ async def sync_usage( return True except Exception as e: + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + await self.record_fail(token_str, status, "rate_limits_auth_failed") logger.warning( f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})" ) @@ -420,11 +432,19 @@ async def record_fail( for pool in self.pools.values(): token = pool.get(raw_token) if token: - if status_code in (401, 403): - token.record_fail(status_code, reason) + if status_code == 401: + threshold = get_config("token.fail_threshold", FAIL_THRESHOLD) + try: + threshold = int(threshold) + except (TypeError, ValueError): + threshold = FAIL_THRESHOLD + if threshold < 1: + threshold = 1 + + token.record_fail(status_code, reason, threshold=threshold) logger.warning( f"Token {raw_token[:10]}...: recorded {status_code} failure " - f"({token.fail_count}/{FAIL_THRESHOLD}) - {reason}" + f"({token.fail_count}/{threshold}) - {reason}" ) else: logger.info( @@ -636,8 +656,6 @@ async def refresh_cooling_tokens(self) -> Dict[str, int]: Returns: {"checked": int, "refreshed": int, "recovered": int, "expired": int} """ - from app.services.grok.services.usage import UsageService - # 收集需要刷新的 token to_refresh: List[TokenInfo] = [] for pool in self.pools.values(): @@ -678,7 +696,7 @@ async def _refresh_one(token_info: TokenInfo) -> dict: # 重试逻辑:最多 2 次重试 for retry in range(3): # 0, 1, 2 try: - result = await usage_service.get(token_str, model_name="grok-3") + result = await usage_service.get(token_str) if result and "remainingTokens" in result: new_quota = result["remainingTokens"] diff --git a/app/services/token/models.py b/app/services/token/models.py index 0701ab7b..86300d90 100644 --- a/app/services/token/models.py +++ b/app/services/token/models.py @@ -128,17 +128,23 @@ def reset(self, default_quota: Optional[int] = None): self.fail_count = 0 self.last_fail_reason = None - def record_fail(self, status_code: int = 401, reason: str = ""): + def record_fail( + self, + status_code: int = 401, + reason: str = "", + threshold: Optional[int] = None, + ): """记录失败,达到阈值后自动标记为 expired""" - # 401/403 错误计入失败(都表示认证/授权失败) - if status_code not in (401, 403): + # 仅 401 计入失败 + if status_code != 401: return self.fail_count += 1 self.last_fail_at = int(datetime.now().timestamp() * 1000) self.last_fail_reason = reason - if self.fail_count >= FAIL_THRESHOLD: + limit = FAIL_THRESHOLD if threshold is None else threshold + if self.fail_count >= limit: self.status = TokenStatus.EXPIRED def record_success(self, is_usage: bool = True): diff --git a/app/services/token/scheduler.py b/app/services/token/scheduler.py index 14277132..5ec8cafb 100644 --- a/app/services/token/scheduler.py +++ b/app/services/token/scheduler.py @@ -36,7 +36,7 @@ async def _refresh_loop(self): lock_acquired = await lock.acquire(blocking=False) else: try: - async with storage.acquire_lock("token_refresh", timeout=0): + async with storage.acquire_lock("token_refresh", timeout=1): lock_acquired = True except StorageError: lock_acquired = False diff --git a/app/services/token/service.py b/app/services/token/service.py index 75b23a67..b441fbeb 100644 --- a/app/services/token/service.py +++ b/app/services/token/service.py @@ -2,7 +2,6 @@ from typing import List, Optional, Dict -from app.services.token.manager import get_token_manager from app.services.token.models import TokenInfo, EffortType @@ -13,6 +12,12 @@ class TokenService: 提供简化的 API,隐藏内部实现细节 """ + @staticmethod + async def _get_manager(): + from app.services.token.manager import get_token_manager + + return await get_token_manager() + @staticmethod async def get_token(pool_name: str = "ssoBasic") -> Optional[str]: """ @@ -24,7 +29,7 @@ async def get_token(pool_name: str = "ssoBasic") -> Optional[str]: Returns: Token 字符串(不含 sso= 前缀)或 None """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return manager.get_token(pool_name) @staticmethod @@ -39,26 +44,23 @@ async def consume(token: str, effort: EffortType = EffortType.LOW) -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.consume(token, effort) @staticmethod - async def sync_usage( - token: str, model: str, effort: EffortType = EffortType.LOW - ) -> bool: + async def sync_usage(token: str, effort: EffortType = EffortType.LOW) -> bool: """ 同步 Token 使用量(优先 API,降级本地) Args: token: Token 字符串 - model: 模型名称 effort: 降级时的消耗力度 Returns: 是否成功 """ - manager = await get_token_manager() - return await manager.sync_usage(token, model, effort) + manager = await TokenService._get_manager() + return await manager.sync_usage(token, effort) @staticmethod async def record_fail(token: str, status_code: int = 401, reason: str = "") -> bool: @@ -73,7 +75,7 @@ async def record_fail(token: str, status_code: int = 401, reason: str = "") -> b Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.record_fail(token, status_code, reason) @staticmethod @@ -88,7 +90,7 @@ async def add_token(token: str, pool_name: str = "ssoBasic") -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.add(token, pool_name) @staticmethod @@ -102,7 +104,7 @@ async def remove_token(token: str) -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.remove(token) @staticmethod @@ -116,13 +118,13 @@ async def reset_token(token: str) -> bool: Returns: 是否成功 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return await manager.reset_token(token) @staticmethod async def reset_all(): """重置所有 Token""" - manager = await get_token_manager() + manager = await TokenService._get_manager() await manager.reset_all() @staticmethod @@ -133,7 +135,7 @@ async def get_stats() -> Dict[str, dict]: Returns: 各池的统计信息 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return manager.get_stats() @staticmethod @@ -147,7 +149,7 @@ async def list_tokens(pool_name: str = "ssoBasic") -> List[TokenInfo]: Returns: Token 列表 """ - manager = await get_token_manager() + manager = await TokenService._get_manager() return manager.get_pool_tokens(pool_name) diff --git a/app/static/cache/cache.css b/app/static/admin/css/cache.css similarity index 100% rename from app/static/cache/cache.css rename to app/static/admin/css/cache.css diff --git a/app/static/config/config.css b/app/static/admin/css/config.css similarity index 80% rename from app/static/config/config.css rename to app/static/admin/css/config.css index b887da93..8fa851be 100644 --- a/app/static/config/config.css +++ b/app/static/admin/css/config.css @@ -28,6 +28,7 @@ .config-field { padding-top: 2px; + position: relative; } .config-field-title { @@ -45,3 +46,14 @@ .config-field-input { margin-top: 6px; } + +.config-field.has-action { + padding-right: 44px; +} + +.config-field-action { + position: absolute; + right: 0; + top: 50%; + transform: translateY(-50%); +} diff --git a/app/static/token/token.css b/app/static/admin/css/token.css similarity index 100% rename from app/static/token/token.css rename to app/static/admin/css/token.css diff --git a/app/static/cache/cache.js b/app/static/admin/js/cache.js similarity index 98% rename from app/static/cache/cache.js rename to app/static/admin/js/cache.js index 40b99d64..261cccd3 100644 --- a/app/static/cache/cache.js +++ b/app/static/admin/js/cache.js @@ -84,7 +84,7 @@ function createIconButton(title, svg, onClick) { } async function init() { - apiKey = await ensureApiKey(); + apiKey = await ensureAdminKey(); if (apiKey === null) return; cacheUI(); setupCacheCards(); @@ -233,7 +233,7 @@ async function loadStats(options = {}) { } else { currentScope = 'none'; } - const url = `/api/v1/admin/cache${params.toString() ? `?${params.toString()}` : ''}`; + const url = `/v1/admin/cache${params.toString() ? `?${params.toString()}` : ''}`; const res = await fetch(url, { headers: buildAuthHeaders(apiKey) }); @@ -446,7 +446,7 @@ async function clearCache(type) { if (!ok) return; try { - const res = await fetch('/api/v1/admin/cache/clear', { + const res = await fetch('/v1/admin/cache/clear', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -770,7 +770,7 @@ async function loadLocalCacheList(type) { body.innerHTML = `加载中...`; try { const params = new URLSearchParams({ type, page: '1', page_size: '1000' }); - const res = await fetch(`/api/v1/admin/cache/list?${params.toString()}`, { + const res = await fetch(`/v1/admin/cache/list?${params.toString()}`, { headers: buildAuthHeaders(apiKey) }); if (!res.ok) { @@ -897,7 +897,7 @@ async function deleteLocalFile(type, name) { async function requestDeleteLocalFile(type, name) { try { - const res = await fetch('/api/v1/admin/cache/item/delete', { + const res = await fetch('/v1/admin/cache/item/delete', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1104,7 +1104,7 @@ async function startBatchLoad(tokens) { refreshBatchUI(); try { - const res = await fetch('/api/v1/admin/cache/online/load/async', { + const res = await fetch('/v1/admin/cache/online/load/async', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1242,7 +1242,7 @@ async function startBatchDelete(tokens) { updateDeleteButton(); refreshBatchUI(); try { - const res = await fetch('/api/v1/admin/cache/online/clear/async', { + const res = await fetch('/v1/admin/cache/online/clear/async', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1337,7 +1337,7 @@ async function clearOnlineCache(targetToken = '', skipConfirm = false) { showToast('正在清理在线资产,请稍候...', 'info'); try { - const res = await fetch('/api/v1/admin/cache/online/clear', { + const res = await fetch('/v1/admin/cache/online/clear', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/app/static/config/config.js b/app/static/admin/js/config.js similarity index 65% rename from app/static/config/config.js rename to app/static/admin/js/config.js index f0282a7c..e2978a66 100644 --- a/app/static/config/config.js +++ b/app/static/admin/js/config.js @@ -13,23 +13,23 @@ const NUMERIC_FIELDS = new Set([ 'fail_threshold', 'limit_mb', 'save_delay_ms', - 'assets_max_concurrent', - 'media_max_concurrent', - 'usage_max_concurrent', - 'assets_delete_batch_size', - 'assets_batch_size', - 'assets_max_tokens', - 'usage_batch_size', - 'usage_max_tokens', + 'upload_concurrent', + 'upload_timeout', + 'download_concurrent', + 'download_timeout', + 'list_concurrent', + 'list_timeout', + 'list_batch_size', + 'delete_concurrent', + 'delete_timeout', + 'delete_batch_size', 'reload_interval_sec', - 'stream_idle_timeout', - 'video_idle_timeout', - 'image_ws_blocked_seconds', - 'image_ws_final_min_bytes', - 'image_ws_medium_min_bytes', - 'nsfw_max_concurrent', - 'nsfw_batch_size', - 'nsfw_max_tokens' + 'stream_timeout', + 'final_timeout', + 'final_min_bytes', + 'medium_min_bytes', + 'concurrent', + 'batch_size' ]); const LOCALE_MAP = { @@ -37,31 +37,30 @@ const LOCALE_MAP = { "label": "应用设置", "api_key": { title: "API 密钥", desc: "调用 Grok2API 服务的 Token(可选)。" }, "app_key": { title: "后台密码", desc: "登录 Grok2API 管理后台的密码(必填)。" }, + "public_enabled": { title: "启用功能玩法", desc: "是否启用功能玩法入口(关闭则功能玩法页面不可访问)。" }, + "public_key": { title: "Public 密码", desc: "功能玩法页面的访问密码(可选)。" }, "app_url": { title: "应用地址", desc: "当前 Grok2API 服务的外部访问 URL,用于文件链接访问。" }, - "image_format": { title: "图片格式", desc: "生成的图片格式(url 或 base64)。" }, - "video_format": { title: "视频格式", desc: "生成的视频格式(html 或 url,url 为处理后的链接)。" } + "image_format": { title: "图片格式", desc: "默认生成的图片格式(url 或 base64)。" }, + "video_format": { title: "视频格式", desc: "默认生成的视频格式(html 或 url,url 为处理后的链接)。" }, + "temporary": { title: "临时对话", desc: "是否默认启用临时对话模式。" }, + "disable_memory": { title: "禁用记忆", desc: "是否默认禁用 Grok 记忆功能。" }, + "stream": { title: "流式响应", desc: "是否默认启用流式输出。" }, + "thinking": { title: "思维链", desc: "是否默认启用思维链输出。" }, + "dynamic_statsig": { title: "动态指纹", desc: "是否默认启用动态生成 Statsig 指纹。" }, + "filter_tags": { title: "过滤标签", desc: "设置自动过滤 Grok 响应中的特殊标签。" } }, - "network": { - "label": "网络配置", - "timeout": { title: "请求超时", desc: "请求 Grok 服务的超时时间(秒)。" }, + + + "proxy": { + "label": "代理配置", "base_proxy_url": { title: "基础代理 URL", desc: "代理请求到 Grok 官网的基础服务地址。" }, - "asset_proxy_url": { title: "资源代理 URL", desc: "代理请求到 Grok 官网的静态资源(图片/视频)地址。" } - }, - "security": { - "label": "反爬虫验证", + "asset_proxy_url": { title: "资源代理 URL", desc: "代理请求到 Grok 官网的静态资源(图片/视频)地址。" }, "cf_clearance": { title: "CF Clearance", desc: "Cloudflare Clearance Cookie,用于绕过反爬虫验证。" }, "browser": { title: "浏览器指纹", desc: "curl_cffi 浏览器指纹标识(如 chrome136)。" }, "user_agent": { title: "User-Agent", desc: "HTTP 请求的 User-Agent 字符串,需与浏览器指纹匹配。" } }, - "chat": { - "label": "对话配置", - "temporary": { title: "临时对话", desc: "是否启用临时对话模式。" }, - "disable_memory": { title: "禁用记忆", desc: "禁用 Grok 记忆功能,以防止响应中出现不相关上下文。" }, - "stream": { title: "流式响应", desc: "是否默认启用流式输出。" }, - "thinking": { title: "思维链", desc: "是否启用模型思维链输出。" }, - "dynamic_statsig": { title: "动态指纹", desc: "是否启用动态生成 Statsig 值。" }, - "filter_tags": { title: "过滤标签", desc: "自动过滤 Grok 响应中的特殊标签。" } - }, + + "retry": { "label": "重试策略", "max_retry": { title: "最大重试次数", desc: "请求 Grok 服务失败时的最大重试次数。" }, @@ -71,19 +70,56 @@ const LOCALE_MAP = { "retry_backoff_max": { title: "退避上限", desc: "单次重试等待的最大延迟(秒)。" }, "retry_budget": { title: "退避预算", desc: "单次请求的最大重试总耗时(秒)。" } }, - "timeout": { - "label": "超时配置", - "stream_idle_timeout": { title: "流空闲超时", desc: "流式响应空闲超时(秒),超过将断开。" }, - "video_idle_timeout": { title: "视频空闲超时", desc: "视频生成空闲超时(秒),超过将断开。" } + + + "chat": { + "label": "对话配置", + "concurrent": { title: "并发上限", desc: "Reverse 接口并发上限。" }, + "timeout": { title: "请求超时", desc: "Reverse 接口超时时间(秒)。" }, + "stream_timeout": { title: "流空闲超时", desc: "流式空闲超时时间(秒)。" } + }, + + + "video": { + "label": "视频配置", + "concurrent": { title: "并发上限", desc: "Reverse 接口并发上限。" }, + "timeout": { title: "请求超时", desc: "Reverse 接口超时时间(秒)。" }, + "stream_timeout": { title: "流空闲超时", desc: "流式空闲超时时间(秒)。" } }, + + "image": { - "label": "图片生成", - "image_ws": { title: "WebSocket 生成", desc: "启用后 /v1/images/generations 走 WebSocket 直连。" }, - "image_ws_nsfw": { title: "NSFW 模式", desc: "WebSocket 请求是否启用 NSFW。" }, - "image_ws_blocked_seconds": { title: "Blocked 阈值", desc: "收到中等图后超过该秒数仍无最终图则判定 blocked。" }, - "image_ws_final_min_bytes": { title: "最终图最小字节", desc: "判定最终图的最小字节数(通常 JPG > 100KB)。" }, - "image_ws_medium_min_bytes": { title: "中等图最小字节", desc: "判定中等质量图的最小字节数。" } + "label": "图像配置", + "timeout": { title: "请求超时", desc: "WebSocket 请求超时时间(秒)。" }, + "stream_timeout": { title: "流空闲超时", desc: "WebSocket 流式空闲超时时间(秒)。" }, + "final_timeout": { title: "最终图超时", desc: "收到中等图后等待最终图的超时秒数。" }, + "nsfw": { title: "NSFW 模式", desc: "WebSocket 请求是否启用 NSFW。" }, + "medium_min_bytes": { title: "中等图最小字节", desc: "判定中等质量图的最小字节数。" }, + "final_min_bytes": { title: "最终图最小字节", desc: "判定最终图的最小字节数(通常 JPG > 100KB)。" } }, + + + "asset": { + "label": "资产配置", + "upload_concurrent": { title: "上传并发", desc: "上传接口的最大并发数。推荐 30。" }, + "upload_timeout": { title: "上传超时", desc: "上传接口超时时间(秒)。推荐 60。" }, + "download_concurrent": { title: "下载并发", desc: "下载接口的最大并发数。推荐 30。" }, + "download_timeout": { title: "下载超时", desc: "下载接口超时时间(秒)。推荐 60。" }, + "list_concurrent": { title: "查询并发", desc: "资产查询接口的最大并发数。推荐 10。" }, + "list_timeout": { title: "查询超时", desc: "资产查询接口超时时间(秒)。推荐 60。" }, + "list_batch_size": { title: "查询批次大小", desc: "单次查询可处理的 Token 数量。推荐 10。" }, + "delete_concurrent": { title: "删除并发", desc: "资产删除接口的最大并发数。推荐 10。" }, + "delete_timeout": { title: "删除超时", desc: "资产删除接口超时时间(秒)。推荐 60。" }, + "delete_batch_size": { title: "删除批次大小", desc: "单次删除可处理的 Token 数量。推荐 10。" } + }, + + + "voice": { + "label": "语音配置", + "timeout": { title: "请求超时", desc: "Voice 请求超时时间(秒)。" } + }, + + "token": { "label": "Token 池管理", "auto_refresh": { title: "自动刷新", desc: "是否开启 Token 自动刷新机制。" }, @@ -93,30 +129,34 @@ const LOCALE_MAP = { "save_delay_ms": { title: "保存延迟", desc: "Token 变更合并写入的延迟(毫秒)。" }, "reload_interval_sec": { title: "同步间隔", desc: "多 worker 场景下 Token 状态刷新间隔(秒)。" } }, + + "cache": { "label": "缓存管理", "enable_auto_clean": { title: "自动清理", desc: "是否启用缓存自动清理,开启后按上限自动回收。" }, "limit_mb": { title: "清理阈值", desc: "缓存大小阈值(MB),超过阈值会触发清理。" } }, - "performance": { - "label": "并发性能", - "media_max_concurrent": { title: "Media 并发上限", desc: "视频/媒体生成请求的并发上限。推荐 50。" }, - "nsfw_max_concurrent": { title: "NSFW 开启并发上限", desc: "批量开启 NSFW 模式时的并发请求上限。推荐 10。" }, - "nsfw_batch_size": { title: "NSFW 开启批量大小", desc: "批量开启 NSFW 模式的单批处理数量。推荐 50。" }, - "nsfw_max_tokens": { title: "NSFW 开启最大数量", desc: "单次批量开启 NSFW 的 Token 数量上限,防止误操作。推荐 1000。" }, - "usage_max_concurrent": { title: "Token 刷新并发上限", desc: "批量刷新 Token 用量时的并发请求上限。推荐 25。" }, - "usage_batch_size": { title: "Token 刷新批次大小", desc: "批量刷新 Token 用量的单批处理数量。推荐 50。" }, - "usage_max_tokens": { title: "Token 刷新最大数量", desc: "单次批量刷新 Token 用量时的处理数量上限。推荐 1000。" }, - "assets_max_concurrent": { title: "Assets 处理并发上限", desc: "批量查找/删除资产时的并发请求上限。推荐 25。" }, - "assets_batch_size": { title: "Assets 处理批次大小", desc: "批量查找/删除资产时的单批处理数量。推荐 10。" }, - "assets_max_tokens": { title: "Assets 处理最大数量", desc: "单次批量查找/删除资产时的处理数量上限。推荐 1000。" }, - "assets_delete_batch_size": { title: "Assets 单账号删除批量大小", desc: "单账号批量删除资产时的单批并发数量。推荐 10。" } + + + "nsfw": { + "label": "NSFW 配置", + "concurrent": { title: "并发上限", desc: "批量开启 NSFW 模式时的并发请求上限。推荐 10。" }, + "batch_size": { title: "批次大小", desc: "批量开启 NSFW 模式的单批处理数量。推荐 50。" }, + "timeout": { title: "请求超时", desc: "NSFW 开启相关请求的超时时间(秒)。推荐 60。" } + }, + + + "usage": { + "label": "Usage 配置", + "concurrent": { title: "并发上限", desc: "批量刷新用量时的并发请求上限。推荐 10。" }, + "batch_size": { title: "批次大小", desc: "批量刷新用量的单批处理数量。推荐 50。" }, + "timeout": { title: "请求超时", desc: "用量查询接口的超时时间(秒)。推荐 60。" } } }; // 配置部分说明(可选) const SECTION_DESCRIPTIONS = { - "security": "配置不正确将导致 403 错误。服务首次请求 Grok 时的 IP 必须与获取 CF Clearance 时的 IP 一致,后续服务器请求 IP 变化不会导致 403。" + "proxy": "配置不正确将导致 403 错误。服务首次请求 Grok 时的 IP 必须与获取 CF Clearance 时的 IP 一致,后续服务器请求 IP 变化不会导致 403。" }; const SECTION_ORDER = new Map(Object.keys(LOCALE_MAP).map((key, index) => [key, index])); @@ -218,6 +258,15 @@ function buildSecretInput(section, key, val) { const wrapper = document.createElement('div'); wrapper.className = 'flex items-center gap-2'; + const genBtn = document.createElement('button'); + genBtn.className = 'flex-none w-[32px] h-[32px] flex items-center justify-center bg-black text-white rounded-md hover:opacity-80 transition-opacity'; + genBtn.type = 'button'; + genBtn.title = '生成'; + genBtn.innerHTML = ``; + genBtn.onclick = () => { + input.value = randomKey(16); + }; + const copyBtn = document.createElement('button'); copyBtn.className = 'flex-none w-[32px] h-[32px] flex items-center justify-center bg-black text-white rounded-md hover:opacity-80 transition-opacity'; copyBtn.type = 'button'; @@ -225,20 +274,38 @@ function buildSecretInput(section, key, val) { copyBtn.onclick = () => copyToClipboard(input.value, copyBtn); wrapper.appendChild(input); + wrapper.appendChild(genBtn); wrapper.appendChild(copyBtn); return { input, node: wrapper }; } +function randomKey(len) { + const chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'; + const out = []; + if (window.crypto && window.crypto.getRandomValues) { + const buf = new Uint8Array(len); + window.crypto.getRandomValues(buf); + for (let i = 0; i < len; i++) { + out.push(chars[buf[i] % chars.length]); + } + return out.join(''); + } + for (let i = 0; i < len; i++) { + out.push(chars[Math.floor(Math.random() * chars.length)]); + } + return out.join(''); +} + async function init() { - apiKey = await ensureApiKey(); + apiKey = await ensureAdminKey(); if (apiKey === null) return; loadData(); } async function loadData() { try { - const res = await fetch('/api/v1/admin/config', { + const res = await fetch('/v1/admin/config', { headers: buildAuthHeaders(apiKey) }); if (res.ok) { @@ -347,7 +414,7 @@ function buildFieldCard(section, key, val) { built = buildJsonInput(section, key, val); } else { - if (key === 'api_key' || key === 'app_key') { + if (key === 'api_key' || key === 'app_key' || key === 'public_key') { built = buildSecretInput(section, key, val); } else { built = buildTextInput(section, key, val); @@ -359,6 +426,23 @@ function buildFieldCard(section, key, val) { } fieldCard.appendChild(inputWrapper); + if (section === 'app' && key === 'public_enabled') { + fieldCard.classList.add('has-action'); + const link = document.createElement('a'); + link.href = '/login'; + link.className = 'config-field-action flex-none w-[32px] h-[32px] flex items-center justify-center bg-black text-white rounded-md hover:opacity-80 transition-opacity'; + link.title = '功能玩法'; + link.setAttribute('aria-label', '功能玩法'); + link.innerHTML = ``; + link.style.display = val ? 'inline-flex' : 'none'; + fieldCard.appendChild(link); + if (built && built.input) { + built.input.addEventListener('change', () => { + link.style.display = built.input.checked ? 'inline-flex' : 'none'; + }); + } + } + return fieldCard; } @@ -395,7 +479,7 @@ async function saveConfig() { newConfig[s][k] = val; }); - const res = await fetch('/api/v1/admin/config', { + const res = await fetch('/v1/admin/config', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/app/static/login/login.js b/app/static/admin/js/login.js similarity index 69% rename from app/static/login/login.js rename to app/static/admin/js/login.js index b9bd6cb6..48dfe956 100644 --- a/app/static/login/login.js +++ b/app/static/admin/js/login.js @@ -1,13 +1,19 @@ const apiKeyInput = document.getElementById('api-key-input'); +const publicKeyInput = document.getElementById('public-key-input'); if (apiKeyInput) { apiKeyInput.addEventListener('keypress', (e) => { if (e.key === 'Enter') login(); }); } +if (publicKeyInput) { + publicKeyInput.addEventListener('keypress', (e) => { + if (e.key === 'Enter') login(); + }); +} async function requestLogin(key) { - const res = await fetch('/api/v1/admin/login', { - method: 'POST', + const res = await fetch('/v1/admin/verify', { + method: 'GET', headers: { 'Authorization': `Bearer ${key}` } }); return res.ok; @@ -15,12 +21,16 @@ async function requestLogin(key) { async function login() { const input = (apiKeyInput ? apiKeyInput.value : '').trim(); + const publicKey = (publicKeyInput ? publicKeyInput.value : '').trim(); if (!input) return; try { const ok = await requestLogin(input); if (ok) { await storeAppKey(input); + if (publicKey) { + await storePublicKey(publicKey); + } window.location.href = '/admin/token'; } else { showToast('密钥无效', 'error'); diff --git a/app/static/token/token.js b/app/static/admin/js/token.js similarity index 99% rename from app/static/token/token.js rename to app/static/admin/js/token.js index 1ed1d0af..8c41e78b 100644 --- a/app/static/token/token.js +++ b/app/static/admin/js/token.js @@ -107,7 +107,7 @@ function getPaginationData() { } async function init() { - apiKey = await ensureApiKey(); + apiKey = await ensureAdminKey(); if (apiKey === null) return; setupEditPoolDefaults(); setupConfirmDialog(); @@ -116,7 +116,7 @@ async function init() { async function loadData() { try { - const res = await fetch('/api/v1/admin/tokens', { + const res = await fetch('/v1/admin/tokens', { headers: buildAuthHeaders(apiKey) }); if (res.ok) { @@ -536,7 +536,7 @@ async function syncToServer() { }); try { - const res = await fetch('/api/v1/admin/tokens', { + const res = await fetch('/v1/admin/tokens', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -622,7 +622,7 @@ async function refreshStatus(token) { btn.innerHTML = ``; } - const res = await fetch('/api/v1/admin/tokens/refresh', { + const res = await fetch('/v1/admin/tokens/refresh', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -673,7 +673,7 @@ async function startBatchRefresh() { setActionButtonsState(); try { - const res = await fetch('/api/v1/admin/tokens/refresh/async', { + const res = await fetch('/v1/admin/tokens/refresh/async', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -1039,7 +1039,7 @@ async function batchEnableNSFW() { try { const tokens = selected.length > 0 ? selected.map(t => t.token) : null; - const res = await fetch('/api/v1/admin/tokens/nsfw/enable/async', { + const res = await fetch('/v1/admin/tokens/nsfw/enable/async', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/app/static/cache/cache.html b/app/static/admin/pages/cache.html similarity index 92% rename from app/static/cache/cache.html rename to app/static/admin/pages/cache.html index 128ab971..52055d24 100644 --- a/app/static/cache/cache.html +++ b/app/static/admin/pages/cache.html @@ -5,13 +5,13 @@ Grok2API - 缓存管理 - + - - - + + + @@ -196,13 +196,13 @@

缓存管理

- - - - - - - + + + + + + + diff --git a/app/static/config/config.html b/app/static/admin/pages/config.html similarity index 75% rename from app/static/config/config.html rename to app/static/admin/pages/config.html index b19411c1..68786b0d 100644 --- a/app/static/config/config.html +++ b/app/static/admin/pages/config.html @@ -5,13 +5,13 @@ Grok2API - 配置管理 - + - - - + + + @@ -46,11 +46,11 @@

配置管理

- - - - - + + + + + diff --git a/app/static/login/login.html b/app/static/admin/pages/login.html similarity index 79% rename from app/static/login/login.html rename to app/static/admin/pages/login.html index 0cb78039..913c33f7 100644 --- a/app/static/login/login.html +++ b/app/static/admin/pages/login.html @@ -5,7 +5,7 @@ Grok2API - 登录 - + @@ -24,9 +24,9 @@ } } - - - + + + @@ -47,7 +47,7 @@ @@ -55,10 +55,10 @@ - - - - + + + + diff --git a/app/static/token/token.html b/app/static/admin/pages/token.html similarity index 94% rename from app/static/token/token.html rename to app/static/admin/pages/token.html index f43aec9c..daf49759 100644 --- a/app/static/token/token.html +++ b/app/static/admin/pages/token.html @@ -5,13 +5,13 @@ Grok2API - Token 管理 - + - - - + + + @@ -291,13 +291,13 @@ - - - - - - - + + + + + + + diff --git a/app/static/common/common.css b/app/static/common/css/common.css similarity index 100% rename from app/static/common/common.css rename to app/static/common/css/common.css diff --git a/app/static/login/login.css b/app/static/common/css/login.css similarity index 100% rename from app/static/login/login.css rename to app/static/common/css/login.css diff --git a/app/static/common/toast.css b/app/static/common/css/toast.css similarity index 100% rename from app/static/common/toast.css rename to app/static/common/css/toast.css diff --git a/app/static/common/footer.html b/app/static/common/html/footer.html similarity index 100% rename from app/static/common/footer.html rename to app/static/common/html/footer.html diff --git a/app/static/common/header.html b/app/static/common/html/header.html similarity index 64% rename from app/static/common/header.html rename to app/static/common/html/header.html index 7303639d..2330a996 100644 --- a/app/static/common/header.html +++ b/app/static/common/html/header.html @@ -14,23 +14,9 @@ class="text-xs text-[var(--accents-4)] hover:text-black">@chenyme
- - + Token管理 + 配置管理 + 缓存管理