Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ logs/
*.log

# Data
data/*.json
data/tmp/*
data/.locks/*
app/data/*.json
app/data/tmp/*
app/data/.locks/*
data/
app/data/

# Testing
.pytest_cache/
Expand Down
9 changes: 4 additions & 5 deletions app/services/grok/batch_services/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@
import asyncio
from typing import Dict, List, Optional

from curl_cffi.requests import AsyncSession

from app.core.config import get_config
from app.core.logger import logger
from app.services.reverse.assets_list import AssetsListReverse
from app.services.reverse.assets_delete import AssetsDeleteReverse
from app.services.reverse.utils.session import ResettableSession
from app.core.batch import run_batch


class BaseAssetsService:
"""Base assets service."""

def __init__(self):
self._session: Optional[AsyncSession] = None
self._session: Optional[ResettableSession] = None

async def _get_session(self) -> AsyncSession:
async def _get_session(self) -> ResettableSession:
if self._session is None:
self._session = AsyncSession()
self._session = ResettableSession()
return self._session

async def close(self):
Expand Down
5 changes: 2 additions & 3 deletions app/services/grok/batch_services/nsfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import asyncio
from typing import Callable, Awaitable, Dict, Any, Optional

from curl_cffi.requests import AsyncSession

from app.core.logger import logger
from app.core.config import get_config
from app.core.exceptions import UpstreamException
from app.services.reverse.accept_tos import AcceptTosReverse
from app.services.reverse.nsfw_mgmt import NsfwMgmtReverse
from app.services.reverse.set_birth import SetBirthReverse
from app.services.reverse.utils.session import ResettableSession
from app.core.batch import run_batch


Expand Down Expand Up @@ -44,7 +43,7 @@ async def batch(
async def _enable(token: str):
try:
browser = get_config("proxy.browser")
async with AsyncSession(impersonate=browser) as session:
async with ResettableSession(impersonate=browser) as session:
async def _record_fail(err: UpstreamException, reason: str):
status = None
if err.details and "status" in err.details:
Expand Down
5 changes: 2 additions & 3 deletions app/services/grok/batch_services/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import asyncio
from typing import Callable, Awaitable, Dict, Any, Optional, List

from curl_cffi.requests import AsyncSession

from app.core.logger import logger
from app.core.config import get_config
from app.services.reverse.rate_limits import RateLimitsReverse
from app.services.reverse.utils.session import ResettableSession
from app.core.batch import run_batch

_USAGE_SEMAPHORE = None
Expand Down Expand Up @@ -43,7 +42,7 @@ async def get(self, token: str) -> Dict:
"""
async with _get_usage_semaphore():
try:
async with AsyncSession() as session:
async with ResettableSession() as session:
response = await RateLimitsReverse.request(session, token)
data = response.json()
remaining = data.get("remainingTokens")
Expand Down
4 changes: 2 additions & 2 deletions app/services/grok/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Dict, List, Any, AsyncGenerator, AsyncIterable

import orjson
from curl_cffi.requests import AsyncSession
from curl_cffi.requests.errors import RequestsError

from app.core.logger import logger
Expand All @@ -25,6 +24,7 @@
from app.services.grok.utils import process as proc_base
from app.services.grok.utils.retry import pick_token, rate_limited
from app.services.reverse.app_chat import AppChatReverse
from app.services.reverse.utils.session import ResettableSession
from app.services.grok.utils.stream import wrap_stream_with_usage
from app.services.token import get_token_manager, EffortType

Expand Down Expand Up @@ -190,7 +190,7 @@ async def chat(
browser = get_config("proxy.browser")

async def _stream():
session = AsyncSession(impersonate=browser)
session = ResettableSession(impersonate=browser)
try:
async with _get_chat_semaphore():
stream_response = await AppChatReverse.request(
Expand Down
12 changes: 6 additions & 6 deletions app/services/grok/services/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Any, AsyncGenerator, AsyncIterable, Optional

import orjson
from curl_cffi.requests import AsyncSession
from curl_cffi.requests.errors import RequestsError

from app.core.logger import logger
Expand All @@ -33,6 +32,7 @@
from app.services.reverse.app_chat import AppChatReverse
from app.services.reverse.media_post import MediaPostReverse
from app.services.reverse.video_upscale import VideoUpscaleReverse
from app.services.reverse.utils.session import ResettableSession
from app.services.token.manager import BASIC_POOL_NAME

_VIDEO_SEMAPHORE = None
Expand Down Expand Up @@ -69,7 +69,7 @@ async def create_post(
prompt_value = prompt if media_type == "MEDIA_POST_TYPE_VIDEO" else ""
media_value = media_url or ""

async with AsyncSession() as session:
async with ResettableSession() as session:
async with _get_video_semaphore():
response = await MediaPostReverse.request(
session,
Expand Down Expand Up @@ -131,7 +131,7 @@ async def generate(
}

async def _stream():
session = AsyncSession()
session = ResettableSession()
try:
async with _get_video_semaphore():
stream_response = await AppChatReverse.request(
Expand Down Expand Up @@ -191,7 +191,7 @@ async def generate_from_image(
}

async def _stream():
session = AsyncSession()
session = ResettableSession()
try:
async with _get_video_semaphore():
stream_response = await AppChatReverse.request(
Expand Down Expand Up @@ -401,7 +401,7 @@ async def _upscale_video_url(self, video_url: str) -> str:
logger.warning("Video upscale skipped: unable to extract video id")
return video_url
try:
async with AsyncSession() as session:
async with ResettableSession() as session:
response = await VideoUpscaleReverse.request(
session, self.token, video_id
)
Expand Down Expand Up @@ -583,7 +583,7 @@ async def _upscale_video_url(self, video_url: str) -> str:
logger.warning("Video upscale skipped: unable to extract video id")
return video_url
try:
async with AsyncSession() as session:
async with ResettableSession() as session:
response = await VideoUpscaleReverse.request(
session, self.token, video_id
)
Expand Down
5 changes: 2 additions & 3 deletions app/services/grok/services/voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

from typing import Any, Dict

from curl_cffi.requests import AsyncSession

from app.core.config import get_config
from app.services.reverse.ws_livekit import LivekitTokenReverse
from app.services.reverse.utils.session import ResettableSession


class VoiceService:
Expand All @@ -21,7 +20,7 @@ async def get_token(
speed: float = 1.0,
) -> Dict[str, Any]:
browser = get_config("proxy.browser")
async with AsyncSession(impersonate=browser) as session:
async with ResettableSession(impersonate=browser) as session:
response = await LivekitTokenReverse.request(
session,
token=token,
Expand Down
8 changes: 4 additions & 4 deletions app/services/grok/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,32 @@
from urllib.parse import urlparse

import aiofiles
from curl_cffi.requests import AsyncSession

from app.core.logger import logger
from app.core.storage import DATA_DIR
from app.core.config import get_config
from app.core.exceptions import AppException
from app.services.reverse.assets_download import AssetsDownloadReverse
from app.services.reverse.utils.session import ResettableSession
from app.services.grok.utils.locks import _get_download_semaphore, _file_lock


class DownloadService:
"""Assets download service."""

def __init__(self):
self._session: Optional[AsyncSession] = None
self._session: Optional[ResettableSession] = None
base_dir = DATA_DIR / "tmp"
self.image_dir = base_dir / "image"
self.video_dir = base_dir / "video"
self.image_dir.mkdir(parents=True, exist_ok=True)
self.video_dir.mkdir(parents=True, exist_ok=True)
self._cleanup_running = False

async def create(self) -> AsyncSession:
async def create(self) -> ResettableSession:
"""Create or reuse a session."""
if self._session is None:
self._session = AsyncSession()
self._session = ResettableSession()
return self._session

async def close(self):
Expand Down
8 changes: 4 additions & 4 deletions app/services/grok/utils/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@
from urllib.parse import urlparse

import aiofiles
from curl_cffi.requests import AsyncSession

from app.core.config import get_config
from app.core.exceptions import AppException, UpstreamException, ValidationException
from app.core.logger import logger
from app.core.storage import DATA_DIR
from app.services.reverse.assets_upload import AssetsUploadReverse
from app.services.reverse.utils.session import ResettableSession
from app.services.grok.utils.locks import _get_upload_semaphore, _file_lock


class UploadService:
"""Assets upload service."""

def __init__(self):
self._session: Optional[AsyncSession] = None
self._session: Optional[ResettableSession] = None
self._chunk_size = 64 * 1024

async def create(self) -> AsyncSession:
async def create(self) -> ResettableSession:
"""Create or reuse a session."""
if self._session is None:
self._session = AsyncSession()
self._session = ResettableSession()
return self._session

async def close(self):
Expand Down
10 changes: 7 additions & 3 deletions app/services/reverse/utils/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import asyncio
import inspect
import random
from typing import Callable, Any, Optional

Expand Down Expand Up @@ -122,7 +123,7 @@ async def retry_on_status(
func: Callable,
*args,
extract_status: Callable[[Exception], Optional[int]] = None,
on_retry: Callable[[int, int, Exception, float], None] = None,
on_retry: Callable[[int, int, Exception, float], Any] = None,
**kwargs,
) -> Any:
"""
Expand All @@ -132,7 +133,8 @@ async def retry_on_status(
func: Retry function
*args: Function arguments
extract_status: Function to extract status code from exception
on_retry: Callback function for retry (attempt, status_code, error, delay)
on_retry: Callback function for retry (attempt, status_code, error, delay).
Can be sync or async.
**kwargs: Function keyword arguments

Returns:
Expand Down Expand Up @@ -204,7 +206,9 @@ def extract_status(e: Exception) -> Optional[int]:

# Callback
if on_retry:
on_retry(ctx.attempt, status_code, e, delay)
result = on_retry(ctx.attempt, status_code, e, delay)
if inspect.isawaitable(result):
await result

await asyncio.sleep(delay)
continue
Expand Down
87 changes: 87 additions & 0 deletions app/services/reverse/utils/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Resettable session wrapper for reverse requests.
"""

import asyncio
from typing import Any, Iterable, Optional

from curl_cffi.requests import AsyncSession

from app.core.config import get_config
from app.core.logger import logger


class ResettableSession:
"""AsyncSession wrapper that resets connection on specific HTTP status codes."""

def __init__(
self,
*,
reset_on_status: Optional[Iterable[int]] = None,
**session_kwargs: Any,
):
self._session_kwargs = dict(session_kwargs)
config_codes = get_config("retry.reset_session_status_codes")
if reset_on_status is None:
reset_on_status = config_codes if config_codes is not None else [403]
if isinstance(reset_on_status, int):
reset_on_status = [reset_on_status]
self._reset_on_status = (
{int(code) for code in reset_on_status} if reset_on_status else set()
)
self._reset_requested = False
self._reset_lock = asyncio.Lock()
self._session = AsyncSession(**self._session_kwargs)

async def _maybe_reset(self) -> None:
if not self._reset_requested:
return
async with self._reset_lock:
if not self._reset_requested:
return
self._reset_requested = False
old_session = self._session
self._session = AsyncSession(**self._session_kwargs)
try:
await old_session.close()
except Exception:
pass
logger.debug("ResettableSession: session reset")

async def _request(self, method: str, *args: Any, **kwargs: Any):
await self._maybe_reset()
response = await getattr(self._session, method)(*args, **kwargs)
if self._reset_on_status and response.status_code in self._reset_on_status:
self._reset_requested = True
return response

async def get(self, *args: Any, **kwargs: Any):
return await self._request("get", *args, **kwargs)

async def post(self, *args: Any, **kwargs: Any):
return await self._request("post", *args, **kwargs)

async def reset(self) -> None:
self._reset_requested = True
await self._maybe_reset()

async def close(self) -> None:
if self._session is None:
return
try:
await self._session.close()
finally:
self._session = None
self._reset_requested = False

async def __aenter__(self) -> "ResettableSession":
return self

async def __aexit__(self, exc_type, exc, tb) -> None:
await self.close()

def __getattr__(self, name: str) -> Any:
return getattr(self._session, name)


__all__ = ["ResettableSession"]
Loading
Loading