From e70609be74b486d7c3c2bb1ec891fcafb9db9378 Mon Sep 17 00:00:00 2001 From: nlgtuankiet Date: Sat, 21 Mar 2026 20:47:46 +0700 Subject: [PATCH 1/2] ignore flow2api.iml --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4c133db..fe62ae7 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ logs.txt # IDE .vscode/ .idea/ +flow2api.iml .lh/ *.swp *.swo From 8f45ca9afad60a88d92a64f78991dddda7b72387 Mon Sep 17 00:00:00 2001 From: nlgtuankiet Date: Sat, 21 Mar 2026 20:47:46 +0700 Subject: [PATCH 2/2] POC: cache media id --- src/core/database.py | 141 +++++++- src/services/generation_handler.py | 490 +++++++++++++++++++++----- tests/test_generation_handler.py | 458 +++++++++++++++++++++++- tests/test_uploaded_image_cache_db.py | 116 ++++++ 4 files changed, 1109 insertions(+), 96 deletions(-) create mode 100644 tests/test_uploaded_image_cache_db.py diff --git a/src/core/database.py b/src/core/database.py index 9cbbddf..a6bc853 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -1,10 +1,12 @@ """Database storage layer for Flow2API""" +from pathlib import Path + import aiosqlite import json -from datetime import datetime from typing import Optional, List, Dict, Any -from pathlib import Path -from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, GenerationConfig, CacheConfig, Project, CaptchaConfig, PluginConfig, CallLogicConfig + +from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, GenerationConfig, CacheConfig, \ + Project, CaptchaConfig, PluginConfig, CallLogicConfig class Database: @@ -325,6 +327,25 @@ async def check_and_migrate_db(self, config_dict: dict = None): ) """) + # Check and create uploaded_image_cache table if missing + if not await self._table_exists(db, "uploaded_image_cache"): + print(" ✓ Creating missing table: uploaded_image_cache") + await db.execute(""" + CREATE TABLE uploaded_image_cache + ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT NOT NULL, + project_id TEXT NOT NULL, + image_hash TEXT NOT NULL, + aspect_ratio TEXT NOT NULL, + media_id TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_used_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE (email, project_id, image_hash, aspect_ratio) + ) + """) + # ========== Step 2: Add missing columns to existing tables ========== # Check and add missing columns to tokens table if await self._table_exists(db, "tokens"): @@ -543,6 +564,23 @@ async def init_db(self): ) """) + # Uploaded image media cache table + await db.execute(""" + CREATE TABLE IF NOT EXISTS uploaded_image_cache + ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT NOT NULL, + project_id TEXT NOT NULL, + image_hash TEXT NOT NULL, + aspect_ratio TEXT NOT NULL, + media_id TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_used_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE (email, project_id, image_hash, aspect_ratio) + ) + """) + # Admin config table await db.execute(""" CREATE TABLE IF NOT EXISTS admin_config ( @@ -656,6 +694,8 @@ async def init_db(self): await db.execute("CREATE INDEX IF NOT EXISTS idx_project_id ON projects(project_id)") await db.execute("CREATE INDEX IF NOT EXISTS idx_tokens_email ON tokens(email)") await db.execute("CREATE INDEX IF NOT EXISTS idx_tokens_is_active_last_used_at ON tokens(is_active, last_used_at)") + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_uploaded_image_cache_lookup ON uploaded_image_cache(email, project_id, image_hash, aspect_ratio)") # Migrate request_logs table if needed await self._migrate_request_logs(db) @@ -944,6 +984,101 @@ async def delete_project(self, project_id: str): await db.execute("DELETE FROM projects WHERE project_id = ?", (project_id,)) await db.commit() + # Uploaded image cache operations + async def get_uploaded_image_cache( + self, + email: str, + project_id: str, + image_hash: str, + aspect_ratio: str, + ) -> Optional[Dict[str, Any]]: + """Get cached uploaded media metadata by cache key.""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + """ + SELECT * + FROM uploaded_image_cache + WHERE email = ? + AND project_id = ? + AND image_hash = ? + AND aspect_ratio = ? + """, + (email, project_id, image_hash, aspect_ratio), + ) + row = await cursor.fetchone() + if row: + return dict(row) + return None + + async def upsert_uploaded_image_cache( + self, + email: str, + project_id: str, + image_hash: str, + aspect_ratio: str, + media_id: str, + ): + """Insert or update cached uploaded media metadata.""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + """ + INSERT INTO uploaded_image_cache (email, project_id, image_hash, aspect_ratio, media_id) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(email, project_id, image_hash, aspect_ratio) + DO UPDATE SET media_id = excluded.media_id, + updated_at = CURRENT_TIMESTAMP, + last_used_at = CURRENT_TIMESTAMP + """, + (email, project_id, image_hash, aspect_ratio, media_id), + ) + await db.commit() + + async def touch_uploaded_image_cache( + self, + email: str, + project_id: str, + image_hash: str, + aspect_ratio: str, + ): + """Update cache entry usage timestamps.""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + """ + UPDATE uploaded_image_cache + SET updated_at = CURRENT_TIMESTAMP, + last_used_at = CURRENT_TIMESTAMP + WHERE email = ? + AND project_id = ? + AND image_hash = ? + AND aspect_ratio = ? + """, + (email, project_id, image_hash, aspect_ratio), + ) + await db.commit() + + async def delete_uploaded_image_cache( + self, + email: str, + project_id: str, + image_hash: str, + aspect_ratio: str, + ): + """Delete cached uploaded media metadata by cache key.""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + """ + DELETE + FROM uploaded_image_cache + WHERE email = ? + AND project_id = ? + AND image_hash = ? + AND aspect_ratio = ? + """, + (email, project_id, image_hash, aspect_ratio), + ) + await db.commit() + # Task operations async def create_task(self, task: Task) -> int: """Create a new task""" diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 95e3fdb..6d0f135 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -1,21 +1,20 @@ """Generation handler for Flow2API""" import asyncio -import base64 +import hashlib import json import time from typing import Optional, AsyncGenerator, List, Dict, Any -from ..core.logger import debug_logger -from ..core.config import config -from ..core.models import Task, RequestLog + +from .file_cache import FileCache from ..core.account_tiers import ( - PAYGATE_TIER_NOT_PAID, get_paygate_tier_label, get_required_paygate_tier_for_model, normalize_user_paygate_tier, supports_model_for_tier, ) -from .file_cache import FileCache - +from ..core.config import config +from ..core.logger import debug_logger +from ..core.models import Task, RequestLog # Model configuration MODEL_CONFIG = { @@ -686,6 +685,8 @@ def __init__(self, flow_client, token_manager, load_balancer, db, concurrency_ma ) self._last_generated_url = None self._last_generation_assets = None + self._uploaded_image_cache_locks: Dict[str, asyncio.Lock] = {} + self._uploaded_image_cache_locks_guard = asyncio.Lock() def _create_generation_result(self) -> Dict[str, Any]: """????????????????""" @@ -712,6 +713,244 @@ def _normalize_error_message(self, error_message: Any, max_length: int = 1000) - return text return f"{text[:max_length - 3]}..." + def _normalize_upload_aspect_ratio(self, aspect_ratio: str) -> str: + """Align upload cache keys with Flow upload semantics.""" + if aspect_ratio.startswith("VIDEO_"): + return aspect_ratio.replace("VIDEO_", "IMAGE_", 1) + return aspect_ratio + + async def _get_uploaded_image_cache_lock(self, cache_key: str) -> asyncio.Lock: + """Get or create a per-key in-process upload dedupe lock.""" + async with self._uploaded_image_cache_locks_guard: + upload_lock = self._uploaded_image_cache_locks.get(cache_key) + if upload_lock is None: + upload_lock = asyncio.Lock() + self._uploaded_image_cache_locks[cache_key] = upload_lock + return upload_lock + + def _build_uploaded_image_cache_key( + self, + email: str, + project_id: str, + image_hash: str, + aspect_ratio: str, + ) -> Dict[str, str]: + """Build the SQLite cache key for uploaded images.""" + return { + "email": email, + "project_id": project_id, + "image_hash": image_hash, + "aspect_ratio": self._normalize_upload_aspect_ratio(aspect_ratio), + } + + async def _get_or_upload_cached_media_id( + self, + token, + project_id: str, + image_bytes: bytes, + aspect_ratio: str, + ) -> Dict[str, Any]: + """Resolve an uploaded image media id via cache or upstream upload.""" + email = str(getattr(token, "email", "") or "").strip() + if not email: + raise ValueError("Token email is required for uploaded image caching") + + image_hash = hashlib.sha256(image_bytes).hexdigest() + cache_key = self._build_uploaded_image_cache_key( + email=email, + project_id=project_id, + image_hash=image_hash, + aspect_ratio=aspect_ratio, + ) + cache_key_id = "|".join( + ( + cache_key["email"], + cache_key["project_id"], + cache_key["image_hash"], + cache_key["aspect_ratio"], + ) + ) + + cached_entry = await self.db.get_uploaded_image_cache(**cache_key) + if cached_entry: + await self.db.touch_uploaded_image_cache(**cache_key) + return { + "media_id": cached_entry["media_id"], + "cache_key": cache_key, + "was_cached": True, + "image_bytes": image_bytes, + } + + upload_lock = await self._get_uploaded_image_cache_lock(cache_key_id) + async with upload_lock: + cached_entry = await self.db.get_uploaded_image_cache(**cache_key) + if cached_entry: + await self.db.touch_uploaded_image_cache(**cache_key) + return { + "media_id": cached_entry["media_id"], + "cache_key": cache_key, + "was_cached": True, + "image_bytes": image_bytes, + } + + media_id = await self.flow_client.upload_image( + token.at, + image_bytes, + cache_key["aspect_ratio"], + project_id=project_id, + ) + await self.db.upsert_uploaded_image_cache( + media_id=media_id, + **cache_key, + ) + return { + "media_id": media_id, + "cache_key": cache_key, + "was_cached": False, + "image_bytes": image_bytes, + } + + async def _resolve_uploaded_media_records( + self, + token, + project_id: str, + images: List[bytes], + aspect_ratio: str, + ) -> List[Dict[str, Any]]: + """Resolve media ids for input images using upload cache.""" + upload_records: List[Dict[str, Any]] = [] + for image_bytes in images: + upload_records.append( + await self._get_or_upload_cached_media_id( + token=token, + project_id=project_id, + image_bytes=image_bytes, + aspect_ratio=aspect_ratio, + ) + ) + return upload_records + + def _build_image_inputs_from_upload_records( + self, + upload_records: List[Dict[str, Any]], + ) -> List[Dict[str, str]]: + """Convert cached upload records into image-generation inputs.""" + return [ + { + "name": upload_record["media_id"], + "imageInputType": "IMAGE_INPUT_TYPE_REFERENCE", + } + for upload_record in upload_records + ] + + def _build_reference_images_from_upload_records( + self, + upload_records: List[Dict[str, Any]], + ) -> List[Dict[str, str]]: + """Convert cached upload records into reference-image inputs.""" + return [ + { + "imageUsageType": "IMAGE_USAGE_TYPE_ASSET", + "mediaId": upload_record["media_id"], + } + for upload_record in upload_records + ] + + def _has_cached_uploaded_media(self, upload_records: List[Dict[str, Any]]) -> bool: + """Whether a request used any cached media ids.""" + return any(upload_record.get("was_cached") for upload_record in upload_records) + + def _is_asset_like_cache_invalidation_error(self, error: Exception) -> bool: + """Detect asset-reference errors that justify evicting cached media ids.""" + error_lower = str(error or "").lower() + if not error_lower: + return False + + if any( + keyword in error_lower + for keyword in [ + "403", + "429", + "recaptcha", + "timed out", + "timeout", + "network", + "tls", + "public_error", + "internal error", + "server error", + "http error 500", + ] + ): + return False + + if not any( + keyword in error_lower + for keyword in [ + "invalid_argument", + "not_found", + "http error 400", + "http error 404", + "bad request", + "not found", + ] + ): + return False + + return any( + keyword in error_lower + for keyword in [ + "media", + "mediaid", + "media id", + "imageinput", + "image input", + "imageinputs", + "referenceimage", + "reference image", + "referenceimages", + "startimage", + "start image", + "endimage", + "end image", + "asset", + ] + ) + + async def _evict_uploaded_media_records(self, upload_records: List[Dict[str, Any]]): + """Evict cached upload records used by a failed generation request.""" + seen_keys = set() + for upload_record in upload_records: + if not upload_record.get("was_cached"): + continue + cache_key = upload_record.get("cache_key") or {} + cache_key_id = ( + cache_key.get("email"), + cache_key.get("project_id"), + cache_key.get("image_hash"), + cache_key.get("aspect_ratio"), + ) + if cache_key_id in seen_keys: + continue + seen_keys.add(cache_key_id) + await self.db.delete_uploaded_image_cache(**cache_key) + + async def _refresh_uploaded_media_records( + self, + token, + project_id: str, + upload_records: List[Dict[str, Any]], + aspect_ratio: str, + ) -> List[Dict[str, Any]]: + """Evict cached media ids and rebuild media records once.""" + await self._evict_uploaded_media_records(upload_records) + return await self._resolve_uploaded_media_records( + token=token, + project_id=project_id, + images=[upload_record["image_bytes"] for upload_record in upload_records], + aspect_ratio=aspect_ratio, + ) + async def _fail_video_task(self, operations: Optional[List[Dict[str, Any]]], error_message: str): """将视频任务收口到失败态,避免残留 processing。""" if not operations: @@ -1124,28 +1363,27 @@ async def _handle_image_generation( try: # 上传图片 (如果有) - upload_started_at = time.time() + upload_elapsed_ms = 0 + upload_records: List[Dict[str, Any]] = [] image_inputs = [] if images and len(images) > 0: if stream: yield self._create_stream_chunk(f"上传 {len(images)} 张参考图片...\n") - # 支持多图输入 - for idx, image_bytes in enumerate(images): - media_id = await self.flow_client.upload_image( - token.at, - image_bytes, - model_config["aspect_ratio"], - project_id=project_id - ) - image_inputs.append({ - "name": media_id, - "imageInputType": "IMAGE_INPUT_TYPE_REFERENCE" - }) + upload_started_at = time.time() + upload_records = await self._resolve_uploaded_media_records( + token=token, + project_id=project_id, + images=images, + aspect_ratio=model_config["aspect_ratio"], + ) + upload_elapsed_ms += int((time.time() - upload_started_at) * 1000) + image_inputs = self._build_image_inputs_from_upload_records(upload_records) + for idx in range(len(upload_records)): if stream: yield self._create_stream_chunk(f"已上传第 {idx + 1}/{len(images)} 张图片\n") if image_trace is not None: - image_trace["upload_images_ms"] = int((time.time() - upload_started_at) * 1000) + image_trace["upload_images_ms"] = upload_elapsed_ms # 调用生成API if stream: @@ -1163,19 +1401,54 @@ async def _image_progress_callback(status_text: str, progress: int): ) generate_started_at = time.time() - result, generation_session_id, upstream_trace = await self.flow_client.generate_image( - at=token.at, - project_id=project_id, - prompt=prompt, - model_name=model_config["model_name"], - aspect_ratio=model_config["aspect_ratio"], - image_inputs=image_inputs, - token_id=token.id, - token_image_concurrency=token.image_concurrency, - progress_callback=_image_progress_callback, - ) + refresh_attempted = False + while True: + try: + result, generation_session_id, upstream_trace = await self.flow_client.generate_image( + at=token.at, + project_id=project_id, + prompt=prompt, + model_name=model_config["model_name"], + aspect_ratio=model_config["aspect_ratio"], + image_inputs=image_inputs, + token_id=token.id, + token_image_concurrency=token.image_concurrency, + progress_callback=_image_progress_callback, + ) + break + except Exception as exc: + if ( + refresh_attempted + or not upload_records + or not self._has_cached_uploaded_media(upload_records) + or not self._is_asset_like_cache_invalidation_error(exc) + ): + raise + + refresh_attempted = True + debug_logger.log_warning( + f"[IMAGE CACHE] Cached media rejected by upstream, evicting and retrying once: {exc}" + ) + await self._update_request_log_progress( + request_log_state, + token_id=token.id, + status_text="uploading_images", + progress=28, + ) + if stream: + yield self._create_stream_chunk("缓存的参考图片已失效,正在重新上传并重试...\n") + upload_started_at = time.time() + upload_records = await self._refresh_uploaded_media_records( + token=token, + project_id=project_id, + upload_records=upload_records, + aspect_ratio=model_config["aspect_ratio"], + ) + upload_elapsed_ms += int((time.time() - upload_started_at) * 1000) + image_inputs = self._build_image_inputs_from_upload_records(upload_records) if image_trace is not None: image_trace["generate_api_ms"] = int((time.time() - generate_started_at) * 1000) + image_trace["upload_images_ms"] = upload_elapsed_ms image_trace["upstream_trace"] = upstream_trace attempts = upstream_trace.get("generation_attempts") if isinstance(upstream_trace, dict) else None if isinstance(attempts, list) and attempts: @@ -1467,6 +1740,8 @@ async def _handle_video_generation( start_media_id = None end_media_id = None reference_images = [] + upload_records: List[Dict[str, Any]] = [] + upload_elapsed_ms = 0 # I2V: 首尾帧处理 if video_type == "i2v" and images: @@ -1474,21 +1749,31 @@ async def _handle_video_generation( # 只有1张图: 仅作为首帧 if stream: yield self._create_stream_chunk("上传首帧图片...\n") - start_media_id = await self.flow_client.upload_image( - token.at, images[0], model_config["aspect_ratio"], project_id=project_id + upload_started_at = time.time() + upload_records = await self._resolve_uploaded_media_records( + token=token, + project_id=project_id, + images=[images[0]], + aspect_ratio=model_config["aspect_ratio"], ) + upload_elapsed_ms += int((time.time() - upload_started_at) * 1000) + start_media_id = upload_records[0]["media_id"] debug_logger.log_info(f"[I2V] 仅上传首帧: {start_media_id}") elif image_count == 2: # 2张图: 首帧+尾帧 if stream: yield self._create_stream_chunk("上传首帧和尾帧图片...\n") - start_media_id = await self.flow_client.upload_image( - token.at, images[0], model_config["aspect_ratio"], project_id=project_id - ) - end_media_id = await self.flow_client.upload_image( - token.at, images[1], model_config["aspect_ratio"], project_id=project_id + upload_started_at = time.time() + upload_records = await self._resolve_uploaded_media_records( + token=token, + project_id=project_id, + images=images[:2], + aspect_ratio=model_config["aspect_ratio"], ) + upload_elapsed_ms += int((time.time() - upload_started_at) * 1000) + start_media_id = upload_records[0]["media_id"] + end_media_id = upload_records[1]["media_id"] debug_logger.log_info(f"[I2V] 上传首尾帧: {start_media_id}, {end_media_id}") # R2V: 多图处理 @@ -1496,83 +1781,124 @@ async def _handle_video_generation( if stream: yield self._create_stream_chunk(f"上传 {image_count} 张参考图片...\n") - for img in images: - media_id = await self.flow_client.upload_image( - token.at, img, model_config["aspect_ratio"], project_id=project_id - ) - reference_images.append({ - "imageUsageType": "IMAGE_USAGE_TYPE_ASSET", - "mediaId": media_id - }) + upload_started_at = time.time() + upload_records = await self._resolve_uploaded_media_records( + token=token, + project_id=project_id, + images=images, + aspect_ratio=model_config["aspect_ratio"], + ) + upload_elapsed_ms += int((time.time() - upload_started_at) * 1000) + reference_images = self._build_reference_images_from_upload_records(upload_records) debug_logger.log_info(f"[R2V] 上传了 {len(reference_images)} 张参考图片") + if video_trace is not None: + video_trace["upload_images_ms"] = upload_elapsed_ms # ========== 调用生成API ========== if stream: yield self._create_stream_chunk("提交视频生成任务...\n") submit_started_at = time.time() - # I2V: 首尾帧生成 - if video_type == "i2v" and start_media_id: - if end_media_id: - # 有首尾帧 - result = await self.flow_client.generate_video_start_end( + async def _submit_video_generation(current_upload_records: List[Dict[str, Any]]) -> Dict[str, Any]: + current_reference_images = self._build_reference_images_from_upload_records(current_upload_records) + current_start_media_id = current_upload_records[0]["media_id"] if current_upload_records else None + current_end_media_id = current_upload_records[1]["media_id"] if len( + current_upload_records) > 1 else None + + # I2V: 首尾帧生成 + if video_type == "i2v" and current_start_media_id: + if current_end_media_id: + return await self.flow_client.generate_video_start_end( + at=token.at, + project_id=project_id, + prompt=prompt, + model_key=model_config["model_key"], + aspect_ratio=model_config["aspect_ratio"], + start_media_id=current_start_media_id, + end_media_id=current_end_media_id, + user_paygate_tier=normalized_tier, + token_id=token.id, + token_video_concurrency=token.video_concurrency, + ) + + actual_model_key = model_config["model_key"].replace("_fl_", "_") + if actual_model_key.endswith("_fl"): + actual_model_key = actual_model_key[:-3] + debug_logger.log_info(f"[I2V] 单帧模式,model_key: {model_config['model_key']} -> {actual_model_key}") + return await self.flow_client.generate_video_start_image( at=token.at, project_id=project_id, prompt=prompt, - model_key=model_config["model_key"], + model_key=actual_model_key, aspect_ratio=model_config["aspect_ratio"], - start_media_id=start_media_id, - end_media_id=end_media_id, + start_media_id=current_start_media_id, user_paygate_tier=normalized_tier, token_id=token.id, token_video_concurrency=token.video_concurrency, ) - else: - # 只有首帧 - 需要去掉 model_key 中的 _fl - # 情况1: _fl_ 在中间 (如 veo_3_1_i2v_s_fast_fl_ultra_relaxed -> veo_3_1_i2v_s_fast_ultra_relaxed) - # 情况2: _fl 在结尾 (如 veo_3_1_i2v_s_fast_ultra_fl -> veo_3_1_i2v_s_fast_ultra) - actual_model_key = model_config["model_key"].replace("_fl_", "_") - if actual_model_key.endswith("_fl"): - actual_model_key = actual_model_key[:-3] - debug_logger.log_info(f"[I2V] 单帧模式,model_key: {model_config['model_key']} -> {actual_model_key}") - result = await self.flow_client.generate_video_start_image( + + # R2V: 多图生成 + if video_type == "r2v" and current_reference_images: + return await self.flow_client.generate_video_reference_images( at=token.at, project_id=project_id, prompt=prompt, - model_key=actual_model_key, + model_key=model_config["model_key"], aspect_ratio=model_config["aspect_ratio"], - start_media_id=start_media_id, + reference_images=current_reference_images, user_paygate_tier=normalized_tier, token_id=token.id, token_video_concurrency=token.video_concurrency, ) - # R2V: 多图生成 - elif video_type == "r2v" and reference_images: - result = await self.flow_client.generate_video_reference_images( + # T2V 或 R2V无图: 纯文本生成 + return await self.flow_client.generate_video_text( at=token.at, project_id=project_id, prompt=prompt, model_key=model_config["model_key"], aspect_ratio=model_config["aspect_ratio"], - reference_images=reference_images, user_paygate_tier=normalized_tier, token_id=token.id, token_video_concurrency=token.video_concurrency, ) - # T2V 或 R2V无图: 纯文本生成 - else: - result = await self.flow_client.generate_video_text( - at=token.at, - project_id=project_id, - prompt=prompt, - model_key=model_config["model_key"], - aspect_ratio=model_config["aspect_ratio"], - user_paygate_tier=normalized_tier, - token_id=token.id, - token_video_concurrency=token.video_concurrency, - ) + refresh_attempted = False + while True: + try: + result = await _submit_video_generation(upload_records) + break + except Exception as exc: + if ( + refresh_attempted + or not upload_records + or not self._has_cached_uploaded_media(upload_records) + or not self._is_asset_like_cache_invalidation_error(exc) + ): + raise + + refresh_attempted = True + debug_logger.log_warning( + f"[VIDEO CACHE] Cached media rejected by upstream, evicting and retrying once: {exc}" + ) + await self._update_request_log_progress( + request_log_state, + token_id=token.id, + status_text="preparing_video", + progress=24, + ) + if stream: + yield self._create_stream_chunk("缓存的参考图片已失效,正在重新上传并重试...\n") + upload_started_at = time.time() + upload_records = await self._refresh_uploaded_media_records( + token=token, + project_id=project_id, + upload_records=upload_records, + aspect_ratio=model_config["aspect_ratio"], + ) + upload_elapsed_ms += int((time.time() - upload_started_at) * 1000) + if video_trace is not None: + video_trace["upload_images_ms"] = upload_elapsed_ms if video_trace is not None: video_trace["submit_generation_ms"] = int((time.time() - submit_started_at) * 1000) diff --git a/tests/test_generation_handler.py b/tests/test_generation_handler.py index a4f068e..d7da82a 100644 --- a/tests/test_generation_handler.py +++ b/tests/test_generation_handler.py @@ -1,12 +1,53 @@ import asyncio +import hashlib from types import SimpleNamespace +import pytest + from src.services.generation_handler import GenerationHandler +IMAGE_MODEL_CONFIG = { + "model_name": "NARWHAL", + "aspect_ratio": "IMAGE_ASPECT_RATIO_SQUARE", +} + +I2V_MODEL_CONFIG = { + "video_type": "i2v", + "model_key": "veo_3_1_i2v_s_fast_fl", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "min_images": 1, + "max_images": 2, +} + +R2V_MODEL_CONFIG = { + "video_type": "r2v", + "model_key": "veo_3_1_r2v_fast_landscape", + "aspect_ratio": "VIDEO_ASPECT_RATIO_LANDSCAPE", + "min_images": 0, + "max_images": 3, +} + class FakeFlowClient: + def __init__(self): + self.upload_calls = [] + self.upload_delay = 0 + self.generate_image_errors = [] + self.generate_image_call_count = 0 + self.video_calls = [] + async def upload_image(self, at, image_bytes, aspect_ratio, project_id=None): - return "media-uploaded" + if self.upload_delay > 0: + await asyncio.sleep(self.upload_delay) + self.upload_calls.append( + { + "at": at, + "image_bytes": image_bytes, + "aspect_ratio": aspect_ratio, + "project_id": project_id, + } + ) + return f"media-uploaded-{len(self.upload_calls)}" async def generate_image( self, @@ -20,9 +61,16 @@ async def generate_image( token_image_concurrency=None, progress_callback=None, ): + self.generate_image_call_count += 1 if progress_callback is not None: await progress_callback("solving_image_captcha", 38) await progress_callback("submitting_image", 48) + + if self.generate_image_errors: + error = self.generate_image_errors.pop(0) + if error is not None: + raise error + return ( { "media": [ @@ -40,10 +88,61 @@ async def generate_image( {"generation_attempts": [{"launch_queue_ms": 0, "launch_stagger_ms": 0}]}, ) + async def generate_video_start_image(self, **kwargs): + self.video_calls.append(("i2v-single", kwargs)) + return { + "operations": [ + { + "operation": {"name": "task-i2v-single"}, + "sceneId": "scene-i2v-single", + "status": "MEDIA_GENERATION_STATUS_PENDING", + } + ] + } + + async def generate_video_start_end(self, **kwargs): + self.video_calls.append(("i2v-pair", kwargs)) + return { + "operations": [ + { + "operation": {"name": "task-i2v-pair"}, + "sceneId": "scene-i2v-pair", + "status": "MEDIA_GENERATION_STATUS_PENDING", + } + ] + } + + async def generate_video_reference_images(self, **kwargs): + self.video_calls.append(("r2v", kwargs)) + return { + "operations": [ + { + "operation": {"name": "task-r2v"}, + "sceneId": "scene-r2v", + "status": "MEDIA_GENERATION_STATUS_PENDING", + } + ] + } + + async def generate_video_text(self, **kwargs): + self.video_calls.append(("t2v", kwargs)) + return { + "operations": [ + { + "operation": {"name": "task-t2v"}, + "sceneId": "scene-t2v", + "status": "MEDIA_GENERATION_STATUS_PENDING", + } + ] + } + class FakeDB: def __init__(self): self.status_updates = [] + self.uploaded_image_cache = {} + self.deleted_keys = [] + self.created_tasks = [] async def update_request_log(self, log_id, **kwargs): self.status_updates.append( @@ -54,6 +153,49 @@ async def update_request_log(self, log_id, **kwargs): } ) + async def get_uploaded_image_cache(self, email, project_id, image_hash, aspect_ratio): + return self.uploaded_image_cache.get((email, project_id, image_hash, aspect_ratio)) + + async def upsert_uploaded_image_cache(self, email, project_id, image_hash, aspect_ratio, media_id): + self.uploaded_image_cache[(email, project_id, image_hash, aspect_ratio)] = { + "email": email, + "project_id": project_id, + "image_hash": image_hash, + "aspect_ratio": aspect_ratio, + "media_id": media_id, + } + + async def touch_uploaded_image_cache(self, email, project_id, image_hash, aspect_ratio): + return None + + async def delete_uploaded_image_cache(self, email, project_id, image_hash, aspect_ratio): + self.deleted_keys.append((email, project_id, image_hash, aspect_ratio)) + self.uploaded_image_cache.pop((email, project_id, image_hash, aspect_ratio), None) + + async def create_task(self, task): + self.created_tasks.append(task) + return len(self.created_tasks) + + async def update_task(self, task_id, **kwargs): + return None + + +class NoPollGenerationHandler(GenerationHandler): + async def _poll_video_result(self, *args, **kwargs): + if False: + yield None + + +def _make_token(email="user@example.com", token_id=1): + return SimpleNamespace( + id=token_id, + at="at-token", + email=email, + image_concurrency=-1, + video_concurrency=-1, + user_paygate_tier="PAYGATE_TIER_NOT_PAID", + ) + async def _collect(async_gen): items = [] @@ -72,12 +214,7 @@ def test_image_generation_progress_switches_from_upload_to_captcha(): concurrency_manager=None, proxy_manager=None, ) - token = SimpleNamespace( - id=1, - at="at-token", - image_concurrency=-1, - user_paygate_tier="PAYGATE_TIER_NOT_PAID", - ) + token = _make_token() generation_result = handler._create_generation_result() request_log_state = {"id": 123} @@ -86,10 +223,7 @@ def test_image_generation_progress_switches_from_upload_to_captcha(): handler._handle_image_generation( token=token, project_id="project-1", - model_config={ - "model_name": "NARWHAL", - "aspect_ratio": "IMAGE_ASPECT_RATIO_SQUARE", - }, + model_config=IMAGE_MODEL_CONFIG, prompt="draw a cat", images=[b"fake-image"], stream=False, @@ -109,3 +243,305 @@ def test_image_generation_progress_switches_from_upload_to_captcha(): "submitting_image", "image_generated", ] + + +def test_image_generation_reuses_cached_media_ids_across_requests(): + db = FakeDB() + flow_client = FakeFlowClient() + handler = GenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = _make_token() + + for _ in range(2): + asyncio.run( + _collect( + handler._handle_image_generation( + token=token, + project_id="project-1", + model_config=IMAGE_MODEL_CONFIG, + prompt="draw a cat", + images=[b"fake-image"], + stream=False, + perf_trace={}, + generation_result=handler._create_generation_result(), + request_log_state={"id": 123}, + pending_token_state={"active": False}, + ) + ) + ) + + assert len(flow_client.upload_calls) == 1 + + +def test_uploaded_media_cache_is_scoped_by_project(): + db = FakeDB() + flow_client = FakeFlowClient() + handler = GenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = _make_token() + + asyncio.run( + handler._resolve_uploaded_media_records( + token=token, + project_id="project-a", + images=[b"same-image"], + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + asyncio.run( + handler._resolve_uploaded_media_records( + token=token, + project_id="project-b", + images=[b"same-image"], + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + + assert len(flow_client.upload_calls) == 2 + + +def test_uploaded_media_cache_is_scoped_by_email(): + db = FakeDB() + flow_client = FakeFlowClient() + handler = GenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + + asyncio.run( + handler._resolve_uploaded_media_records( + token=_make_token(email="first@example.com", token_id=1), + project_id="project-1", + images=[b"same-image"], + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + asyncio.run( + handler._resolve_uploaded_media_records( + token=_make_token(email="second@example.com", token_id=2), + project_id="project-1", + images=[b"same-image"], + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + + assert len(flow_client.upload_calls) == 2 + + +def test_uploaded_media_cache_dedupes_concurrent_identical_uploads(): + db = FakeDB() + flow_client = FakeFlowClient() + flow_client.upload_delay = 0.05 + handler = GenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = _make_token() + + async def _run(): + return await asyncio.gather( + handler._resolve_uploaded_media_records( + token=token, + project_id="project-1", + images=[b"same-image"], + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ), + handler._resolve_uploaded_media_records( + token=token, + project_id="project-1", + images=[b"same-image"], + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ), + ) + + results = asyncio.run(_run()) + + assert len(flow_client.upload_calls) == 1 + assert results[0][0]["media_id"] == results[1][0]["media_id"] + + +def test_image_generation_evicts_stale_cached_media_and_reuploads_once(): + db = FakeDB() + flow_client = FakeFlowClient() + flow_client.generate_image_errors = [ + Exception("INVALID_ARGUMENT: referenceImages mediaId not found"), + None, + ] + handler = GenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = _make_token() + cache_key = handler._build_uploaded_image_cache_key( + email=token.email, + project_id="project-1", + image_hash=hashlib.sha256(b"fake-image").hexdigest(), + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + asyncio.run(db.upsert_uploaded_image_cache(media_id="stale-media", **cache_key)) + + asyncio.run( + _collect( + handler._handle_image_generation( + token=token, + project_id="project-1", + model_config=IMAGE_MODEL_CONFIG, + prompt="draw a cat", + images=[b"fake-image"], + stream=False, + perf_trace={}, + generation_result=handler._create_generation_result(), + request_log_state={"id": 123}, + pending_token_state={"active": False}, + ) + ) + ) + + assert flow_client.generate_image_call_count == 2 + assert len(flow_client.upload_calls) == 1 + assert db.deleted_keys == [ + ( + token.email, + "project-1", + hashlib.sha256(b"fake-image").hexdigest(), + "IMAGE_ASPECT_RATIO_SQUARE", + ) + ] + + +def test_image_generation_does_not_evict_cache_for_non_asset_errors(): + db = FakeDB() + flow_client = FakeFlowClient() + flow_client.generate_image_errors = [Exception("PUBLIC_ERROR: internal error")] + handler = GenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = _make_token() + cache_key = handler._build_uploaded_image_cache_key( + email=token.email, + project_id="project-1", + image_hash=hashlib.sha256(b"fake-image").hexdigest(), + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + asyncio.run(db.upsert_uploaded_image_cache(media_id="cached-media", **cache_key)) + + with pytest.raises(Exception, match="PUBLIC_ERROR: internal error"): + asyncio.run( + _collect( + handler._handle_image_generation( + token=token, + project_id="project-1", + model_config=IMAGE_MODEL_CONFIG, + prompt="draw a cat", + images=[b"fake-image"], + stream=False, + perf_trace={}, + generation_result=handler._create_generation_result(), + request_log_state={"id": 123}, + pending_token_state={"active": False}, + ) + ) + ) + + assert len(flow_client.upload_calls) == 0 + assert db.deleted_keys == [] + + +def test_i2v_reuses_cached_uploads_for_start_and_end_images(): + db = FakeDB() + flow_client = FakeFlowClient() + handler = NoPollGenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = _make_token() + + for _ in range(2): + asyncio.run( + _collect( + handler._handle_video_generation( + token=token, + project_id="project-1", + model_config=I2V_MODEL_CONFIG, + prompt="animate", + images=[b"start", b"end"], + stream=False, + perf_trace={}, + generation_result=handler._create_generation_result(), + request_log_state={"id": 123}, + pending_token_state={"active": False}, + ) + ) + ) + + assert len(flow_client.upload_calls) == 2 + assert [call[0] for call in flow_client.video_calls] == ["i2v-pair", "i2v-pair"] + + +def test_r2v_reuses_cached_uploads_for_reference_images(): + db = FakeDB() + flow_client = FakeFlowClient() + handler = NoPollGenerationHandler( + flow_client=flow_client, + token_manager=None, + load_balancer=None, + db=db, + concurrency_manager=None, + proxy_manager=None, + ) + token = _make_token() + images = [b"one", b"two", b"three"] + + for _ in range(2): + asyncio.run( + _collect( + handler._handle_video_generation( + token=token, + project_id="project-1", + model_config=R2V_MODEL_CONFIG, + prompt="animate", + images=images, + stream=False, + perf_trace={}, + generation_result=handler._create_generation_result(), + request_log_state={"id": 123}, + pending_token_state={"active": False}, + ) + ) + ) + + assert len(flow_client.upload_calls) == 3 + assert [call[0] for call in flow_client.video_calls] == ["r2v", "r2v"] diff --git a/tests/test_uploaded_image_cache_db.py b/tests/test_uploaded_image_cache_db.py new file mode 100644 index 0000000..6c5373b --- /dev/null +++ b/tests/test_uploaded_image_cache_db.py @@ -0,0 +1,116 @@ +import asyncio +import time + +from src.core.database import Database + + +def test_uploaded_image_cache_crud_roundtrip(tmp_path): + db = Database(db_path=str(tmp_path / "flow.db")) + asyncio.run(db.init_db()) + + asyncio.run( + db.upsert_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + media_id="media-1", + ) + ) + + entry = asyncio.run( + db.get_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + assert entry is not None + assert entry["media_id"] == "media-1" + + asyncio.run( + db.upsert_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + media_id="media-2", + ) + ) + + updated_entry = asyncio.run( + db.get_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + assert updated_entry is not None + assert updated_entry["media_id"] == "media-2" + + asyncio.run( + db.delete_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + + deleted_entry = asyncio.run( + db.get_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + assert deleted_entry is None + + +def test_uploaded_image_cache_touch_updates_last_used_at(tmp_path): + db = Database(db_path=str(tmp_path / "flow.db")) + asyncio.run(db.init_db()) + + asyncio.run( + db.upsert_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + media_id="media-1", + ) + ) + before_touch = asyncio.run( + db.get_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + + time.sleep(1.1) + + asyncio.run( + db.touch_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + after_touch = asyncio.run( + db.get_uploaded_image_cache( + email="user@example.com", + project_id="project-1", + image_hash="hash-1", + aspect_ratio="IMAGE_ASPECT_RATIO_SQUARE", + ) + ) + + assert before_touch is not None + assert after_touch is not None + assert after_touch["last_used_at"] >= before_touch["last_used_at"]