66 changes: 60 additions & 6 deletions app/api/v1/image.py
@@ -36,6 +36,7 @@
from app.services.grok.processor import ImageCollectProcessor, ImageStreamProcessor
from app.services.quota import enforce_daily_quota
from app.services.request_stats import request_stats
from app.services.token_usage import build_image_usage
from app.services.token import get_token_manager


@@ -510,6 +511,23 @@ async def _record_request(model_id: str, success: bool):
pass


async def _record_request_with_usage(model_id: str, success: bool, prompt: str, success_count: int = 1):
try:
usage = build_image_usage(prompt, success_count=success_count)
raw = usage.get("_raw") or {}
await request_stats.record_request(
model_id,
success=success,
total_tokens=int(usage.get("total_tokens", 0) or 0),
input_tokens=int(usage.get("input_tokens", 0) or 0),
output_tokens=int(usage.get("output_tokens", 0) or 0),
reasoning_tokens=int(raw.get("reasoning_tokens", 0) or 0),
cached_tokens=int(raw.get("cached_tokens", 0) or 0),
)
except Exception:
pass


async def _get_token_for_model(model_id: str):
"""获取指定模型可用 token,失败时抛出统一异常"""
try:
@@ -659,7 +677,12 @@ async def _wrapped_experimental_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Generation: {request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -707,7 +730,12 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Generation: {request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -766,7 +794,15 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, bool(success))
if success:
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Generation: {request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
pass

@@ -919,7 +955,12 @@ async def _wrapped_experimental_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Edit: {edit_request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -970,7 +1011,12 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Edit: {edit_request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -1055,7 +1101,15 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, bool(success))
if success:
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Edit: {edit_request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
pass

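All six call sites above hand the raw prompt to `build_image_usage` and read back `total_tokens`, `input_tokens`, `output_tokens`, plus a private `_raw` breakdown. The helper itself is not part of this diff; below is a minimal sketch of the expected return shape, assuming a rough character-count estimate for the prompt and a flat per-image output cost (both assumptions, not taken from the PR):

```python
# Hypothetical sketch of app/services/token_usage.build_image_usage.
# Assumptions: ~4 characters per token for the prompt, and a flat
# per-image output cost; the real estimator is not shown in this diff.

TOKENS_PER_IMAGE = 256  # assumed placeholder cost per generated image


def build_image_usage(prompt: str, success_count: int = 1) -> dict:
    input_tokens = max(1, len(prompt) // 4)
    output_tokens = TOKENS_PER_IMAGE * max(0, success_count)
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        # fields without a top-level slot travel in _raw
        "_raw": {"reasoning_tokens": 0, "cached_tokens": 0},
    }
```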
33 changes: 29 additions & 4 deletions app/services/grok/chat.py
@@ -532,7 +532,9 @@ async def completions(

# Process the response
if is_stream:
processor = StreamProcessor(model_name, token, think).process(response)
stream_processor = StreamProcessor(model_name, token, think)
processor = stream_processor.process(response)
prompt_messages = [msg.model_dump() for msg in messages]

async def _wrapped_stream():
completed = False
@@ -544,19 +546,42 @@ async def _wrapped_stream():
# Only count as "success" when the stream ends naturally.
try:
if completed:
usage = stream_processor.build_usage(prompt_messages)
raw = usage.get("_raw") or {}
await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
await request_stats.record_request(model_name, success=True)
await request_stats.record_request(
model_name,
success=True,
total_tokens=int(usage.get("total_tokens", 0) or 0),
input_tokens=int(usage.get("prompt_tokens", 0) or 0),
output_tokens=int(usage.get("completion_tokens", 0) or 0),
reasoning_tokens=int(raw.get("reasoning_tokens", 0) or 0),
cached_tokens=int(raw.get("cached_tokens", 0) or 0),
)
else:
await request_stats.record_request(model_name, success=False)
except Exception:
pass

return _wrapped_stream()

result = await CollectProcessor(model_name, token).process(response)
result = await CollectProcessor(model_name, token).process(
response,
prompt_messages=[msg.model_dump() for msg in messages],
)
try:
usage = result.get("usage") or {}
raw = usage.get("_raw") or {}
await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
await request_stats.record_request(model_name, success=True)
await request_stats.record_request(
model_name,
success=True,
total_tokens=int(usage.get("total_tokens", 0) or 0),
input_tokens=int(usage.get("prompt_tokens", 0) or 0),
output_tokens=int(usage.get("completion_tokens", 0) or 0),
reasoning_tokens=int(raw.get("reasoning_tokens", 0) or 0),
cached_tokens=int(raw.get("cached_tokens", 0) or 0),
)
except Exception:
pass
return result
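Both the streaming and non-streaming paths read the same keys back out of the usage dict: `prompt_tokens`, `completion_tokens`, `total_tokens`, and the internal `_raw` block. A hedged sketch of that contract for `build_chat_usage`, assuming a character-based approximation rather than a real tokenizer:

```python
# Hypothetical sketch of app/services/token_usage.build_chat_usage.
# Assumption: counts are approximated from character length; a real
# implementation might use an actual tokenizer instead.


def _estimate_tokens(text: str) -> int:
    return max(1, len(text) // 4) if text else 0


def build_chat_usage(prompt_messages: list[dict], output_text: str) -> dict:
    prompt_text = "".join(str(m.get("content", "")) for m in prompt_messages)
    prompt_tokens = _estimate_tokens(prompt_text)
    completion_tokens = _estimate_tokens(output_text)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "prompt_tokens_details": {"cached_tokens": 0, "text_tokens": prompt_tokens,
                                  "audio_tokens": 0, "image_tokens": 0},
        "completion_tokens_details": {"text_tokens": completion_tokens,
                                      "audio_tokens": 0, "reasoning_tokens": 0},
        "_raw": {"reasoning_tokens": 0, "cached_tokens": 0},
    }
```

Note that `CollectProcessor` returns this dict verbatim in the response body, so under this sketch the private `_raw` block would also be visible to API clients.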
28 changes: 21 additions & 7 deletions app/services/grok/processor.py
@@ -11,6 +11,7 @@
from app.core.config import get_config
from app.core.logger import logger
from app.services.grok.assets import DownloadService
from app.services.token_usage import build_chat_usage


ASSET_URL = "https://assets.grok.com/"
@@ -116,6 +117,8 @@ def __init__(self, model: str, token: str = "", think: bool = None):
self.fingerprint: str = ""
self.think_opened: bool = False
self.role_sent: bool = False
self._output_text: str = ""
self._reasoning_text: str = ""
self.filter_tags = get_config("grok.filter_tags", [])
self.image_format = get_config("app.image_format", "url")

@@ -157,32 +160,37 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
idx = img.get('imageIndex', 0) + 1
progress = img.get('progress', 0)
yield self._sse(f"正在生成第{idx}张图片中,当前进度{progress}%\n")
self._reasoning_text += f"正在生成第{idx}张图片中,当前进度{progress}%\n"
continue

# modelResponse
if mr := resp.get("modelResponse"):
if self.think_opened and self.show_think:
if msg := mr.get("message"):
yield self._sse(msg + "\n")
self._reasoning_text += msg + "\n"
yield self._sse("</think>\n")
self.think_opened = False

# Handle generated images
for url in mr.get("generatedImageUrls", []):
parts = url.split("/")
img_id = parts[-2] if len(parts) >= 2 else "image"

if self.image_format == "base64":
dl_service = self._get_dl()
base64_data = await dl_service.to_base64(url, self.token, "image")
if base64_data:
yield self._sse(f"![{img_id}]({base64_data})\n")
self._output_text += f"![{img_id}]({base64_data})\n"
else:
final_url = await self.process_url(url, "image")
yield self._sse(f"![{img_id}]({final_url})\n")
self._output_text += f"![{img_id}]({final_url})\n"
else:
final_url = await self.process_url(url, "image")
yield self._sse(f"![{img_id}]({final_url})\n")
self._output_text += f"![{img_id}]({final_url})\n"

if (meta := mr.get("metadata", {})).get("llm_info", {}).get("modelHash"):
self.fingerprint = meta["llm_info"]["modelHash"]
@@ -192,9 +200,14 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
if (token := resp.get("token")) is not None:
if token and not (self.filter_tags and any(t in token for t in self.filter_tags)):
yield self._sse(token)
if self.think_opened and self.show_think:
self._reasoning_text += token
else:
self._output_text += token

if self.think_opened:
yield self._sse("</think>\n")
self.think_opened = False
yield self._sse(finish="stop")
yield "data: [DONE]\n\n"
except Exception as e:
@@ -203,6 +216,10 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, None]:
finally:
await self.close()

def build_usage(self, prompt_messages: Optional[list[dict]] = None) -> dict[str, Any]:
usage = build_chat_usage(prompt_messages or [], (self._output_text + self._reasoning_text))
return usage


class CollectProcessor(BaseProcessor):
"""非流式响应处理器"""
@@ -211,7 +228,7 @@ def __init__(self, model: str, token: str = ""):
super().__init__(model, token)
self.image_format = get_config("app.image_format", "url")

async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
async def process(self, response: AsyncIterable[bytes], prompt_messages: Optional[list[dict]] = None) -> dict[str, Any]:
"""处理并收集完整响应"""
response_id = ""
fingerprint = ""
@@ -261,6 +278,7 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
finally:
await self.close()

usage = build_chat_usage(prompt_messages or [], content)
return {
"id": response_id,
"object": "chat.completion",
@@ -272,11 +290,7 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]:
"message": {"role": "assistant", "content": content, "refusal": None, "annotations": []},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
"prompt_tokens_details": {"cached_tokens": 0, "text_tokens": 0, "audio_tokens": 0, "image_tokens": 0},
"completion_tokens_details": {"text_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0}
}
"usage": usage
}


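The stream processor now keeps two buffers: `_reasoning_text` while a think block is open, `_output_text` otherwise, and `build_usage` counts their concatenation as completion output. A standalone illustration of that split (simplified; in the processor the flag is flipped by fields in Grok's streamed response, not by matching a literal tag token):

```python
# Standalone illustration of the reasoning/output split used by
# StreamProcessor; the real flag transitions come from the upstream
# response, not from matching a literal "</think>" token.
think_opened = True
reasoning_text, output_text = "", ""

for token in ["pondering ", "the request", "</think>", "final ", "answer"]:
    if token == "</think>":
        think_opened = False
        continue
    if think_opened:
        reasoning_text += token  # later billed as completion output
    else:
        output_text += token

assert reasoning_text == "pondering the request"
assert output_text == "final answer"
```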
15 changes: 15 additions & 0 deletions app/services/request_logger.py
@@ -21,6 +21,11 @@ class RequestLog:
status: int
key_name: str
token_suffix: str
total_tokens: int = 0
input_tokens: int = 0
output_tokens: int = 0
reasoning_tokens: int = 0
cached_tokens: int = 0
error: str = ""

class RequestLogger:
@@ -95,6 +100,11 @@ async def add_log(self,
status: int,
key_name: str,
token_suffix: str = "",
total_tokens: int = 0,
input_tokens: int = 0,
output_tokens: int = 0,
reasoning_tokens: int = 0,
cached_tokens: int = 0,
error: str = ""):
"""添加日志"""
if not self._loaded:
@@ -115,6 +125,11 @@
"status": status,
"key_name": key_name,
"token_suffix": token_suffix,
"total_tokens": int(total_tokens or 0),
"input_tokens": int(input_tokens or 0),
"output_tokens": int(output_tokens or 0),
"reasoning_tokens": int(reasoning_tokens or 0),
"cached_tokens": int(cached_tokens or 0),
"error": error
}

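The five new counters all default to 0, which keeps log entries persisted before this change loadable. A small sketch of that backward compatibility, trimmed to the fields visible in this hunk (the real `RequestLog` declares further fields above `status` that are elided here):

```python
from dataclasses import dataclass


@dataclass
class RequestLogSketch:  # trimmed stand-in; the real RequestLog has more fields
    status: int
    key_name: str
    token_suffix: str
    total_tokens: int = 0
    input_tokens: int = 0
    output_tokens: int = 0
    reasoning_tokens: int = 0
    cached_tokens: int = 0
    error: str = ""


# An entry persisted before this change carries no token fields; the
# defaults let it deserialize cleanly with every counter at 0.
old_entry = {"status": 200, "key_name": "default", "token_suffix": "ab12"}
log = RequestLogSketch(**old_entry)
assert log.total_tokens == 0 and log.cached_tokens == 0
```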