diff --git a/docs/openapi.yaml b/docs/openapi.yaml index 4445a20c9..3ec6b6ccd 100644 --- a/docs/openapi.yaml +++ b/docs/openapi.yaml @@ -1081,6 +1081,146 @@ paths: schema: $ref: "#/components/schemas/AIInteractionResponse" + /image: + post: + tags: + - AI-SVG + summary: AI generate images + description: | + Generate AI-powered images and return image metadata including URLs, dimensions, and generation details. + + Supports multiple AI providers (svgio, recraft, openai) with automatic prompt translation. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/GenerateImageParams" + responses: + "200": + description: Successfully generated image and returned metadata. + content: + application/json: + schema: + $ref: "#/components/schemas/ImageResponse" + "400": + description: Invalid request parameters. + content: + application/json: + schema: + type: object + properties: + error: + type: string + examples: + - "prompt must be at least 3 characters" + - "provider must be one of: svgio, recraft, openai" + "500": + description: Internal server error during image generation. + content: + application/json: + schema: + type: object + properties: + error: + type: string + examples: + - "failed to generate image" + + /image/svg: + post: + tags: + - AI-SVG + summary: Generate SVG image directly + description: | + Generate AI-powered SVG images and return the SVG file content directly in the response body. + + Returns actual SVG content instead of metadata, with generation details provided in response headers. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/GenerateImageParams" + responses: + "200": + description: Successfully generated SVG image. 
+ headers: + Content-Type: + description: MIME type of the response + schema: + type: string + example: "image/svg+xml" + Content-Disposition: + description: File download information + schema: + type: string + example: 'attachment; filename="img_abc123.svg"' + X-Image-Id: + description: Unique identifier for the generated image + schema: + type: string + example: "img_abc123" + X-Image-Width: + description: Width of the generated image in pixels + schema: + type: integer + example: 512 + X-Image-Height: + description: Height of the generated image in pixels + schema: + type: integer + example: 512 + X-Provider: + description: AI provider used for generation + schema: + type: string + example: "svgio" + X-Original-Prompt: + description: Original prompt provided by user (if available) + schema: + type: string + example: "A beautiful sunset over mountains" + X-Translated-Prompt: + description: Translated version of prompt (if translation was applied) + schema: + type: string + example: "A beautiful sunset over mountains" + X-Was-Translated: + description: Whether the prompt was automatically translated + schema: + type: string + example: "false" + content: + image/svg+xml: + schema: + type: string + format: binary + description: The generated SVG image content + "400": + description: Invalid request parameters. + content: + application/json: + schema: + type: object + properties: + error: + type: string + examples: + - "prompt must be at least 3 characters" + - "provider must be one of: svgio, recraft, openai" + "500": + description: Internal server error during image generation. 
+ content: + application/json: + schema: + type: object + properties: + error: + type: string + examples: + - "failed to generate SVG" + /course: post: tags: @@ -1979,6 +2119,135 @@ components: examples: - 0 + GenerateImageParams: + type: object + required: + - prompt + properties: + prompt: + type: string + minLength: 3 + description: The text prompt for image generation + examples: + - "A beautiful sunset over mountains" + negative_prompt: + type: string + description: Text describing what should not appear in the image + examples: + - "blurry, low quality" + style: + type: string + description: Style of the generated image + examples: + - "cartoon" + - "realistic" + - "sketch" + provider: + type: string + enum: + - svgio + - recraft + - openai + default: svgio + description: The AI provider to use for generation + examples: + - "svgio" + format: + type: string + description: Output format of the generated image + examples: + - "svg" + - "png" + skip_translate: + type: boolean + default: false + description: Whether to skip automatic prompt translation + model: + type: string + description: Specific model to use for generation + examples: + - "gpt-4" + size: + type: string + description: Size of the generated image + examples: + - "512x512" + - "1024x1024" + substyle: + type: string + description: Sub-style variant for the image + examples: + - "hand-drawn" + n: + type: integer + minimum: 1 + description: Number of images to generate + examples: + - 1 + + ImageResponse: + type: object + properties: + id: + type: string + description: Unique identifier for the generated image + examples: + - "img_abc123" + svg_url: + type: string + format: uri + description: URL to access the generated SVG image + examples: + - "https://example.com/images/img_abc123.svg" + png_url: + type: string + format: uri + description: URL to access the PNG version of the image (if available) + examples: + - "https://example.com/images/img_abc123.png" + width: + type: integer + minimum: 1 + 
description: Width of the generated image in pixels + examples: + - 512 + height: + type: integer + minimum: 1 + description: Height of the generated image in pixels + examples: + - 512 + provider: + type: string + enum: + - svgio + - recraft + - openai + description: The AI provider used for generation + examples: + - "svgio" + original_prompt: + type: string + description: The original prompt provided by the user + examples: + - "A beautiful sunset over mountains" + translated_prompt: + type: string + description: The translated version of the prompt (if translation was applied) + examples: + - "A beautiful sunset over mountains" + was_translated: + type: boolean + description: Whether the prompt was automatically translated + examples: + - false + created_at: + type: string + format: date-time + description: Timestamp when the image was generated + examples: + - "2024-01-15T10:30:00Z" + AIInteractionResponse: type: object properties: diff --git a/spx-algorithm/client/README.md b/spx-algorithm/client/README.md new file mode 100644 index 000000000..4abb43185 --- /dev/null +++ b/spx-algorithm/client/README.md @@ -0,0 +1,96 @@ +# Image Search API 客户端 + +这个目录包含了调用Image Search API的客户端示例代码。 + +## 文件说明 + +- `image_search_client.py` - 完整的客户端类,包含所有API调用功能 +- `simple_client.py` - 简单的API调用示例 +- `README.md` - 本说明文件 + +## 使用方法 + +### 1. 安装依赖 + +```bash +pip install requests +``` + +### 2. 启动API服务 + +首先确保API服务正在运行: + +```bash +cd ../project +python run.py +``` + +### 3. 
运行客户端 + +#### 使用完整客户端类: + +```bash +python image_search_client.py +``` + +#### 使用简单示例: + +```bash +python simple_client.py +``` + +## API调用示例 + +### 文件上传搜索 + +```python +import requests + +url = "http://localhost:5000/api/search" +files = [ + ('images', open('image1.jpg', 'rb')), + ('images', open('image2.png', 'rb')), +] +data = { + 'text': 'a cute cat', + 'top_k': 3 +} + +response = requests.post(url, files=files, data=data) +result = response.json() +``` + +### URL搜索 + +```python +import requests + +url = "http://localhost:5000/api/search/url" +data = { + "text": "beautiful landscape", + "image_urls": [ + "https://example.com/image1.jpg", + "https://example.com/image2.png" + ], + "top_k": 5 +} + +response = requests.post(url, json=data) +result = response.json() +``` + +### 健康检查 + +```python +import requests + +response = requests.get("http://localhost:5000/api/health") +health = response.json() +``` + +## 注意事项 + +1. 确保API服务正在运行(默认端口5000) +2. 文件搜索需要提供实际存在的图片文件路径 +3. URL搜索需要提供可访问的图片URL +4. 
class ImageSearchClient:
    """HTTP client for the Image Search API.

    Every request goes through one shared ``requests.Session`` and targets
    ``<base_url>/api/...``. Non-2xx responses raise via
    ``response.raise_for_status()``.
    """

    def __init__(self, base_url: str = "http://localhost:5000"):
        """
        Create the client.

        Args:
            base_url: Address of the API server; trailing slashes are stripped.
        """
        self.base_url = base_url.rstrip('/')
        self.session = requests.Session()

    def health_check(self) -> Dict[str, Any]:
        """
        Call ``GET /api/health``.

        Returns:
            The decoded JSON body of the health response.
        """
        response = self.session.get(f"{self.base_url}/api/health")
        response.raise_for_status()
        return response.json()

    def search_with_files(
        self,
        text: str,
        image_paths: List[str],
        top_k: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Search by uploading local image files to ``POST /api/search``.

        Args:
            text: Query text.
            image_paths: Paths of local image files to upload.
            top_k: Number of top results to request; omitted from the form
                when ``None``.

        Returns:
            The decoded JSON search result.

        Raises:
            FileNotFoundError: If any path in ``image_paths`` does not exist.
                No file is opened in that case.
        """
        # Validate every path BEFORE opening anything. Previously the files
        # were opened while validating, so a missing later path raised with
        # the earlier handles still open (they were outside the try/finally).
        for path in image_paths:
            if not os.path.exists(path):
                raise FileNotFoundError(f"图片文件不存在: {path}")

        # Form fields sent alongside the multipart file parts.
        data: Dict[str, Any] = {'text': text}
        if top_k is not None:
            data['top_k'] = top_k

        files = []
        try:
            # Open inside the try so a failure part-way through (for example a
            # permission error) still closes the handles opened so far.
            for path in image_paths:
                files.append(('images', open(path, 'rb')))

            response = self.session.post(
                f"{self.base_url}/api/search",
                files=files,
                data=data
            )
            response.raise_for_status()
            return response.json()
        finally:
            for _, handle in files:
                handle.close()

    def search_with_urls(
        self,
        text: str,
        image_urls: List[str],
        top_k: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Search by sending image URLs as JSON to ``POST /api/search/url``.

        Args:
            text: Query text.
            image_urls: URLs of the images to rank.
            top_k: Number of top results to request; omitted from the payload
                when ``None``.

        Returns:
            The decoded JSON search result.
        """
        data: Dict[str, Any] = {
            "text": text,
            "image_urls": image_urls
        }
        if top_k is not None:
            data["top_k"] = top_k

        response = self.session.post(
            f"{self.base_url}/api/search/url",
            json=data,
            headers={'Content-Type': 'application/json'}
        )
        response.raise_for_status()
        return response.json()

    def print_results(self, results: Dict[str, Any]):
        """
        Pretty-print a search result to stdout.

        Args:
            results: Mapping with the keys ``query``, ``total_images``,
                ``results_count`` and ``results``. Each entry of ``results``
                needs ``rank`` and ``similarity``; the label shown is
                ``filename``, falling back to ``image_path``, then ``'N/A'``.
        """
        print(f"查询文本: '{results['query']}'")
        print(f"总图片数: {results['total_images']}")
        print(f"结果数量: {results['results_count']}")
        print("-" * 60)

        for result in results['results']:
            print(f"排名 {result['rank']}: {result.get('filename', result.get('image_path', 'N/A'))}")
            print(f"相似度: {result['similarity']:.4f}")
            print("-" * 30)
def search_with_files_example():
    """Upload whichever sample SVGs exist locally to ``/api/search`` and print the ranking."""
    endpoint = "http://localhost:5000/api/search"

    # Sample images (replace with real image paths); missing ones are skipped.
    candidates = ("../resource/dog.svg", "../resource/cute.svg", "../resource/cute2.svg")
    uploads = [('images', open(path, 'rb')) for path in candidates if os.path.exists(path)]

    if len(uploads) == 0:
        print("没有找到图片文件,请确保图片路径正确")
        return

    form = dict(text='a cute cat', top_k=3)

    try:
        reply = requests.post(endpoint, files=uploads, data=form)
        if reply.status_code != 200:
            print(f"请求失败: {reply.status_code}")
            print(reply.text)
        else:
            payload = reply.json()
            print("文件搜索结果:")
            print(f"查询: {payload['query']}")
            for hit in payload['results']:
                print(f" 排名{hit['rank']}: {hit['filename']} (相似度: {hit['similarity']:.4f})")
    except Exception as err:
        print(f"请求错误: {err}")
    finally:
        # Always release the uploaded file handles.
        for entry in uploads:
            entry[1].close()


def search_with_urls_example():
    """Post a fixed list of image URLs to ``/api/search/url`` and print the ranking."""
    endpoint = "http://localhost:5000/api/search/url"

    sample_urls = [
        "https://svgsilh.com/svg/1801287.svg",
        "https://svgsilh.com/svg/1790711.svg",
        "https://svgsilh.com/svg/1300187.svg",
        "https://svgsilh.com/svg/1295198-e91e63.svg",
    ]
    body = {"text": "a cute cat", "image_urls": sample_urls, "top_k": 4}

    try:
        reply = requests.post(endpoint, json=body)
        if reply.status_code != 200:
            print(f"请求失败: {reply.status_code}")
            print(reply.text)
        else:
            payload = reply.json()
            print("URL搜索结果:")
            print(f"查询: {payload['query']}")
            for hit in payload['results']:
                print(f" 排名{hit['rank']}: {hit['image_path']} (相似度: {hit['similarity']:.4f})")
    except Exception as err:
        print(f"请求错误: {err}")


def health_check():
    """Call ``/api/health`` and print either the JSON status or the failure."""
    endpoint = "http://localhost:5000/api/health"

    try:
        reply = requests.get(endpoint)
        if reply.status_code != 200:
            print(f"健康检查失败: {reply.status_code}")
        else:
            print("服务健康状态:", reply.json())
    except Exception as err:
        print(f"健康检查错误: {err}")
def process_svg_with_clip(svg_path, texts):
    """Score one SVG image against several text prompts with OpenCLIP.

    Args:
        svg_path: Path or URL of the SVG, passed to ``cairosvg.svg2png(url=...)``.
        texts: List of candidate text descriptions.

    Returns:
        A tensor holding one softmax distribution over ``texts`` per image row
        (a single row here, because one image is encoded).
    """
    # Rasterise the SVG to a 224x224 PNG held in memory.
    png_data = cairosvg.svg2png(url=svg_path, output_width=224, output_height=224)
    image = Image.open(io.BytesIO(png_data))

    # Load the model, its preprocessing transform and the tokenizer.
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    # Inference only; the OpenCLIP usage example switches to eval mode first.
    model.eval()

    processed_image = preprocess(image).unsqueeze(0)
    processed_text = tokenizer(texts)

    with torch.no_grad():
        image_features = model.encode_image(processed_image)
        text_features = model.encode_text(processed_text)

        # L2-normalise both sides so the dot product is a cosine similarity,
        # then scale by 100 before the softmax. The previous version fed raw,
        # unnormalised dot products to the softmax, so the scores depended on
        # the embedding norms rather than on the angle between embeddings.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    return similarity
cat", "a white cat", "a grew cat", "a black cat", "a fat cat", "a thin cat"]) +print(result) \ No newline at end of file diff --git a/spx-algorithm/demo/demo2.py b/spx-algorithm/demo/demo2.py new file mode 100644 index 000000000..f62e87300 --- /dev/null +++ b/spx-algorithm/demo/demo2.py @@ -0,0 +1,115 @@ + +import open_clip +import torch +from PIL import Image +import numpy as np +import cairosvg +import io + + +# 文本匹配多张图片的代码示例 + +def text_to_images_search(text_query, image_paths, model_name='ViT-B-32', pretrained='laion2b_s34b_b79k'): + """ + 用文本查询匹配多张图片 + + Args: + text_query: 查询文本 + image_paths: 图片路径列表 + model_name: 模型名称 + pretrained: 预训练权重 + + Returns: + 排序后的结果列表,包含图片路径和相似度分数 + """ + # 加载模型 + device = "cuda" if torch.cuda.is_available() else "cpu" + model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained) + tokenizer = open_clip.get_tokenizer(model_name) + + model = model.to(device) + + # 处理文本 + text = tokenizer([text_query]).to(device) + + # 处理所有图片 + images = [] + valid_paths = [] + + for img_path in image_paths: + try: + # 转换 SVG + png_data = cairosvg.svg2png(url=img_path, output_width=224, output_height=224) + image = Image.open(io.BytesIO(png_data)) + # image = Image.open(img_path).convert('RGB') + images.append(preprocess(image)) + valid_paths.append(img_path) + except Exception as e: + print(f"无法加载图片 {img_path}: {e}") + continue + + if not images: + return [] + + # 批量处理图片 + images_tensor = torch.stack(images).to(device) + + with torch.no_grad(): + # 编码文本和图片 + text_features = model.encode_text(text) + image_features = model.encode_image(images_tensor) + + # 归一化特征 + text_features /= text_features.norm(dim=-1, keepdim=True) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # 计算相似度 + similarity = (text_features @ image_features.T).squeeze(0) + + # 转换为 numpy 并排序 + similarity_scores = similarity.cpu().numpy() + + # 创建结果列表 + results = [] + for i, score in enumerate(similarity_scores): + results.append({ + 
class ImageSearchEngine:
    """Text-to-image search over a directory of images using an OpenCLIP model.

    ``index_images`` encodes every readable image into a normalised feature
    row; ``search`` ranks those rows against an encoded text query. Row ``i``
    of ``image_features`` always corresponds to ``image_paths[i]``.
    """

    def __init__(self, model_name='ViT-B-32', pretrained='laion2b_s34b_b79k'):
        """Load the model, preprocessing transform and tokenizer.

        Args:
            model_name: OpenCLIP model name.
            pretrained: Pretrained weight tag.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        self.tokenizer = open_clip.get_tokenizer(model_name)
        self.model = self.model.to(self.device)

        self.image_features_cache = {}
        self.image_paths = []

    def index_images(self, image_directory, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
        """Build the feature index for every image below ``image_directory``.

        Files that cannot be opened are skipped and are NOT recorded in
        ``image_paths``, so paths and feature rows stay aligned.

        Args:
            image_directory: Directory searched recursively.
            extensions: File suffixes to include; each is also tried upper-cased.
        """
        print("正在索引图片...")

        # Collect candidate files. Matching both the given and the upper-cased
        # suffix can return the same file twice on case-insensitive
        # filesystems, so de-duplicate while keeping discovery order.
        root = Path(image_directory)
        discovered = []
        for ext in extensions:
            discovered.extend(root.glob(f"**/*{ext}"))
            discovered.extend(root.glob(f"**/*{ext.upper()}"))
        candidate_paths = list(dict.fromkeys(str(p) for p in discovered))
        print(f"找到 {len(candidate_paths)} 张图片")

        batch_size = 32  # tune to the available memory
        all_features = []
        indexed_paths = []

        for start in range(0, len(candidate_paths), batch_size):
            batch_images = []
            batch_valid = []

            for img_path in candidate_paths[start:start + batch_size]:
                try:
                    image = Image.open(img_path).convert('RGB')
                    batch_images.append(self.preprocess(image))
                    batch_valid.append(img_path)
                except Exception:
                    # Best-effort indexing: report and skip unreadable files.
                    print(f"跳过损坏的图片: {img_path}")
                    continue

            if batch_images:
                images_tensor = torch.stack(batch_images).to(self.device)

                with torch.no_grad():
                    features = self.model.encode_image(images_tensor)
                    features /= features.norm(dim=-1, keepdim=True)

                # The index is kept on CPU.
                all_features.append(features.cpu())
                indexed_paths.extend(batch_valid)

        # Only successfully encoded files are kept. Previously every discovered
        # path was stored, so one unreadable file shifted all later results
        # onto the wrong filename.
        self.image_paths = indexed_paths

        if all_features:
            self.image_features = torch.cat(all_features, dim=0)
            print(f"成功索引 {len(self.image_features)} 张图片")
        else:
            # Drop any stale index from an earlier call.
            self.image_features = None
            print("没有找到有效的图片")

    def search(self, query_text, top_k=10):
        """Rank indexed images against ``query_text``.

        Args:
            query_text: Text query.
            top_k: Maximum number of results; clamped to the index size.

        Returns:
            List of dicts with ``rank`` (1-based), ``image_path``,
            ``similarity`` and ``filename``, best match first.

        Raises:
            ValueError: If no index has been built or loaded.
        """
        image_features = getattr(self, 'image_features', None)
        if image_features is None:
            raise ValueError("请先使用 index_images() 建立索引")

        text = self.tokenizer([query_text]).to(self.device)

        with torch.no_grad():
            text_features = self.model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            # The index lives on CPU while the query is encoded on
            # self.device; align device and dtype before the matmul.
            text_features = text_features.to(
                device=image_features.device, dtype=image_features.dtype
            )
            similarity = (text_features @ image_features.T).squeeze(0)

        # Never request more rows than the index holds.
        top_k = min(top_k, image_features.shape[0], len(self.image_paths))
        top_scores, top_indices = similarity.topk(top_k)

        results = []
        ranked = zip(top_indices.tolist(), top_scores.tolist())
        for rank, (idx, score) in enumerate(ranked, start=1):
            path = self.image_paths[idx]
            results.append({
                'rank': rank,
                'image_path': path,
                'similarity': float(score),
                'filename': os.path.basename(path)
            })

        return results

    def save_index(self, index_path):
        """Write ``image_paths`` and ``image_features`` to ``index_path`` as JSON."""
        index_data = {
            'image_paths': self.image_paths,
            'image_features': self.image_features.numpy().tolist()
        }
        with open(index_path, 'w') as f:
            json.dump(index_data, f)

    def load_index(self, index_path):
        """Restore ``image_paths`` and ``image_features`` from a JSON index file."""
        with open(index_path, 'r') as f:
            index_data = json.load(f)

        self.image_paths = index_data['image_paths']
        self.image_features = torch.tensor(index_data['image_features'])
URL 搜索 (`/api/search/url`) +- **方法**: POST +- **输入**: 文本查询 + 图像URL列表 +- **输出**: 按相似度排序的结果列表 + +### 响应格式 +```json +{ + "query": "查询文本", + "total_images": 10, + "results_count": 5, + "results": [ + { + "rank": 1, + "similarity": 0.8567, + "filename": "image1.jpg" + } + ] +} +``` + +## 技术特性 + +### 1. 性能优化 +- **模型复用**: 服务启动时加载一次,避免重复加载 +- **批量处理**: 支持多图像并行编码 +- **GPU 加速**: 自动检测并使用 CUDA(如果可用) + +### 2. 错误处理 +- 图像文件验证 +- 格式不支持的优雅降级 +- 详细的错误日志记录 + +### 3. 扩展性 +- 可配置的模型参数 +- 支持不同的预训练权重 +- 易于集成新的图像格式 + + +## 配置选项 + +### 环境变量 +- `CLIP_MODEL_NAME`: CLIP 模型名称(默认: ViT-B-32) +- `CLIP_PRETRAINED`: 预训练权重(默认: laion2b_s34b_b79k) + +### 支持的模型类型 +- `ViT-B-32`: 轻量级,速度快 +- `ViT-B-16`: 更高精度 +- `ViT-L-14`: 大型模型,最高精度 +- `RN50/RN101`: ResNet 架构 + + + diff --git a/spx-algorithm/docs/v2.md b/spx-algorithm/docs/v2.md new file mode 100644 index 000000000..ce7c29066 --- /dev/null +++ b/spx-algorithm/docs/v2.md @@ -0,0 +1,124 @@ +# OpenCLIP 功能说明文档 + +## 简介 + +本项目基于 [OpenCLIP](https://github.com/mlfoundations/open_clip) 实现了一个图像搜索服务,能够通过自然语言描述来搜索和匹配图像。OpenCLIP 是一个开源的 CLIP(Contrastive Language-Image Pre-training)模型实现,能够理解图像和文本之间的语义关系。 + +## OpenCLIP 核心功能 + +### 1. 多模态理解 +- **文本编码**: 将自然语言文本转换为向量表示 +- **图像编码**: 将图像转换为语义向量表示 +- **跨模态匹配**: 计算文本和图像之间的语义相似度 + +### 2. 支持的模型类型 +项目中使用的预训练模型: +- `ViT-B-32`: Vision Transformer Base 模型,32x32 patch size +- 预训练数据集: `laion2b_s34b_b79k`(LAION-2B 数据集) + +### 3. 图像格式支持 +- **常规格式**: PNG, JPG, JPEG, WebP, BMP, TIFF +- **SVG 支持**: 通过 CairoSVG 转换为位图后处理 +- **预处理**: 自动调整图像尺寸到 224x224 像素 + +## 项目中的应用 + +### 核心服务类: ImageSearchService + +位置: `project/app/services/image_search_service.py` + +#### 主要功能方法: + +1. **模型初始化** (`_load_model`) +```python +# 加载 CLIP 模型和预处理器 +model, _, preprocess = open_clip.create_model_and_transforms( + model_name, pretrained=pretrained +) +tokenizer = open_clip.get_tokenizer(model_name) +``` + +2. **图像处理** (`_process_image`) +- 支持多种图像格式 +- SVG 文件特殊处理(转换为 PNG) +- 统一预处理到标准尺寸 + +3. 
**语义搜索** (`search_images`) +- 文本查询编码 +- 批量图像编码 +- 相似度计算和排序 + + +## API 接口 + +### 1. 文件上传搜索 (`/api/search`) +- **方法**: POST +- **输入**: 文本查询 + 图像文件 +- **输出**: 按相似度排序的结果列表 + +### 2. URL 搜索 (`/api/search/url`) +- **方法**: POST +- **输入**: 文本查询 + 图像URL列表 +- **输出**: 按相似度排序的结果列表 + +### 响应格式 +```json +{ + "query": "查询文本", + "total_images": 10, + "results_count": 5, + "results": [ + { + "rank": 1, + "similarity": 0.8567, + "filename": "image1.jpg" + } + ] +} +``` + +## 技术特性 + +### 1. 性能优化 +- **模型复用**: 服务启动时加载一次,避免重复加载 +- **批量处理**: 支持多图像并行编码 +- **GPU 加速**: 自动检测并使用 CUDA(如果可用) + +### 2. 错误处理 +- 图像文件验证 +- 格式不支持的优雅降级 +- 详细的错误日志记录 + +### 3. 扩展性 +- 可配置的模型参数 +- 支持不同的预训练权重 +- 易于集成新的图像格式 + + +## 配置选项 + +### 环境变量 +- `CLIP_MODEL_NAME`: CLIP 模型名称(默认: ViT-B-32) +- `CLIP_PRETRAINED`: 预训练权重(默认: laion2b_s34b_b79k) + +### 支持的模型类型 +- `ViT-B-32`: 轻量级,速度快 +- `ViT-B-16`: 更高精度 +- `ViT-L-14`: 大型模型,最高精度 +- `RN50/RN101`: ResNet 架构 + + + +## 效果演示 + +### **模拟数据库样本** + + + +### **运行结果** + + + + + + \ No newline at end of file diff --git a/spx-algorithm/project/.env.example b/spx-algorithm/project/.env.example new file mode 100644 index 000000000..3395fb6ad --- /dev/null +++ b/spx-algorithm/project/.env.example @@ -0,0 +1,11 @@ +# Flask环境配置 +FLASK_ENV=development +SECRET_KEY=your-secret-key-here + +# 服务器配置 +HOST=0.0.0.0 +PORT=5000 + +# CLIP模型配置 +CLIP_MODEL_NAME=ViT-B-32 +CLIP_PRETRAINED=laion2b_s34b_b79k \ No newline at end of file diff --git a/spx-algorithm/project/.gitignore b/spx-algorithm/project/.gitignore new file mode 100644 index 000000000..5f97e9e3e --- /dev/null +++ b/spx-algorithm/project/.gitignore @@ -0,0 +1,184 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script 
from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be added to the global gitignore or merged into this project gitignore. For a PyCharm +# project, uncomment below lines: +#.idea/ + +# Application specific +uploads/ +logs/ +*.log + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Model cache +models/ +checkpoints/ \ No newline at end of file diff --git a/spx-algorithm/project/Dockerfile b/spx-algorithm/project/Dockerfile new file mode 100644 index 000000000..2e4c07d8b --- /dev/null +++ b/spx-algorithm/project/Dockerfile @@ -0,0 +1,33 @@ +FROM python:3.9-slim + +WORKDIR /app + +# 安装系统依赖 +RUN apt-get update && apt-get install -y \ + libcairo2-dev \ + libgdk-pixbuf2.0-dev \ + libffi-dev \ + shared-mime-info \ + && rm -rf /var/lib/apt/lists/* + +# 复制依赖文件 +COPY requirements.txt . + +# 安装Python依赖 +RUN pip install --no-cache-dir -r requirements.txt + +# 复制应用代码 +COPY . . 
+ +# 创建必要的目录 +RUN mkdir -p uploads logs + +# 设置环境变量 +ENV FLASK_ENV=production +ENV PYTHONPATH=/app + +# 暴露端口 +EXPOSE 5000 + +# 启动命令 +CMD ["gunicorn", "--bind", "0.0.0.0:5000", "--workers", "4", "run:app"] \ No newline at end of file diff --git a/spx-algorithm/project/README.md b/spx-algorithm/project/README.md new file mode 100644 index 000000000..a76567230 --- /dev/null +++ b/spx-algorithm/project/README.md @@ -0,0 +1,309 @@ +# Image Search API + +基于 OpenCLIP 的图像搜索 API,支持通过文本查询在多张图片中进行语义搜索。 + +## 功能特性 + +- 🔍 基于 CLIP 模型的文本-图像语义搜索 +- 📁 支持多种图片格式(PNG, JPG, SVG, WebP 等) +- 🚀 Flask RESTful API 接口 +- 📊 返回相似度排序结果 +- 🔧 支持自定义返回结果数量(top-k) +- 🐳 Docker 容器化部署 +- 📝 完整的错误处理和日志记录 + +## 项目结构 + +``` +project/ +├── app/ +│ ├── __init__.py # Flask 应用工厂 +│ ├── api/ +│ │ ├── __init__.py +│ │ └── routes.py # API 路由 +│ ├── services/ +│ │ ├── __init__.py +│ │ └── image_search_service.py # 核心搜索服务 +│ ├── config/ +│ │ ├── __init__.py +│ │ └── config.py # 配置管理 +│ └── models/ # 数据模型(预留) +├── tests/ # 测试文件 +├── uploads/ # 上传文件临时目录 +├── logs/ # 日志文件 +├── static/ # 静态文件 +├── requirements.txt # Python 依赖 +├── Dockerfile # Docker 配置 +├── .env.example # 环境变量示例 +└── run.py # 应用启动文件 +``` + +## 快速开始 + +### 环境要求 + +- Python 3.8+ +- PyTorch +- 足够的内存运行 CLIP 模型 + +### 安装依赖 + +```bash +# 克隆项目 +cd project + +# 创建虚拟环境 +python -m venv venv +source venv/bin/activate # Linux/Mac +# 或 venv\\Scripts\\activate # Windows + +# 安装依赖 +pip install -r requirements.txt +``` + +### 配置环境 + +```bash +# 复制环境变量配置文件 +cp .env.example .env + +# 编辑配置文件 +nano .env +``` + +### 启动服务 + +```bash +# 开发模式 +python run.py + +# 或使用 gunicorn(生产环境) +gunicorn --bind 0.0.0.0:5000 --workers 4 run:app +``` + +## API 接口 + +### 1. 健康检查 + +```http +GET /api/health +``` + +**响应示例:** +```json +{ + "status": "healthy", + "service": "image-search-api" +} +``` + +### 2. 
文件上传搜索 + +```http +POST /api/search +Content-Type: multipart/form-data +``` + +**请求参数:** +- `text` (string, required): 查询文本 +- `images` (files, required): 图片文件列表 +- `top_k` (integer, optional): 返回前 k 个结果 + +**响应示例:** +```json +{ + "query": "a cute cat", + "total_images": 4, + "results_count": 2, + "results": [ + { + "rank": 1, + "similarity": 0.8567, + "filename": "cute_cat.jpg" + }, + { + "rank": 2, + "similarity": 0.7234, + "filename": "kitten.png" + } + ] +} +``` + +### 3. URL 搜索 + +```http +POST /api/search/url +Content-Type: application/json +``` + +**请求体:** +```json +{ + "text": "查询文本", + "image_urls": [ + "http://example.com/image1.jpg", + "http://example.com/image2.png" + ], + "top_k": 5 +} +``` + +## 使用示例 + +### Python 客户端示例 + +```python +import requests + +# 文件上传搜索 +def search_with_files(): + url = "http://localhost:5000/api/search" + + files = [ + ('images', open('image1.jpg', 'rb')), + ('images', open('image2.png', 'rb')), + ] + + data = { + 'text': 'a beautiful sunset', + 'top_k': 3 + } + + response = requests.post(url, files=files, data=data) + return response.json() + +# URL 搜索 +def search_with_urls(): + url = "http://localhost:5000/api/search/url" + + data = { + "text": "a cute dog", + "image_urls": [ + "https://example.com/dog1.jpg", + "https://example.com/dog2.png" + ], + "top_k": 2 + } + + response = requests.post(url, json=data) + return response.json() +``` + +### cURL 示例 + +```bash +# 文件上传搜索 +curl -X POST http://localhost:5000/api/search \ + -F "text=a cute cat" \ + -F "images=@image1.jpg" \ + -F "images=@image2.png" \ + -F "top_k=3" + +# URL 搜索 +curl -X POST http://localhost:5000/api/search/url \ + -H "Content-Type: application/json" \ + -d '{ + "text": "a beautiful landscape", + "image_urls": ["http://example.com/img1.jpg"], + "top_k": 5 + }' +``` + +## Docker 部署 + +```bash +# 构建镜像 +docker build -t image-search-api . + +# 运行容器 +docker run -p 5000:5000 image-search-api + +# 或使用 docker-compose +echo "version: '3.8' +services: + api: + build: . 
+ ports: + - '5000:5000' + environment: + - FLASK_ENV=production + volumes: + - ./logs:/app/logs" > docker-compose.yml + +docker-compose up -d +``` + +## 配置选项 + +### 环境变量 + +- `FLASK_ENV`: 运行环境 (development/production) +- `SECRET_KEY`: Flask 密钥 +- `HOST`: 服务器主机地址 +- `PORT`: 服务器端口 +- `CLIP_MODEL_NAME`: CLIP 模型名称 +- `CLIP_PRETRAINED`: 预训练权重 + +### 支持的 CLIP 模型 + +- `ViT-B-32` +- `ViT-B-16` +- `ViT-L-14` +- `RN50` +- `RN101` + +## 错误处理 + +API 返回标准的 HTTP 状态码和错误信息: + +```json +{ + "error": "错误描述", + "code": "ERROR_CODE", + "details": "详细错误信息" +} +``` + +常见错误码: +- `MISSING_TEXT_QUERY`: 缺少查询文本 +- `NO_FILES_UPLOADED`: 没有上传文件 +- `INVALID_TOP_K`: top_k 参数无效 +- `INTERNAL_ERROR`: 内部服务器错误 + +## 开发指南 + +### 运行测试 + +```bash +# 安装测试依赖 +pip install pytest pytest-flask + +# 运行测试 +pytest tests/ +``` + +### 代码格式化 + +```bash +# 格式化代码 +black app/ tests/ + +# 检查代码风格 +flake8 app/ tests/ +``` + +## 性能优化 + +- 模型只在启动时加载一次 +- 支持批量图片处理 +- 临时文件自动清理 +- 可配置的工作进程数量 + +## 许可证 + +MIT License + +## 贡献 + +欢迎提交 Issue 和 Pull Request! 
\ No newline at end of file diff --git a/spx-algorithm/project/app/__init__.py b/spx-algorithm/project/app/__init__.py new file mode 100644 index 000000000..3b726903c --- /dev/null +++ b/spx-algorithm/project/app/__init__.py @@ -0,0 +1,67 @@ +from flask import Flask +import os +import logging +from logging.handlers import RotatingFileHandler + + +def create_app(config_name='default'): + """Application factory function""" + app = Flask(__name__) + + # Configure application + from .config.config import config + app.config.from_object(config[config_name]) + + # Create necessary directories + os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) + os.makedirs(app.config['LOG_FOLDER'], exist_ok=True) + + # Configure logging + if not os.path.exists('logs'): + os.mkdir('logs') + + # Setup console handler for debug mode + if app.debug: + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)s: %(message)s [in %(name)s:%(lineno)d]' + )) + console_handler.setLevel(logging.INFO) + + # Set root logger level to INFO + root_logger = logging.getLogger() + root_logger.setLevel(logging.INFO) + root_logger.addHandler(console_handler) + + # Setup file handler for production + if not app.debug and not app.testing: + file_handler = RotatingFileHandler( + os.path.join(app.config['LOG_FOLDER'], 'app.log'), + maxBytes=10240000, backupCount=10 + ) + file_handler.setFormatter(logging.Formatter( + '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]' + )) + file_handler.setLevel(logging.INFO) + app.logger.addHandler(file_handler) + app.logger.setLevel(logging.INFO) + app.logger.info('Image Search API startup') + + # Register blueprints + from .api.routes import api_bp + app.register_blueprint(api_bp) + + # Add root route + @app.route('/') + def index(): + return { + 'message': 'Image Search API', + 'version': '1.0.0', + 'endpoints': { + 'health': '/api/health', + 'search': '/api/search (POST)', + 
'search_by_url': '/api/search/url (POST)' + } + } + + return app \ No newline at end of file diff --git a/spx-algorithm/project/app/api/__init__.py b/spx-algorithm/project/app/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spx-algorithm/project/app/api/routes.py b/spx-algorithm/project/app/api/routes.py new file mode 100644 index 000000000..507188612 --- /dev/null +++ b/spx-algorithm/project/app/api/routes.py @@ -0,0 +1,212 @@ +import logging +import os +import uuid + +from flask import Blueprint, request, jsonify, current_app +from werkzeug.utils import secure_filename + +from ..services.image_search_service import ImageSearchService + +logger = logging.getLogger(__name__) + +api_bp = Blueprint('api', __name__, url_prefix='/api') + +# 全局服务实例 +search_service = None + + +def init_search_service(model_name: str = 'ViT-B-32', pretrained: str = 'laion2b_s34b_b79k'): + """初始化搜索服务""" + global search_service + if search_service is None: + search_service = ImageSearchService(model_name, pretrained) + + +def allowed_file(filename: str) -> bool: + """检查文件扩展名是否允许""" + ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'svg', 'webp'} + return '.' 
in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + +@api_bp.route('/health', methods=['GET']) +def health_check(): + """健康检查接口""" + return jsonify({ + 'status': 'healthy', + 'service': 'image-search-api' + }) + + +@api_bp.route('/search', methods=['POST']) +def search_images(): + """ + 图像搜索接口 + + 接受表单数据: + - text: 查询文本 + - images: 多个图片文件 + - top_k: 返回前k张图片(可选,默认返回所有) + """ + try: + # 检查是否有文本查询 + text_query = request.form.get('text', '').strip() + if not text_query: + return jsonify({ + 'error': '查询文本不能为空', + 'code': 'MISSING_TEXT_QUERY' + }), 400 + + # 检查是否有上传的文件 + if 'images' not in request.files: + return jsonify({ + 'error': '没有上传图片文件', + 'code': 'NO_FILES_UPLOADED' + }), 400 + + files = request.files.getlist('images') + if not files or all(f.filename == '' for f in files): + return jsonify({ + 'error': '没有选择有效的图片文件', + 'code': 'NO_VALID_FILES' + }), 400 + + # 获取top_k参数 + top_k = request.form.get('top_k') + if top_k: + try: + top_k = int(top_k) + if top_k <= 0: + top_k = None + except ValueError: + return jsonify({ + 'error': 'top_k参数必须是正整数', + 'code': 'INVALID_TOP_K' + }), 400 + + # 保存上传的文件 + upload_folder = current_app.config['UPLOAD_FOLDER'] + session_id = str(uuid.uuid4()) + session_folder = os.path.join(upload_folder, session_id) + os.makedirs(session_folder, exist_ok=True) + + image_paths = [] + for file in files: + if file and file.filename != '' and allowed_file(file.filename): + filename = secure_filename(file.filename) + # 添加时间戳避免文件名冲突 + name, ext = os.path.splitext(filename) + filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}" + file_path = os.path.join(session_folder, filename) + file.save(file_path) + image_paths.append(file_path) + + if not image_paths: + return jsonify({ + 'error': '没有有效的图片文件被保存', + 'code': 'NO_VALID_IMAGES_SAVED' + }), 400 + + # 初始化服务(如果还没有初始化) + if search_service is None: + init_search_service() + + # 执行搜索 + results = search_service.search_images(text_query, image_paths, top_k) + + # 清理临时文件 + try: + for 
file_path in image_paths: + if os.path.exists(file_path): + os.remove(file_path) + os.rmdir(session_folder) + except Exception as e: + logger.warning(f"清理临时文件失败: {e}") + + # 处理结果,移除绝对路径信息 + processed_results = [] + for result in results: + processed_results.append({ + 'rank': result['rank'], + 'similarity': result['similarity'], + 'filename': os.path.basename(result['image_path']) + }) + + return jsonify({ + 'query': text_query, + 'total_images': len(image_paths), + 'results_count': len(processed_results), + 'results': processed_results + }) + + except Exception as e: + logger.error(f"搜索过程中出现错误: {e}") + return jsonify({ + 'error': '内部服务器错误', + 'code': 'INTERNAL_ERROR', + 'details': str(e) + }), 500 + + +@api_bp.route('/search/url', methods=['POST']) +def search_images_by_url(): + """ + 通过URL搜索图像接口 + + 接受JSON数据: + { + "text": "查询文本", + "image_urls": ["图片URL列表"], + "top_k": 5 // 可选 + } + """ + try: + data = request.get_json() + if not data: + return jsonify({ + 'error': '请求体必须是有效的JSON', + 'code': 'INVALID_JSON' + }), 400 + + text_query = data.get('text', '').strip() + if not text_query: + return jsonify({ + 'error': '查询文本不能为空', + 'code': 'MISSING_TEXT_QUERY' + }), 400 + + image_urls = data.get('image_urls', []) + if not image_urls or not isinstance(image_urls, list): + return jsonify({ + 'error': 'image_urls必须是非空的URL列表', + 'code': 'INVALID_IMAGE_URLS' + }), 400 + + top_k = data.get('top_k') + if top_k is not None: + if not isinstance(top_k, int) or top_k <= 0: + return jsonify({ + 'error': 'top_k参数必须是正整数', + 'code': 'INVALID_TOP_K' + }), 400 + + # 初始化服务(如果还没有初始化) + if search_service is None: + init_search_service() + + # 执行搜索 + results = search_service.search_images(text_query, image_urls, top_k) + + return jsonify({ + 'query': text_query, + 'total_images': len(image_urls), + 'results_count': len(results), + 'results': results + }) + + except Exception as e: + logger.error(f"URL搜索过程中出现错误: {e}") + return jsonify({ + 'error': '内部服务器错误', + 'code': 'INTERNAL_ERROR', + 
'details': str(e) + }), 500 \ No newline at end of file diff --git a/spx-algorithm/project/app/config/__init__.py b/spx-algorithm/project/app/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spx-algorithm/project/app/config/config.py b/spx-algorithm/project/app/config/config.py new file mode 100644 index 000000000..da75fca67 --- /dev/null +++ b/spx-algorithm/project/app/config/config.py @@ -0,0 +1,56 @@ +import os +from pathlib import Path + +# 项目根目录 +BASE_DIR = Path(__file__).parent.parent.parent + + +class Config: + """基础配置类""" + SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key-change-in-production' + + # 文件上传配置 + UPLOAD_FOLDER = os.path.join(BASE_DIR, 'uploads') + MAX_CONTENT_LENGTH = 32 * 1024 * 1024 # 32MB + + # 日志配置 + LOG_FOLDER = os.path.join(BASE_DIR, 'logs') + + # CLIP模型配置 + CLIP_MODEL_NAME = os.environ.get('CLIP_MODEL_NAME') or 'ViT-B-32' + CLIP_PRETRAINED = os.environ.get('CLIP_PRETRAINED') or 'laion2b_s34b_b79k' + + # API配置 + JSON_SORT_KEYS = False + JSONIFY_PRETTYPRINT_REGULAR = True + + +class DevelopmentConfig(Config): + """开发环境配置""" + DEBUG = True + ENV = 'development' + + +class ProductionConfig(Config): + """生产环境配置""" + DEBUG = False + ENV = 'production' + + # 生产环境应该使用更强的密钥 + SECRET_KEY = os.environ.get('SECRET_KEY') or 'production-secret-key-must-be-set' + + +class TestingConfig(Config): + """测试环境配置""" + TESTING = True + DEBUG = True + UPLOAD_FOLDER = os.path.join(BASE_DIR, 'test_uploads') + + +# 配置字典 +config = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, + 'default': DevelopmentConfig +} \ No newline at end of file diff --git a/spx-algorithm/project/app/services/__init__.py b/spx-algorithm/project/app/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spx-algorithm/project/app/services/image_search_service.py b/spx-algorithm/project/app/services/image_search_service.py new file mode 100644 index 000000000..895c0e476 
--- /dev/null +++ b/spx-algorithm/project/app/services/image_search_service.py @@ -0,0 +1,241 @@ +import io +import logging +import os +from typing import List, Dict, Any, Optional +from urllib.parse import urlparse + +import cairosvg +import open_clip +import requests +import torch +from PIL import Image + +logger = logging.getLogger(__name__) + + +class ImageSearchService: + """图像搜索服务类""" + + def __init__(self, model_name: str = 'ViT-B-32', pretrained: str = 'laion2b_s34b_b79k'): + """ + 初始化图像搜索服务 + + Args: + model_name: CLIP模型名称 + pretrained: 预训练权重 + """ + self.model_name = model_name + self.pretrained = pretrained + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = None + self.preprocess = None + self.tokenizer = None + self._load_model() + + def _load_model(self): + """加载CLIP模型""" + try: + self.model, _, self.preprocess = open_clip.create_model_and_transforms( + self.model_name, pretrained=self.pretrained + ) + self.tokenizer = open_clip.get_tokenizer(self.model_name) + self.model = self.model.to(self.device) + logger.info(f"模型加载成功: {self.model_name} on {self.device}") + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise + + def _process_image(self, image_source: str) -> Optional[torch.Tensor]: + """ + 处理单张图片,支持本地路径和网络URL + + Args: + image_source: 图片路径或URL + + Returns: + 处理后的图片张量,失败返回None + """ + try: + # 判断是否为网络URL + parsed = urlparse(image_source) + is_url = bool(parsed.scheme and parsed.netloc) + + if is_url: + # 处理网络图片 + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', + 'Accept': 'image/*, */*', + 'Accept-Language': 'en-US,en;q=0.9', + 'Accept-Encoding': 'gzip, deflate, br', + 'Connection': 'keep-alive' + } + response = requests.get(image_source, headers=headers, timeout=10) + response.raise_for_status() + + # 从Content-Type或URL判断是否为SVG + content_type = response.headers.get('content-type', '').lower() + is_svg = 'svg' in 
content_type or image_source.lower().endswith('.svg') + + logger.info(f"处理网络图片: {image_source}") + logger.info(f"Content-Type: {content_type}") + logger.info(f"Content-Length: {len(response.content)}") + + if is_svg: + # 处理网络SVG + if len(response.content) == 0: + logger.info(f"SVG文件内容为空: {image_source}") + return None + # 添加白色背景,避免透明导致的问题 + png_data = cairosvg.svg2png( + bytestring=response.content, + output_width=224, + output_height=224, + background_color='white' + ) + image = Image.open(io.BytesIO(png_data)).convert('RGB') + else: + # 处理其他格式网络图片 + if len(response.content) == 0: + logger.error(f"图片文件内容为空: {image_source}") + return None + image = Image.open(io.BytesIO(response.content)).convert('RGB') + + logger.info(f"图片尺寸: {image.size}, 模式: {image.mode}") + + # 可选:保存调试图片(取消注释来启用) + # debug_path = f"debug_{hash(image_source) % 10000}.png" + # image.save(debug_path) + # logger.debug(f"调试图片已保存: {debug_path}") + else: + # 处理本地文件 + if not os.path.exists(image_source): + logger.warning(f"图片文件不存在: {image_source}") + return None + + logger.info(f"处理本地图片: {image_source}") + + # 根据文件扩展名选择处理方式 + _, ext = os.path.splitext(image_source.lower()) + + if ext == '.svg': + # 处理本地SVG文件 + png_data = cairosvg.svg2png( + url=image_source, + output_width=224, + output_height=224, + background_color='white' + ) + image = Image.open(io.BytesIO(png_data)).convert('RGB') + else: + # 处理本地其他格式图片 + image = Image.open(image_source).convert('RGB') + + logger.info(f"本地图片尺寸: {image.size}, 模式: {image.mode}") + + # 预处理前保存原始图片信息 + logger.info(f"预处理前图片: {image.size}, 模式: {image.mode}") + + # 检查图片是否为空或全黑/全白 + import numpy as np + img_array = np.array(image) + logger.info(f"图片数组形状: {img_array.shape}") + logger.info(f"图片像素值范围: min={img_array.min()}, max={img_array.max()}, mean={img_array.mean():.2f}") + + # 检查图片是否为纯色 + unique_colors = len(np.unique(img_array.reshape(-1, img_array.shape[-1]), axis=0)) + logger.info(f"图片唯一颜色数量: {unique_colors}") + + # 如果图片是纯色或颜色过少,可能有问题 + if unique_colors < 5: + 
logger.warning(f"图片颜色过少 ({unique_colors}),可能是空白图片") + + # 确保图片转换为RGB模式 + if image.mode != 'RGB': + logger.info(f"转换图片模式从 {image.mode} 到 RGB") + image = image.convert('RGB') + + processed_tensor = self.preprocess(image) + logger.info(f"预处理后张量形状: {processed_tensor.shape}") + logger.info(f"预处理后张量统计: min={processed_tensor.min().item():.4f}, max={processed_tensor.max().item():.4f}, mean={processed_tensor.mean().item():.4f}") + + return processed_tensor + + except Exception as e: + logger.error(f"处理图片失败 {image_source}: {e}") + return None + + def search_images(self, text_query: str, image_paths: List[str], top_k: Optional[int] = None) -> List[Dict[str, Any]]: + """ + 根据文本查询搜索图片 + + Args: + text_query: 查询文本 + image_paths: 图片路径列表(支持本地路径和网络URL) + top_k: 返回前k张图片,None则返回所有 + + Returns: + 排序后的结果列表,包含图片路径和相似度分数 + """ + if not text_query.strip(): + raise ValueError("查询文本不能为空") + + if not image_paths: + raise ValueError("图片路径列表不能为空") + + # 处理文本 + text = self.tokenizer([text_query]).to(self.device) + + # 处理所有图片 + images = [] + valid_paths = [] + + for img_path in image_paths: + processed_image = self._process_image(img_path) + if processed_image is not None: + images.append(processed_image) + valid_paths.append(img_path) + + if not images: + logger.warning("没有有效的图片可以处理") + return [] + + # 批量处理图片 + images_tensor = torch.stack(images).to(self.device) + + with torch.no_grad(): + # 编码文本和图片 + text_features = self.model.encode_text(text) + image_features = self.model.encode_image(images_tensor) + + # 归一化特征 + text_features /= text_features.norm(dim=-1, keepdim=True) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # 计算相似度 + similarity = (text_features @ image_features.T).squeeze(0) + + # 转换为 numpy 并排序 + similarity_scores = similarity.cpu().numpy() + + # 创建结果列表 + results = [] + for i, score in enumerate(similarity_scores): + results.append({ + 'image_path': valid_paths[i], + 'similarity': float(score), + 'rank': i + 1 + }) + + # 按相似度降序排序 + results.sort(key=lambda x: 
x['similarity'], reverse=True) + + # 更新排名 + for i, result in enumerate(results): + result['rank'] = i + 1 + + # 返回前k个结果 + if top_k is not None and top_k > 0: + results = results[:top_k] + + logger.info(f"搜索完成,查询: '{text_query}', 找到 {len(results)} 个结果") + return results \ No newline at end of file diff --git a/spx-algorithm/project/requirements.txt b/spx-algorithm/project/requirements.txt new file mode 100644 index 000000000..d6b3d43ee --- /dev/null +++ b/spx-algorithm/project/requirements.txt @@ -0,0 +1,25 @@ +# Web框架 +Flask==2.3.3 +Werkzeug==2.3.7 + +# 机器学习和计算机视觉 +torch>=1.13.0 +torchvision>=0.14.0 +open_clip_torch>=2.20.0 +Pillow>=9.5.0 +numpy==1.24.0 + +# HTTP客户端 +requests>=2.31.0 + +# SVG处理 +cairosvg>=2.7.0 + +# 开发和测试工具 +pytest>=7.4.0 +pytest-flask>=1.2.0 +black>=23.0.0 +flake8>=6.0.0 + +# 生产环境 +gunicorn>=21.2.0 \ No newline at end of file diff --git a/spx-algorithm/project/run.py b/spx-algorithm/project/run.py new file mode 100644 index 000000000..23614f51a --- /dev/null +++ b/spx-algorithm/project/run.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +""" +Image Search API 启动文件 +""" +import os +from app import create_app +from app.api.routes import init_search_service + +# 获取配置环境 +config_name = os.getenv('FLASK_ENV', 'development') + +# 创建应用实例 +app = create_app(config_name) + +# 初始化搜索服务 +with app.app_context(): + init_search_service( + model_name=app.config['CLIP_MODEL_NAME'], + pretrained=app.config['CLIP_PRETRAINED'] + ) + +if __name__ == '__main__': + # 开发环境配置 + port = int(os.environ.get('PORT', 5000)) + host = os.environ.get('HOST', '0.0.0.0') + + app.run( + host=host, + port=port, + debug=app.config['DEBUG'], + use_reloader=False + ) \ No newline at end of file diff --git a/spx-algorithm/project/tests/__init__.py b/spx-algorithm/project/tests/__init__.py new file mode 100644 index 000000000..739954cbf --- /dev/null +++ b/spx-algorithm/project/tests/__init__.py @@ -0,0 +1 @@ +# Tests package \ No newline at end of file diff --git 
a/spx-algorithm/project/tests/test_api.py b/spx-algorithm/project/tests/test_api.py new file mode 100644 index 000000000..ed3e53e5c --- /dev/null +++ b/spx-algorithm/project/tests/test_api.py @@ -0,0 +1,83 @@ +import pytest +import json +import os +from .. import create_app + + +@pytest.fixture +def app(): + """创建测试应用""" + app = create_app('testing') + return app + + +@pytest.fixture +def client(app): + """创建测试客户端""" + return app.test_client() + + +def test_health_check(client): + """测试健康检查接口""" + response = client.get('/api/health') + assert response.status_code == 200 + + data = json.loads(response.data) + assert data['status'] == 'healthy' + assert data['service'] == 'image-search-api' + + +def test_index_route(client): + """测试根路由""" + response = client.get('/') + assert response.status_code == 200 + + data = json.loads(response.data) + assert 'message' in data + assert 'endpoints' in data + + +def test_search_missing_text(client): + """测试缺少查询文本的情况""" + response = client.post('/api/search', data={}) + assert response.status_code == 400 + + data = json.loads(response.data) + assert data['code'] == 'MISSING_TEXT_QUERY' + + +def test_search_no_files(client): + """测试没有上传文件的情况""" + response = client.post('/api/search', data={'text': 'test query'}) + assert response.status_code == 400 + + data = json.loads(response.data) + assert data['code'] == 'NO_FILES_UPLOADED' + + +def test_search_url_invalid_json(client): + """测试无效JSON请求""" + response = client.post('/api/search/url', + data='invalid json', + content_type='application/json') + assert response.status_code == 400 + + +def test_search_url_missing_text(client): + """测试URL搜索缺少文本""" + response = client.post('/api/search/url', + json={'image_urls': ['http://example.com/img.jpg']}) + assert response.status_code == 400 + + data = json.loads(response.data) + assert data['code'] == 'MISSING_TEXT_QUERY' + + +def test_search_url_invalid_urls(client): + """测试无效的URL列表""" + response = client.post('/api/search/url', + 
json={'text': 'test', 'image_urls': 'not a list'}) + assert response.status_code == 400 + + data = json.loads(response.data) + assert data['code'] == 'INVALID_IMAGE_URLS' \ No newline at end of file diff --git a/spx-algorithm/resource/cute.svg b/spx-algorithm/resource/cute.svg new file mode 100644 index 000000000..9eac5733d --- /dev/null +++ b/spx-algorithm/resource/cute.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/spx-algorithm/resource/cute2.svg b/spx-algorithm/resource/cute2.svg new file mode 100644 index 000000000..411a5835b --- /dev/null +++ b/spx-algorithm/resource/cute2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/spx-algorithm/resource/demo.png b/spx-algorithm/resource/demo.png new file mode 100644 index 000000000..c313113a2 Binary files /dev/null and b/spx-algorithm/resource/demo.png differ diff --git a/spx-algorithm/resource/dog.svg b/spx-algorithm/resource/dog.svg new file mode 100644 index 000000000..2c66245a6 --- /dev/null +++ b/spx-algorithm/resource/dog.svg @@ -0,0 +1,213 @@ + + + + + + + + + diff --git a/spx-algorithm/resource/image.svg b/spx-algorithm/resource/image.svg new file mode 100644 index 000000000..acd7f1aab --- /dev/null +++ b/spx-algorithm/resource/image.svg @@ -0,0 +1,112 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/spx-algorithm/resource/images.jpg b/spx-algorithm/resource/images.jpg new file mode 100644 index 000000000..a76b1df1b Binary files /dev/null and b/spx-algorithm/resource/images.jpg differ diff --git a/spx-algorithm/resource/vector_db/README_vector_db.md b/spx-algorithm/resource/vector_db/README_vector_db.md new file mode 100644 index 000000000..2def3cc56 --- /dev/null +++ b/spx-algorithm/resource/vector_db/README_vector_db.md @@ -0,0 +1,190 @@ +# 图像向量数据库 + +基于open-clip的图像向量化和相似度搜索系统,复用了现有的图像向量化方式。 + +## 功能特性 + +- **图像向量化**: 使用open-clip模型将图像编码为高维向量 +- **高效搜索**: 基于Faiss的向量相似度搜索 +- **文本搜索**: 支持通过文本描述搜索相关图像 +- **批量处理**: 支持批量添加和索引图像 +- **元数据管理**: 保存图像文件信息和自定义元数据 +- **重复检测**: 自动检测相似或重复的图像 
+ +## 核心组件 + +### 1. VectorDatabase 类 +- 核心向量数据库实现 +- 图像编码和向量存储 +- 相似度搜索功能 + +### 2. VectorDBManager 类 +- 高级管理接口 +- 目录批量索引 +- 重复图片检测 +- 元数据导出 + +## 安装依赖 + +```bash +pip install -r requirements_vector_db.txt +``` + +主要依赖: +- torch: 深度学习框架 +- open-clip-torch: CLIP模型 +- faiss-cpu/faiss-gpu: 向量搜索引擎 +- Pillow: 图像处理 +- numpy: 数值计算 + +## 快速开始 + +### 基本使用 + +```python +from vector_database import VectorDatabase + +# 初始化数据库 +db = VectorDatabase(db_path='my_vector_db') + +# 添加单张图片 +image_id = db.add_image('path/to/image.jpg', + metadata={'category': 'animals', 'tags': ['cute', 'dog']}) + +# 文本搜索 +results = db.search_by_text('cute dog', k=5) +for result in results: + print(f"{result['image_path']}: {result['similarity']:.3f}") + +# 图片相似搜索 +similar = db.search_by_image('query_image.jpg', k=5) +``` + +### 使用管理工具 + +```python +from vector_db_utils import VectorDBManager + +# 创建管理器 +manager = VectorDBManager(db_path='my_vector_db') + +# 索引整个目录 +stats = manager.index_directory('images_folder/') +print(f"索引了 {stats['indexed']} 张图片") + +# 搜索 +results = manager.search_by_description('beautiful landscape', k=10) +``` + +### 运行演示 + +```bash +python demo_vector_db.py +``` + +## 技术实现 + +### 向量化方法 +- 使用ViT-B-32模型(默认)提取图像特征 +- 特征向量维度: 512 +- L2归一化处理 +- 支持本地文件和网络URL + +### 存储结构 +``` +vector_db/ +├── index.faiss # Faiss向量索引文件 +└── metadata.pkl # 图像元数据和配置 +``` + +### 相似度计算 +- 使用内积(Inner Product)计算向量相似度 +- Faiss IndexFlatIP索引类型 +- 实时搜索,无需预计算 + +## API 参考 + +### VectorDatabase + +#### 初始化 +```python +VectorDatabase( + model_name='ViT-B-32', # CLIP模型名称 + pretrained='laion2b_s34b_b79k', # 预训练权重 + db_path='vector_db', # 数据库存储路径 + dimension=512 # 向量维度 +) +``` + +#### 主要方法 +- `add_image(image_path, metadata)`: 添加单张图片 +- `add_images_batch(image_paths, metadatas)`: 批量添加图片 +- `search_by_text(query_text, k)`: 文本搜索 +- `search_by_image(query_image_path, k)`: 图片搜索 +- `get_stats()`: 获取数据库统计信息 + +### VectorDBManager + +#### 高级功能 +- `index_directory(directory_path)`: 索引目录 +- 
`find_duplicates(similarity_threshold)`: 查找重复图片 +- `export_metadata(output_file)`: 导出元数据 + +## 性能优化 + +### GPU加速 +```python +# 自动检测并使用GPU +db = VectorDatabase() # 自动使用CUDA如果可用 +``` + +### 批量处理 +```python +# 批量添加比逐个添加更高效 +image_paths = ['img1.jpg', 'img2.jpg', ...] +db.add_images_batch(image_paths) +``` + +### 内存优化 +- 模型加载后复用 +- 向量计算使用float32精度 +- 支持大规模图片集合 + +## 注意事项 + +1. **模型下载**: 首次运行会下载CLIP模型(约1GB) +2. **内存使用**: 大量图片需要足够内存存储向量 +3. **删除限制**: Faiss不支持直接删除,需要重建索引 +4. **文件路径**: 确保图片路径在添加后不会改变 + +## 扩展功能 + +### 自定义模型 +```python +# 使用不同的CLIP模型 +db = VectorDatabase( + model_name='ViT-L-14', + pretrained='openai' +) +``` + +### 高级搜索 +```python +# 组合搜索结果 +text_results = db.search_by_text('query', k=20) +image_results = db.search_by_image('query.jpg', k=20) +# 合并和排序结果... +``` + +## 故障排除 + +### 常见问题 +1. **Faiss安装失败**: 尝试使用conda安装 +2. **CUDA内存不足**: 减少批量大小或使用CPU +3. **图片加载失败**: 检查文件格式和路径 + +### 日志调试 +```python +import logging +logging.basicConfig(level=logging.DEBUG) +``` \ No newline at end of file diff --git a/spx-algorithm/resource/vector_db/__pycache__/vector_database.cpython-312.pyc b/spx-algorithm/resource/vector_db/__pycache__/vector_database.cpython-312.pyc new file mode 100644 index 000000000..d4e82d858 Binary files /dev/null and b/spx-algorithm/resource/vector_db/__pycache__/vector_database.cpython-312.pyc differ diff --git a/spx-algorithm/resource/vector_db/__pycache__/vector_db_utils.cpython-312.pyc b/spx-algorithm/resource/vector_db/__pycache__/vector_db_utils.cpython-312.pyc new file mode 100644 index 000000000..a1fd4b9bd Binary files /dev/null and b/spx-algorithm/resource/vector_db/__pycache__/vector_db_utils.cpython-312.pyc differ diff --git a/spx-algorithm/resource/vector_db/batch_example_db/index.faiss b/spx-algorithm/resource/vector_db/batch_example_db/index.faiss new file mode 100644 index 000000000..99335b2fc Binary files /dev/null and b/spx-algorithm/resource/vector_db/batch_example_db/index.faiss differ diff --git 
a/spx-algorithm/resource/vector_db/batch_example_db/metadata.pkl b/spx-algorithm/resource/vector_db/batch_example_db/metadata.pkl new file mode 100644 index 000000000..1d4df95b5 Binary files /dev/null and b/spx-algorithm/resource/vector_db/batch_example_db/metadata.pkl differ diff --git a/spx-algorithm/resource/vector_db/cute.svg b/spx-algorithm/resource/vector_db/cute.svg new file mode 100644 index 000000000..9eac5733d --- /dev/null +++ b/spx-algorithm/resource/vector_db/cute.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/spx-algorithm/resource/vector_db/cute2.svg b/spx-algorithm/resource/vector_db/cute2.svg new file mode 100644 index 000000000..411a5835b --- /dev/null +++ b/spx-algorithm/resource/vector_db/cute2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/spx-algorithm/resource/vector_db/database_metadata.json b/spx-algorithm/resource/vector_db/database_metadata.json new file mode 100644 index 000000000..447b88b28 --- /dev/null +++ b/spx-algorithm/resource/vector_db/database_metadata.json @@ -0,0 +1,42 @@ +{ + "database_stats": { + "total_images": 3, + "dimension": 512, + "model_name": "ViT-B-32", + "pretrained": "laion2b_s34b_b79k", + "db_path": "/Users/qiniu/Documents/builder/spx-algorithm/resource/vector_db/demo_vector_db", + "device": "cpu" + }, + "images": { + "0": { + "image_path": "/Users/qiniu/Documents/builder/spx-algorithm/resource/vector_db/dog.svg", + "metadata": { + "filename": "dog.svg", + "directory": "/Users/qiniu/Documents/builder/spx-algorithm/resource/vector_db", + "file_size": 15505, + "relative_path": "dog.svg" + }, + "added_at": "2025-08-26T13:59:58.374985" + }, + "1": { + "image_path": "/Users/qiniu/Documents/builder/spx-algorithm/resource/vector_db/cute.svg", + "metadata": { + "filename": "cute.svg", + "directory": "/Users/qiniu/Documents/builder/spx-algorithm/resource/vector_db", + "file_size": 37641, + "relative_path": "cute.svg" + }, + "added_at": "2025-08-26T13:59:58.546975" + }, + "2": { + "image_path": 
"/Users/qiniu/Documents/builder/spx-algorithm/resource/vector_db/cute2.svg", + "metadata": { + "filename": "cute2.svg", + "directory": "/Users/qiniu/Documents/builder/spx-algorithm/resource/vector_db", + "file_size": 59391, + "relative_path": "cute2.svg" + }, + "added_at": "2025-08-26T13:59:58.727322" + } + } +} \ No newline at end of file diff --git a/spx-algorithm/resource/vector_db/demo_vector_db.py b/spx-algorithm/resource/vector_db/demo_vector_db.py new file mode 100644 index 000000000..704b3032d --- /dev/null +++ b/spx-algorithm/resource/vector_db/demo_vector_db.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +""" +向量数据库演示脚本 +""" +import os +import sys +import logging +from vector_db_utils import VectorDBManager, create_sample_database + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +def main(): + print("=== 向量数据库演示 ===\n") + + # 获取当前脚本所在目录 + current_dir = os.path.dirname(os.path.abspath(__file__)) + + try: + # 创建数据库管理器 + print("1. 初始化向量数据库...") + db_manager = VectorDBManager(db_path=os.path.join(current_dir, 'demo_vector_db')) + + # 索引当前目录的图片 + print("2. 索引当前目录的图片...") + stats = db_manager.index_directory(current_dir) + print(f" 索引结果: {stats}") + + # 显示数据库信息 + print("\n3. 数据库信息:") + info = db_manager.get_database_info() + for key, value in info.items(): + print(f" {key}: {value}") + + if info['total_images'] == 0: + print("\n注意: 当前目录没有找到图片文件,请添加一些图片文件后重试") + return + + # 文本搜索示例 + print("\n4. 文本搜索示例:") + search_queries = [ + "cute animal", + "dog", + "cartoon", + "colorful image" + ] + + for query in search_queries: + print(f"\n 搜索: '{query}'") + results = db_manager.search_by_description(query, k=3) + if results: + for i, result in enumerate(results, 1): + filename = result['metadata']['filename'] + similarity = result['similarity'] + print(f" {i}. {filename} (相似度: {similarity:.3f})") + else: + print(" 未找到匹配结果") + + # 图片相似搜索示例(如果有多张图片) + if info['total_images'] > 1: + print("\n5. 
图片相似搜索示例:") + + # 获取第一张图片作为查询 + first_image_path = None + for image_id, data in db_manager.db.metadata.items(): + first_image_path = data['image_path'] + break + + if first_image_path: + print(f" 使用图片: {os.path.basename(first_image_path)}") + similar_results = db_manager.search_similar_images(first_image_path, k=5) + for i, result in enumerate(similar_results, 1): + filename = result['metadata']['filename'] + similarity = result['similarity'] + print(f" {i}. {filename} (相似度: {similarity:.3f})") + + # 查找重复图片 + print("\n6. 查找重复图片:") + duplicates = db_manager.find_duplicates(similarity_threshold=0.9) + if duplicates: + print(f" 找到 {len(duplicates)} 组相似图片:") + for i, group in enumerate(duplicates, 1): + print(f" 组 {i}:") + for item in group: + filename = item['metadata']['filename'] + similarity = item['similarity'] + print(f" - {filename} (相似度: {similarity:.3f})") + else: + print(" 未找到重复图片") + + # 导出元数据 + print("\n7. 导出数据库元数据...") + export_path = os.path.join(current_dir, 'database_metadata.json') + db_manager.export_metadata(export_path) + print(f" 元数据已导出到: {export_path}") + + print("\n=== 演示完成 ===") + + except ImportError as e: + if "faiss" in str(e): + print("错误: 缺少faiss库") + print("请安装faiss库:") + print(" CPU版本: pip install faiss-cpu") + print(" GPU版本: pip install faiss-gpu") + else: + print(f"导入错误: {e}") + sys.exit(1) + except Exception as e: + print(f"运行错误: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/spx-algorithm/resource/vector_db/demo_vector_db/index.faiss b/spx-algorithm/resource/vector_db/demo_vector_db/index.faiss new file mode 100644 index 000000000..99335b2fc Binary files /dev/null and b/spx-algorithm/resource/vector_db/demo_vector_db/index.faiss differ diff --git a/spx-algorithm/resource/vector_db/demo_vector_db/metadata.pkl b/spx-algorithm/resource/vector_db/demo_vector_db/metadata.pkl new file mode 100644 index 000000000..6ecc6b0c0 Binary files /dev/null and 
b/spx-algorithm/resource/vector_db/demo_vector_db/metadata.pkl differ diff --git a/spx-algorithm/resource/vector_db/dog.svg b/spx-algorithm/resource/vector_db/dog.svg new file mode 100644 index 000000000..2c66245a6 --- /dev/null +++ b/spx-algorithm/resource/vector_db/dog.svg @@ -0,0 +1,213 @@ + + + + + + + + + diff --git a/spx-algorithm/resource/vector_db/example_db/index.faiss b/spx-algorithm/resource/vector_db/example_db/index.faiss new file mode 100644 index 000000000..c65dbecc4 Binary files /dev/null and b/spx-algorithm/resource/vector_db/example_db/index.faiss differ diff --git a/spx-algorithm/resource/vector_db/example_db/metadata.pkl b/spx-algorithm/resource/vector_db/example_db/metadata.pkl new file mode 100644 index 000000000..041fc6684 Binary files /dev/null and b/spx-algorithm/resource/vector_db/example_db/metadata.pkl differ diff --git a/spx-algorithm/resource/vector_db/requirements_vector_db.txt b/spx-algorithm/resource/vector_db/requirements_vector_db.txt new file mode 100644 index 000000000..85d203e37 --- /dev/null +++ b/spx-algorithm/resource/vector_db/requirements_vector_db.txt @@ -0,0 +1,8 @@ +# 向量数据库依赖库 +torch>=1.9.0 +open-clip-torch +Pillow>=8.0.0 +numpy>=1.20.0 +faiss-cpu>=1.7.0 # 或者使用 faiss-gpu 如果有GPU支持 +requests>=2.25.0 +cairosvg>=2.5.0 # SVG图片支持 \ No newline at end of file diff --git a/spx-algorithm/resource/vector_db/simple_usage_example.py b/spx-algorithm/resource/vector_db/simple_usage_example.py new file mode 100644 index 000000000..82a6e14d0 --- /dev/null +++ b/spx-algorithm/resource/vector_db/simple_usage_example.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +简单的向量数据库使用示例 +展示如何在其他项目中集成和使用向量数据库 +""" +import os +from vector_database import VectorDatabase +from vector_db_utils import VectorDBManager + +def example_basic_usage(): + """基本使用示例""" + print("=== 基本使用示例 ===") + + # 1. 创建向量数据库实例 + db = VectorDatabase(db_path='example_db') + + # 2. 
添加单张图片 + current_dir = os.path.dirname(os.path.abspath(__file__)) + image_path = os.path.join(current_dir, 'cute.svg') + + if os.path.exists(image_path): + image_id = db.add_image(image_path, metadata={'category': 'cute', 'type': 'svg'}) + print(f"添加图片成功,ID: {image_id}") + + # 3. 文本搜索 + results = db.search_by_text('cute animal', k=3) + print(f"搜索 'cute animal' 的结果:") + for result in results: + print(f" - {os.path.basename(result['image_path'])}: {result['similarity']:.3f}") + + print() + +def example_batch_processing(): + """批量处理示例""" + print("=== 批量处理示例 ===") + + # 使用管理器进行批量操作 + manager = VectorDBManager(db_path='batch_example_db') + + # 索引当前目录 + current_dir = os.path.dirname(os.path.abspath(__file__)) + stats = manager.index_directory(current_dir) + print(f"索引统计: 总计 {stats['total']} 个文件,成功 {stats['indexed']} 个") + + # 搜索示例 + results = manager.search_by_description('dog', k=2) + print(f"搜索 'dog' 的结果:") + for result in results: + print(f" - {result['metadata']['filename']}: {result['similarity']:.3f}") + + print() + +def example_integration_with_existing_service(): + """与现有服务集成的示例""" + print("=== 与现有服务集成示例 ===") + + # 模拟与现有ImageSearchService的集成 + print("1. 可以复用相同的模型配置") + print(" - 模型: ViT-B-32") + print(" - 预训练权重: laion2b_s34b_b79k") + print(" - 支持相同的图片格式(包括SVG)") + + print("\n2. 向量数据库的优势:") + print(" - 预计算图片向量,搜索更快") + print(" - 支持大规模图片集合") + print(" - 持久化存储,无需重复计算") + + print("\n3. 
class VectorDatabase:
    """Image vector database built on open-clip embeddings and a Faiss index.

    Images are encoded with a CLIP model, L2-normalised and stored in a
    ``faiss.IndexFlatIP`` index, so the inner product of two stored vectors is
    their cosine similarity. Per-image metadata lives in ``self.metadata``,
    keyed by a monotonically increasing integer image id.

    A flat Faiss index addresses vectors by row position only. Row positions
    shift whenever the index is rebuilt, so ``self._position_ids`` records which
    image id each row belongs to; search results are translated through it.
    """

    def __init__(self,
                 model_name: str = 'ViT-B-32',
                 pretrained: str = 'laion2b_s34b_b79k',
                 db_path: str = 'vector_db',
                 dimension: int = 512):
        """
        Create (or reopen) a vector database.

        Args:
            model_name: CLIP model name passed to open-clip.
            pretrained: Name of the pretrained weights.
            db_path: Directory where the index and metadata are persisted.
            dimension: Dimensionality of the stored vectors.
        """
        self.model_name = model_name
        self.pretrained = pretrained
        self.db_path = db_path
        self.dimension = dimension
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Model state, filled in by _init_model().
        self.model = None
        self.preprocess = None
        self.tokenizer = None

        # Database state, filled in by _init_database().
        self.index = None
        self.metadata = {}
        self.id_counter = 0
        # Faiss row position -> image id (see class docstring).
        self._position_ids: List[int] = []

        self._init_model()
        self._init_database()

    def _init_model(self):
        """Load the CLIP model, its preprocessing transform and its tokenizer."""
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name, pretrained=self.pretrained
            )
            self.tokenizer = open_clip.get_tokenizer(self.model_name)
            self.model = self.model.to(self.device)
            self.model.eval()
            logger.info(f"模型加载成功: {self.model_name} on {self.device}")
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise

    def _init_database(self):
        """Create an empty inner-product index, then load persisted data if any.

        Raises:
            ImportError: If faiss is not installed.
        """
        if faiss is None:
            raise ImportError("faiss库未安装,请运行: pip install faiss-cpu 或 pip install faiss-gpu")

        os.makedirs(self.db_path, exist_ok=True)

        # Inner product on normalised vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)

        self._load_database()

    def _save_database(self):
        """Persist the Faiss index and the metadata (including the row->id map)."""
        try:
            index_path = os.path.join(self.db_path, 'index.faiss')
            faiss.write_index(self.index, index_path)

            metadata_path = os.path.join(self.db_path, 'metadata.pkl')
            with open(metadata_path, 'wb') as f:
                pickle.dump({
                    'metadata': self.metadata,
                    'id_counter': self.id_counter,
                    'dimension': self.dimension,
                    'model_name': self.model_name,
                    'created_at': datetime.now().isoformat(),
                    'position_ids': list(self._position_ids)
                }, f)

            logger.info(f"数据库已保存到: {self.db_path}")
        except Exception as e:
            logger.error(f"保存数据库失败: {e}")
            raise

    def _load_database(self):
        """Load the index and metadata from ``db_path`` when both files exist.

        On any failure the in-memory state is reset to an empty database.
        NOTE(review): after such a reset the next save overwrites the files on
        disk; confirm that silently discarding an unreadable database is wanted.
        """
        try:
            index_path = os.path.join(self.db_path, 'index.faiss')
            metadata_path = os.path.join(self.db_path, 'metadata.pkl')

            if os.path.exists(index_path) and os.path.exists(metadata_path):
                index = faiss.read_index(index_path)

                # SECURITY: pickle.load can execute arbitrary code. Only open
                # database directories that come from a trusted source.
                with open(metadata_path, 'rb') as f:
                    data = pickle.load(f)
                metadata = data['metadata']
                id_counter = data['id_counter']
                # Files written before the row->id map existed: rows were
                # appended in ascending id order, so sorted ids reproduce it.
                position_ids = list(data.get('position_ids', sorted(metadata)))
                if len(position_ids) != index.ntotal:
                    logger.warning(
                        f"索引向量数量与ID映射不一致: {index.ntotal} != {len(position_ids)}"
                    )

                self.index = index
                self.metadata = metadata
                self.id_counter = id_counter
                self._position_ids = position_ids

                logger.info(f"数据库加载成功,包含 {len(self.metadata)} 个向量")
            else:
                logger.info("未找到已有数据库,创建新的数据库")
        except Exception as e:
            logger.error(f"加载数据库失败: {e}")
            # Fall back to an empty database.
            self.index = faiss.IndexFlatIP(self.dimension)
            self.metadata = {}
            self.id_counter = 0
            self._position_ids = []

    def _encode_image(self, image_path: str) -> Optional[np.ndarray]:
        """
        Encode one local image file (SVG included) into a normalised vector.

        Args:
            image_path: Path of the image on the local filesystem.

        Returns:
            A float32 array holding one feature row, or None on any failure.
        """
        try:
            if not os.path.exists(image_path):
                logger.warning(f"图片文件不存在: {image_path}")
                return None

            logger.info(f"处理本地图片: {image_path}")

            # Pick the loader from the (lower-cased) file extension.
            _, ext = os.path.splitext(image_path.lower())

            if ext == '.svg':
                if cairosvg is None:
                    logger.error("SVG支持需要安装cairosvg库: pip install cairosvg")
                    return None

                # Rasterise the SVG on a white background before encoding.
                png_data = cairosvg.svg2png(
                    url=image_path,
                    output_width=224,
                    output_height=224,
                    background_color='white'
                )
                with Image.open(io.BytesIO(png_data)) as raw_image:
                    image = raw_image.convert('RGB')
            else:
                # The context manager releases the file handle promptly.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')

            logger.info(f"图片尺寸: {image.size}, 模式: {image.mode}")

            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)
                # Normalise so that inner product equals cosine similarity.
                image_features /= image_features.norm(dim=-1, keepdim=True)

            return image_features.cpu().numpy().astype('float32')

        except Exception as e:
            logger.error(f"编码图片失败 {image_path}: {e}")
            return None

    def _encode_text(self, text: str) -> Optional[np.ndarray]:
        """
        Encode a text query into a normalised vector.

        Args:
            text: Query text.

        Returns:
            A float32 array holding one feature row, or None on any failure.
        """
        try:
            text_tokens = self.tokenizer([text]).to(self.device)

            with torch.no_grad():
                text_features = self.model.encode_text(text_tokens)
                text_features /= text_features.norm(dim=-1, keepdim=True)

            return text_features.cpu().numpy().astype('float32')

        except Exception as e:
            logger.error(f"编码文本失败 {text}: {e}")
            return None

    def add_image(self, image_path: str, metadata: Optional[Dict[str, Any]] = None) -> Optional[int]:
        """
        Encode one image, add it to the index and persist the database.

        Args:
            image_path: Path of the image.
            metadata: Arbitrary user metadata stored alongside the image.

        Returns:
            The new image id, or None if the image could not be encoded.
        """
        vector = self._encode_image(image_path)
        if vector is None:
            return None

        self.index.add(vector)

        image_id = self.id_counter
        self._position_ids.append(image_id)
        self.metadata[image_id] = {
            'image_path': image_path,
            'metadata': metadata or {},
            'added_at': datetime.now().isoformat()
        }
        self.id_counter += 1

        self._save_database()

        logger.info(f"图片已添加到数据库: {image_path} (ID: {image_id})")
        return image_id

    def add_images_batch(self, image_paths: List[str],
                        metadatas: Optional[List[Dict[str, Any]]] = None) -> List[Optional[int]]:
        """
        Encode several images and add them to the index in one operation.

        Args:
            image_paths: Paths of the images to add.
            metadatas: Optional metadata, one entry per path.

        Returns:
            One entry per input path: the new image id, or None when that
            image could not be encoded.

        Raises:
            ValueError: If ``metadatas`` is given with a different length than
                ``image_paths`` (previously the extra entries were dropped
                silently and the result was shorter than the input).
        """
        if metadatas is None:
            metadatas = [None] * len(image_paths)
        elif len(metadatas) != len(image_paths):
            raise ValueError(
                f"image_paths 与 metadatas 长度不一致: {len(image_paths)} != {len(metadatas)}"
            )

        results: List[Optional[int]] = []
        vectors = []
        new_entries: Dict[int, Dict[str, Any]] = {}

        for image_path, metadata in zip(image_paths, metadatas):
            vector = self._encode_image(image_path)
            if vector is None:
                results.append(None)
                continue

            image_id = self.id_counter + len(vectors)
            vectors.append(vector)
            new_entries[image_id] = {
                'image_path': image_path,
                'metadata': metadata or {},
                'added_at': datetime.now().isoformat()
            }
            results.append(image_id)

        if vectors:
            # Touch the index first so a failing add leaves metadata untouched.
            self.index.add(np.vstack(vectors))
            self._position_ids.extend(new_entries)
            self.metadata.update(new_entries)
            self.id_counter += len(vectors)

            self._save_database()

            logger.info(f"批量添加 {len(vectors)} 张图片到数据库")

        return results

    def search_by_image(self, query_image_path: str, k: int = 10) -> List[Dict[str, Any]]:
        """
        Find stored images similar to a query image.

        Args:
            query_image_path: Path of the query image.
            k: Maximum number of results.

        Returns:
            Result dicts ordered by decreasing similarity; empty if the query
            image could not be encoded.
        """
        query_vector = self._encode_image(query_image_path)
        if query_vector is None:
            return []

        return self._search_by_vector(query_vector, k)

    def search_by_text(self, query_text: str, k: int = 10) -> List[Dict[str, Any]]:
        """
        Find stored images matching a text query.

        Args:
            query_text: Query text.
            k: Maximum number of results.

        Returns:
            Result dicts ordered by decreasing similarity; empty if the text
            could not be encoded.
        """
        query_vector = self._encode_text(query_text)
        if query_vector is None:
            return []

        return self._search_by_vector(query_vector, k)

    def _search_by_vector(self, query_vector: np.ndarray, k: int) -> List[Dict[str, Any]]:
        """
        Search the index with an already encoded query vector.

        Args:
            query_vector: Query vector with a leading batch axis of size 1.
            k: Maximum number of results.

        Returns:
            Dicts with ``id``, ``similarity``, ``rank`` and the stored metadata
            fields of each hit.
        """
        if self.index.ntotal == 0:
            logger.warning("数据库为空")
            return []

        k = min(k, self.index.ntotal)
        similarities, positions = self.index.search(query_vector, k)

        results = []
        for rank, (similarity, position) in enumerate(zip(similarities[0], positions[0]), 1):
            position = int(position)
            # Skip padding rows and rows without a known id.
            if not 0 <= position < len(self._position_ids):
                continue
            # Translate the Faiss row position into the image id.
            image_id = self._position_ids[position]
            if image_id in self.metadata:
                results.append({
                    'id': int(image_id),
                    'similarity': float(similarity),
                    'rank': rank,
                    **self.metadata[image_id]
                })

        return results

    def get_stats(self) -> Dict[str, Any]:
        """Return basic statistics and configuration of the database."""
        return {
            'total_images': self.index.ntotal,
            'dimension': self.dimension,
            'model_name': self.model_name,
            'pretrained': self.pretrained,
            'db_path': self.db_path,
            'device': self.device
        }

    def remove_image(self, image_id: int) -> bool:
        """
        Remove an image; the flat index is rebuilt from the remaining images.

        Args:
            image_id: Id of the image to remove.

        Returns:
            True if the image existed and was removed, False otherwise.
        """
        if image_id not in self.metadata:
            logger.warning(f"图片ID不存在: {image_id}")
            return False

        del self.metadata[image_id]

        self._rebuild_index()

        logger.info(f"图片已删除: {image_id}")
        return True

    def _rebuild_index(self):
        """Rebuild the index from the images still present in the metadata.

        Every remaining image is re-encoded from disk. Images that can no
        longer be encoded are dropped from the metadata as well.
        """
        if not self.metadata:
            self.index = faiss.IndexFlatIP(self.dimension)
            self._position_ids = []
            self._save_database()
            return

        vectors = []
        position_ids = []
        stale_ids = []
        # Iterate over a snapshot: the dict is pruned only after the loop.
        for image_id, data in list(self.metadata.items()):
            vector = self._encode_image(data['image_path'])
            if vector is not None:
                vectors.append(vector)
                position_ids.append(image_id)
            else:
                stale_ids.append(image_id)

        for image_id in stale_ids:
            del self.metadata[image_id]

        self.index = faiss.IndexFlatIP(self.dimension)
        if vectors:
            self.index.add(np.vstack(vectors))
        self._position_ids = position_ids

        self._save_database()
class VectorDBManager:
    """Convenience layer on top of VectorDatabase.

    Adds directory indexing, duplicate detection and metadata export.
    """

    def __init__(self, db_path: str = 'vector_db'):
        """
        Create the manager and open the underlying database.

        Args:
            db_path: Directory of the vector database.
        """
        self.db_path = db_path
        self.db = VectorDatabase(db_path=db_path)

    def index_directory(self, directory_path: str,
                       supported_extensions: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Index every supported image found under a directory (recursively).

        Files whose path is already stored in the database are skipped, so
        calling this repeatedly on the same directory does not add duplicate
        vectors.

        Args:
            directory_path: Directory to scan.
            supported_extensions: File extensions to accept. Matching is
                case-insensitive and the leading dot is optional.

        Returns:
            A stats dict with ``success``, ``total`` (matching files found),
            ``indexed``, ``failed``, ``skipped`` (already indexed) and
            ``directory``. When the directory does not exist ``success`` is
            False, ``error`` is set and all counts are zero.
        """
        if supported_extensions is None:
            supported_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.svg']
        # Normalise caller input so 'PNG', 'png' and '.png' all match.
        extensions = {
            ext.lower() if ext.startswith('.') else f".{ext.lower()}"
            for ext in supported_extensions
        }

        if not os.path.exists(directory_path):
            logger.error(f"目录不存在: {directory_path}")
            return {
                'success': False,
                'error': '目录不存在',
                'total': 0,
                'indexed': 0,
                'failed': 0,
                'skipped': 0,
                'directory': directory_path
            }

        image_files = []
        for root, _dirs, files in os.walk(directory_path):
            for file in files:
                _, ext = os.path.splitext(file.lower())
                if ext in extensions:
                    image_files.append(os.path.join(root, file))

        if not image_files:
            logger.warning(f"目录中未找到支持的图片文件: {directory_path}")
            return {'success': True, 'indexed': 0, 'failed': 0, 'total': 0, 'skipped': 0}

        logger.info(f"找到 {len(image_files)} 个图片文件,开始索引...")

        # Compare absolute paths so relative and absolute spellings agree.
        known_paths = {
            os.path.abspath(entry['image_path'])
            for entry in self.db.metadata.values()
        }
        new_files = [
            path for path in image_files
            if os.path.abspath(path) not in known_paths
        ]
        skipped = len(image_files) - len(new_files)
        if skipped:
            logger.info(f"跳过 {skipped} 个已索引的图片文件")

        indexed = 0
        failed = 0
        if new_files:
            metadatas = [
                {
                    'filename': os.path.basename(image_path),
                    'directory': os.path.dirname(image_path),
                    'file_size': os.path.getsize(image_path),
                    'relative_path': os.path.relpath(image_path, directory_path)
                }
                for image_path in new_files
            ]

            results = self.db.add_images_batch(new_files, metadatas)

            indexed = sum(1 for r in results if r is not None)
            failed = len(results) - indexed

        stats = {
            'success': True,
            'total': len(image_files),
            'indexed': indexed,
            'failed': failed,
            'skipped': skipped,
            'directory': directory_path
        }

        logger.info(f"索引完成: {stats}")
        return stats

    def search_similar_images(self, query_path: str, k: int = 10) -> List[Dict[str, Any]]:
        """
        Find stored images similar to a query image.

        Args:
            query_path: Path of the query image.
            k: Maximum number of results.

        Returns:
            Similar images as returned by the database.
        """
        return self.db.search_by_image(query_path, k)

    def search_by_description(self, description: str, k: int = 10) -> List[Dict[str, Any]]:
        """
        Find stored images matching a text description.

        Args:
            description: Text description.
            k: Maximum number of results.

        Returns:
            Matching images as returned by the database.
        """
        return self.db.search_by_text(description, k)

    def get_database_info(self) -> Dict[str, Any]:
        """Return the statistics reported by the underlying database."""
        return self.db.get_stats()

    def export_metadata(self, output_file: Optional[str] = None) -> Dict[str, Any]:
        """
        Collect database statistics and per-image metadata, optionally as JSON.

        Args:
            output_file: If given, the export is also written to this path as
                UTF-8 JSON (integer image ids become string keys in the file).

        Returns:
            Dict with ``database_stats`` and ``images`` (image id -> entry).
        """
        metadata_export = {
            'database_stats': self.db.get_stats(),
            'images': dict(self.db.metadata)
        }

        if output_file:
            import json
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(metadata_export, f, indent=2, ensure_ascii=False)
            logger.info(f"元数据已导出到: {output_file}")

        return metadata_export

    def find_duplicates(self, similarity_threshold: float = 0.95,
                        max_neighbors: int = 10) -> List[List[Dict[str, Any]]]:
        """
        Group stored images whose similarity reaches a threshold.

        Each image not yet assigned to a group is used as a query; hits at or
        above the threshold that are not yet assigned form one group. Only
        groups with more than one member are reported.

        Args:
            similarity_threshold: Minimum similarity for two images to count
                as duplicates.
            max_neighbors: Number of nearest neighbours inspected per image,
                which also caps the size of a group.

        Returns:
            List of groups, each a list of search-result dicts.
        """
        duplicates = []
        processed_ids = set()

        for image_id, data in self.db.metadata.items():
            if image_id in processed_ids:
                continue

            similar_images = self.db.search_by_image(data['image_path'], k=max_neighbors)

            duplicate_group = []
            for similar in similar_images:
                if similar['similarity'] >= similarity_threshold and similar['id'] not in processed_ids:
                    duplicate_group.append(similar)
                    processed_ids.add(similar['id'])

            if len(duplicate_group) > 1:
                duplicates.append(duplicate_group)

        logger.info(f"找到 {len(duplicates)} 组重复图片")
        return duplicates
'sample_vector_db')) + + # 如果资源目录存在图片,则索引它们 + if os.path.exists(resource_dir): + logger.info(f"索引资源目录: {resource_dir}") + stats = db_manager.index_directory(resource_dir) + logger.info(f"索引统计: {stats}") + + return db_manager + + +if __name__ == "__main__": + # 设置日志 + logging.basicConfig(level=logging.INFO) + + # 创建示例数据库 + manager = create_sample_database() + + # 显示数据库信息 + info = manager.get_database_info() + print(f"数据库信息: {info}") + + # 示例搜索 + if info['total_images'] > 0: + # 文本搜索示例 + results = manager.search_by_description("cute animal", k=3) + print(f"文本搜索结果: {len(results)} 张图片") + for result in results: + print(f" {result['metadata']['filename']}: {result['similarity']:.3f}") \ No newline at end of file diff --git a/spx-backend/.env.dev b/spx-backend/.env.dev index 907a90944..c431042e4 100644 --- a/spx-backend/.env.dev +++ b/spx-backend/.env.dev @@ -1,65 +1,76 @@ -PORT=:8080 -ALLOWED_ORIGIN=* +PORT= +ALLOWED_ORIGIN= # Use local DB by default for dev -GOP_SPX_DSN=root:123456@tcp(127.0.0.1:3306)/builder?charset=utf8mb4&parseTime=True&loc=UTC +GOP_SPX_DSN= # AIGC Service -AIGC_ENDPOINT=http://36.213.14.15:8888 +AIGC_ENDPOINT= -# Redis -REDIS_ADDR= -REDIS_PASSWORD= -REDIS_DB= -REDIS_POOL_SIZE= - -# Qiniu Kodo +# AK & SK for Qiniu (Kodo) KODO_AK= KODO_SK= +# Kodo Bucket Name KODO_BUCKET= +# Kodo Bucket Region KODO_BUCKET_REGION= +# Base URL to read file from Kodo KODO_BASE_URL= -# Use casdoor service of test env for quick start -GOP_CASDOOR_ENDPOINT="https://goplus-casdoor.qiniu.io" -GOP_CASDOOR_CLIENTID="389313df51ffd2093b2f" -GOP_CASDOOR_CLIENTSECRET="bd22db46e4ab52ab0d0ac9cbd771a21775039409" -GOP_CASDOOR_CERTIFICATE="-----BEGIN CERTIFICATE----- -MIIE3TCCAsWgAwIBAgIDAeJAMA0GCSqGSIb3DQEBCwUAMCgxDjAMBgNVBAoTBWFk -bWluMRYwFAYDVQQDEw1jZXJ0LWJ1aWx0LWluMB4XDTI0MDEyMjAxMjcyM1oXDTQ0 -MDEyMjAxMjcyM1owKDEOMAwGA1UEChMFYWRtaW4xFjAUBgNVBAMTDWNlcnQtYnVp -bHQtaW4wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQC4Qfq85ynOMFbV 
-vAk+IDnt7B+3N2ntw5udQptjlVMDptDmL2s5J7uguWzf+QcjIYZMR7r8PaSugybf -k5AVaQgHgvM3L8uUhSdLJGRBCfxSUZCAvVV8QbnP1h4oqOn1r1ZEGUvLGUD1WG+0 -fCNLhG1xYZE9yOyiJ6uEa/aMavgSswYSay9sc0gCJ0FE0gldFt+T3r3RNxQK3hl9 -3i7ifYbvCrdvTHbdVwciUWD9xC6JD2gL7KkNwCmL+PB4+XLZ+6/v4XGfK+OZ6/dN -JuEJg+KR/v5eV7am5vp35mk/9evLYm7cjHHNZ49VIoz1wwHaApDVmRTE/ilGVcZy -Ynxq/mI3bl3PLd/yK6gB1itD97XRf3JUXRjFlvhCZITYPhhMliiGi2cQzp++sxGU -osciNI+CkQRkMOOfJa+2bwexNocfr1d01xkNw02PpRqBKGG2eJVN1pmXQH122bfT -ReLWYQGUOIvwJSX/nqPZU+OI9wQOUOJkA49G/652AOTSphidPY83qiJ1xQ6Qsftj -KHDT7XrN7l00FY7D2L++CwYybpwcfva9VsAsMsd2X+/r9+Qpla1USSfxFuaNFLYI -qZKlHNgqNl7YY+uThzzFjL8x96u2ar8emRDqpE2HCirnqfYmZ5VldWkwJ9wKqRIv -IYaqZAmSs4hJovB+sujgLOKNeTCOkwIDAQABoxAwDjAMBgNVHRMBAf8EAjAAMA0G -CSqGSIb3DQEBCwUAA4ICAQCo3+PvCMnGpY28v+JLFxrxC2aP/ZipUGsBnWGU5Czc -/J2gs7MdNLny0QDVvjEBCIWGu2YkNkii7Z8WskGGDmJCClN6N91aFizfDK+f49LW -7KXnqUztATf848ypPDcbhjcTwtvU8ABo7BLAfNO9TVc4nPLNjXMetojAj8S+n56o -qghbE7UaHd2d2q/ETFqunb8BbBPjw0Fv8VcpuIFrgKFEt+5nq8miYoGrBySPx/CZ -Whm4PZ1HIobFXkfN0FhoX++sArL4FAacisBxqX9HYYnwGYvXzUtmdFNfqctKH2Gt -YiSxoI1Jf9YysAroiIRxWNNr+O+Nkn0iIhkScKhvpXhLMnv1b2MIxwURUUnxehE1 -X4OALKxIZNyDvQUHZukzbVNQRJUorg+g4wPuCOkfvbNlHAFZerLxbigN1rafqH1j -Pt/ORY44z+IeU7zc3zW0Gug5ws5X6IWgYwlZODhEVd8jELk180SF0jEMqmbFud0I -SqmP0/koI7IzNhwqLilLWa3pUKs8AOxW4MCff/1fLB9IhR6c5+SG2S7ZbupuxXZS -hLX23sQstIcgUKEhMTger6IyFreuQY3kXDKSSReE6IA9HLUtYdykfUR0oL7pQwMx -ZNIZlCk3CrqUzVnAln07K9++5egCisZ0XADLp1cgCZk0NRM7a2LuMDx/iZXw1QaQ -Iw== ------END CERTIFICATE-----" -GOP_CASDOOR_ORGANIZATIONNAME="GoPlus" -GOP_CASDOOR_APPLICATIONNAME="application_x8aevk" +# Casdoor config +GOP_CASDOOR_ENDPOINT= +GOP_CASDOOR_CLIENTID= +GOP_CASDOOR_CLIENTSECRET= +GOP_CASDOOR_CERTIFICATE= +GOP_CASDOOR_ORGANIZATIONNAME= +GOP_CASDOOR_APPLICATIONNAME= -# OpenAI +# OpenAI compatible OPENAI_API_KEY= OPENAI_API_ENDPOINT= OPENAI_MODEL_ID= + + # OpenAI (PREMIUM) OPENAI_PREMIUM_API_KEY= OPENAI_PREMIUM_API_ENDPOINT= OPENAI_PREMIUM_MODEL_ID= + +# SVG Generation Providers + +# SVG.IO Provider 
+SVGIO_ENABLED= +SVGIO_BASE_URL= +SVGIO_GENERATE_ENDPOINT= +SVGIO_TIMEOUT= +SVGIO_MAX_RETRIES= +SVGIO_API_KEY= + +# Recraft Provider +RECRAFT_ENABLED= +RECRAFT_BASE_URL= +RECRAFT_GENERATE_ENDPOINT= +RECRAFT_VECTORIZE_ENDPOINT= +RECRAFT_DEFAULT_MODEL= +RECRAFT_SUPPORTED_MODELS= +RECRAFT_TIMEOUT= +RECRAFT_MAX_RETRIES= +RECRAFT_API_KEY= + +# OpenAI Provider for SVG Generation (supports all OpenAI compatible models) +SVG_OPENAI_ENABLED= +SVG_OPENAI_BASE_URL= +SVG_OPENAI_DEFAULT_MODEL= +SVG_OPENAI_MAX_TOKENS= +SVG_OPENAI_TEMPERATURE= +SVG_OPENAI_TIMEOUT= +SVG_OPENAI_MAX_RETRIES= +SVG_OPENAI_API_KEY= + +# Translation Service for SVG Generation +TRANSLATION_ENABLED= +TRANSLATION_SERVICE_URL= +TRANSLATION_DEFAULT_MODEL= +TRANSLATION_TIMEOUT= +TRANSLATION_MAX_RETRIES= +TRANSLATION_API_KEY= + diff --git a/spx-backend/SVG_GENERATION_CONFIG.md b/spx-backend/SVG_GENERATION_CONFIG.md new file mode 100644 index 000000000..5b7085744 --- /dev/null +++ b/spx-backend/SVG_GENERATION_CONFIG.md @@ -0,0 +1,217 @@ +# SVG Generation Configuration Guide + +This document explains how to configure the SVG generation services in spx-backend. + +## Environment Variables + +The following environment variables have been added to support SVG generation functionality: + +### SVG.IO Provider + +SVG.IO is a direct SVG generation service that provides high-quality vector graphics. + +```bash +# Enable/disable the SVG.IO provider +SVGIO_ENABLED=true + +# SVG.IO API base URL +SVGIO_BASE_URL=https://api.svg.io + +# API endpoint for image generation +SVGIO_GENERATE_ENDPOINT=/v1/generate-image + +# Request timeout +SVGIO_TIMEOUT=60s + +# Maximum retry attempts +SVGIO_MAX_RETRIES=3 + +# API key (required for production use) +# SVGIO_API_KEY=your_svgio_api_key_here +``` + +### Recraft Provider + +Recraft is an AI-powered image generation service that can create high-quality images and convert them to SVG format. 
+ +```bash +# Enable/disable the Recraft provider +RECRAFT_ENABLED=true + +# Recraft API base URL +RECRAFT_BASE_URL=https://external.api.recraft.ai + +# API endpoints +RECRAFT_GENERATE_ENDPOINT=/v1/images/generations +RECRAFT_VECTORIZE_ENDPOINT=/v1/images/vectorize + +# Default model to use +RECRAFT_DEFAULT_MODEL=recraftv3 + +# Supported models (comma-separated) +RECRAFT_SUPPORTED_MODELS=recraftv3,recraftv2 + +# Request timeout +RECRAFT_TIMEOUT=60s + +# Maximum retry attempts +RECRAFT_MAX_RETRIES=3 + +# API key (required for production use) +# RECRAFT_API_KEY=your_recraft_api_key_here +``` + +### OpenAI Compatible Provider for SVG Generation + +This provider uses OpenAI-compatible APIs to generate SVG code directly using large language models. + +```bash +# Enable/disable the OpenAI SVG provider +SVG_OPENAI_ENABLED=true + +# API base URL (supports OpenAI, Claude, and other compatible services) +SVG_OPENAI_BASE_URL=https://api.qnaigc.com/v1/ + +# Default model to use for SVG generation +SVG_OPENAI_DEFAULT_MODEL=claude-4.0-sonnet + +# Maximum tokens for generation +SVG_OPENAI_MAX_TOKENS=4000 + +# Temperature for generation (0.0 - 1.0) +SVG_OPENAI_TEMPERATURE=0.7 + +# Request timeout +SVG_OPENAI_TIMEOUT=60s + +# Maximum retry attempts +SVG_OPENAI_MAX_RETRIES=3 + +# API key (required for production use) +# SVG_OPENAI_API_KEY=your_openai_compatible_api_key_here +``` + +### Translation Service + +The translation service is used to translate prompts from Chinese to English when using providers that don't support Chinese input. 
+ +```bash +# Enable/disable translation service +TRANSLATION_ENABLED=true + +# Translation service URL +TRANSLATION_SERVICE_URL=https://api.qnaigc.com/v1/chat/completions + +# Model to use for translation +TRANSLATION_DEFAULT_MODEL=claude-4.0-sonnet + +# Request timeout +TRANSLATION_TIMEOUT=45s + +# Maximum retry attempts +TRANSLATION_MAX_RETRIES=2 + +# API key for translation service +# TRANSLATION_API_KEY=your_translation_api_key_here +``` + +## API Endpoints + +The following API endpoints are now available: + +### POST /image/svg +Generates and returns SVG content directly. + +**Request Body:** +```json +{ + "prompt": "A cute cartoon cat sitting on a cloud", + "style": "vector_illustration", + "provider": "svgio", + "negative_prompt": "ugly, blurred", + "skip_translate": false, + "model": "recraftv3", + "size": "1024x1024", + "substyle": "flat", + "n": 1 +} +``` + +**Response:** +- Content-Type: `image/svg+xml` +- Body: SVG file content +- Headers include metadata (ID, dimensions, provider, etc.) + +### POST /image +Generates an image and returns metadata information. + +**Request Body:** Same as above + +**Response:** +```json +{ + "id": "svgio_1234567890", + "svg_url": "https://example.com/generated.svg", + "png_url": "https://example.com/generated.png", + "width": 1024, + "height": 1024, + "provider": "svgio", + "original_prompt": "一只可爱的猫", + "translated_prompt": "A cute cat", + "was_translated": true, + "created_at": "2025-01-01T12:00:00Z" +} +``` + +## Provider Selection + +You can specify which provider to use in the request: + +- `"svgio"` - SVG.IO service (default) +- `"recraft"` - Recraft AI service +- `"openai"` - OpenAI-compatible LLM service + +Each provider has different strengths: + +- **SVG.IO**: Direct SVG generation, good for simple vector graphics +- **Recraft**: High-quality AI image generation with vectorization +- **OpenAI**: LLM-powered SVG code generation, highly customizable + +## Development Setup + +1. 
Copy the environment variables from `.env.dev` to your local environment +2. Uncomment and set the API keys for the providers you want to use +3. Adjust the base URLs and models as needed for your setup +4. Start the server and test the endpoints + +## Production Considerations + +1. **API Keys**: Ensure all production API keys are properly secured +2. **Rate Limiting**: Consider implementing rate limiting for the SVG generation endpoints +3. **Caching**: Consider caching generated SVGs to reduce API costs +4. **Monitoring**: Monitor API usage and costs for each provider +5. **Fallbacks**: Configure multiple providers for redundancy + +## Example Usage + +```bash +# Generate SVG directly +curl -X POST http://localhost:8080/image/svg \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A minimalist mountain landscape", + "style": "flat_vector", + "provider": "svgio" + }' \ + -o generated.svg + +# Get image metadata +curl -X POST http://localhost:8080/image \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A futuristic city skyline", + "provider": "recraft", + "model": "recraftv3", + "size": "1024x1024" + }' +``` \ No newline at end of file diff --git a/spx-backend/cmd/spx-backend/get_health.yap b/spx-backend/cmd/spx-backend/get_health.yap new file mode 100644 index 000000000..b200cefaf --- /dev/null +++ b/spx-backend/cmd/spx-backend/get_health.yap @@ -0,0 +1,13 @@ +// test health. +// +// Request: +// GET /health + +ctx := &Context + +result, err := ctrl.Health(ctx.Context()) +if err != nil { + replyWithInnerError(ctx, err) + return +} +json result \ No newline at end of file diff --git a/spx-backend/cmd/spx-backend/get_themes.yap b/spx-backend/cmd/spx-backend/get_themes.yap new file mode 100644 index 000000000..383729688 --- /dev/null +++ b/spx-backend/cmd/spx-backend/get_themes.yap @@ -0,0 +1,18 @@ +// Get all available themes for image generation. 
+// +// Request: +// GET /themes + +import ( + "github.com/goplus/builder/spx-backend/internal/controller" +) + +ctx := &Context + +result, err := ctrl.GetThemes(ctx.Context()) +if err != nil { + replyWithInnerError(ctx, err) + return +} + +json result \ No newline at end of file diff --git a/spx-backend/cmd/spx-backend/post_image.yap b/spx-backend/cmd/spx-backend/post_image.yap new file mode 100644 index 000000000..b51e7c163 --- /dev/null +++ b/spx-backend/cmd/spx-backend/post_image.yap @@ -0,0 +1,27 @@ +// Generate image and return metadata. +// +// Request: +// POST /image + +import ( + "github.com/goplus/builder/spx-backend/internal/controller" +) + +ctx := &Context + +params := &controller.GenerateImageParams{} +if !parseJSON(ctx, params) { + return +} +if ok, msg := params.Validate(); !ok { + replyWithCodeMsg(ctx, errorInvalidArgs, msg) + return +} + +result, err := ctrl.GenerateImage(ctx.Context(), params) +if err != nil { + replyWithInnerError(ctx, err) + return +} + +json result \ No newline at end of file diff --git a/spx-backend/cmd/spx-backend/post_image_svg.yap b/spx-backend/cmd/spx-backend/post_image_svg.yap new file mode 100644 index 000000000..8a7bf4107 --- /dev/null +++ b/spx-backend/cmd/spx-backend/post_image_svg.yap @@ -0,0 +1,34 @@ +// Generate SVG image directly. 
+// +// Request: +// POST /image/svg + +import ( + "github.com/goplus/builder/spx-backend/internal/controller" +) + +ctx := &Context + +params := &controller.GenerateSVGParams{} +if !parseJSON(ctx, params) { + return +} +if ok, msg := params.Validate(); !ok { + replyWithCodeMsg(ctx, errorInvalidArgs, msg) + return +} + +result, err := ctrl.GenerateSVG(ctx.Context(), params) +if err != nil { + replyWithInnerError(ctx, err) + return +} + +// Set response headers +for key, value := range result.Headers { + ctx.ResponseWriter.Header().Set(key, value) +} + +// Write SVG content directly +ctx.ResponseWriter.WriteHeader(200) +ctx.ResponseWriter.Write(result.Data) \ No newline at end of file diff --git a/spx-backend/cmd/spx-backend/xgo_autogen.go b/spx-backend/cmd/spx-backend/xgo_autogen.go index 8313931f8..6c9041616 100644 --- a/spx-backend/cmd/spx-backend/xgo_autogen.go +++ b/spx-backend/cmd/spx-backend/xgo_autogen.go @@ -76,6 +76,10 @@ type get_courses_list struct { yap.Handler *AppV2 } +type get_health struct { + yap.Handler + *AppV2 +} type get_project_release_owner_project_release struct { yap.Handler *AppV2 @@ -148,6 +152,14 @@ type post_course struct { yap.Handler *AppV2 } +type post_image struct { + yap.Handler + *AppV2 +} +type post_image_svg struct { + yap.Handler + *AppV2 +} type post_project_release struct { yap.Handler *AppV2 @@ -295,36 +307,39 @@ func (this *AppV2) Main() { _xgo_obj9 := &get_course_serieses_list{AppV2: this} _xgo_obj10 := &get_course_id{AppV2: this} _xgo_obj11 := &get_courses_list{AppV2: this} - _xgo_obj12 := &get_project_release_owner_project_release{AppV2: this} - _xgo_obj13 := &get_project_releases_list{AppV2: this} - _xgo_obj14 := &get_project_owner_name{AppV2: this} - _xgo_obj15 := &get_project_owner_name_liking{AppV2: this} - _xgo_obj16 := &get_projects_list{AppV2: this} - _xgo_obj17 := &get_user{AppV2: this} - _xgo_obj18 := &get_user_username{AppV2: this} - _xgo_obj19 := &get_user_username_following{AppV2: this} - _xgo_obj20 := 
&get_users_list{AppV2: this} - _xgo_obj21 := &get_util_upinfo{AppV2: this} - _xgo_obj22 := &post_ai_interaction_turn{AppV2: this} - _xgo_obj23 := &post_aigc_matting{AppV2: this} - _xgo_obj24 := &post_asset{AppV2: this} - _xgo_obj25 := &post_copilot_message{AppV2: this} - _xgo_obj26 := &post_copilot_stream_message{AppV2: this} - _xgo_obj27 := &post_course_series{AppV2: this} - _xgo_obj28 := &post_course{AppV2: this} - _xgo_obj29 := &post_project_release{AppV2: this} - _xgo_obj30 := &post_project{AppV2: this} - _xgo_obj31 := &post_project_owner_name_liking{AppV2: this} - _xgo_obj32 := &post_project_owner_name_view{AppV2: this} - _xgo_obj33 := &post_user_username_following{AppV2: this} - _xgo_obj34 := &post_util_fileurls{AppV2: this} - _xgo_obj35 := &post_workflow_stream_message{AppV2: this} - _xgo_obj36 := &put_asset_id{AppV2: this} - _xgo_obj37 := &put_course_series_id{AppV2: this} - _xgo_obj38 := &put_course_id{AppV2: this} - _xgo_obj39 := &put_project_owner_name{AppV2: this} - _xgo_obj40 := &put_user{AppV2: this} - yap.Gopt_AppV2_Main(this, _xgo_obj0, _xgo_obj1, _xgo_obj2, _xgo_obj3, _xgo_obj4, _xgo_obj5, _xgo_obj6, _xgo_obj7, _xgo_obj8, _xgo_obj9, _xgo_obj10, _xgo_obj11, _xgo_obj12, _xgo_obj13, _xgo_obj14, _xgo_obj15, _xgo_obj16, _xgo_obj17, _xgo_obj18, _xgo_obj19, _xgo_obj20, _xgo_obj21, _xgo_obj22, _xgo_obj23, _xgo_obj24, _xgo_obj25, _xgo_obj26, _xgo_obj27, _xgo_obj28, _xgo_obj29, _xgo_obj30, _xgo_obj31, _xgo_obj32, _xgo_obj33, _xgo_obj34, _xgo_obj35, _xgo_obj36, _xgo_obj37, _xgo_obj38, _xgo_obj39, _xgo_obj40) + _xgo_obj12 := &get_health{AppV2: this} + _xgo_obj13 := &get_project_release_owner_project_release{AppV2: this} + _xgo_obj14 := &get_project_releases_list{AppV2: this} + _xgo_obj15 := &get_project_owner_name{AppV2: this} + _xgo_obj16 := &get_project_owner_name_liking{AppV2: this} + _xgo_obj17 := &get_projects_list{AppV2: this} + _xgo_obj18 := &get_user{AppV2: this} + _xgo_obj19 := &get_user_username{AppV2: this} + _xgo_obj20 := 
&get_user_username_following{AppV2: this} + _xgo_obj21 := &get_users_list{AppV2: this} + _xgo_obj22 := &get_util_upinfo{AppV2: this} + _xgo_obj23 := &post_ai_interaction_turn{AppV2: this} + _xgo_obj24 := &post_aigc_matting{AppV2: this} + _xgo_obj25 := &post_asset{AppV2: this} + _xgo_obj26 := &post_copilot_message{AppV2: this} + _xgo_obj27 := &post_copilot_stream_message{AppV2: this} + _xgo_obj28 := &post_course_series{AppV2: this} + _xgo_obj29 := &post_course{AppV2: this} + _xgo_obj30 := &post_image{AppV2: this} + _xgo_obj31 := &post_image_svg{AppV2: this} + _xgo_obj32 := &post_project_release{AppV2: this} + _xgo_obj33 := &post_project{AppV2: this} + _xgo_obj34 := &post_project_owner_name_liking{AppV2: this} + _xgo_obj35 := &post_project_owner_name_view{AppV2: this} + _xgo_obj36 := &post_user_username_following{AppV2: this} + _xgo_obj37 := &post_util_fileurls{AppV2: this} + _xgo_obj38 := &post_workflow_stream_message{AppV2: this} + _xgo_obj39 := &put_asset_id{AppV2: this} + _xgo_obj40 := &put_course_series_id{AppV2: this} + _xgo_obj41 := &put_course_id{AppV2: this} + _xgo_obj42 := &put_project_owner_name{AppV2: this} + _xgo_obj43 := &put_user{AppV2: this} + yap.Gopt_AppV2_Main(this, _xgo_obj0, _xgo_obj1, _xgo_obj2, _xgo_obj3, _xgo_obj4, _xgo_obj5, _xgo_obj6, _xgo_obj7, _xgo_obj8, _xgo_obj9, _xgo_obj10, _xgo_obj11, _xgo_obj12, _xgo_obj13, _xgo_obj14, _xgo_obj15, _xgo_obj16, _xgo_obj17, _xgo_obj18, _xgo_obj19, _xgo_obj20, _xgo_obj21, _xgo_obj22, _xgo_obj23, _xgo_obj24, _xgo_obj25, _xgo_obj26, _xgo_obj27, _xgo_obj28, _xgo_obj29, _xgo_obj30, _xgo_obj31, _xgo_obj32, _xgo_obj33, _xgo_obj34, _xgo_obj35, _xgo_obj36, _xgo_obj37, _xgo_obj38, _xgo_obj39, _xgo_obj40, _xgo_obj41, _xgo_obj42, _xgo_obj43) } //line cmd/spx-backend/delete_asset_#id.yap:6 func (this *delete_asset_id) Main(_xgo_arg0 *yap.Context) { @@ -893,6 +908,30 @@ func (this *get_courses_list) Classclone() yap.HandlerProto { _xgo_ret := *this return &_xgo_ret } +//line cmd/spx-backend/get_health.yap:6 +func 
(this *get_health) Main(_xgo_arg0 *yap.Context) { + this.Handler.Main(_xgo_arg0) +//line cmd/spx-backend/get_health.yap:6:1 + ctx := &this.Context +//line cmd/spx-backend/get_health.yap:8:1 + result, err := this.ctrl.Health(ctx.Context()) +//line cmd/spx-backend/get_health.yap:9:1 + if err != nil { +//line cmd/spx-backend/get_health.yap:10:1 + replyWithInnerError(ctx, err) +//line cmd/spx-backend/get_health.yap:11:1 + return + } +//line cmd/spx-backend/get_health.yap:13:1 + this.Json__1(result) +} +func (this *get_health) Classfname() string { + return "get_health" +} +func (this *get_health) Classclone() yap.HandlerProto { + _xgo_ret := *this + return &_xgo_ret +} //line cmd/spx-backend/get_project-release_#owner_#project_#release.yap:10 func (this *get_project_release_owner_project_release) Main(_xgo_arg0 *yap.Context) { this.Handler.Main(_xgo_arg0) @@ -1828,6 +1867,94 @@ func (this *post_course) Classclone() yap.HandlerProto { _xgo_ret := *this return &_xgo_ret } +//line cmd/spx-backend/post_image.yap:10 +func (this *post_image) Main(_xgo_arg0 *yap.Context) { + this.Handler.Main(_xgo_arg0) +//line cmd/spx-backend/post_image.yap:10:1 + ctx := &this.Context +//line cmd/spx-backend/post_image.yap:12:1 + params := &controller.GenerateImageParams{} +//line cmd/spx-backend/post_image.yap:13:1 + if !parseJSON(ctx, params) { +//line cmd/spx-backend/post_image.yap:14:1 + return + } +//line cmd/spx-backend/post_image.yap:16:1 + if +//line cmd/spx-backend/post_image.yap:16:1 + ok, msg := params.Validate(); !ok { +//line cmd/spx-backend/post_image.yap:17:1 + replyWithCodeMsg(ctx, errorInvalidArgs, msg) +//line cmd/spx-backend/post_image.yap:18:1 + return + } +//line cmd/spx-backend/post_image.yap:21:1 + result, err := this.ctrl.GenerateImage(ctx.Context(), params) +//line cmd/spx-backend/post_image.yap:22:1 + if err != nil { +//line cmd/spx-backend/post_image.yap:23:1 + replyWithInnerError(ctx, err) +//line cmd/spx-backend/post_image.yap:24:1 + return + } +//line 
cmd/spx-backend/post_image.yap:27:1 + this.Json__1(result) +} +func (this *post_image) Classfname() string { + return "post_image" +} +func (this *post_image) Classclone() yap.HandlerProto { + _xgo_ret := *this + return &_xgo_ret +} +//line cmd/spx-backend/post_image_svg.yap:10 +func (this *post_image_svg) Main(_xgo_arg0 *yap.Context) { + this.Handler.Main(_xgo_arg0) +//line cmd/spx-backend/post_image_svg.yap:10:1 + ctx := &this.Context +//line cmd/spx-backend/post_image_svg.yap:12:1 + params := &controller.GenerateSVGParams{} +//line cmd/spx-backend/post_image_svg.yap:13:1 + if !parseJSON(ctx, params) { +//line cmd/spx-backend/post_image_svg.yap:14:1 + return + } +//line cmd/spx-backend/post_image_svg.yap:16:1 + if +//line cmd/spx-backend/post_image_svg.yap:16:1 + ok, msg := params.Validate(); !ok { +//line cmd/spx-backend/post_image_svg.yap:17:1 + replyWithCodeMsg(ctx, errorInvalidArgs, msg) +//line cmd/spx-backend/post_image_svg.yap:18:1 + return + } +//line cmd/spx-backend/post_image_svg.yap:21:1 + result, err := this.ctrl.GenerateSVG(ctx.Context(), params) +//line cmd/spx-backend/post_image_svg.yap:22:1 + if err != nil { +//line cmd/spx-backend/post_image_svg.yap:23:1 + replyWithInnerError(ctx, err) +//line cmd/spx-backend/post_image_svg.yap:24:1 + return + } + for +//line cmd/spx-backend/post_image_svg.yap:28:1 + key, value := range result.Headers { +//line cmd/spx-backend/post_image_svg.yap:29:1 + ctx.ResponseWriter.Header().Set(key, value) + } +//line cmd/spx-backend/post_image_svg.yap:33:1 + ctx.ResponseWriter.WriteHeader(200) +//line cmd/spx-backend/post_image_svg.yap:34:1 + ctx.ResponseWriter.Write(result.Data) +} +func (this *post_image_svg) Classfname() string { + return "post_image_svg" +} +func (this *post_image_svg) Classclone() yap.HandlerProto { + _xgo_ret := *this + return &_xgo_ret +} //line cmd/spx-backend/post_project-release.yap:10 func (this *post_project_release) Main(_xgo_arg0 *yap.Context) { this.Handler.Main(_xgo_arg0) diff --git 
a/spx-backend/docs/DATABASE_DESIGN.md b/spx-backend/docs/DATABASE_DESIGN.md new file mode 100644 index 000000000..f06eb663f --- /dev/null +++ b/spx-backend/docs/DATABASE_DESIGN.md @@ -0,0 +1,135 @@ +# 数据库表设计文档 + +## 概述 + +本文档描述了 SPX Backend 中新增的 AI 资源管理相关数据库表设计。这些表主要用于管理 AI 生成的资源(如 SVG 图像)及其标签分类系统。 + +## 表结构设计 + +### 1. AI 资源表 (aiResource) + +**表名**: `aiResource` +**描述**: 存储 AI 生成的资源信息,主要是 SVG 图像资源 + +| 字段名 | 数据类型 | 约束 | 描述 | +|--------|----------|------|------| +| id | bigint | PRIMARY KEY, AUTO_INCREMENT | 主键,资源唯一标识 | +| url | varchar | NOT NULL | 资源访问 URL | +| created_at | timestamp | NOT NULL | 创建时间 | +| updated_at | timestamp | NOT NULL | 更新时间 | +| deleted_at | timestamp | INDEX | 软删除时间戳 | + +**特性**: +- 继承自基础 Model 结构,包含标准的时间戳字段和软删除支持 +- 支持通过多对多关系与标签关联 +- URL 字段存储资源的访问地址(可能是文件路径或云存储 URL) + +### 2. 标签表 (label) + +**表名**: `label` +**描述**: 存储用于分类 AI 资源的标签信息 + +| 字段名 | 数据类型 | 约束 | 描述 | +|--------|----------|------|------| +| id | bigint | PRIMARY KEY, AUTO_INCREMENT | 主键,标签唯一标识 | +| labelName | varchar(50) | UNIQUE, NOT NULL | 标签名称,唯一约束 | +| created_at | timestamp | NOT NULL | 创建时间 | +| updated_at | timestamp | NOT NULL | 更新时间 | +| deleted_at | timestamp | INDEX | 软删除时间戳 | + +**特性**: +- 标签名称具有唯一性约束,防止重复标签 +- 限制标签名称长度为 50 字符 +- 支持通过多对多关系与 AI 资源关联 + +### 3. 
资源标签关联表 (resource_label) + +**表名**: `resource_label` +**描述**: AI 资源与标签的多对多关联表 + +| 字段名 | 数据类型 | 约束 | 描述 | +|--------|----------|------|------| +| id | bigint | PRIMARY KEY, AUTO_INCREMENT | 主键,关联记录唯一标识 | +| aiResourceId | bigint | NOT NULL, INDEX, FOREIGN KEY | 外键,引用 aiResource 表的 id | +| labelId | bigint | NOT NULL, INDEX, FOREIGN KEY | 外键,引用 label 表的 id | +| created_at | timestamp | NOT NULL | 创建时间 | +| updated_at | timestamp | NOT NULL | 更新时间 | +| deleted_at | timestamp | INDEX | 软删除时间戳 | + +**外键关系**: +- `aiResourceId` → `aiResource.id` +- `labelId` → `label.id` + +**索引设计**: +- `aiResourceId` 字段建立索引,优化基于资源查询标签的性能 +- `labelId` 字段建立索引,优化基于标签查询资源的性能 +- `deleted_at` 字段建立索引,支持软删除查询优化 + +## 关系模型 + +``` +aiResource (1) ←→ (N) resource_label (N) ←→ (1) label +``` + +这是一个典型的多对多关系设计: +- 一个 AI 资源可以有多个标签 +- 一个标签可以被多个 AI 资源使用 +- 通过中间表 `resource_label` 实现多对多关联 + +## 数据库特性 + +### 软删除支持 +所有表都支持软删除功能: +- 使用 `deleted_at` 字段标记删除状态 +- 查询时自动过滤已删除记录 +- 支持数据恢复 + +### 时间戳审计 +每个表都包含完整的时间戳字段: +- `created_at`: 记录创建时间 +- `updated_at`: 记录最后更新时间 +- `deleted_at`: 软删除时间(为 NULL 表示未删除) + +### 索引优化 +- 主键自动建立聚簇索引 +- 外键字段建立普通索引,优化关联查询 +- `deleted_at` 字段建立索引,优化软删除查询 +- 标签名称的唯一索引确保数据一致性 + +## 使用场景 + +### 1. AI 资源管理 +- 存储 SVG 图像生成结果 +- 记录资源的访问 URL +- 支持资源的软删除和恢复 + +### 2. 标签分类系统 +- 为 AI 资源打标签进行分类 +- 支持多维度标签组合 +- 便于资源的检索和过滤 + +### 3. 查询优化 +- 通过标签快速筛选相关资源 +- 通过资源查看所有关联标签 +- 支持复杂的标签组合查询 + +## 扩展性考虑 + +### 1. 字段扩展 +- AI 资源表可以根据需要添加更多元数据字段(如尺寸、格式等) +- 标签表可以添加描述、颜色、图标等字段 + +### 2. 关系扩展 +- 可以为关联表添加权重字段,支持标签重要性排序 +- 可以添加标签层级关系,支持标签树形结构 + +### 3. 性能优化 +- 根据查询模式可以添加复合索引 +- 考虑分表策略应对大数据量场景 + +## 注意事项 + +1. **数据一致性**: 删除标签时需要考虑关联资源的处理策略 +2. **性能监控**: 需要监控多对多查询的性能,必要时进行索引优化 +3. **数据清理**: 定期清理软删除的数据,避免表空间过度膨胀 +4. 
**权限控制**: 考虑用户权限,确保资源访问的安全性 \ No newline at end of file diff --git a/spx-backend/docs/SVG_THEME_TEMPLATES.md b/spx-backend/docs/SVG_THEME_TEMPLATES.md new file mode 100644 index 000000000..e00b6f09a --- /dev/null +++ b/spx-backend/docs/SVG_THEME_TEMPLATES.md @@ -0,0 +1,430 @@ +# SVG主题提示词模板功能文档 + +## 概述 + +SVG主题提示词模板功能为SVG生成服务提供了智能的主题化增强能力。通过预定义的主题模板,系统能够自动将主题相关的提示词与用户输入进行智能拼接,显著提升生成的SVG图像质量和一致性。 + +## 功能特性 + +1. **智能提示词增强**: 根据不同教育场景和用途提供专门优化的提示词模板 +2. **多语言支持**: 同时支持中英文提示词,自动根据语言偏好选择 +3. **可配置强度**: 支持0-1范围的主题影响强度控制 +4. **灵活定制**: 支持自定义前缀、后缀和风格修饰词 +5. **热更新**: 支持运行时重载主题配置 +6. **向下兼容**: 完全兼容现有API,不影响未启用主题的请求 + +## 架构设计 + +### 包结构 +``` +internal/svggen/theme/ +├── types.go # 主题相关类型定义 +├── templates.go # 内置模板定义 +├── config.go # 配置加载和环境变量处理 +├── manager.go # 主题管理器核心逻辑 +└── manager_test.go # 单元测试 +``` + +### 核心组件 + +#### ThemeManager +- 负责加载和管理主题模板 +- 提供主题查询和提示词拼接功能 +- 支持模板缓存和运行时重载 +- 线程安全的并发访问 + +#### ThemeTemplate +- 定义主题的结构化信息 +- 包含多语言提示词、风格修饰词、质量增强词 +- 支持提供商特定的配置覆盖 + +### 集成方式 + +主题功能通过中间件方式集成到现有的`svggen.ServiceManager`中,在生成请求处理流程中自动进行主题增强: + +``` +用户请求 -> 主题处理 -> 翻译处理 -> 提供商生成 -> 响应 +``` + +## 数据结构 + +### 主题类型定义 + +```go +type ThemeType string + +const ( + ThemeEducationMath ThemeType = "education_math" + ThemeEducationScience ThemeType = "education_science" + ThemeEducationArt ThemeType = "education_art" + ThemeGameCartoon ThemeType = "game_cartoon" + ThemeGamePixel ThemeType = "game_pixel" + ThemeUIIcon ThemeType = "ui_icon" + ThemeUIIllustration ThemeType = "ui_illustration" + ThemeGeneral ThemeType = "general" +) +``` + +### 主题模板结构 + +```go +type ThemeTemplate struct { + Type ThemeType `json:"type"` + Name string `json:"name"` + NameCN string `json:"name_cn"` + Description string `json:"description"` + DescriptionCN string `json:"description_cn"` + + // 提示词模板 + PrefixPrompt string `json:"prefix_prompt"` + PrefixPromptCN string `json:"prefix_prompt_cn"` + SuffixPrompt string `json:"suffix_prompt"` + SuffixPromptCN string `json:"suffix_prompt_cn"` + + // 默认参数 + 
DefaultStyle string `json:"default_style"` + DefaultNegative string `json:"default_negative"` + DefaultNegativeCN string `json:"default_negative_cn"` + + // 高级配置 + StyleModifiers []string `json:"style_modifiers"` + QualityEnhancers []string `json:"quality_enhancers"` + + // 提供商特定配置 + ProviderOverrides map[string]ProviderConfig `json:"provider_overrides"` +} +``` + +### API参数扩展 + +现有的`GenerateRequest`和控制器参数已扩展支持主题字段: + +```go +type GenerateRequest struct { + // 现有字段... + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Style string `json:"style,omitempty"` + Provider Provider `json:"provider,omitempty"` + + // 新增主题字段 + Theme string `json:"theme,omitempty"` + EnableTheme bool `json:"enable_theme,omitempty"` + ThemeStrength float64 `json:"theme_strength,omitempty"` + CustomPrefix string `json:"custom_prefix,omitempty"` + CustomSuffix string `json:"custom_suffix,omitempty"` + Language string `json:"language,omitempty"` +} +``` + +## 核心算法 + +### 提示词增强流程 + +```go +func (tm *ThemeManager) BuildEnhancedPrompt(req *ProcessRequest, template *ThemeTemplate) string { + var parts []string + + // 1. 自定义前缀或主题前缀 + if req.CustomPrefix != "" { + parts = append(parts, req.CustomPrefix) + } else if template.PrefixPrompt != "" { + prefix := tm.selectPromptByLanguage(template.PrefixPrompt, template.PrefixPromptCN, req.Language) + if prefix != "" { + parts = append(parts, prefix) + } + } + + // 2. 用户原始提示词 + if req.Prompt != "" { + parts = append(parts, req.Prompt) + } + + // 3. 主题风格修饰(根据强度选择) + if req.ThemeStrength > 0.5 && len(template.StyleModifiers) > 0 { + modifiers := tm.selectModifiers(template.StyleModifiers, req.ThemeStrength) + parts = append(parts, modifiers...) + } + + // 4. 
自定义后缀或主题后缀 + if req.CustomSuffix != "" { + parts = append(parts, req.CustomSuffix) + } else if template.SuffixPrompt != "" { + suffix := tm.selectPromptByLanguage(template.SuffixPrompt, template.SuffixPromptCN, req.Language) + if suffix != "" { + parts = append(parts, suffix) + } + } + + // 5. 质量增强词(总是添加) + if len(template.QualityEnhancers) > 0 { + parts = append(parts, template.QualityEnhancers...) + } + + return strings.Join(tm.filterEmptyStrings(parts), ", ") +} +``` + +### 主题强度影响 + +- **0.0 - 0.5**: 不应用风格修饰词,仅使用前缀、后缀和质量增强词 +- **0.5 - 1.0**: 根据强度比例选择风格修饰词数量 + +```go +func (tm *ThemeManager) selectModifiers(modifiers []string, strength float64) []string { + count := int(float64(len(modifiers)) * strength) + if count == 0 { + count = 1 + } + if count > len(modifiers) { + count = len(modifiers) + } + return modifiers[:count] +} +``` + +## 内置主题详情 + +### 1. 教育数学主题 (education_math) +- **用途**: 数学教育插图 +- **前缀**: "Educational mathematics illustration" / "教育数学插图" +- **后缀**: "clean vector design, educational style, suitable for children" / "简洁矢量设计,教育风格,适合儿童" +- **风格修饰**: geometric, colorful, friendly, simple +- **质量增强**: high quality, clean lines, professional + +### 2. 游戏卡通主题 (game_cartoon) +- **用途**: 游戏角色和卡通插图 +- **前缀**: "Cute cartoon game character" / "可爱卡通游戏角色" +- **后缀**: "vibrant colors, friendly appearance, game-ready design" / "鲜艳色彩,友好外观,游戏就绪设计" +- **风格修饰**: playful, colorful, expressive, rounded +- **质量增强**: detailed, polished, game art style + +### 3. 
UI图标主题 (ui_icon) +- **用途**: 用户界面图标 +- **前缀**: "Minimal UI icon" / "简约UI图标" +- **后缀**: "flat design, consistent style, scalable vector" / "扁平设计,一致风格,可缩放矢量" +- **风格修饰**: minimal, modern, clean, geometric +- **质量增强**: pixel perfect, scalable, optimized + +## API端点 + +### 主题查询端点 + +#### 获取所有主题 +``` +GET /image/themes +``` + +响应格式: +```json +{ + "themes": [ + { + "type": "education_math", + "name": "Education Math", + "name_cn": "教育数学", + "description": "Mathematical illustrations optimized for educational content", + "description_cn": "专为教育内容优化的数学插图", + "default_style": "educational_illustration", + "style_modifiers": ["geometric", "colorful", "friendly", "simple"], + "quality_enhancers": ["high quality", "clean lines", "professional"] + } + ], + "count": 8 +} +``` + +#### 获取特定主题 +``` +GET /image/themes/{theme_type} +``` + +### SVG生成端点(已扩展) + +``` +POST /image/svg +``` + +支持的主题参数: +```json +{ + "prompt": "一只小猫", + "theme": "game_cartoon", + "enable_theme": true, + "theme_strength": 0.8, + "custom_prefix": "可爱的", + "custom_suffix": "适合儿童", + "language": "zh", + "provider": "svgio" +} +``` + +## 配置管理 + +### 环境变量 + +```bash +# 主题功能开关 +SVG_THEMES_ENABLED=true + +# 主题配置文件路径(可选) +SVG_THEMES_CONFIG_PATH=./config/theme_templates.yaml + +# 默认主题 +SVG_DEFAULT_THEME=general + +# 主题缓存时间(秒) +SVG_THEMES_CACHE_TTL=3600 +``` + +### 配置文件支持 + +支持YAML格式的自定义主题配置文件: + +```yaml +themes: + custom_education: + name: "Custom Education" + name_cn: "自定义教育" + enabled: true + prefix_prompt: "Educational custom illustration" + prefix_prompt_cn: "教育自定义插图" + suffix_prompt: "child-friendly, educational, clear" + suffix_prompt_cn: "儿童友好,教育性,清晰" + default_style: "educational_custom" + style_modifiers: ["educational", "child-friendly", "colorful"] + quality_enhancers: ["high quality", "professional"] +``` + +## 使用示例 + +### 基础主题使用 + +```bash +curl -X POST http://localhost:8080/image/svg \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "一只小猫学数学", + "theme": "education_math", + "enable_theme": 
true, + "provider": "svgio" + }' +``` + +### 高级配置 + +```bash +curl -X POST http://localhost:8080/image/svg \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "molecular structure", + "theme": "education_science", + "enable_theme": true, + "theme_strength": 0.7, + "custom_suffix": "suitable for university textbook", + "language": "en", + "provider": "recraft" + }' +``` + +## 提示词变换示例 + +### 输入 +``` +原始提示词: "a cute cat" +主题: game_cartoon +强度: 0.8 +语言: en +``` + +### 处理流程 +1. **前缀**: "Cute cartoon game character" +2. **用户输入**: "a cute cat" +3. **风格修饰**: "playful, colorful, expressive" (4个中的前3个,因为 int(4 × 0.8) = 3) +4. **后缀**: "vibrant colors, friendly appearance, game-ready design" +5. **质量增强**: "detailed, polished, game art style" + +### 最终结果 +``` +"Cute cartoon game character, a cute cat, playful, colorful, expressive, vibrant colors, friendly appearance, game-ready design, detailed, polished, game art style" +``` + +## 性能特性 + +### 缓存策略 +- **内存缓存**: 所有主题模板在启动时加载到内存 +- **线程安全**: 使用读写锁保证并发安全 +- **热更新**: 支持运行时重载而无需重启服务 + +### 性能优化 +- **快速路径**: 未启用主题时零开销 +- **字符串优化**: 避免不必要的字符串分配和复制 +- **去重处理**: 自动过滤空字符串和重复内容 + +### 错误处理 +- **优雅降级**: 主题处理失败时自动回退到原始提示词 +- **详细日志**: 记录主题处理过程和错误信息 +- **不中断生成**: 主题错误不影响SVG生成的继续执行 + +## 监控和日志 + +### 关键日志 +```go +logger.Printf("Theme processing - theme: %s, original_prompt: %q, enhanced_prompt: %q", + req.Theme, originalPrompt, enhancedPrompt) +``` + +### 统计信息 +- 主题使用频率统计 +- 不同主题的成功率跟踪 +- 提示词长度分布分析 +- 生成时间对比数据 + +## 扩展性 + +### 自定义主题注册 +```go +customTheme := &ThemeTemplate{ + Type: ThemeType("custom_theme"), + Name: "Custom Theme", + PrefixPrompt: "Custom prefix", +} +manager.RegisterCustomTheme(customTheme) +``` + +### 未来扩展方向 +- **A/B测试**: 主题变体对比测试 +- **用户偏好**: 个性化主题推荐 +- **智能学习**: 基于用户反馈的主题优化 +- **更多主题**: 持续扩展内置主题库 + +## 测试策略 + +### 单元测试覆盖 +- 主题模板加载和管理 +- 提示词拼接逻辑验证 +- 多语言选择机制 +- 参数验证和错误处理 +- 并发安全性测试 + +### 集成测试 +- 端到端API功能测试 +- 多提供商兼容性验证 +- 性能基准测试 +- 内存使用和泄漏检测 + +## 部署指南 + +### 渐进式启用 +1. **第一阶段**: 部署代码,主题功能默认关闭 +2. 
**第二阶段**: 开启主题功能,监控系统表现 +3. **第三阶段**: 扩展主题库,收集用户反馈 +4. **第四阶段**: 优化算法,增强用户体验 + +### 兼容性保证 +- 现有API完全向下兼容 +- 新增字段均为可选参数 +- 提供完整的降级机制 +- 支持平滑的功能开关 + +这套主题系统为SVG生成服务提供了强大而灵活的主题化能力,在提升生成质量的同时保持了良好的用户体验和系统稳定性。 \ No newline at end of file diff --git a/spx-backend/docs/THEME_FEATURE.md b/spx-backend/docs/THEME_FEATURE.md new file mode 100644 index 000000000..351b44c0b --- /dev/null +++ b/spx-backend/docs/THEME_FEATURE.md @@ -0,0 +1,222 @@ +# SVG生成主题功能文档 + +## 概述 + +SVG生成主题功能为图片生成提供了预定义的风格主题,通过在提示词中自动添加主题相关的风格描述,确保生成的图片严格遵循指定的视觉风格。 + +## 功能特性 + +- **9种预定义主题**:涵盖卡通、写实、极简等多种风格 +- **严格风格控制**:使用强制性指令词汇确保AI严格遵循风格要求 +- **自动提示词增强**:后端自动将主题提示词拼接到用户原始提示词 +- **主题查询API**:前端可动态获取所有可用主题信息 +- **中文支持**:提供中文名称、描述和提示词 + +## 支持的主题 + +### 1. 无主题 (ThemeNone) +- **ID**: `""` +- **中文名**: 无主题 +- **描述**: 不应用任何特定主题风格 +- **提示词**: 无 + +### 2. 卡通风格 (ThemeCartoon) +- **ID**: `"cartoon"` +- **中文名**: 卡通风格 +- **描述**: 色彩鲜艳的卡通风格,适合可爱有趣的内容 +- **提示词**: 必须使用卡通风格,必须色彩鲜艳丰富,必须可爱有趣,严格使用简单几何形状,强制使用明亮饱和的色彩,禁止写实细节 + +### 3. 写实风格 (ThemeRealistic) +- **ID**: `"realistic"` +- **中文名**: 写实风格 +- **描述**: 高度写实的风格,细节丰富逼真 +- **提示词**: 必须使用写实风格,严格要求高度细节化,强制逼真效果,必须专业高质量渲染,禁止卡通化或简化元素 + +### 4. 极简风格 (ThemeMinimal) +- **ID**: `"minimal"` +- **中文名**: 极简风格 +- **描述**: 极简主义风格,简洁干净的设计 +- **提示词**: 必须使用极简风格,严格限制元素数量,强制使用干净线条和几何形状,严格使用黑白或单色调,禁止复杂装饰 + +### 5. 奇幻风格 (ThemeFantasy) +- **ID**: `"fantasy"` +- **中文名**: 奇幻风格 +- **描述**: 充满魔法和超自然元素的奇幻风格 +- **提示词**: 必须使用奇幻魔法风格,强制添加神秘魔法元素,严格使用梦幻色彩,必须包含超自然效果,禁止现实主义元素 + +### 6. 复古风格 (ThemeRetro) +- **ID**: `"retro"` +- **中文名**: 复古风格 +- **描述**: 怀旧复古风格,经典老式美学 +- **提示词**: 必须使用复古怀旧风格,严格遵循经典老式美学,强制使用怀旧色调和设计元素,禁止现代化元素 + +### 7. 科幻风格 (ThemeScifi) +- **ID**: `"scifi"` +- **中文名**: 科幻风格 +- **描述**: 未来科技风格,充满科幻元素 +- **提示词**: 必须使用科幻未来风格,强制添加科技元素,严格使用霓虹和金属色彩,必须包含未来感设计,禁止传统元素 + +### 8. 自然风格 (ThemeNature) +- **ID**: `"nature"` +- **中文名**: 自然风格 +- **描述**: 自然有机风格,使用自然元素和大地色调 +- **提示词**: 必须使用自然有机风格,严格使用自然元素和植物,强制使用大地色调和绿色系,禁止人工几何元素 + +### 9. 
商务风格 (ThemeBusiness) +- **ID**: `"business"` +- **中文名**: 商务风格 +- **描述**: 专业商务风格,现代企业形象 +- **提示词**: 必须使用商务专业风格,严格保持企业形象,强制使用现代简洁设计,必须专业精致,禁止卡通或娱乐元素 + +## API接口 + +### 1. 查询所有主题 + +**接口地址**: `GET /themes` + +**响应格式**: +```json +[ + { + "id": "cartoon", + "name": "卡通风格", + "description": "色彩鲜艳的卡通风格,适合可爱有趣的内容", + "prompt": "必须使用卡通风格,必须色彩鲜艳丰富,必须可爱有趣,严格使用简单几何形状,强制使用明亮饱和的色彩,禁止写实细节" + } + // ... 其他主题 +] +``` + +### 2. 生成SVG(支持主题) + +**接口地址**: `POST /image/svg` + +**请求参数**: +```json +{ + "prompt": "一只可爱的小猫", + "theme": "cartoon", + "provider": "svgio" +} +``` + +**处理流程**: +1. 验证主题参数是否有效 +2. 应用主题提示词:`"一只可爱的小猫,必须使用卡通风格,必须色彩鲜艳丰富..."` +3. 调用图片生成服务 +4. 返回生成的SVG内容 + +### 3. 生成图片元数据(支持主题) + +**接口地址**: `POST /image` + +**请求参数**: 与生成SVG相同 + +**响应**: 返回图片元数据信息,包含URL、尺寸等 + +## 实现细节 + +### 核心文件 + +1. **`internal/controller/theme.go`** + - 定义主题类型和常量 + - 提供主题信息映射 + - 实现提示词拼接逻辑 + - 提供主题查询方法 + +2. **`internal/controller/svg.go`** + - 集成主题功能到生成接口 + - 添加主题查询controller方法 + +3. **`cmd/spx-backend/get_themes.yap`** + - 主题查询API端点 + +### 核心方法 + +#### 1. 主题验证 +```go +func IsValidTheme(theme ThemeType) bool +``` + +#### 2. 提示词拼接 +```go +func ApplyThemeToPrompt(originalPrompt string, theme ThemeType) string +``` + +#### 3. 主题信息查询 +```go +func GetAllThemesInfo() []ThemeInfo +func GetThemeInfo(theme ThemeType) ThemeInfo +``` + +### 提示词设计原则 + +1. **强制性指令**: 使用"必须"、"严格"、"强制"等命令性词汇 +2. **明确要求**: 详细描述每个主题的具体视觉特征 +3. **禁止条款**: 明确禁止与主题不符的元素 +4. 
**中文优化**: 使用中文逗号拼接,符合中文表达习惯 + +## 使用示例 + +### 前端获取主题列表 +```javascript +fetch('/themes') + .then(response => response.json()) + .then(themes => { + themes.forEach(theme => { + console.log(`${theme.name}: ${theme.description}`); + }); + }); +``` + +### 前端生成图片 +```javascript +const request = { + prompt: "一座美丽的城堡", + theme: "fantasy", + provider: "svgio" +}; + +fetch('/image/svg', { + method: 'POST', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify(request) +}); +``` + +### 实际处理效果 +- **原始提示词**: "一座美丽的城堡" +- **应用奇幻主题后**: "一座美丽的城堡,必须使用奇幻魔法风格,强制添加神秘魔法元素,严格使用梦幻色彩,必须包含超自然效果,禁止现实主义元素" + +## 测试覆盖 + +项目包含完整的单元测试,覆盖以下功能: + +1. 主题验证逻辑 +2. 提示词拼接功能 +3. 主题信息查询 +4. 参数验证 + +测试文件: `internal/controller/theme_test.go` + +## 扩展性 + +### 添加新主题 + +1. 在 `ThemeType` 中定义新常量 +2. 在 `ThemePrompts` 中添加提示词 +3. 在 `ThemeNames` 中添加中文名称 +4. 在 `ThemeDescriptions` 中添加描述 +5. 更新 `GetAvailableThemes()` 方法 +6. 添加对应的测试用例 + +### 主题定制 + +可以通过修改 `ThemePrompts` 中的提示词来调整主题效果,或者添加更细粒度的主题参数控制。 + +## 注意事项 + +1. 主题功能会修改用户的原始提示词,在日志中会记录修改前后的内容 +2. 无主题选项 (`ThemeNone`) 不会修改原始提示词 +3. 主题验证在参数验证阶段进行,无效主题会返回错误 +4. 所有主题信息通过API动态获取,支持后端热更新 \ No newline at end of file diff --git a/spx-backend/internal/config/config.go b/spx-backend/internal/config/config.go index ca850f717..e7a2a8450 100644 --- a/spx-backend/internal/config/config.go +++ b/spx-backend/internal/config/config.go @@ -1,19 +1,30 @@ package config +import ( + + "time" +) + // Config holds all configuration for the application. type Config struct { - Server ServerConfig - Database DatabaseConfig - Redis RedisConfig - Kodo KodoConfig - Casdoor CasdoorConfig - OpenAI OpenAIConfig - AIGC AIGCConfig + Server ServerConfig + Database DatabaseConfig + Redis RedisConfig + Kodo KodoConfig + Casdoor CasdoorConfig + OpenAI OpenAIConfig + AIGC AIGCConfig + Providers ProvidersConfig + Translation TranslationConfig } // ServerConfig holds server configuration. 
type ServerConfig struct { - Port string + Port string + Host string + Timeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration } // GetPort returns the server port, defaulting to ":8080" if not set. @@ -24,6 +35,22 @@ func (c *ServerConfig) GetPort() string { return ":8080" } +// GetServerAddr returns the server listen address. +func (c *ServerConfig) GetServerAddr() string { + host := c.Host + if host == "" { + host = "0.0.0.0" + } + port := c.Port + if port == "" { + port = ":8080" + } + if port[0] != ':' { + port = ":" + port + } + return host + port +} + // DatabaseConfig holds database configuration. type DatabaseConfig struct { DSN string @@ -111,3 +138,75 @@ func (c *OpenAIConfig) GetPremiumModelID() string { type AIGCConfig struct { Endpoint string } + +// ProvidersConfig holds provider configurations for SVG generation. +type ProvidersConfig struct { + SVGIO SVGIOConfig + Recraft RecraftConfig + SVGOpenAI OpenAISVGConfig +} + +// SVGIOConfig holds SVG.IO provider configuration. +type SVGIOConfig struct { + BaseURL string + Timeout time.Duration + MaxRetries int + Enabled bool + Endpoints SVGIOEndpoints +} + +// SVGIOEndpoints holds SVG.IO API endpoint configuration. +type SVGIOEndpoints struct { + Generate string +} + +// RecraftConfig holds Recraft provider configuration. +type RecraftConfig struct { + BaseURL string + Timeout time.Duration + MaxRetries int + Enabled bool + DefaultModel string + SupportedModels []string + Endpoints RecraftEndpoints +} + +// RecraftEndpoints holds Recraft API endpoint configuration. +type RecraftEndpoints struct { + Generate string + Vectorize string +} + +// OpenAISVGConfig holds OpenAI provider configuration for SVG generation (supports all OpenAI compatible models). +type OpenAISVGConfig struct { + BaseURL string + Timeout time.Duration + MaxRetries int + Enabled bool + DefaultModel string + MaxTokens int + Temperature float64 +} + +// TranslationConfig holds translation service configuration. 
+type TranslationConfig struct { + Enabled bool + ServiceURL string + DefaultModel string + Timeout time.Duration + MaxRetries int +} + +// IsProviderEnabled checks if a provider is enabled. +func (c *Config) IsProviderEnabled(provider string) bool { + switch provider { + case "svgio": + return c.Providers.SVGIO.Enabled + case "recraft": + return c.Providers.Recraft.Enabled + case "openai": + return c.Providers.SVGOpenAI.Enabled + default: + return false + } +} diff --git a/spx-backend/internal/config/loader.go b/spx-backend/internal/config/loader.go index 2aeb63b0a..4b748f864 100644 --- a/spx-backend/internal/config/loader.go +++ b/spx-backend/internal/config/loader.go @@ -5,6 +5,8 @@ import ( "io/fs" "os" "strconv" + "strings" + "time" "github.com/joho/godotenv" "github.com/qiniu/x/log" @@ -57,6 +59,47 @@ func Load(logger *log.Logger) (*Config, error) { AIGC: AIGCConfig{ Endpoint: mustGetEnv(logger, "AIGC_ENDPOINT"), }, + + + Providers: ProvidersConfig{ + SVGIO: SVGIOConfig{ + BaseURL: getEnvAsString("SVGIO_BASE_URL", "https://api.svg.io"), + Timeout: getEnvAsDuration("SVGIO_TIMEOUT", "60s"), + MaxRetries: getEnvAsInt("SVGIO_MAX_RETRIES", 3), + Enabled: getEnvAsBool("SVGIO_ENABLED", true), + Endpoints: SVGIOEndpoints{ + Generate: getEnvAsString("SVGIO_GENERATE_ENDPOINT", "/v1/generate-image"), + }, + }, + Recraft: RecraftConfig{ + BaseURL: getEnvAsString("RECRAFT_BASE_URL", "https://external.api.recraft.ai"), + Timeout: getEnvAsDuration("RECRAFT_TIMEOUT", "60s"), + MaxRetries: getEnvAsInt("RECRAFT_MAX_RETRIES", 3), + Enabled: getEnvAsBool("RECRAFT_ENABLED", true), + DefaultModel: getEnvAsString("RECRAFT_DEFAULT_MODEL", "recraftv3"), + SupportedModels: getEnvAsStringSlice("RECRAFT_SUPPORTED_MODELS", "recraftv3,recraftv2"), + Endpoints: RecraftEndpoints{ + Generate: getEnvAsString("RECRAFT_GENERATE_ENDPOINT", "/v1/images/generations"), + Vectorize: getEnvAsString("RECRAFT_VECTORIZE_ENDPOINT", "/v1/images/vectorize"), + }, + }, + SVGOpenAI: OpenAISVGConfig{ + 
BaseURL: getEnvAsString("SVG_OPENAI_BASE_URL", "https://api.qnaigc.com/v1/"), + Timeout: getEnvAsDuration("SVG_OPENAI_TIMEOUT", "60s"), + MaxRetries: getEnvAsInt("SVG_OPENAI_MAX_RETRIES", 3), + Enabled: getEnvAsBool("SVG_OPENAI_ENABLED", true), + DefaultModel: getEnvAsString("SVG_OPENAI_DEFAULT_MODEL", "claude-4.0-sonnet"), + MaxTokens: getEnvAsInt("SVG_OPENAI_MAX_TOKENS", 4000), + Temperature: getEnvAsFloat("SVG_OPENAI_TEMPERATURE", 0.7), + }, + }, + Translation: TranslationConfig{ + Enabled: getEnvAsBool("TRANSLATION_ENABLED", true), + ServiceURL: getEnvAsString("TRANSLATION_SERVICE_URL", "https://api.qnaigc.com/v1/chat/completions"), + DefaultModel: getEnvAsString("TRANSLATION_DEFAULT_MODEL", "claude-4.0-sonnet"), + Timeout: getEnvAsDuration("TRANSLATION_TIMEOUT", "45s"), + MaxRetries: getEnvAsInt("TRANSLATION_MAX_RETRIES", 2), + }, } return config, nil } @@ -79,3 +122,63 @@ func getIntEnv(key string) int { intValue, _ := strconv.Atoi(value) return intValue } + +// getEnvAsString gets the environment variable value or returns default value. +func getEnvAsString(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +// getEnvAsInt gets an integer environment variable value or returns default value. +func getEnvAsInt(key string, defaultValue int) int { + if value := os.Getenv(key); value != "" { + if intValue, err := strconv.Atoi(value); err == nil { + return intValue + } + } + return defaultValue +} + +// getEnvAsFloat gets a float environment variable value or returns default value. +func getEnvAsFloat(key string, defaultValue float64) float64 { + if value := os.Getenv(key); value != "" { + if floatValue, err := strconv.ParseFloat(value, 64); err == nil { + return floatValue + } + } + return defaultValue +} + +// getEnvAsBool gets a boolean environment variable value or returns default value. 
+func getEnvAsBool(key string, defaultValue bool) bool { + if value := os.Getenv(key); value != "" { + if boolValue, err := strconv.ParseBool(value); err == nil { + return boolValue + } + } + return defaultValue +} + +// getEnvAsDuration gets a duration environment variable value or returns default value. +func getEnvAsDuration(key, defaultValue string) time.Duration { + if value := os.Getenv(key); value != "" { + if duration, err := time.ParseDuration(value); err == nil { + return duration + } + } + if duration, err := time.ParseDuration(defaultValue); err == nil { + return duration + } + return 30 * time.Second // fallback default +} + +// getEnvAsStringSlice gets a comma-separated string environment variable value or returns default value. +func getEnvAsStringSlice(key, defaultValue string) []string { + value := getEnvAsString(key, defaultValue) + if value == "" { + return []string{} + } + return strings.Split(value, ",") +} diff --git a/spx-backend/internal/controller/controller.go b/spx-backend/internal/controller/controller.go index ab4954ce6..cff781d0f 100644 --- a/spx-backend/internal/controller/controller.go +++ b/spx-backend/internal/controller/controller.go @@ -12,7 +12,9 @@ import ( "github.com/goplus/builder/spx-backend/internal/aiinteraction" "github.com/goplus/builder/spx-backend/internal/config" "github.com/goplus/builder/spx-backend/internal/copilot" + "github.com/goplus/builder/spx-backend/internal/log" "github.com/goplus/builder/spx-backend/internal/model" + "github.com/goplus/builder/spx-backend/internal/svggen" "github.com/goplus/builder/spx-backend/internal/workflow" "github.com/openai/openai-go" "github.com/openai/openai-go/option" @@ -34,6 +36,7 @@ type Controller struct { workflow *workflow.Workflow aiInteraction *aiinteraction.AIInteraction aigc *aigc.AigcClient + svggen *svggen.ServiceManager } // New creates a new controller. 
import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/goplus/builder/spx-backend/internal/log"
	"github.com/goplus/builder/spx-backend/internal/svggen"
)
+type GenerateSVGParams struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Style string `json:"style,omitempty"` + Theme ThemeType `json:"theme,omitempty"` + Provider svggen.Provider `json:"provider,omitempty"` + Format string `json:"format,omitempty"` + SkipTranslate bool `json:"skip_translate,omitempty"` + Model string `json:"model,omitempty"` + Size string `json:"size,omitempty"` + Substyle string `json:"substyle,omitempty"` + NumImages int `json:"n,omitempty"` +} + +// Validate validates the SVG generation parameters. +func (p *GenerateSVGParams) Validate() (bool, string) { + if len(p.Prompt) < 3 { + return false, "prompt must be at least 3 characters" + } + + if p.Provider == "" { + p.Provider = svggen.ProviderSVGIO // Default to SVGIO + } + + // Validate provider + validProviders := []svggen.Provider{ + svggen.ProviderSVGIO, + svggen.ProviderRecraft, + svggen.ProviderOpenAI, + } + isValid := false + for _, vp := range validProviders { + if p.Provider == vp { + isValid = true + break + } + } + if !isValid { + return false, "provider must be one of: svgio, recraft, openai" + } + + // Validate theme + if !IsValidTheme(p.Theme) { + return false, "invalid theme type" + } + + return true, "" +} + +// GenerateImageParams represents parameters for image generation (metadata only). +type GenerateImageParams struct { + GenerateSVGParams +} + +// SVGResponse represents the response for direct SVG requests. +type SVGResponse struct { + Data []byte `json:"-"` // SVG content + Headers map[string]string `json:"-"` // Response headers +} + +// ImageResponse represents the response for image metadata requests. 
+type ImageResponse struct { + ID string `json:"id"` + SVGURL string `json:"svg_url"` + PNGURL string `json:"png_url,omitempty"` + Width int `json:"width"` + Height int `json:"height"` + Provider svggen.Provider `json:"provider"` + OriginalPrompt string `json:"original_prompt,omitempty"` + TranslatedPrompt string `json:"translated_prompt,omitempty"` + WasTranslated bool `json:"was_translated"` + CreatedAt time.Time `json:"created_at"` +} + +// GenerateSVG generates an SVG image and returns the SVG content directly. +func (ctrl *Controller) GenerateSVG(ctx context.Context, params *GenerateSVGParams) (*SVGResponse, error) { + logger := log.GetReqLogger(ctx) + logger.Printf("GenerateSVG request - provider: %s, theme: %s, prompt: %q", params.Provider, params.Theme, params.Prompt) + + // Apply theme to prompt if specified + finalPrompt := ApplyThemeToPrompt(params.Prompt, params.Theme) + if params.Theme != ThemeNone { + logger.Printf("Theme applied - original: %q, enhanced: %q", params.Prompt, finalPrompt) + } + + // Convert to svggen request + req := svggen.GenerateRequest{ + Prompt: finalPrompt, + NegativePrompt: params.NegativePrompt, + Style: params.Style, + Theme: string(params.Theme), + Provider: params.Provider, + Format: params.Format, + SkipTranslate: params.SkipTranslate, + Model: params.Model, + Size: params.Size, + Substyle: params.Substyle, + NumImages: params.NumImages, + } + + // Generate image + result, err := ctrl.svggen.GenerateImage(ctx, req) + if err != nil { + logger.Printf("SVG generation failed: %v", err) + return nil, fmt.Errorf("failed to generate SVG: %w", err) + } + + logger.Printf("SVG generation successful - ID: %s", result.ID) + + // Download or parse SVG content + svgBytes, err := ctrl.getSVGContent(ctx, result.SVGURL) + if err != nil { + logger.Printf("Failed to get SVG content: %v", err) + return nil, fmt.Errorf("failed to get SVG content: %w", err) + } + + // Prepare response headers + headers := map[string]string{ + "Content-Type": 
"image/svg+xml", + "Content-Disposition": fmt.Sprintf("attachment; filename=\"%s.svg\"", result.ID), + "X-Image-Id": result.ID, + "X-Image-Width": strconv.Itoa(result.Width), + "X-Image-Height": strconv.Itoa(result.Height), + "X-Provider": string(result.Provider), + } + + // Add translation info if available + if result.WasTranslated { + headers["X-Original-Prompt"] = result.OriginalPrompt + headers["X-Translated-Prompt"] = result.TranslatedPrompt + headers["X-Was-Translated"] = "true" + } + + return &SVGResponse{ + Data: svgBytes, + Headers: headers, + }, nil +} + +// GenerateImage generates an image and returns metadata information. +func (ctrl *Controller) GenerateImage(ctx context.Context, params *GenerateImageParams) (*ImageResponse, error) { + logger := log.GetReqLogger(ctx) + logger.Printf("GenerateImage request - provider: %s, theme: %s, prompt: %q", params.Provider, params.Theme, params.Prompt) + + // Apply theme to prompt if specified + finalPrompt := ApplyThemeToPrompt(params.Prompt, params.Theme) + if params.Theme != ThemeNone { + logger.Printf("Theme applied - original: %q, enhanced: %q", params.Prompt, finalPrompt) + } + + // Convert to svggen request + req := svggen.GenerateRequest{ + Prompt: finalPrompt, + NegativePrompt: params.NegativePrompt, + Style: params.Style, + Theme: string(params.Theme), + Provider: params.Provider, + Format: params.Format, + SkipTranslate: params.SkipTranslate, + Model: params.Model, + Size: params.Size, + Substyle: params.Substyle, + NumImages: params.NumImages, + } + + // Generate image + result, err := ctrl.svggen.GenerateImage(ctx, req) + if err != nil { + logger.Printf("Image generation failed: %v", err) + return nil, fmt.Errorf("failed to generate image: %w", err) + } + + logger.Printf("Image generation successful - ID: %s", result.ID) + + return &ImageResponse{ + ID: result.ID, + SVGURL: result.SVGURL, + PNGURL: result.PNGURL, + Width: result.Width, + Height: result.Height, + Provider: result.Provider, + 
OriginalPrompt: result.OriginalPrompt, + TranslatedPrompt: result.TranslatedPrompt, + WasTranslated: result.WasTranslated, + CreatedAt: result.CreatedAt, + }, nil +} + +// getSVGContent retrieves SVG content from URL or data URL. +func (ctrl *Controller) getSVGContent(ctx context.Context, svgURL string) ([]byte, error) { + if strings.HasPrefix(svgURL, "data:") { + // Parse data URL + return ctrl.parseDataURL(svgURL) + } + + // Download from HTTP URL + return svggen.DownloadFile(ctx, svgURL) +} + +// parseDataURL parses a data URL and returns the decoded content. +func (ctrl *Controller) parseDataURL(dataURL string) ([]byte, error) { + // data URL format: data:[][;base64], + // e.g., data:image/svg+xml;base64, + + if !strings.HasPrefix(dataURL, "data:") { + return nil, errors.New("invalid data URL: missing data: prefix") + } + + // Remove "data:" prefix + dataURL = dataURL[5:] + + // Find comma separator + commaIndex := strings.Index(dataURL, ",") + if commaIndex == -1 { + return nil, errors.New("invalid data URL: missing comma separator") + } + + // Get media type and encoding info + mediaType := dataURL[:commaIndex] + data := dataURL[commaIndex+1:] + + // Check if it's base64 encoded + if strings.Contains(mediaType, "base64") { + // Base64 decode + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 data: %w", err) + } + return decoded, nil + } + + // Not base64 encoded, return string bytes directly + return []byte(data), nil +} + +// GetThemes returns all available themes with their information. 
+func (ctrl *Controller) GetThemes(ctx context.Context) ([]ThemeInfo, error) { + logger := log.GetReqLogger(ctx) + logger.Printf("GetThemes request") + + themes := GetAllThemesInfo() + + logger.Printf("Returned %d themes", len(themes)) + return themes, nil +} \ No newline at end of file diff --git a/spx-backend/internal/controller/svg_test.go b/spx-backend/internal/controller/svg_test.go new file mode 100644 index 000000000..d5a1c0e07 --- /dev/null +++ b/spx-backend/internal/controller/svg_test.go @@ -0,0 +1,137 @@ +package controller + +import ( + "testing" + + "github.com/goplus/builder/spx-backend/internal/svggen" +) + +func TestGenerateSVGParams_Validate(t *testing.T) { + tests := []struct { + name string + params GenerateSVGParams + wantOK bool + wantMsg string + }{ + { + name: "valid parameters", + params: GenerateSVGParams{ + Prompt: "A cute cat", + Provider: svggen.ProviderSVGIO, + }, + wantOK: true, + wantMsg: "", + }, + { + name: "prompt too short", + params: GenerateSVGParams{ + Prompt: "hi", + }, + wantOK: false, + wantMsg: "prompt must be at least 3 characters", + }, + { + name: "invalid provider", + params: GenerateSVGParams{ + Prompt: "A cute cat", + Provider: "invalid", + }, + wantOK: false, + wantMsg: "provider must be one of: svgio, recraft, openai", + }, + { + name: "default provider", + params: GenerateSVGParams{ + Prompt: "A cute cat", + // Provider not set - should default to SVGIO + }, + wantOK: true, + wantMsg: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOK, gotMsg := tt.params.Validate() + if gotOK != tt.wantOK { + t.Errorf("Validate() gotOK = %v, want %v", gotOK, tt.wantOK) + } + if gotMsg != tt.wantMsg { + t.Errorf("Validate() gotMsg = %v, want %v", gotMsg, tt.wantMsg) + } + + // Check that default provider is set when not specified + if tt.params.Provider == "" && gotOK { + if tt.params.Provider != svggen.ProviderSVGIO { + t.Errorf("Default provider should be set to SVGIO, got %v", 
tt.params.Provider) + } + } + }) + } +} + +func TestGenerateImageParams_Validate(t *testing.T) { + params := GenerateImageParams{ + GenerateSVGParams: GenerateSVGParams{ + Prompt: "A beautiful landscape", + Provider: svggen.ProviderRecraft, + }, + } + + ok, msg := params.Validate() + if !ok { + t.Errorf("GenerateImageParams validation failed: %s", msg) + } +} + +func TestController_parseDataURL(t *testing.T) { + ctrl := &Controller{} + + tests := []struct { + name string + dataURL string + wantErr bool + wantData string + }{ + { + name: "valid base64 data URL", + dataURL: "", // + wantErr: false, + wantData: "", + }, + { + name: "valid plain data URL", + dataURL: "", + wantErr: true, + }, + { + name: "invalid data URL - missing comma", + dataURL: "data:image/svg+xml;base64PHN2Zz48L3N2Zz4=", + wantErr: true, + }, + { + name: "invalid base64", + dataURL: "!!!", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ctrl.parseDataURL(tt.dataURL) + if (err != nil) != tt.wantErr { + t.Errorf("parseDataURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && string(got) != tt.wantData { + t.Errorf("parseDataURL() = %v, want %v", string(got), tt.wantData) + } + }) + } +} \ No newline at end of file diff --git a/spx-backend/internal/controller/theme.go b/spx-backend/internal/controller/theme.go new file mode 100644 index 000000000..c926019e4 --- /dev/null +++ b/spx-backend/internal/controller/theme.go @@ -0,0 +1,132 @@ +package controller + +import "fmt" + +// ThemeType represents different SVG generation themes +type ThemeType string + +const ( + ThemeNone ThemeType = "" + ThemeCartoon ThemeType = "cartoon" + ThemeRealistic ThemeType = "realistic" + ThemeMinimal ThemeType = "minimal" + ThemeFantasy ThemeType = "fantasy" + ThemeRetro ThemeType = "retro" + ThemeScifi ThemeType = "scifi" + ThemeNature ThemeType = "nature" + ThemeBusiness ThemeType = "business" +) + +// ThemeInfo represents detailed 
information about a theme +type ThemeInfo struct { + ID ThemeType `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Prompt string `json:"prompt"` +} + +// ThemePrompts maps each theme to its corresponding prompt enhancement +var ThemePrompts = map[ThemeType]string{ + ThemeCartoon: "必须使用卡通风格,必须色彩鲜艳丰富,必须可爱有趣,严格使用简单几何形状,强制使用明亮饱和的色彩,禁止写实细节", + ThemeRealistic: "必须使用写实风格,严格要求高度细节化,强制逼真效果,必须专业高质量渲染,禁止卡通化或简化元素", + ThemeMinimal: "必须使用极简风格,严格限制元素数量,强制使用干净线条和几何形状,严格使用黑白或单色调,禁止复杂装饰", + ThemeFantasy: "必须使用奇幻魔法风格,强制添加神秘魔法元素,严格使用梦幻色彩,必须包含超自然效果,禁止现实主义元素", + ThemeRetro: "必须使用复古怀旧风格,严格遵循经典老式美学,强制使用怀旧色调和设计元素,禁止现代化元素", + ThemeScifi: "必须使用科幻未来风格,强制添加科技元素,严格使用霓虹和金属色彩,必须包含未来感设计,禁止传统元素", + ThemeNature: "必须使用自然有机风格,严格使用自然元素和植物,强制使用大地色调和绿色系,禁止人工几何元素", + ThemeBusiness: "必须使用商务专业风格,严格保持企业形象,强制使用现代简洁设计,必须专业精致,禁止卡通或娱乐元素", +} + +// ThemeNames maps each theme to its Chinese name +var ThemeNames = map[ThemeType]string{ + ThemeNone: "无主题", + ThemeCartoon: "卡通风格", + ThemeRealistic: "写实风格", + ThemeMinimal: "极简风格", + ThemeFantasy: "奇幻风格", + ThemeRetro: "复古风格", + ThemeScifi: "科幻风格", + ThemeNature: "自然风格", + ThemeBusiness: "商务风格", +} + +// ThemeDescriptions maps each theme to its description +var ThemeDescriptions = map[ThemeType]string{ + ThemeNone: "不应用任何特定主题风格", + ThemeCartoon: "色彩鲜艳的卡通风格,适合可爱有趣的内容", + ThemeRealistic: "高度写实的风格,细节丰富逼真", + ThemeMinimal: "极简主义风格,简洁干净的设计", + ThemeFantasy: "充满魔法和超自然元素的奇幻风格", + ThemeRetro: "怀旧复古风格,经典老式美学", + ThemeScifi: "未来科技风格,充满科幻元素", + ThemeNature: "自然有机风格,使用自然元素和大地色调", + ThemeBusiness: "专业商务风格,现代企业形象", +} + +// IsValidTheme checks if the given theme is valid +func IsValidTheme(theme ThemeType) bool { + if theme == ThemeNone { + return true + } + _, exists := ThemePrompts[theme] + return exists +} + +// GetThemePromptEnhancement returns the prompt enhancement for a given theme +func GetThemePromptEnhancement(theme ThemeType) string { + if theme == ThemeNone { + return "" + } + return ThemePrompts[theme] +} + +// ApplyThemeToPrompt 
applies theme enhancement to the original prompt +func ApplyThemeToPrompt(originalPrompt string, theme ThemeType) string { + if theme == ThemeNone { + return originalPrompt + } + + themeEnhancement := GetThemePromptEnhancement(theme) + if themeEnhancement == "" { + return originalPrompt + } + + return fmt.Sprintf("%s,%s", originalPrompt, themeEnhancement) +} + +// GetAvailableThemes returns all available themes +func GetAvailableThemes() []ThemeType { + return []ThemeType{ + ThemeNone, + ThemeCartoon, + ThemeRealistic, + ThemeMinimal, + ThemeFantasy, + ThemeRetro, + ThemeScifi, + ThemeNature, + ThemeBusiness, + } +} + +// GetThemeInfo returns detailed information for a specific theme +func GetThemeInfo(theme ThemeType) ThemeInfo { + return ThemeInfo{ + ID: theme, + Name: ThemeNames[theme], + Description: ThemeDescriptions[theme], + Prompt: ThemePrompts[theme], + } +} + +// GetAllThemesInfo returns detailed information for all themes +func GetAllThemesInfo() []ThemeInfo { + themes := GetAvailableThemes() + result := make([]ThemeInfo, len(themes)) + + for i, theme := range themes { + result[i] = GetThemeInfo(theme) + } + + return result +} \ No newline at end of file diff --git a/spx-backend/internal/controller/theme_test.go b/spx-backend/internal/controller/theme_test.go new file mode 100644 index 000000000..5c48bc6a4 --- /dev/null +++ b/spx-backend/internal/controller/theme_test.go @@ -0,0 +1,195 @@ +package controller + +import ( + "testing" +) + +func TestIsValidTheme(t *testing.T) { + tests := []struct { + theme ThemeType + expected bool + }{ + {ThemeNone, true}, + {ThemeCartoon, true}, + {ThemeRealistic, true}, + {ThemeMinimal, true}, + {ThemeFantasy, true}, + {ThemeRetro, true}, + {ThemeScifi, true}, + {ThemeNature, true}, + {ThemeBusiness, true}, + {"invalid", false}, + {"", true}, // ThemeNone + } + + for _, test := range tests { + result := IsValidTheme(test.theme) + if result != test.expected { + t.Errorf("IsValidTheme(%q) = %v, expected %v", test.theme, 
result, test.expected) + } + } +} + +func TestGetThemePromptEnhancement(t *testing.T) { + tests := []struct { + theme ThemeType + expected string + }{ + {ThemeNone, ""}, + {ThemeCartoon, "必须使用卡通风格,必须色彩鲜艳丰富,必须可爱有趣,严格使用简单几何形状,强制使用明亮饱和的色彩,禁止写实细节"}, + {ThemeRealistic, "必须使用写实风格,严格要求高度细节化,强制逼真效果,必须专业高质量渲染,禁止卡通化或简化元素"}, + {ThemeMinimal, "必须使用极简风格,严格限制元素数量,强制使用干净线条和几何形状,严格使用黑白或单色调,禁止复杂装饰"}, + {ThemeFantasy, "必须使用奇幻魔法风格,强制添加神秘魔法元素,严格使用梦幻色彩,必须包含超自然效果,禁止现实主义元素"}, + } + + for _, test := range tests { + result := GetThemePromptEnhancement(test.theme) + if result != test.expected { + t.Errorf("GetThemePromptEnhancement(%q) = %q, expected %q", test.theme, result, test.expected) + } + } +} + +func TestApplyThemeToPrompt(t *testing.T) { + tests := []struct { + prompt string + theme ThemeType + expected string + }{ + {"一只猫", ThemeNone, "一只猫"}, + {"一只猫", ThemeCartoon, "一只猫,必须使用卡通风格,必须色彩鲜艳丰富,必须可爱有趣,严格使用简单几何形状,强制使用明亮饱和的色彩,禁止写实细节"}, + {"一座房子", ThemeMinimal, "一座房子,必须使用极简风格,严格限制元素数量,强制使用干净线条和几何形状,严格使用黑白或单色调,禁止复杂装饰"}, + {"", ThemeCartoon, ",必须使用卡通风格,必须色彩鲜艳丰富,必须可爱有趣,严格使用简单几何形状,强制使用明亮饱和的色彩,禁止写实细节"}, + } + + for _, test := range tests { + result := ApplyThemeToPrompt(test.prompt, test.theme) + if result != test.expected { + t.Errorf("ApplyThemeToPrompt(%q, %q) = %q, expected %q", test.prompt, test.theme, result, test.expected) + } + } +} + +func TestGetAvailableThemes(t *testing.T) { + themes := GetAvailableThemes() + expectedCount := 9 // ThemeNone + 8 themed options + + if len(themes) != expectedCount { + t.Errorf("GetAvailableThemes() returned %d themes, expected %d", len(themes), expectedCount) + } + + // Check that all themes are valid + for _, theme := range themes { + if !IsValidTheme(theme) { + t.Errorf("GetAvailableThemes() returned invalid theme: %q", theme) + } + } +} + +func TestGenerateSVGParamsValidateWithTheme(t *testing.T) { + tests := []struct { + params GenerateSVGParams + valid bool + errorMsg string + }{ + { + params: GenerateSVGParams{ + Prompt: "test prompt", + Theme: 
ThemeCartoon, + }, + valid: true, + }, + { + params: GenerateSVGParams{ + Prompt: "test prompt", + Theme: ThemeNone, + }, + valid: true, + }, + { + params: GenerateSVGParams{ + Prompt: "test prompt", + Theme: "invalid_theme", + }, + valid: false, + errorMsg: "invalid theme type", + }, + { + params: GenerateSVGParams{ + Prompt: "ab", // too short + Theme: ThemeCartoon, + }, + valid: false, + errorMsg: "prompt must be at least 3 characters", + }, + } + + for i, test := range tests { + valid, msg := test.params.Validate() + if valid != test.valid { + t.Errorf("Test %d: Validate() = %v, expected %v", i, valid, test.valid) + } + if !test.valid && msg != test.errorMsg { + t.Errorf("Test %d: Validate() error message = %q, expected %q", i, msg, test.errorMsg) + } + } +} + +func TestGetThemeInfo(t *testing.T) { + info := GetThemeInfo(ThemeCartoon) + + expected := ThemeInfo{ + ID: ThemeCartoon, + Name: "卡通风格", + Description: "色彩鲜艳的卡通风格,适合可爱有趣的内容", + Prompt: "必须使用卡通风格,必须色彩鲜艳丰富,必须可爱有趣,严格使用简单几何形状,强制使用明亮饱和的色彩,禁止写实细节", + } + + if info.ID != expected.ID { + t.Errorf("GetThemeInfo ID = %v, expected %v", info.ID, expected.ID) + } + if info.Name != expected.Name { + t.Errorf("GetThemeInfo Name = %v, expected %v", info.Name, expected.Name) + } + if info.Description != expected.Description { + t.Errorf("GetThemeInfo Description = %v, expected %v", info.Description, expected.Description) + } + if info.Prompt != expected.Prompt { + t.Errorf("GetThemeInfo Prompt = %v, expected %v", info.Prompt, expected.Prompt) + } +} + +func TestGetAllThemesInfo(t *testing.T) { + themes := GetAllThemesInfo() + expectedCount := 9 + + if len(themes) != expectedCount { + t.Errorf("GetAllThemesInfo returned %d themes, expected %d", len(themes), expectedCount) + } + + // Check that all themes have required fields + for _, theme := range themes { + if theme.ID == "" && theme.Name != "无主题" { + t.Errorf("Theme ID is empty for theme: %s", theme.Name) + } + if theme.Name == "" { + t.Errorf("Theme Name is empty for 
ID: %s", theme.ID) + } + if theme.ID != ThemeNone && theme.Prompt == "" { + t.Errorf("Theme Prompt is empty for ID: %s", theme.ID) + } + // Description can be empty for some themes, so we don't check it + } + + // Check that ThemeNone is included + found := false + for _, theme := range themes { + if theme.ID == ThemeNone { + found = true + break + } + } + if !found { + t.Errorf("ThemeNone not found in GetAllThemesInfo result") + } +} \ No newline at end of file diff --git a/spx-backend/internal/svggen/example_usage.md b/spx-backend/internal/svggen/example_usage.md new file mode 100644 index 000000000..1f1571eaa --- /dev/null +++ b/spx-backend/internal/svggen/example_usage.md @@ -0,0 +1,156 @@ +# SVG Generation Service Usage + +This document shows how to use the migrated SVG generation services in spx-backend. + +## Quick Start + +```go +package main + +import ( + "context" + "log" + + "github.com/goplus/builder/spx-backend/internal/config" + "github.com/goplus/builder/spx-backend/internal/svggen" + qlog "github.com/qiniu/x/log" +) + +func main() { + // Load configuration + cfg, err := config.Load(qlog.Std) + if err != nil { + log.Fatal(err) + } + + // Create service manager + sm := svggen.NewServiceManager(cfg, qlog.Std) + + // Create generation request + req := svggen.GenerateRequest{ + Prompt: "A cute cartoon cat sitting on a cloud", + Style: "vector_illustration", + Provider: svggen.ProviderSVGIO, // or ProviderRecraft, ProviderOpenAI + Format: "svg", + } + + // Generate image + ctx := context.Background() + response, err := sm.GenerateImage(ctx, req) + if err != nil { + log.Fatal(err) + } + + // Use the response + log.Printf("Generated SVG ID: %s\n", response.ID) + log.Printf("SVG URL: %s\n", response.SVGURL) + log.Printf("PNG URL: %s\n", response.PNGURL) +} +``` + +## Configuration + +The SVG generation services use the configuration from the main spx-backend config system.
Make sure to set up the following environment variables: + +### SVGIO Provider +```bash +SVGIO_ENABLED=true +SVGIO_BASE_URL=https://api.svg.io +SVGIO_GENERATE_ENDPOINT=/v1/generate-image +SVGIO_TIMEOUT=60s +SVGIO_MAX_RETRIES=3 +``` + +### Recraft Provider +```bash +RECRAFT_ENABLED=true +RECRAFT_BASE_URL=https://external.api.recraft.ai +RECRAFT_GENERATE_ENDPOINT=/v1/images/generations +RECRAFT_VECTORIZE_ENDPOINT=/v1/images/vectorize +RECRAFT_DEFAULT_MODEL=recraftv3 +RECRAFT_SUPPORTED_MODELS=recraftv3,recraftv2 +RECRAFT_TIMEOUT=60s +RECRAFT_MAX_RETRIES=3 +``` + +### OpenAI Provider (SVG Generation) +```bash +SVG_OPENAI_ENABLED=true +SVG_OPENAI_BASE_URL=https://api.qnaigc.com/v1/ +SVG_OPENAI_DEFAULT_MODEL=claude-4.0-sonnet +SVG_OPENAI_MAX_TOKENS=4000 +SVG_OPENAI_TEMPERATURE=0.7 +SVG_OPENAI_TIMEOUT=60s +SVG_OPENAI_MAX_RETRIES=3 +``` + +## API Integration + +To integrate this into the spx-backend HTTP API, you can create a new controller: + +```go +// In internal/controller/svg.go +package controller + +import ( + "net/http" + + "github.com/goplus/builder/spx-backend/internal/svggen" + "github.com/gin-gonic/gin" +) + +func (ctrl *Controller) GenerateSVG(c *gin.Context) { + var req svggen.GenerateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + response, err := ctrl.svggen.GenerateImage(c.Request.Context(), req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, response) +} +``` + +## Provider Features + +### SVGIO +- Direct SVG generation +- Built-in PNG export +- Style customization +- Negative prompt support + +### Recraft +- High-quality image generation +- Automatic vectorization for SVG +- Multiple model support (recraftv3, recraftv2) +- Style and substyle options + +### OpenAI +- AI-powered SVG code generation +- Highly customizable prompts +- Compatible with OpenAI API format +- Supports various
models (GPT-4, Claude, etc.) + +## Error Handling + +All services implement proper error handling and logging. Errors are returned with descriptive messages and logged with appropriate context. + +```go +response, err := sm.GenerateImage(ctx, req) +if err != nil { + // Handle specific errors + switch { + case strings.Contains(err.Error(), "provider not configured"): + // Provider is disabled or not set up + case strings.Contains(err.Error(), "http request"): + // Network or API errors + default: + // Other errors + } +} +``` \ No newline at end of file diff --git a/spx-backend/internal/svggen/openai.go b/spx-backend/internal/svggen/openai.go new file mode 100644 index 000000000..f3b68a1e9 --- /dev/null +++ b/spx-backend/internal/svggen/openai.go @@ -0,0 +1,270 @@ +package svggen + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "os" + "regexp" + "strings" + "time" + + "github.com/goplus/builder/spx-backend/internal/config" + "github.com/goplus/builder/spx-backend/internal/log" + qlog "github.com/qiniu/x/log" +) + +// OpenAIService implements OpenAI compatible API calls. +type OpenAIService struct { + config *config.OpenAISVGConfig + httpClient *http.Client + logger *qlog.Logger +} + +// NewOpenAIService creates a new OpenAI service instance. +func NewOpenAIService(cfg *config.Config, httpClient *http.Client, logger *qlog.Logger) *OpenAIService { + return &OpenAIService{ + config: &cfg.Providers.SVGOpenAI, + httpClient: httpClient, + logger: logger, + } +} + +// GenerateImage generates SVG code using OpenAI compatible models. 
+func (s *OpenAIService) GenerateImage(ctx context.Context, req GenerateRequest) (*ImageResponse, error) { + logger := log.GetReqLogger(ctx) + logger.Printf("[OPENAI] Starting SVG generation request...") + + // OpenAI supports Chinese natively, so we use the prompt as-is + // Translation is handled at the ServiceManager level if needed + + // Build OpenAI prompt + prompt := s.buildSVGPrompt(req.Prompt, req.Style, req.NegativePrompt) + + // Build OpenAI API request + openaiReq := OpenAIGenerateReq{ + Model: s.getModelFromConfig(req.Model), + MaxTokens: s.config.MaxTokens, + Temperature: s.config.Temperature, + Messages: []OpenAIMessage{ + { + Role: "system", + Content: `You are a world-class SVG graphics designer and vector artist with expertise in creating stunning, precise, and semantically meaningful SVG illustrations. Your specialties include: + +1. **Technical Excellence**: You create perfectly valid, optimized SVG code that renders flawlessly across all browsers and devices +2. **Visual Design**: You have an exceptional eye for composition, color theory, typography, and visual hierarchy +3. **Style Adaptation**: You can seamlessly adapt to any artistic style - from minimalist line art to detailed illustrations, from cartoon to realistic, from modern flat design to vintage aesthetics +4. **Semantic Structure**: You use meaningful element IDs, proper grouping, and clean hierarchical structure in your SVG code +5. 
**Optimization**: Your SVG code is clean, efficient, and follows best practices for file size and performance + +When creating SVG graphics, you: +- Pay careful attention to the exact subject, style, and mood requested +- Use appropriate colors, gradients, and visual effects to match the desired aesthetic +- Ensure proper proportions, perspective, and composition +- Add fine details that enhance the overall quality and realism +- Create scalable graphics that look crisp at any size +- Follow accessibility best practices when relevant + +You respond ONLY with clean, valid SVG code - no explanations, no code blocks, just the pure SVG markup ready to render.`, + }, + { + Role: "user", + Content: prompt, + }, + }, + } + + body, err := json.Marshal(openaiReq) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := s.config.BaseURL + "chat/completions" + logger.Printf("[OPENAI] Sending request to %s", url) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+s.getAPIKey()) + + resp, err := s.httpClient.Do(httpReq) + if err != nil { + logger.Printf("[OPENAI] HTTP request failed: %v", err) + return nil, fmt.Errorf("http request: %w", err) + } + defer resp.Body.Close() + + logger.Printf("[OPENAI] Received response with status: %s", resp.Status) + + if resp.StatusCode >= 300 { + var errResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&errResp) + logger.Printf("[OPENAI] Error response body: %+v", errResp) + return nil, fmt.Errorf("openai API error: %s", resp.Status) + } + + var openaiResp OpenAIGenerateResp + if err := json.NewDecoder(resp.Body).Decode(&openaiResp); err != nil { + logger.Printf("[OPENAI] Failed to decode response: %v", err) + return nil, fmt.Errorf("decode openai response: %w", err) + } + + // Add 
debug information + logger.Printf("[OPENAI] Response structure: ID=%s, Object=%s, Model=%s, Choices length=%d", + openaiResp.ID, openaiResp.Object, openaiResp.Model, len(openaiResp.Choices)) + + // If there's content, print the first choice's content + if len(openaiResp.Choices) > 0 { + firstChoice := openaiResp.Choices[0] + logger.Printf("[OPENAI] First choice: Role=%s, Content prefix=%s", + firstChoice.Message.Role, s.truncateString(firstChoice.Message.Content, 100)) + } + + if len(openaiResp.Choices) == 0 { + logger.Printf("[OPENAI] Choices array is empty, response: %+v", openaiResp) + return nil, fmt.Errorf("no choices in openai response") + } + + // Extract SVG code + svgContent := openaiResp.Choices[0].Message.Content + svgCode := s.extractSVGCode(svgContent) + + if svgCode == "" { + logger.Printf("[OPENAI] No valid SVG found in response") + return nil, fmt.Errorf("no valid SVG generated") + } + + // Generate temporary SVG file URL (in actual application, might need to save to file service) + imageID := GenerateImageID(ProviderOpenAI) + svgURL := s.createSVGDataURL(svgCode) + + logger.Printf("[OPENAI] Successfully generated SVG - ID: %s", imageID) + + return &ImageResponse{ + ID: imageID, + Prompt: req.Prompt, + NegativePrompt: req.NegativePrompt, + Style: req.Style, + SVGURL: svgURL, + PNGURL: svgURL, // OpenAI generates SVG code, both URLs are the same + Width: 1024, // Default size + Height: 1024, + CreatedAt: time.Now(), + Provider: ProviderOpenAI, + }, nil +} + +// buildSVGPrompt builds a prompt for generating SVG. 
+func (s *OpenAIService) buildSVGPrompt(prompt, style, negativePrompt string) string { + var promptBuilder strings.Builder + + promptBuilder.WriteString("Create a high-quality SVG illustration of: ") + promptBuilder.WriteString(prompt) + + if style != "" { + promptBuilder.WriteString(fmt.Sprintf("\n\nArtistic style and visual requirements: %s", style)) + } + + if negativePrompt != "" { + promptBuilder.WriteString(fmt.Sprintf("\n\nIMPORTANT - Do NOT include these elements: %s", negativePrompt)) + } + + promptBuilder.WriteString(` + +Technical Requirements: +• Use viewBox="0 0 1024 1024" for consistent sizing +• Ensure the SVG is completely self-contained and valid +• Use semantic and descriptive element IDs (e.g., id="main-character", id="background-sky") +• Organize elements in logical groups using tags +• Use appropriate colors that match the subject and style +• Include gradients, shadows, or other effects when they enhance the design +• Ensure the illustration is centered and well-composed within the viewBox +• Make it scalable and crisp at any resolution + +Visual Quality Standards: +• Pay attention to proper proportions and anatomy +• Use appropriate line weights and stroke styles +• Include relevant details that make the illustration engaging +• Ensure good contrast and readability +• Follow the specified artistic style consistently +• Create depth and dimension through layering and visual effects + +Output Format: +Return ONLY the complete SVG code starting with and ending with . No explanations, no code blocks, no additional text. +In SVG, do not use elements like rect; all drawing should be implemented via the path element. + +`) + + return promptBuilder.String() +} + +// extractSVGCode extracts SVG code from the response. 
+func (s *OpenAIService) extractSVGCode(content string) string { + // Look for SVG tags + svgRegex := regexp.MustCompile(`(?s)]*>.*?`) + matches := svgRegex.FindString(content) + + if matches != "" { + return matches + } + + // If no complete SVG found, try to look for SVG code blocks + codeBlockRegex := regexp.MustCompile("(?s)```(?:svg|xml)?\n?(.*?)\n?```") + codeMatches := codeBlockRegex.FindStringSubmatch(content) + + if len(codeMatches) > 1 { + svgCode := strings.TrimSpace(codeMatches[1]) + if strings.Contains(svgCode, "") { + return strings.TrimSpace(content) + } + + return "" +} + +// createSVGDataURL creates an SVG Data URL. +func (s *OpenAIService) createSVGDataURL(svgCode string) string { + // For demonstration, we return a data URL + // In actual production environment, you might want to save to file service and return real URL + return fmt.Sprintf("data:image/svg+xml;base64,%s", + s.encodeSVGToBase64(svgCode)) +} + +// encodeSVGToBase64 encodes SVG to Base64. +func (s *OpenAIService) encodeSVGToBase64(svgCode string) string { + return base64.StdEncoding.EncodeToString([]byte(svgCode)) +} + +// truncateString truncates string for logging. +func (s *OpenAIService) truncateString(str string, maxLen int) string { + if len(str) <= maxLen { + return str + } + return str[:maxLen] + "..." +} + +// getModelFromConfig gets model name from configuration or request. +func (s *OpenAIService) getModelFromConfig(requestModel string) string { + if requestModel != "" { + return requestModel + } + return s.config.DefaultModel +} + +// getAPIKey gets the API key from environment or configuration. 
+func (s *OpenAIService) getAPIKey() string { + // Get API key from environment variable + // In production, consider using a more secure method like AWS Secrets Manager + return os.Getenv("SVG_OPENAI_API_KEY") +} diff --git a/spx-backend/internal/svggen/recraft.go b/spx-backend/internal/svggen/recraft.go new file mode 100644 index 000000000..b2f486ef4 --- /dev/null +++ b/spx-backend/internal/svggen/recraft.go @@ -0,0 +1,258 @@ +package svggen + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "mime/multipart" + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/goplus/builder/spx-backend/internal/config" + "github.com/goplus/builder/spx-backend/internal/log" + qlog "github.com/qiniu/x/log" +) + +// RecraftService implements Recraft API calls. +type RecraftService struct { + config *config.RecraftConfig + httpClient *http.Client + logger *qlog.Logger +} + +// NewRecraftService creates a new Recraft service instance. +func NewRecraftService(cfg *config.Config, httpClient *http.Client, logger *qlog.Logger) *RecraftService { + return &RecraftService{ + config: &cfg.Providers.Recraft, + httpClient: httpClient, + logger: logger, + } +} + +// GenerateImage generates an image using Recraft API. 
+func (s *RecraftService) GenerateImage(ctx context.Context, req GenerateRequest) (*ImageResponse, error) { + logger := log.GetReqLogger(ctx) + logger.Printf("[RECRAFT] Starting generation request...") + + // Build optimized prompt (add transparent background requirement) + enhancedPrompt, enhancedNegativePrompt := s.buildRecraftPrompt(req.Prompt, req.Style, req.NegativePrompt) + + logger.Printf("[RECRAFT] Original prompt: %s", req.Prompt) + logger.Printf("[RECRAFT] Enhanced prompt: %s", enhancedPrompt) + logger.Printf("[RECRAFT] Enhanced negative prompt: %s", enhancedNegativePrompt) + + // Build Recraft API request + recraftReq := RecraftGenerateReq{ + Prompt: enhancedPrompt, + NegativePrompt: enhancedNegativePrompt, + Style: req.Style, + Substyle: req.Substyle, + Model: req.Model, + Size: req.Size, + N: req.NumImages, + ResponseFormat: "url", // Fixed to use URL format + } + + // Set default values + if recraftReq.Model == "" { + recraftReq.Model = s.config.DefaultModel + } + if recraftReq.Size == "" { + recraftReq.Size = "1024x1024" + } + if recraftReq.Style == "" { + recraftReq.Style = "vector_illustration" + } + if recraftReq.N == 0 { + recraftReq.N = 1 + } + + body, err := json.Marshal(recraftReq) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := s.config.BaseURL + s.config.Endpoints.Generate + logger.Printf("[RECRAFT] Sending request to %s with payload size: %d bytes", url, len(body)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+s.getAPIKey()) + + resp, err := s.httpClient.Do(httpReq) + if err != nil { + logger.Printf("[RECRAFT] HTTP request failed: %v", err) + return nil, fmt.Errorf("http request: %w", err) + } + defer resp.Body.Close() + + logger.Printf("[RECRAFT] Received response with 
status: %s", resp.Status) + + if resp.StatusCode >= 300 { + var errResp map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&errResp) + logger.Printf("[RECRAFT] Error response body: %+v", errResp) + return nil, fmt.Errorf("recraft API error: %s", resp.Status) + } + + var recraftResp RecraftGenerateResp + if err := json.NewDecoder(resp.Body).Decode(&recraftResp); err != nil { + logger.Printf("[RECRAFT] Failed to decode response: %v", err) + return nil, fmt.Errorf("decode response: %w", err) + } + + if len(recraftResp.Data) == 0 { + logger.Printf("[RECRAFT] No images in response") + return nil, errors.New("no images generated") + } + + imageData := recraftResp.Data[0] // Take the first image + logger.Printf("[RECRAFT] Successfully parsed response - URL: %s", imageData.URL) + + // Parse image dimensions + width, height := s.parseSizeFromString(recraftReq.Size) + + // Generate a simple ID (Recraft doesn't provide one) + imageID := GenerateImageID(ProviderRecraft) + + // For Recraft, we need to convert the image to SVG + // Here we can use Recraft's vectorize API + svgURL := imageData.URL // Default to original image + + // If SVG is needed, call vectorize API + if req.Format == "svg" || strings.Contains(req.Style, "vector") { + vectorizedURL, err := s.vectorizeImage(ctx, imageData.URL) + if err != nil { + logger.Printf("[RECRAFT] Vectorization failed: %v", err) + // Continue with original image on failure + } else { + svgURL = vectorizedURL + } + } + + return &ImageResponse{ + ID: imageID, + Prompt: req.Prompt, + NegativePrompt: req.NegativePrompt, + Style: recraftReq.Style, + SVGURL: svgURL, + PNGURL: imageData.URL, + Width: width, + Height: height, + CreatedAt: time.Unix(int64(recraftResp.Created), 0), + Provider: ProviderRecraft, + }, nil +} + +// vectorizeImage uses Recraft's vectorization API to convert image to SVG. 
+func (s *RecraftService) vectorizeImage(ctx context.Context, imageURL string) (string, error) { + logger := log.GetReqLogger(ctx) + logger.Printf("[RECRAFT] Vectorizing image: %s", imageURL) + + // Download image + imageBytes, err := DownloadFile(ctx, imageURL) + if err != nil { + return "", fmt.Errorf("download image: %w", err) + } + + // Create multipart form + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + + // Add image file + part, err := writer.CreateFormFile("file", "image.png") + if err != nil { + return "", fmt.Errorf("create form file: %w", err) + } + _, err = part.Write(imageBytes) + if err != nil { + return "", fmt.Errorf("write image data: %w", err) + } + + // Add response_format parameter + writer.WriteField("response_format", "url") + writer.Close() + + url := s.config.BaseURL + s.config.Endpoints.Vectorize + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &buf) + if err != nil { + return "", fmt.Errorf("create vectorize request: %w", err) + } + + httpReq.Header.Set("Content-Type", writer.FormDataContentType()) + httpReq.Header.Set("Authorization", "Bearer "+s.getAPIKey()) + + resp, err := s.httpClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("vectorize request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 300 { + return "", fmt.Errorf("vectorize API error: %s", resp.Status) + } + + var vectorizeResp RecraftVectorizeResp + if err := json.NewDecoder(resp.Body).Decode(&vectorizeResp); err != nil { + return "", fmt.Errorf("decode vectorize response: %w", err) + } + + logger.Printf("[RECRAFT] Vectorization successful - SVG URL: %s", vectorizeResp.Image.URL) + return vectorizeResp.Image.URL, nil +} + +// parseSizeFromString parses size string like "1024x1024". 
+func (s *RecraftService) parseSizeFromString(size string) (width, height int) { + parts := strings.Split(size, "x") + if len(parts) != 2 { + return 1024, 1024 // Default value + } + + width, _ = strconv.Atoi(parts[0]) + height, _ = strconv.Atoi(parts[1]) + + if width <= 0 { + width = 1024 + } + if height <= 0 { + height = 1024 + } + + return width, height +} + +// buildRecraftPrompt builds Recraft prompt, automatically adding transparent background requirement. +func (s *RecraftService) buildRecraftPrompt(prompt, style, negativePrompt string) (string, string) { + var promptBuilder strings.Builder + var negativeBuilder strings.Builder + + // Build main prompt + promptBuilder.WriteString(prompt) + promptBuilder.WriteString(", 背景色为透明色,不要背景边框") + + // Build negative prompt + if negativePrompt != "" { + negativeBuilder.WriteString(negativePrompt) + negativeBuilder.WriteString(", ") + } + + // Add default background-related negative prompts + + return promptBuilder.String(), negativeBuilder.String() +} + +// getAPIKey gets the API key from environment or configuration. +func (s *RecraftService) getAPIKey() string { + // Get API key from environment variable + // In production, consider using a more secure method like AWS Secrets Manager + return os.Getenv("RECRAFT_API_KEY") +} \ No newline at end of file diff --git a/spx-backend/internal/svggen/svggen.go b/spx-backend/internal/svggen/svggen.go new file mode 100644 index 000000000..61a94238c --- /dev/null +++ b/spx-backend/internal/svggen/svggen.go @@ -0,0 +1,141 @@ +package svggen + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/goplus/builder/spx-backend/internal/config" + "github.com/goplus/builder/spx-backend/internal/log" + qlog "github.com/qiniu/x/log" +) + +// Provider interface defines the contract for image generation providers. 
+type ProviderService interface { + GenerateImage(ctx context.Context, req GenerateRequest) (*ImageResponse, error) +} + +// ServiceManager manages multiple upstream services. +type ServiceManager struct { + svgioService ProviderService + recraftService ProviderService + openaiService ProviderService + translateService TranslateService + httpClient *http.Client + logger *qlog.Logger +} + +// NewServiceManager creates a new service manager. +func NewServiceManager(cfg *config.Config, logger *qlog.Logger) *ServiceManager { + httpClient := &http.Client{ + Timeout: 60 * time.Second, + } + + sm := &ServiceManager{ + httpClient: httpClient, + logger: logger, + } + + // Initialize providers based on configuration + if cfg.Providers.SVGIO.Enabled { + sm.svgioService = NewSVGIOService(cfg, httpClient, logger) + } + + if cfg.Providers.Recraft.Enabled { + sm.recraftService = NewRecraftService(cfg, httpClient, logger) + } + + if cfg.Providers.SVGOpenAI.Enabled { + sm.openaiService = NewOpenAIService(cfg, httpClient, logger) + // Initialize translation service using the same OpenAI configuration + sm.translateService = NewOpenAITranslateService(cfg, httpClient, logger) + } + + return sm +} + +// RegisterProvider registers a new provider. +func (sm *ServiceManager) RegisterProvider(providerType Provider, provider ProviderService) { + switch providerType { + case ProviderSVGIO: + sm.svgioService = provider + case ProviderRecraft: + sm.recraftService = provider + case ProviderOpenAI: + sm.openaiService = provider + } +} + +// GetProvider gets the specified provider. +func (sm *ServiceManager) GetProvider(providerType Provider) ProviderService { + switch providerType { + case ProviderSVGIO: + return sm.svgioService + case ProviderRecraft: + return sm.recraftService + case ProviderOpenAI: + return sm.openaiService + default: + return sm.svgioService // Default to SVGIO + } +} + +// GenerateImage generates an image using the specified provider. 
+func (sm *ServiceManager) GenerateImage(ctx context.Context, req GenerateRequest) (*ImageResponse, error) { + logger := log.GetReqLogger(ctx) + + provider := sm.GetProvider(req.Provider) + if provider == nil { + logger.Printf("provider not configured: %s", string(req.Provider)) + return nil, errors.New("provider not configured: " + string(req.Provider)) + } + + // Handle translation for providers that need it + originalPrompt := req.Prompt + translatedPrompt := req.Prompt + wasTranslated := false + + // Only translate for SVGIO provider (Recraft and OpenAI support Chinese natively) + if !req.SkipTranslate && sm.translateService != nil && req.Provider == ProviderSVGIO { + translated, err := sm.translateService.Translate(ctx, req.Prompt) + if err != nil { + logger.Printf("translation failed: %v", err) + // Continue with original prompt if translation fails + } else if translated != req.Prompt { + translatedPrompt = translated + wasTranslated = true + req.Prompt = translatedPrompt + logger.Printf("prompt translated: %q -> %q", originalPrompt, translatedPrompt) + } + } + + logger.Printf("generating image with provider: %s", string(req.Provider)) + resp, err := provider.GenerateImage(ctx, req) + if err != nil { + return nil, err + } + + // Add translation information to response + if wasTranslated { + resp.OriginalPrompt = originalPrompt + resp.TranslatedPrompt = translatedPrompt + resp.WasTranslated = wasTranslated + } + + return resp, nil +} + +// IsProviderEnabled checks if a provider is enabled. 
+func (sm *ServiceManager) IsProviderEnabled(provider Provider) bool { + switch provider { + case ProviderSVGIO: + return sm.svgioService != nil + case ProviderRecraft: + return sm.recraftService != nil + case ProviderOpenAI: + return sm.openaiService != nil + default: + return false + } +} \ No newline at end of file diff --git a/spx-backend/internal/svggen/svggen_test.go b/spx-backend/internal/svggen/svggen_test.go new file mode 100644 index 000000000..a63c85bab --- /dev/null +++ b/spx-backend/internal/svggen/svggen_test.go @@ -0,0 +1,119 @@ +package svggen + +import ( + "context" + "testing" + + "github.com/goplus/builder/spx-backend/internal/config" + qlog "github.com/qiniu/x/log" +) + +func TestServiceManager_NewServiceManager(t *testing.T) { + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + SVGIO: config.SVGIOConfig{ + Enabled: true, + BaseURL: "https://api.svg.io", + Endpoints: config.SVGIOEndpoints{ + Generate: "/v1/generate", + }, + }, + Recraft: config.RecraftConfig{ + Enabled: true, + BaseURL: "https://api.recraft.ai", + DefaultModel: "recraftv3", + Endpoints: config.RecraftEndpoints{ + Generate: "/v1/images/generations", + Vectorize: "/v1/images/vectorize", + }, + }, + SVGOpenAI: config.OpenAISVGConfig{ + Enabled: true, + BaseURL: "https://api.openai.com/v1", + DefaultModel: "gpt-4", + MaxTokens: 4000, + Temperature: 0.7, + }, + }, + } + + logger := qlog.Std + sm := NewServiceManager(cfg, logger) + + if sm == nil { + t.Fatal("ServiceManager should not be nil") + } + + // Test provider availability + if !sm.IsProviderEnabled(ProviderSVGIO) { + t.Error("SVGIO provider should be enabled") + } + + if !sm.IsProviderEnabled(ProviderRecraft) { + t.Error("Recraft provider should be enabled") + } + + if !sm.IsProviderEnabled(ProviderOpenAI) { + t.Error("OpenAI provider should be enabled") + } +} + +func TestServiceManager_GenerateImage(t *testing.T) { + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + SVGIO: config.SVGIOConfig{ + 
Enabled: false, // Disabled for test + }, + }, + } + + logger := qlog.Std + sm := NewServiceManager(cfg, logger) + + req := GenerateRequest{ + Prompt: "test prompt", + Provider: ProviderSVGIO, + } + + ctx := context.Background() + _, err := sm.GenerateImage(ctx, req) + + // Should return error since provider is not configured + if err == nil { + t.Error("Should return error for unconfigured provider") + } +} + +func TestGenerateImageID(t *testing.T) { + id1 := GenerateImageID(ProviderSVGIO) + id2 := GenerateImageID(ProviderRecraft) + + if id1 == id2 { + t.Error("Generated IDs should be different") + } + + if id1 == "" || id2 == "" { + t.Error("Generated IDs should not be empty") + } +} + +func TestParseSizeFromString(t *testing.T) { + tests := []struct { + input string + width int + height int + }{ + {"512x512", 512, 512}, + {"1024x1024", 1024, 1024}, + {"", 1024, 1024}, + {"invalid", 1024, 1024}, + } + + for _, test := range tests { + w, h := ParseSizeFromString(test.input) + if w != test.width || h != test.height { + t.Errorf("ParseSizeFromString(%s) = (%d, %d), want (%d, %d)", + test.input, w, h, test.width, test.height) + } + } +} \ No newline at end of file diff --git a/spx-backend/internal/svggen/svgio.go b/spx-backend/internal/svggen/svgio.go new file mode 100644 index 000000000..75d811973 --- /dev/null +++ b/spx-backend/internal/svggen/svgio.go @@ -0,0 +1,145 @@ +package svggen + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "os" + "time" + + "github.com/goplus/builder/spx-backend/internal/config" + "github.com/goplus/builder/spx-backend/internal/log" + qlog "github.com/qiniu/x/log" +) + +// SVGIOService implements SVG.IO API calls. +type SVGIOService struct { + config *config.SVGIOConfig + httpClient *http.Client + logger *qlog.Logger +} + +// NewSVGIOService creates a new SVG.IO service instance. 
+func NewSVGIOService(cfg *config.Config, httpClient *http.Client, logger *qlog.Logger) *SVGIOService { + return &SVGIOService{ + config: &cfg.Providers.SVGIO, + httpClient: httpClient, + logger: logger, + } +} + +// svgioGenerateReq represents SVG.IO API generation request. +type svgioGenerateReq struct { + Prompt string `json:"prompt"` + NegativePrompt *string `json:"negativePrompt,omitempty"` + Style *string `json:"style,omitempty"` +} + +// svgioGenerateResp represents SVG.IO API generation response. +type svgioGenerateResp struct { + Success bool `json:"success"` + Data []struct { + ID string `json:"id"` + Prompt string `json:"prompt"` + NegativePrompt string `json:"negativePrompt"` + Style string `json:"style"` + SVGURL string `json:"svgUrl"` + PNGURL string `json:"pngUrl"` + Width int `json:"width"` + Height int `json:"height"` + CreatedAt string `json:"createdAt"` + } `json:"data"` +} + +// GenerateImage generates an image using SVG.IO API. +func (s *SVGIOService) GenerateImage(ctx context.Context, req GenerateRequest) (*ImageResponse, error) { + logger := log.GetReqLogger(ctx) + logger.Printf("[SVGIO] Starting generation request...") + + upReq := svgioGenerateReq{ + Prompt: req.Prompt, + NegativePrompt: &req.NegativePrompt, + Style: &req.Style, + } + + if req.NegativePrompt == "" { + defaultNegativePrompt := "NULL" + upReq.NegativePrompt = &defaultNegativePrompt + } + + if req.Style == "" { + defaultStyle := "FLAT_VECTOR" + upReq.Style = &defaultStyle + } + + body, err := json.Marshal(upReq) + if err != nil { + logger.Printf("[SVGIO] Failed to marshal request: %v", err) + return nil, err + } + + apiURL := s.config.BaseURL + s.config.Endpoints.Generate + logger.Printf("[SVGIO] Sending request to %s with payload size: %d bytes", apiURL, len(body)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + logger.Printf("[SVGIO] Failed to create request: %v", err) + return nil, err + } + + 
httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+s.getAPIKey()) + + resp, err := s.httpClient.Do(httpReq) + if err != nil { + logger.Printf("[SVGIO] HTTP request failed: %v", err) + return nil, err + } + defer resp.Body.Close() + + logger.Printf("[SVGIO] Received response with status: %s", resp.Status) + + if resp.StatusCode >= 300 { + var raw interface{} + _ = json.NewDecoder(resp.Body).Decode(&raw) + logger.Printf("[SVGIO] Error response body: %+v", raw) + return nil, errors.New("upstream status: " + resp.Status) + } + + var upResp svgioGenerateResp + if err := json.NewDecoder(resp.Body).Decode(&upResp); err != nil { + logger.Printf("[SVGIO] Failed to decode response: %v", err) + return nil, err + } + + if !upResp.Success || len(upResp.Data) == 0 { + logger.Printf("[SVGIO] Invalid response: success=%v, data_count=%d", upResp.Success, len(upResp.Data)) + return nil, errors.New("upstream no data") + } + + it := upResp.Data[0] + logger.Printf("[SVGIO] Successfully parsed response - ID: %s, SVG: %s, PNG: %s", it.ID, it.SVGURL, it.PNGURL) + + createdAt, _ := time.Parse(time.RFC3339, it.CreatedAt) + return &ImageResponse{ + ID: it.ID, + Prompt: it.Prompt, + NegativePrompt: it.NegativePrompt, + Style: it.Style, + SVGURL: it.SVGURL, + PNGURL: it.PNGURL, + Width: it.Width, + Height: it.Height, + CreatedAt: createdAt, + Provider: ProviderSVGIO, + }, nil +} + +// getAPIKey gets the API key from environment or configuration. 
+func (s *SVGIOService) getAPIKey() string { + // Get API key from environment variable + // In production, consider using a more secure method like AWS Secrets Manager + return os.Getenv("SVGIO_API_KEY") +} \ No newline at end of file diff --git a/spx-backend/internal/svggen/translate.go b/spx-backend/internal/svggen/translate.go new file mode 100644 index 000000000..c30dbc2af --- /dev/null +++ b/spx-backend/internal/svggen/translate.go @@ -0,0 +1,131 @@ +package svggen + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "strings" + + "github.com/goplus/builder/spx-backend/internal/config" + "github.com/goplus/builder/spx-backend/internal/log" + qlog "github.com/qiniu/x/log" +) + +// TranslateService defines the contract for translation services. +type TranslateService interface { + Translate(ctx context.Context, text string) (string, error) +} + +// OpenAITranslateService implements translation using OpenAI compatible APIs. +type OpenAITranslateService struct { + config *config.OpenAISVGConfig + httpClient *http.Client + logger *qlog.Logger +} + +// NewOpenAITranslateService creates a new OpenAI translation service instance. +func NewOpenAITranslateService(cfg *config.Config, httpClient *http.Client, logger *qlog.Logger) *OpenAITranslateService { + return &OpenAITranslateService{ + config: &cfg.Providers.SVGOpenAI, + httpClient: httpClient, + logger: logger, + } +} + +// Translate translates text from Chinese to English for image generation prompts. 
+func (s *OpenAITranslateService) Translate(ctx context.Context, text string) (string, error) { + logger := log.GetReqLogger(ctx) + + // Check if text contains Chinese characters + if !containsChinese(text) { + logger.Printf("[TRANSLATE] Text appears to be English already, skipping translation: %q", text) + return text, nil + } + + logger.Printf("[TRANSLATE] Translating text: %q", text) + + prompt := fmt.Sprintf(`Please translate the following text to English, maintaining the original meaning and making it suitable for AI image generation prompts. Return only the translation result without any explanation: + +%s`, text) + + reqBody := OpenAIGenerateReq{ + Model: s.config.DefaultModel, + Messages: []OpenAIMessage{ + { + Role: "user", + Content: prompt, + }, + }, + MaxTokens: 150, + Temperature: 0.3, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshal translate request: %w", err) + } + + url := s.config.BaseURL + "chat/completions" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData)) + if err != nil { + return "", fmt.Errorf("create translate request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+s.getAPIKey()) + + resp, err := s.httpClient.Do(req) + if err != nil { + logger.Printf("[TRANSLATE] HTTP request failed: %v", err) + return "", fmt.Errorf("http translate request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 300 { + var errResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&errResp) + logger.Printf("[TRANSLATE] Error response body: %+v", errResp) + return "", fmt.Errorf("translate API error: %s", resp.Status) + } + + var translateResp OpenAIGenerateResp + if err := json.NewDecoder(resp.Body).Decode(&translateResp); err != nil { + return "", fmt.Errorf("decode translate response: %w", err) + } + + if len(translateResp.Choices) == 0 { + return "", errors.New("no translation 
choices returned") + } + + translated := strings.TrimSpace(translateResp.Choices[0].Message.Content) + logger.Printf("[TRANSLATE] Translation result: %q -> %q", text, translated) + + return translated, nil +} + +// getAPIKey gets the API key from the same source as OpenAI service. +func (s *OpenAITranslateService) getAPIKey() string { + // Use the same API key as OpenAI SVG generation + return getOpenAIAPIKey() +} + +// containsChinese checks if text contains Chinese characters. +func containsChinese(text string) bool { + for _, char := range text { + if char >= 0x4e00 && char <= 0x9fff { + return true + } + } + return false +} + +// getOpenAIAPIKey is a helper function to get OpenAI API key. +// This should match the implementation in openai.go +func getOpenAIAPIKey() string { + // Implementation matches openai.go getAPIKey method + return os.Getenv("SVG_OPENAI_API_KEY") +} \ No newline at end of file diff --git a/spx-backend/internal/svggen/types.go b/spx-backend/internal/svggen/types.go new file mode 100644 index 000000000..ea4cf17fb --- /dev/null +++ b/spx-backend/internal/svggen/types.go @@ -0,0 +1,142 @@ +package svggen + +import "time" + +// Provider defines different image generation providers. +type Provider string + +const ( + ProviderSVGIO Provider = "svgio" + ProviderRecraft Provider = "recraft" + ProviderOpenAI Provider = "openai" +) + +// GenerateRequest represents a request to generate an SVG image. 
// GenerateRequest represents a request to generate an SVG image.
// All fields except Prompt are omitted from the JSON encoding when empty.
type GenerateRequest struct {
	Prompt         string   `json:"prompt"`
	NegativePrompt string   `json:"negative_prompt,omitempty"`
	Style          string   `json:"style,omitempty"`
	Theme          string   `json:"theme,omitempty"`
	Provider       Provider `json:"provider,omitempty"`
	Format         string   `json:"format,omitempty"`
	SkipTranslate  bool     `json:"skip_translate,omitempty"`

	// Recraft specific parameters
	Model     string `json:"model,omitempty"`
	Size      string `json:"size,omitempty"`
	Substyle  string `json:"substyle,omitempty"`
	NumImages int    `json:"n,omitempty"` // encoded under the JSON key "n"
}

// ImageResponse represents the response from image generation.
type ImageResponse struct {
	ID             string    `json:"id"`
	Prompt         string    `json:"prompt"`
	NegativePrompt string    `json:"negative_prompt"`
	Style          string    `json:"style"`
	SVGURL         string    `json:"svg_url"`
	PNGURL         string    `json:"png_url"`
	Width          int       `json:"width"`
	Height         int       `json:"height"`
	CreatedAt      time.Time `json:"created_at"`
	Provider       Provider  `json:"provider"`

	// Translation related information
	OriginalPrompt   string `json:"original_prompt,omitempty"`
	TranslatedPrompt string `json:"translated_prompt,omitempty"`
	WasTranslated    bool   `json:"was_translated"` // always encoded, even when false
}

// SVG.IO upstream API related types

// SVGIOGenerateReq is the request body sent to the SVG.IO API.
// Unlike the types above it uses camelCase JSON keys.
type SVGIOGenerateReq struct {
	Prompt         string `json:"prompt"`
	NegativePrompt string `json:"negativePrompt"`
	Style          string `json:"style,omitempty"`
}

// SVGIOGenerateItem is a single generated image entry in an SVG.IO response.
// UpdatedAt and CreatedAt are kept as the raw strings received.
type SVGIOGenerateItem struct {
	ID                  string `json:"id"`
	Description         string `json:"description"`
	Height              int    `json:"height"`
	HasInitialImage     bool   `json:"hasInitialImage"`
	IsPrivate           bool   `json:"isPrivate"`
	NSFWTextDetected    bool   `json:"nsfwTextDetected"`
	NSFWContentDetected bool   `json:"nsfwContentDetected"`
	PNGURL              string `json:"pngUrl"`
	SVGURL              string `json:"svgUrl"`
	Style               string `json:"style"`
	Prompt              string `json:"prompt"`
	NegativePrompt      string `json:"negativePrompt"`
	Width               int    `json:"width"`
	UpdatedAt           string `json:"updatedAt"`
	CreatedAt           string `json:"createdAt"`
}

// SVGIOGenerateResp is the envelope of an SVG.IO generation response.
type SVGIOGenerateResp struct {
	Success bool                `json:"success"`
	Data    []SVGIOGenerateItem `json:"data"`
}

// Recraft API related types

// RecraftGenerateReq is the request body for Recraft image generation.
// Every field except Prompt is omitted from the JSON encoding when empty.
type RecraftGenerateReq struct {
	Prompt         string `json:"prompt"`
	NegativePrompt string `json:"negative_prompt,omitempty"`
	Style          string `json:"style,omitempty"`
	Substyle       string `json:"substyle,omitempty"`
	Model          string `json:"model,omitempty"`
	Size           string `json:"size,omitempty"`
	N              int    `json:"n,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
}

// RecraftImageData describes one image in a Recraft response, carried either
// as a URL or as the "b64_json" field.
type RecraftImageData struct {
	URL           string `json:"url"`
	B64JSON       string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// RecraftGenerateResp is the response body of a Recraft generation request.
type RecraftGenerateResp struct {
	Created int                `json:"created"`
	Data    []RecraftImageData `json:"data"`
}

// RecraftVectorizeReq is the request body for Recraft vectorization.
type RecraftVectorizeReq struct {
	ResponseFormat string `json:"response_format,omitempty"`
}

// RecraftVectorizeResp is the response body of a Recraft vectorization request.
type RecraftVectorizeResp struct {
	Image RecraftImageData `json:"image"`
}

// OpenAI API related types (supports all OpenAI compatible models)

// OpenAIMessage is a single chat message with its role and text content.
type OpenAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OpenAIGenerateReq is a chat completion request.
// MaxTokens has no omitempty and is therefore always encoded.
type OpenAIGenerateReq struct {
	Model       string          `json:"model"`
	Messages    []OpenAIMessage `json:"messages"`
	MaxTokens   int             `json:"max_tokens"`
	Temperature float64         `json:"temperature,omitempty"`
	Stream      bool            `json:"stream,omitempty"`
}

// OpenAIGenerateResp is a chat completion response: the returned choices plus
// token usage counters.
type OpenAIGenerateResp struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Model   string `json:"model"`
	Choices []struct {
		Index   int `json:"index"`
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
b/spx-backend/internal/svggen/utils.go new file mode 100644 index 000000000..48dc04cb4 --- /dev/null +++ b/spx-backend/internal/svggen/utils.go @@ -0,0 +1,59 @@ +package svggen + +import ( + "context" + "fmt" + "io" + "net/http" + "time" +) + +// DownloadFile downloads a file from the given URL. +func DownloadFile(ctx context.Context, url string) ([]byte, error) { + client := &http.Client{ + Timeout: 30 * time.Second, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("download file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download failed with status: %s", resp.Status) + } + + return io.ReadAll(resp.Body) +} + +// ParseSizeFromString parses size string like "1024x1024". +func ParseSizeFromString(size string) (width, height int) { + if size == "" { + return 1024, 1024 + } + + // Simple parsing for common formats + switch size { + case "512x512": + return 512, 512 + case "1024x1024": + return 1024, 1024 + case "1536x1536": + return 1536, 1536 + case "2048x2048": + return 2048, 2048 + default: + return 1024, 1024 + } +} + +// GenerateImageID generates a simple image ID. +func GenerateImageID(provider Provider) string { + return fmt.Sprintf("%s_%d", string(provider), time.Now().UnixNano()) +} \ No newline at end of file diff --git a/spx-gui/docs/aigc-design.md b/spx-gui/docs/aigc-design.md new file mode 100644 index 000000000..c562c63bb --- /dev/null +++ b/spx-gui/docs/aigc-design.md @@ -0,0 +1,34 @@ +# AIGC Architecture Design +## I. Generator +The generator is the core component of the AIGC module, responsible for rendering the entire AI generation interface. It also invokes encapsulated APIs to initiate interactions with the backend and manages the logic of AI content generation. 
Error handling is implemented through a try-catch block that triggers the error-handling logic defined in error.vue. +## II. error.vue – Error Handling Page +The error component provides an error-handling interface for the generator. By invoking its functions and passing parameters, the generator can display the appropriate error page and present the user with corresponding error messages. +## III. modelSelector – Model Selector +The model selector is primarily a dropdown menu component. It allows users to choose styles and other presets, which are then passed as parameters to the backend through the generator. +## IV. Prompt Organization Area +This section lets users choose how to input prompts and complete the prompt-entry process: +1. Preset prompts – Provides a “fill-in-the-blank” style interaction. Once completed by the user, the prompt is concatenated into a full string and sent to the generator for requests. +2. Free-form prompt input – Users can directly enter custom prompts. +## V. src/apis/picgc.ts – API Layer +This module encapsulates all backend AI interaction requests and exposes them for use by the generator. 
+* Architecture Diagram +```mermaid +flowchart TD + Backend((Backend)) + picgc[picgc.ts] + Generator[Generator] + utils[utils] + Painter((Painter)) + + prompt[prompt input area] + error[Error handling page] + model[model selector] + + Backend <--> picgc + picgc <--> Generator + Generator --> prompt + Generator --> error + Generator --> model + Generator --> utils + utils --> Painter +``` diff --git a/spx-gui/docs/aigc-design.zh.md b/spx-gui/docs/aigc-design.zh.md new file mode 100644 index 000000000..207a458c3 --- /dev/null +++ b/spx-gui/docs/aigc-design.zh.md @@ -0,0 +1,35 @@ +# AIGC架构设计 +## 一、generator生成器 +生成器是aigc部分的主组件,负责整个ai生成组件的渲染。同时,负责调用封装好的接口,发起与后端的交互。并处理AI生成的相关逻辑。错误处理部分,则通过try-catch调用error.vue中的错误处理逻辑。 +## 二、error.vue 错误处理页面 +error作为错误处理组件,提供接口函数给generator使用。通过调用和参数传递,呼出相应的错误页面,给用户提供相应的错误信息。 +## 三、modelSelector 模型选择器 +主要是一个下拉菜单组件。给用户提供风格等预设,作为发后端的接口参数给generator使用。 +## 四、prompt组织区域 +此区域用于给用户选择提示词输入方式,并完成提示词输入流程。 +1.预设提示词:提供类似“完形填空”的效果。将用户填写完的prompt拼成整个字符串给generator发请求用 +2.自由输入提示词 +## 五、src/apis/picgc.ts api层 +将与后端ai交互的请求封装,供给generator使用 +* 架构设计图 + +```mermaid +flowchart TD + Backend((Backend)) + picgc[picgc.ts] + Generator[Generator] + utils[utils] + Painter((Painter)) + + prompt[prompt input area] + error[Error handling page] + model[model selector] + + Backend <--> picgc + picgc <--> Generator + Generator --> prompt + Generator --> error + Generator --> model + Generator --> utils + utils --> Painter +``` diff --git a/spx-gui/package-lock.json b/spx-gui/package-lock.json index 1d8b35882..8aa4be1b1 100644 --- a/spx-gui/package-lock.json +++ b/spx-gui/package-lock.json @@ -52,6 +52,7 @@ "naive-ui": "^2.38.1", "nanoid": "^5.0.5", "npm": "^10.3.0", + "paper": "^0.12.18", "prettier": "^3.2.2", "property-information": "^6.5.0", "qiniu-js": "^3.4.2", @@ -11255,6 +11256,18 @@ "resolved": "https://registry.npmjs.org/pako/-/pako-1.0.11.tgz", "integrity": "sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==" }, + 
/**
 * Sends a request to the backend and returns the parsed JSON response body.
 *
 * `Content-Type: application/json` is applied as a default header; any header
 * passed in `options.headers` is sent as well and wins over that default.
 *
 * @param path - Path appended to `apiBaseUrl`.
 * @param options - Standard `fetch` options (method, body, headers, ...).
 * @returns The response body parsed as JSON.
 * @throws Error with status, status text and response body when the response is not OK.
 */
async function picgcRequest(path: string, options: RequestInit = {}) {
  const url = apiBaseUrl + path

  // Merge headers explicitly. Spreading `options` after a `headers` literal
  // would let a caller-supplied `headers` replace the whole object and drop
  // the default Content-Type. `Headers` also copes with every `HeadersInit`
  // shape (plain object, tuple array, Headers instance).
  const headers = new Headers(options.headers)
  if (!headers.has('Content-Type')) {
    headers.set('Content-Type', 'application/json')
  }

  const response = await fetch(url, { ...options, headers })

  if (!response.ok) {
    const errorText = await response.text()
    throw new Error(`Request failed: ${response.status} ${response.statusText} - ${errorText}`)
  }

  return response.json()
}
/** Width and height, in pixels, forced onto the root element of every generated SVG. */
const GENERATED_SVG_SIZE = 512

/**
 * Providers accepted by `generateSvgDirect`.
 * NOTE(review): the backend API description lists `svgio`, `recraft` and
 * `openai`; confirm that `claude` is still a valid provider value.
 */
const SVG_DIRECT_PROVIDERS: readonly string[] = ['claude', 'recraft', 'svgio']

/** Sets `name="value"` inside the attribute text of an opening tag, replacing an existing value or appending a new attribute. */
function setTagAttribute(attributes: string, name: string, value: string): string {
  // The name must be preceded by whitespace (or start the text) so that
  // `stroke-width=` is never mistaken for `width=`.
  const existing = new RegExp(`(^|\\s)${name}\\s*=\\s*(?:"[^"]*"|'[^']*')`)
  if (existing.test(attributes)) {
    return attributes.replace(existing, `$1${name}="${value}"`)
  }
  return `${attributes} ${name}="${value}"`
}

/** Forces the given width and height onto the first `<svg>` opening tag of the markup. */
function resizeSvgRoot(svgContent: string, width: number, height: number): string {
  return svgContent.replace(/<svg\b([^>]*?)(\/?)>/, (_match: string, attributes: string, selfClosing: string) => {
    const sized = setTagAttribute(setTagAttribute(attributes, 'width', String(width)), 'height', String(height))
    return `<svg${sized}${selfClosing}>`
  })
}

/**
 * Generates an image and returns the SVG markup directly.
 *
 * The request is posted to `/image/svg`; the returned markup has the width and
 * height of its root `<svg>` element set to 512x512. The image id is read from
 * the `X-Image-Id` response header (`'unknown'` when absent).
 *
 * @param provider - Generation provider: `claude`, `recraft` or `svgio`.
 * @param prompt - Text prompt describing the desired image.
 * @param options - Optional generation parameters forwarded to the backend.
 * @returns The resized SVG markup together with its id, width and height.
 * @throws Error `Invalid provider` for an unsupported provider, or an error
 *   carrying status and response body when the backend answers with a non-OK status.
 */
export async function generateSvgDirect(
  provider: string,
  prompt: string,
  options?: {
    negative_prompt?: string
    style?: string
    format?: string
    skip_translate?: boolean
    model?: string
    size?: string
    substyle?: string
    n?: number
  }
): Promise<{
  svgContent: string
  id: string
  width: number
  height: number
}> {
  // Reject unsupported providers before any network traffic.
  if (!SVG_DIRECT_PROVIDERS.includes(provider)) {
    throw new Error('Invalid provider')
  }

  const payload: GenerateImageRequest = {
    prompt,
    negative_prompt: options?.negative_prompt || 'text, watermark',
    style: options?.style,
    provider,
    format: options?.format,
    skip_translate: options?.skip_translate,
    model: options?.model,
    size: options?.size,
    substyle: options?.substyle,
    n: options?.n
  }

  const response = await fetch(apiBaseUrl + '/image/svg', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json'
    },
    body: JSON.stringify(payload)
  })

  if (!response.ok) {
    const errorText = await response.text()
    throw new Error(`Request failed: ${response.status} ${response.statusText} - ${errorText}`)
  }

  // The response body is the SVG markup; metadata travels in response headers.
  const svgContent = await response.text()
  const id = response.headers.get('X-Image-Id') || 'unknown'
  const width = GENERATED_SVG_SIZE
  const height = GENERATED_SVG_SIZE

  return {
    svgContent: resizeSvgRoot(svgContent, width, height),
    id,
    width,
    height
  }
}
b/spx-gui/src/components/editor/common/painter/components/aigc/error.vue new file mode 100644 index 000000000..979cba801 --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/aigc/error.vue @@ -0,0 +1,330 @@ + + + + + + ⚠️ + {{ $t({ en: 'Generation Failed', zh: '生成失败' }) }} + + × + + + + + + {{ errorMessage }} + + + {{ $t({ en: 'Suggestions:', zh: '建议:' }) }} + + {{ $t({ en: 'Try simplifying your description', zh: '尝试简化您的描述' }) }} + {{ $t({ en: 'Check your network connection', zh: '检查网络连接状态' }) }} + {{ $t({ en: 'Ensure prompt is at least 3 characters', zh: '确保提示词至少3个字符' }) }} + {{ $t({ en: 'Try again in a few moments', zh: '稍后再试' }) }} + {{ $t({ en: 'Contact support if the problem persists', zh: '如问题持续存在,请联系技术支持' }) }} + + + + + + + + + + + + diff --git a/spx-gui/src/components/editor/common/painter/components/aigc/generator.vue b/spx-gui/src/components/editor/common/painter/components/aigc/generator.vue new file mode 100644 index 000000000..2a0367f2c --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/aigc/generator.vue @@ -0,0 +1,723 @@ + + + + + + + + + {{ $t({ en: 'AI Generate Image', zh: 'AI生成图片' }) }} + + × + + + + + + + + + + + {{ $t({ en: 'Select Generation Model', zh: '选择生成模型' }) }} + + + 🖼️ + + {{ $t({ en: 'SVGIO Vector', zh: 'SVGIO矢量图' }) }} + {{ $t({ en: 'Generate simple, accurate vector images', zh: '生成简单,精确的矢量图' }) }} + + + + 📐 + + {{ $t({ en: 'Recraft SVG Vector', zh: 'Recraft SVG矢量' }) }} + {{ $t({ en: 'Generate fabulous editable vector graphics', zh: '生成精美的可编辑矢量图形' }) }} + + + + + + + + {{ $t({ en: 'Describe the image you want', zh: '描述您想要的图片' }) }} + + + {{ $t({ en: 'Tip: The more detailed the description, the better the generated image effect', zh: '提示:描述越详细,生成的图片效果越好' }) }} + + + + + + + + {{ isGenerating ? $t({ en: 'Generating...', zh: '生成中...' 
}) : $t({ en: 'Generate Image', zh: '生成图片' }) }} + + + + + + + {{ $t({ en: 'Preview', zh: '预览效果' }) }} + + + + {{ $t({ en: 'AI is generating images for you...', zh: 'AI正在为您生成图片...' }) }} + + + + + + + 🖼️ + {{ $t({ en: 'Generated images will be previewed here', zh: '生成的图片将在这里预览' }) }} + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/components/circle_tool.vue b/spx-gui/src/components/editor/common/painter/components/circle_tool.vue new file mode 100644 index 000000000..033e63b27 --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/circle_tool.vue @@ -0,0 +1,15 @@ + + + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/components/draw_brush.vue b/spx-gui/src/components/editor/common/painter/components/draw_brush.vue new file mode 100644 index 000000000..8fa7abae5 --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/draw_brush.vue @@ -0,0 +1,109 @@ + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/components/draw_line.vue b/spx-gui/src/components/editor/common/painter/components/draw_line.vue new file mode 100644 index 000000000..7f96cfb4d --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/draw_line.vue @@ -0,0 +1,150 @@ + + + + + + + + + + diff --git a/spx-gui/src/components/editor/common/painter/components/eraser_tool.vue b/spx-gui/src/components/editor/common/painter/components/eraser_tool.vue new file mode 100644 index 000000000..c10e929a7 --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/eraser_tool.vue @@ -0,0 +1,15 @@ + + + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/components/fill_tool.vue b/spx-gui/src/components/editor/common/painter/components/fill_tool.vue new file mode 100644 index 000000000..ac36c9536 --- /dev/null +++ 
b/spx-gui/src/components/editor/common/painter/components/fill_tool.vue @@ -0,0 +1,15 @@ + + + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/components/rectangle_tool.vue b/spx-gui/src/components/editor/common/painter/components/rectangle_tool.vue new file mode 100644 index 000000000..7a13a80a5 --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/rectangle_tool.vue @@ -0,0 +1,15 @@ + + + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/components/reshape.vue b/spx-gui/src/components/editor/common/painter/components/reshape.vue new file mode 100644 index 000000000..10ac52358 --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/reshape.vue @@ -0,0 +1,404 @@ + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/components/text_tool.vue b/spx-gui/src/components/editor/common/painter/components/text_tool.vue new file mode 100644 index 000000000..af27dc491 --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/components/text_tool.vue @@ -0,0 +1,15 @@ + + + + + + + + + \ No newline at end of file diff --git a/spx-gui/src/components/editor/common/painter/painter.vue b/spx-gui/src/components/editor/common/painter/painter.vue new file mode 100644 index 000000000..601ac40fe --- /dev/null +++ b/spx-gui/src/components/editor/common/painter/painter.vue @@ -0,0 +1,705 @@ + + + + + + {{ $t({ en: 'Drawing Tools', zh: '绘图工具' }) }} + + + + + + {{ $t({ en: 'Line', zh: '直线' }) }} + + + + + + + + {{ $t({ en: 'Brush', zh: '笔刷' }) }} + + + + + + + + {{ $t({ en: 'Reshape', zh: '变形' }) }} + + + + + + + + + {{ $t({ en: 'Eraser', zh: '橡皮' }) }} + + + + + + + {{ $t({ en: 'Rectangle', zh: '矩形' }) }} + + + + + + + {{ $t({ en: 'Circle', zh: '圆形' }) }} + + + + + + + + + + + {{ $t({ en: 'Fill', zh: '填充' }) }} + + + + + + + + + {{ $t({ en: 'Text', zh: '文本' }) }} + + + + + + {{ $t({ en: 'AI Tools', 
/** Drawing tools the painter can activate. */
export type ToolType = 'line' | 'brush' | 'reshape' | 'eraser' | 'rectangle' | 'circle' | 'fill' | 'text'

/** Canvas event kinds understood by the delegator. */
export type EventType = 'click' | 'mousedown' | 'mousemove' | 'mouseup' | 'drag'

/** Plain canvas coordinate. */
export interface Point {
  x: number
  y: number
}

/** Handlers a tool component may expose; every handler is optional. */
export interface ToolHandler {
  handleCanvasClick?: (point: Point | paper.Point) => void
  handleClick?: (point: Point | paper.Point) => void
  handleMouseDown?: (point: Point | paper.Point) => void
  handleMouseMove?: (point: Point | paper.Point) => void
  handleMouseDrag?: (point: Point | paper.Point) => void
  handleMouseUp?: (point: Point | paper.Point) => void
}

/** Maps each tool to the handler object of its component, if mounted. */
export interface ToolRefs {
  line?: ToolHandler
  brush?: ToolHandler
  reshape?: ToolHandler
  eraser?: ToolHandler
  rectangle?: ToolHandler
  circle?: ToolHandler
  fill?: ToolHandler
  text?: ToolHandler
}

/**
 * Canvas event delegator.
 *
 * Dispatches canvas mouse events to the handler of the currently selected
 * tool, converting client coordinates to canvas coordinates on the way.
 */
export class CanvasEventDelegator {
  private toolRefs: ToolRefs = {}
  private currentTool: ToolType | null = null
  /** Most recent canvas position seen by any delegated event; used for releases that carry no coordinates. */
  private lastCanvasPoint: paper.Point | null = null

  /** Replaces the tool handler map. */
  setToolRefs(refs: ToolRefs): void {
    this.toolRefs = refs
  }

  /** Selects the tool that receives subsequent events (`null` disables delegation). */
  setCurrentTool(tool: ToolType | null): void {
    this.currentTool = tool
  }

  /** Returns the currently selected tool. */
  getCurrentTool(): ToolType | null {
    return this.currentTool
  }

  /**
   * Converts a mouse event's client coordinates to canvas coordinates and
   * remembers the result as the last known pointer position.
   */
  private getCanvasPoint(event: MouseEvent, canvasRef: HTMLCanvasElement | null): paper.Point | null {
    if (!canvasRef) return null

    const rect = canvasRef.getBoundingClientRect()
    const point = new paper.Point(event.clientX - rect.left, event.clientY - rect.top)
    this.lastCanvasPoint = point
    return point
  }

  /** Delegates a canvas click to the current tool. */
  delegateClick(event: MouseEvent, canvasRef: HTMLCanvasElement | null): void {
    if (!this.currentTool) return

    const point = this.getCanvasPoint(event, canvasRef)
    if (!point) return

    const handler = this.toolRefs[this.currentTool]
    if (!handler) return

    // Prefer handleCanvasClick; fall back to handleClick when it is absent.
    if (handler.handleCanvasClick) {
      // The line tool expects a plain {x, y} object.
      if (this.currentTool === 'line') {
        handler.handleCanvasClick({ x: point.x, y: point.y })
      } else {
        handler.handleCanvasClick(point)
      }
    } else if (handler.handleClick) {
      handler.handleClick(point)
    }
  }

  /** Delegates a mouse-down event to the current tool. */
  delegateMouseDown(event: MouseEvent, canvasRef: HTMLCanvasElement | null): void {
    if (!this.currentTool) return

    const point = this.getCanvasPoint(event, canvasRef)
    if (!point) return

    const handler = this.toolRefs[this.currentTool]
    if (!handler?.handleMouseDown) return

    // The brush tool expects a plain {x, y} object; other tools take paper.Point.
    if (this.currentTool === 'brush') {
      handler.handleMouseDown({ x: point.x, y: point.y })
    } else {
      handler.handleMouseDown(point)
    }
  }

  /** Delegates a mouse-move event; only line, brush and reshape react to it. */
  delegateMouseMove(event: MouseEvent, canvasRef: HTMLCanvasElement | null): void {
    if (!this.currentTool) return

    const point = this.getCanvasPoint(event, canvasRef)
    if (!point) return

    const handler = this.toolRefs[this.currentTool]
    if (!handler) return

    // Each tool listens on a different handler.
    if (this.currentTool === 'line' && handler.handleMouseMove) {
      handler.handleMouseMove({ x: point.x, y: point.y })
    } else if (this.currentTool === 'brush' && handler.handleMouseDrag) {
      handler.handleMouseDrag({ x: point.x, y: point.y })
    } else if (this.currentTool === 'reshape' && handler.handleMouseMove) {
      handler.handleMouseMove(point)
    }
  }

  /** Delegates a mouse-up event on the canvas to the current tool. */
  delegateMouseUp(event: MouseEvent, canvasRef: HTMLCanvasElement | null): void {
    if (!this.currentTool) return

    const point = this.getCanvasPoint(event, canvasRef)
    if (!point) return

    const handler = this.toolRefs[this.currentTool]
    if (!handler?.handleMouseUp) return

    // The brush tool expects a plain {x, y} object.
    if (this.currentTool === 'brush') {
      handler.handleMouseUp({ x: point.x, y: point.y })
    } else {
      handler.handleMouseUp(point)
    }
  }

  /**
   * Delegates a global mouse-up (released outside the canvas, so the event
   * carries no usable coordinates). Only the reshape tool reacts to it.
   *
   * The handler is invoked with the last canvas position seen by this
   * delegator. When no position has been recorded yet, no canvas interaction
   * can be in progress and the call is skipped.
   */
  delegateGlobalMouseUp(): void {
    if (this.currentTool !== 'reshape') return

    const handler = this.toolRefs.reshape
    if (!handler?.handleMouseUp) return
    if (!this.lastCanvasPoint) return

    // Previously this was the bare expression `handler.handleMouseUp`, which
    // never invoked the handler.
    handler.handleMouseUp(this.lastCanvasPoint)
  }

  /**
   * Generic delegation entry point.
   * @param eventType - Kind of event to dispatch.
   * @param event - The originating mouse event.
   * @param canvasRef - The canvas element, or null when not mounted.
   */
  delegate(eventType: EventType, event: MouseEvent, canvasRef: HTMLCanvasElement | null): void {
    switch (eventType) {
      case 'click':
        this.delegateClick(event, canvasRef)
        break
      case 'mousedown':
        this.delegateMouseDown(event, canvasRef)
        break
      case 'mousemove':
        this.delegateMouseMove(event, canvasRef)
        break
      case 'mouseup':
        this.delegateMouseUp(event, canvasRef)
        break
      default:
        console.warn(`未支持的事件类型: ${eventType}`)
    }
  }

  /** Reports whether the given tool has a handler for the given event kind. */
  supportsEvent(tool: ToolType, eventType: EventType): boolean {
    const handler = this.toolRefs[tool]
    if (!handler) return false

    switch (eventType) {
      case 'click':
        return !!(handler.handleCanvasClick || handler.handleClick)
      case 'mousedown':
        return !!handler.handleMouseDown
      case 'mousemove':
        return !!(handler.handleMouseMove || handler.handleMouseDrag)
      case 'mouseup':
        return !!handler.handleMouseUp
      default:
        return false
    }
  }

  /** Returns the tools that currently have a handler registered. */
  getSupportedTools(): ToolType[] {
    return Object.keys(this.toolRefs) as ToolType[]
  }

  /** Removes all tool handler references. */
  clearToolRefs(): void {
    this.toolRefs = {}
  }
}

// Shared singleton instance
export const canvasEventDelegator = new CanvasEventDelegator()

// Default export
export default canvasEventDelegator
+ }) as string + + // 恢复背景可见性 + this.restoreBackgroundAfterExport(prevVisible) + + return svgStr + } catch (error) { + console.error('Failed to export SVG:', error) + return null + } + } + + /** + * 导出SVG并触发事件 + */ + exportSvgAndEmit(options: ExportOptions = {}): void { + const svgStr = this.exportSvg(options) + if (svgStr) { + this.dependencies.emit('svg-change', svgStr) + } + } + + /** + * 导入SVG到画布 + */ + async importSvg(svgContent: string, options: ImportOptions = {}): Promise { + const { + clearCanvas = false, + position = 'center', + updatePaths = true, + triggerExport = true + } = options + + if (!paper.project) { + console.warn('Paper project not initialized') + return false + } + + try { + if (clearCanvas) { + this.clearCanvas(false) // 清空但不触发导出 + } + + // 解析SVG内容 + const parser = new DOMParser() + const svgDoc = parser.parseFromString(svgContent, 'image/svg+xml') + const svgElement = svgDoc.documentElement + + if (svgElement.nodeName !== 'svg') { + console.error('Invalid SVG content') + return false + } + + // 导入SVG到Paper.js + const importedItem = paper.project.importSVG(svgElement as unknown as SVGElement) + + if (!importedItem) { + console.error('Failed to import SVG') + return false + } + + // 设置位置 + if (position === 'center') { + importedItem.position = paper.view.center + } + + // 收集可编辑路径 + if (updatePaths) { + this.collectPathsFromImport(importedItem) + } + + paper.view.update() + + // 触发导出(避免循环导入) + if (triggerExport && !this.dependencies.isImportingFromProps.value) { + this.exportSvgAndEmit() + } + + return true + } catch (error) { + console.error('Failed to import SVG:', error) + return false + } + } + + /** + * 导入图片(PNG/JPG等)到画布 + */ + async importImage(imageSrc: string, options: ImportOptions = {}): Promise { + const { + position = 'center', + triggerExport = true + } = options + + if (!paper.project) { + console.warn('Paper project not initialized') + return false + } + + return new Promise((resolve) => { + // 清除现有背景图片 + if 
(this.dependencies.backgroundImage.value) { + try { + this.dependencies.backgroundImage.value.remove() + } catch (error) { + console.warn('Error removing previous background image:', error) + } + this.dependencies.backgroundImage.value = null + } + + // 创建新的光栅图像 + const raster = new paper.Raster(imageSrc) + + raster.onLoad = () => { + try { + if (position === 'center') { + raster.position = paper.view.center + } + + // 将图片放到最底层作为背景 + raster.sendToBack() + + // 更新背景图片引用 + this.dependencies.backgroundImage.value = raster + + paper.view.update() + + if (triggerExport) { + this.exportSvgAndEmit() + } + + resolve(true) + } catch (error) { + console.error('Error processing loaded image:', error) + resolve(false) + } + } + + raster.onError = () => { + console.error('Failed to load image:', imageSrc) + resolve(false) + } + }) + } + + /** + * 根据URL自动判断文件类型并导入 + */ + async importFile(fileUrl: string, options: ImportOptions = {}): Promise { + try { + const response = await fetch(fileUrl) + + // 尝试判断是否为SVG + let isSvg = false + let content: string | null = null + + // 方法1: 通过blob类型判断 + try { + const blob = await response.clone().blob() + if (blob?.type?.includes('image/svg')) { + isSvg = true + content = await blob.text() + } + } catch {} + + // 方法2: 通过响应头判断 + if (!isSvg) { + const contentType = response.headers.get('content-type') || '' + if (contentType.includes('image/svg')) { + isSvg = true + content = await response.clone().text() + } + } + + // 方法3: 尝试解析文本内容 + if (!isSvg) { + try { + const text = await response.clone().text() + if (/^\s*\s*$/i.test(text)) { + isSvg = true + content = text + } + } catch {} + } + + // 根据类型导入 + if (isSvg && content) { + return await this.importSvg(content, options) + } else { + return await this.importImage(fileUrl, options) + } + } catch (error) { + console.error('Failed to import file:', error) + // 回退到图片导入 + return await this.importImage(fileUrl, options) + } + } + + /** + * 清空画布 + */ + clearCanvas(triggerExport: boolean = true): void { 
+ if (!paper.project) return + + try { + // 隐藏控制点 + this.hideControlPointsForExport() + + // 移除所有路径 + this.dependencies.allPaths.value.forEach((path: paper.Path) => { + if (path && path.parent) { + path.remove() + } + }) + this.dependencies.allPaths.value = [] + + // 清空项目 + paper.project.clear() + + // 重新创建背景 + this.createBackground() + + // 重置状态 + this.resetToolStates() + + paper.view.update() + + if (triggerExport) { + this.exportSvgAndEmit() + } + } catch (error) { + console.error('Failed to clear canvas:', error) + } + } + + /** + * 创建画布背景 + */ + private createBackground(): void { + const background = new paper.Path.Rectangle({ + point: [0, 0], + size: [this.dependencies.canvasWidth.value, this.dependencies.canvasHeight.value], + fillColor: 'transparent' + }) + this.dependencies.backgroundRect.value = background + } + + /** + * 重置工具状态 + */ + private resetToolStates(): void { + // 通过reshape引用重置 + if (this.dependencies.reshapeRef.value?.resetDrawing) { + this.dependencies.reshapeRef.value.resetDrawing() + } + } + + /** + * 从导入的项目中收集可编辑路径 + */ + private collectPathsFromImport(item: paper.Item): void { + const allPaths = this.dependencies.allPaths.value + + const collectPaths = (item: paper.Item): void => { + if (item instanceof paper.Path && item.segments && item.segments.length > 0) { + allPaths.push(item) + + // 添加鼠标事件处理 + item.onMouseDown = () => { + if (this.dependencies.currentTool.value === 'reshape' && this.dependencies.reshapeRef.value) { + this.dependencies.reshapeRef.value.showControlPoints(item) + paper.view.update() + } + } + } else if (item instanceof paper.Group || item instanceof paper.CompoundPath) { + if (item.children) { + item.children.forEach(child => collectPaths(child)) + } + } + } + + collectPaths(item) + } + + /** + * 导出前隐藏控制点 + */ + private hideControlPointsForExport(): void { + if (this.dependencies.reshapeRef.value?.hideControlPoints) { + this.dependencies.reshapeRef.value.hideControlPoints() + } + } + + /** + * 导出前隐藏背景 + */ + private 
hideBackgroundForExport(): boolean { + const prevVisible = this.dependencies.backgroundRect.value?.visible ?? true + if (this.dependencies.backgroundRect.value) { + this.dependencies.backgroundRect.value.visible = false + } + return prevVisible + } + + /** + * 导出后恢复背景 + */ + private restoreBackgroundAfterExport(prevVisible: boolean): void { + if (this.dependencies.backgroundRect.value) { + this.dependencies.backgroundRect.value.visible = prevVisible + } + } + +} + +// 创建单例实例 +export const createImportExportManager = (dependencies: ImportExportDependencies): ImportExportManager => { + return new ImportExportManager(dependencies) +} \ No newline at end of file diff --git a/spx-gui/src/components/editor/sprite/CostumeDetail.vue b/spx-gui/src/components/editor/sprite/CostumeDetail.vue index 0fd2b94ef..253310a69 100644 --- a/spx-gui/src/components/editor/sprite/CostumeDetail.vue +++ b/spx-gui/src/components/editor/sprite/CostumeDetail.vue @@ -1,21 +1,23 @@ - + - + diff --git a/spx-gui/src/components/editor/sprite/CostumeItem.vue b/spx-gui/src/components/editor/sprite/CostumeItem.vue index 6e1711014..4bab8af20 100644 --- a/spx-gui/src/components/editor/sprite/CostumeItem.vue +++ b/spx-gui/src/components/editor/sprite/CostumeItem.vue @@ -16,7 +16,7 @@ @@ -33,10 +34,14 @@ const imgStyle = computed(() => diff --git a/spx-gui/src/utils/file.ts b/spx-gui/src/utils/file.ts index 69d0ccb05..7f27dcb32 100644 --- a/spx-gui/src/utils/file.ts +++ b/spx-gui/src/utils/file.ts @@ -153,6 +153,54 @@ export function useFileUrl(fileSource: WatchSource) { return [urlRef, loadingRef] as const } +/** + * Get url for File with smooth transition (no flickering) + * This version keeps the old URL until the new one is ready + */ +export function useFileUrlSmooth(fileSource: WatchSource) { + const urlRef = ref(null) + const loadingRef = ref(false) + watch( + fileSource, + (file, _, onCleanup) => { + if (file == null) { + urlRef.value = null + return + } + loadingRef.value = true + let cancelled = 
false + let fileUrlCleanup: (() => void) | null = null + onCleanup(() => { + cancelled = true + // Don't clear urlRef.value here - keep old URL until component unmounts + fileUrlCleanup?.() + }) + file + .url((cleanup) => { + if (cancelled) { + cleanup() + return + } + fileUrlCleanup = cleanup + }) + .then((url) => { + if (cancelled) return + // Only update URL when new URL is ready + urlRef.value = url + }) + .catch((e) => { + if (e instanceof Cancelled) return + throw e + }) + .finally(() => { + loadingRef.value = false + }) + }, + { immediate: true } + ) + return [urlRef, loadingRef] as const +} + /** * Get image element (HTMLImageElement) based on given (image) file. * The image element is guaranteed to be loaded when set to ref.
+ {{ errorMessage }} +