diff --git a/.env.example b/.env.example index 78a3b72c..67466278 100644 --- a/.env.example +++ b/.env.example @@ -4,11 +4,39 @@ LLM_API_KEY=your_api_key_here LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus +LLM_MAX_TOKENS=4096 # LLM 最大输出 token 数 -# ===== ZEP记忆图谱配置 ===== +# 嵌入模型配置(用于 Graphiti local 模式,可独立配置) +# 如果不配置,则默认使用 LLM 的 API +# EMBEDDING_API_KEY=your_embedding_api_key # 可选 +# EMBEDDING_BASE_URL=your_embedding_base_url # 可选 +# EMBEDDING_MODEL=text-embedding-3-small +# EMBEDDING_DIM=1536 +EMBEDDING_BATCH_SIZE=5 # 嵌入向量批处理大小 + +# ===== 知识图谱配置 ===== +# 模式选择: "cloud" (Zep Cloud) 或 "local" (Graphiti + Neo4j) +KNOWLEDGE_GRAPH_MODE=cloud + +# Zep Cloud 配置 (KNOWLEDGE_GRAPH_MODE=cloud 时需要) # 每月免费额度即可支撑简单使用:https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key_here +# Graphiti / Neo4j 配置 (KNOWLEDGE_GRAPH_MODE=local 时需要) +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=your_neo4j_password +# 嵌入向量 API Key(可选,用 LLM_API_KEY 即可,会自动使用同一 base_url) +# OPENAI_API_KEY=your_openai_key_for_embedding + +# OASIS 模拟配置 +OASIS_DEFAULT_MAX_ROUNDS=10 # 默认最大轮数 + +# Report Agent 配置 +REPORT_AGENT_MAX_TOOL_CALLS=5 # 最大工具调用次数 +REPORT_AGENT_MAX_REFLECTION_ROUNDS=2 # 最大反思轮数 +REPORT_AGENT_TEMPERATURE=0.5 # 温度参数 + # ===== 加速 LLM 配置(可选)===== # 注意如果不使用加速配置,env文件中就不要出现下面的配置项 LLM_BOOST_API_KEY=your_api_key_here diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 12ff1ba2..1aa18ab1 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -282,9 +282,9 @@ def build_graph(): try: logger.info("=== 开始构建图谱 ===") - # 检查配置 + # 检查配置 (cloud 模式需要 Zep Cloud) errors = [] - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: errors.append("ZEP_API_KEY未配置") if errors: logger.error(f"配置错误: {errors}") @@ -374,7 +374,7 @@ def build_graph(): def build_task(): build_logger = get_logger('mirofish.build') try: - build_logger.info(f"[{task_id}] 开始构建图谱...") + 
build_logger.debug(f"[{task_id}] 开始构建图谱...") task_manager.update_task( task_id, status=TaskStatus.PROCESSING, @@ -410,12 +410,15 @@ def build_task(): ProjectManager.save_project(project) # 设置本体 + build_logger.debug(f"[{task_id}] 准备设置本体...") task_manager.update_task( task_id, message="设置本体定义...", progress=15 ) + build_logger.debug(f"[{task_id}] 开始设置本体...") builder.set_ontology(graph_id, ontology) + build_logger.debug(f"[{task_id}] 本体设置完成") # 添加文本(progress_callback 签名是 (msg, progress_ratio)) def add_progress_callback(msg, progress_ratio): @@ -431,15 +434,18 @@ def add_progress_callback(msg, progress_ratio): message=f"开始添加 {total_chunks} 个文本块...", progress=15 ) - + + build_logger.debug(f"[{task_id}] 准备添加文本,共 {total_chunks} 个块") episode_uuids = builder.add_text_batches( - graph_id, + graph_id, chunks, batch_size=3, progress_callback=add_progress_callback ) + build_logger.debug(f"[{task_id}] 文本添加完成,共 {len(episode_uuids)} 个 episode") # 等待Zep处理完成(查询每个episode的processed状态) + build_logger.debug(f"[{task_id}] 开始等待处理,共 {len(episode_uuids)} 个 episode") task_manager.update_task( task_id, message="等待Zep处理数据...", @@ -567,13 +573,13 @@ def get_graph_data(graph_id: str): 获取图谱数据(节点和边) """ try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + + builder = GraphBuilderService() graph_data = builder.get_graph_data(graph_id) return jsonify({ @@ -595,13 +601,13 @@ def delete_graph(graph_id: str): 删除Zep图谱 """ try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + + builder = GraphBuilderService() builder.delete_graph(graph_id) return jsonify({ diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 
3a0f6816..e2f6e32e 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -56,18 +56,18 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - + entity_types_str = request.args.get('entity_types', '') entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None enrich = request.args.get('enrich', 'true').lower() == 'true' - + logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}") - + reader = ZepEntityReader() result = reader.filter_defined_entities( graph_id=graph_id, @@ -93,12 +93,12 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - + reader = ZepEntityReader() entity = reader.get_entity_with_context(graph_id, entity_uuid) @@ -126,12 +126,12 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: + if Config.KNOWLEDGE_GRAPH_MODE == 'cloud' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": "ZEP_API_KEY未配置" }), 500 - + enrich = request.args.get('enrich', 'true').lower() == 'true' reader = ZepEntityReader() @@ -982,6 +982,44 @@ def get_simulation_history(): }), 500 +@simulation_bp.route('/', methods=['DELETE']) +def delete_simulation(simulation_id: str): + """ + 删除模拟及其所有相关数据 + + Args: + simulation_id: 模拟ID + + Returns: + { + "success": true, + "message": "删除成功" + } + """ + try: + manager = SimulationManager() + success = manager.delete_simulation(simulation_id) + + if success: + return jsonify({ + 
"success": True, + "message": "删除成功" + }) + else: + return jsonify({ + "success": False, + "error": "删除失败,模拟可能不存在" + }), 404 + + except Exception as e: + logger.error(f"删除模拟失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + @simulation_bp.route('//profiles', methods=['GET']) def get_simulation_profiles(simulation_id: str): """ @@ -1409,7 +1447,7 @@ def generate_profiles(): "error": "没有找到符合条件的实体" }), 400 - generator = OasisProfileGenerator() + generator = OasisProfileGenerator(graph_id=graph_id) profiles = generator.generate_profiles_from_entities( entities=filtered.entities, use_llm=use_llm @@ -2440,7 +2478,7 @@ def interview_all_agents(): simulation_id = data.get('simulation_id') prompt = data.get('prompt') platform = data.get('platform') # 可选:twitter/reddit/None - timeout = data.get('timeout', 180) + timeout = data.get('timeout', 240) if not simulation_id: return jsonify({ diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50..7e9270c3 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -31,9 +31,29 @@ class Config: LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') - - # Zep配置 + LLM_MAX_TOKENS = int(os.environ.get('LLM_MAX_TOKENS', '4096')) + + # 嵌入模型配置(用于 Graphiti local 模式,可独立配置) + EMBEDDING_API_KEY = os.environ.get('EMBEDDING_API_KEY') # 可选,默认使用 LLM_API_KEY + EMBEDDING_BASE_URL = os.environ.get('EMBEDDING_BASE_URL') # 可选,默认使用 LLM_BASE_URL + EMBEDDING_MODEL = os.environ.get('EMBEDDING_MODEL', 'text-embedding-3-small') + EMBEDDING_DIM = int(os.environ.get('EMBEDDING_DIM', '1536')) + EMBEDDING_BATCH_SIZE = int(os.environ.get('EMBEDDING_BATCH_SIZE', '5')) # 批处理大小,默认5 + + # 知识图谱模式配置 + # cloud: 使用 Zep Cloud (默认) + # local: 使用 Graphiti + Neo4j (本地部署) + KNOWLEDGE_GRAPH_MODE = os.environ.get('KNOWLEDGE_GRAPH_MODE', 'cloud') + + # 
Zep Cloud 配置 (KNOWLEDGE_GRAPH_MODE=cloud 时需要) ZEP_API_KEY = os.environ.get('ZEP_API_KEY') + + # Graphiti / Neo4j 配置 (KNOWLEDGE_GRAPH_MODE=local 时需要) + NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') + NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j') + NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD') + # OpenAI API 用于嵌入向量 (Graphiti 模式需要) + OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB @@ -69,7 +89,18 @@ def validate(cls): errors = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + + # 根据模式验证对应的配置 + if cls.KNOWLEDGE_GRAPH_MODE == 'cloud': + if not cls.ZEP_API_KEY: + errors.append("ZEP_API_KEY 未配置 (当前模式: cloud)") + elif cls.KNOWLEDGE_GRAPH_MODE == 'local': + if not cls.NEO4J_PASSWORD: + errors.append("NEO4J_PASSWORD 未配置 (当前模式: local)") + if not cls.LLM_API_KEY and not cls.OPENAI_API_KEY: + errors.append("LLM_API_KEY 或 OPENAI_API_KEY 未配置 (当前模式: local,用于嵌入向量)") + else: + errors.append(f"未知的 KNOWLEDGE_GRAPH_MODE: {cls.KNOWLEDGE_GRAPH_MODE}") + return errors diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 0e0444bf..601be4fb 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -1,6 +1,7 @@ """ 图谱构建服务 -接口2:使用Zep API构建Standalone Graph +接口2:使用知识图谱API构建图谱 +支持 Zep Cloud 和 Graphiti (本地) 两种模式 """ import os @@ -10,14 +11,19 @@ from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass -from zep_cloud.client import Zep -from zep_cloud import EpisodeData, EntityEdgeSourceTarget - from ..config import Config from ..models.task import TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .kg_adapter import get_knowledge_graph_adapter from .text_processor import TextProcessor +# 保留原有的导入,用于动态类生成(兼容模式) +try: + from zep_cloud import EpisodeData, EntityEdgeSourceTarget + from 
zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel + ZEP_CLOUD_AVAILABLE = True +except ImportError: + ZEP_CLOUD_AVAILABLE = False + @dataclass class GraphInfo: @@ -39,15 +45,14 @@ def to_dict(self) -> Dict[str, Any]: class GraphBuilderService: """ 图谱构建服务 - 负责调用Zep API构建知识图谱 + 负责调用知识图谱 API 构建图谱 + 支持 Zep Cloud 和 Graphiti 两种模式 """ - + def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + # 使用适配器 + self.kg = get_knowledge_graph_adapter() self.task_manager = TaskManager() def build_graph_async( @@ -185,15 +190,14 @@ def _build_graph_worker( self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: - """创建Zep图谱(公开方法)""" + """创建图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - - self.client.graph.create( + + self.kg.create_graph( graph_id=graph_id, name=name, - description="MiroFish Social Simulation Graph" ) - + return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): @@ -277,13 +281,14 @@ def safe_attr_name(attr_name: str) -> str: if source_targets: edge_definitions[name] = (edge_class, source_targets) - # 调用Zep API设置本体 + # 调用图谱API设置本体 if entity_types or edge_definitions: - self.client.graph.set_ontology( - graph_ids=[graph_id], - entities=entity_types if entity_types else None, - edges=edge_definitions if edge_definitions else None, - ) + # 封装为 ontology 格式 + ontology = { + "entities": entity_types if entity_types else None, + "edges": edge_definitions if edge_definitions else None, + } + self.kg.set_ontology(graph_id, ontology) def add_text_batches( self, @@ -293,49 +298,59 @@ def add_text_batches( progress_callback: Optional[Callable] = None ) -> List[str]: """分批添加文本到图谱,返回所有 episode 的 uuid 列表""" + import logging + build_logger = logging.getLogger('mirofish.build') + episode_uuids = [] 
total_chunks = len(chunks) - + + build_logger.debug(f"[add_text_batches] 开始添加 {total_chunks} 个块,batch_size={batch_size}") + for i in range(0, total_chunks, batch_size): batch_chunks = chunks[i:i + batch_size] batch_num = i // batch_size + 1 total_batches = (total_chunks + batch_size - 1) // batch_size - + if progress_callback: progress = (i + len(batch_chunks)) / total_chunks progress_callback( f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...", progress ) - - # 构建episode数据 - episodes = [ - EpisodeData(data=chunk, type="text") - for chunk in batch_chunks - ] - - # 发送到Zep + + build_logger.debug(f"[add_text_batches] 准备发送批次 {batch_num}/{total_batches}") + + # 发送到图谱 try: - batch_result = self.client.graph.add_batch( + # 使用适配器的批量添加方法 + build_logger.debug(f"[add_text_batches] 调用 kg.add_episodes_batch...") + batch_result = self.kg.add_episodes_batch( graph_id=graph_id, - episodes=episodes + texts=batch_chunks ) - - # 收集返回的 episode uuid + build_logger.debug(f"[add_text_batches] 批次 {batch_num} 发送完成") + + # 收集返回的 episode uuid(兼容 dict 和对象两种格式) if batch_result and isinstance(batch_result, list): for ep in batch_result: - ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) + if isinstance(ep, dict): + ep_uuid = ep.get('uuid') or ep.get('uuid_') + else: + ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) if ep_uuid: episode_uuids.append(ep_uuid) - + build_logger.debug(f"[add_text_batches] 收集到 episode uuid: {ep_uuid}") + # 避免请求过快 time.sleep(1) - + except Exception as e: + build_logger.error(f"[add_text_batches] 批次 {batch_num} 发送失败: {str(e)}") if progress_callback: progress_callback(f"批次 {batch_num} 发送失败: {str(e)}", 0) raise - + + build_logger.debug(f"[add_text_batches] 所有批次发送完成,共 {len(episode_uuids)} 个 episode") return episode_uuids def _wait_for_episodes( @@ -370,13 +385,17 @@ def _wait_for_episodes( # 检查每个 episode 的处理状态 for ep_uuid in list(pending_episodes): try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) - is_processed 
= getattr(episode, 'processed', False) - + episode = self.kg.get_episode(ep_uuid) + # 兼容 dict 和对象两种格式 + if isinstance(episode, dict): + is_processed = episode.get('processed', False) + else: + is_processed = getattr(episode, 'processed', False) + if is_processed: pending_episodes.remove(ep_uuid) completed_count += 1 - + except Exception as e: # 忽略单个查询错误,继续 pass @@ -396,17 +415,18 @@ def _wait_for_episodes( def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" - # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) + # 获取节点(使用适配器) + nodes = self.kg.get_nodes(graph_id, limit=2000) - # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + # 获取边(使用适配器) + edges = self.kg.get_edges(graph_id, limit=2000) # 统计实体类型 entity_types = set() for node in nodes: - if node.labels: - for label in node.labels: + labels = node.labels if hasattr(node, 'labels') else node.get('labels', []) + if labels: + for label in labels: if label not in ["Entity", "Node"]: entity_types.add(label) @@ -420,72 +440,113 @@ def _get_graph_info(self, graph_id: str) -> GraphInfo: def get_graph_data(self, graph_id: str) -> Dict[str, Any]: """ 获取完整图谱数据(包含详细信息) - + Args: graph_id: 图谱ID - + Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + # 使用适配器获取节点和边 + nodes = self.kg.get_nodes(graph_id, limit=2000) + edges = self.kg.get_edges(graph_id, limit=2000) - # 创建节点映射用于获取节点名称 + # 创建节点映射用于获取节点名称(兼容对象和字典两种格式) node_map = {} for node in nodes: - node_map[node.uuid_] = node.name or "" - + if isinstance(node, dict): + node_map[node.get('uuid_', '')] = node.get('name', '') or "" + else: + node_map[getattr(node, 'uuid_', '')] = getattr(node, 'name', '') or "" + nodes_data = [] for node in nodes: - # 获取创建时间 - created_at = getattr(node, 'created_at', None) - if created_at: - created_at = str(created_at) - - nodes_data.append({ - "uuid": node.uuid_, - "name": node.name, - "labels": node.labels or [], - 
"summary": node.summary or "", - "attributes": node.attributes or {}, - "created_at": created_at, - }) - + # 兼容对象和字典两种格式 + if isinstance(node, dict): + created_at = node.get('created_at') + if created_at: + created_at = str(created_at) + nodes_data.append({ + "uuid": node.get('uuid_', ''), + "name": node.get('name', ''), + "labels": node.get('labels', []), + "summary": node.get('summary', ''), + "attributes": node.get('attributes', {}), + "created_at": created_at, + }) + else: + created_at = getattr(node, 'created_at', None) + if created_at: + created_at = str(created_at) + nodes_data.append({ + "uuid": getattr(node, 'uuid_', ''), + "name": getattr(node, 'name', ''), + "labels": getattr(node, 'labels', []), + "summary": getattr(node, 'summary', ''), + "attributes": getattr(node, 'attributes', {}), + "created_at": created_at, + }) + edges_data = [] for edge in edges: - # 获取时间信息 - created_at = getattr(edge, 'created_at', None) - valid_at = getattr(edge, 'valid_at', None) - invalid_at = getattr(edge, 'invalid_at', None) - expired_at = getattr(edge, 'expired_at', None) - - # 获取 episodes - episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) - if episodes and not isinstance(episodes, list): - episodes = [str(episodes)] - elif episodes: - episodes = [str(e) for e in episodes] - - # 获取 fact_type - fact_type = getattr(edge, 'fact_type', None) or edge.name or "" - - edges_data.append({ - "uuid": edge.uuid_, - "name": edge.name or "", - "fact": edge.fact or "", - "fact_type": fact_type, - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "source_node_name": node_map.get(edge.source_node_uuid, ""), - "target_node_name": node_map.get(edge.target_node_uuid, ""), - "attributes": edge.attributes or {}, - "created_at": str(created_at) if created_at else None, - "valid_at": str(valid_at) if valid_at else None, - "invalid_at": str(invalid_at) if invalid_at else None, - "expired_at": str(expired_at) if expired_at 
else None, - "episodes": episodes or [], - }) - + # 兼容对象和字典两种格式 + if isinstance(edge, dict): + created_at = edge.get('created_at') + valid_at = edge.get('valid_at') + invalid_at = edge.get('invalid_at') + expired_at = edge.get('expired_at') + episodes = edge.get('episodes', []) + fact_type = edge.get('fact_type', '') or edge.get('name', '') + edges_data.append({ + "uuid": edge.get('uuid_', ''), + "name": edge.get('name', ''), + "fact": edge.get('fact', ''), + "fact_type": fact_type, + "source_node_uuid": edge.get('source_node_uuid', ''), + "target_node_uuid": edge.get('target_node_uuid', ''), + "source_node_name": node_map.get(edge.get('source_node_uuid', ''), ''), + "target_node_name": node_map.get(edge.get('target_node_uuid', ''), ''), + "attributes": edge.get('attributes', {}), + "created_at": str(created_at) if created_at else None, + "valid_at": str(valid_at) if valid_at else None, + "invalid_at": str(invalid_at) if invalid_at else None, + "expired_at": str(expired_at) if expired_at else None, + "episodes": episodes if isinstance(episodes, list) else [], + }) + else: + # 获取时间信息 + created_at = getattr(edge, 'created_at', None) + valid_at = getattr(edge, 'valid_at', None) + invalid_at = getattr(edge, 'invalid_at', None) + expired_at = getattr(edge, 'expired_at', None) + + # 获取 episodes + episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) + if episodes and not isinstance(episodes, list): + episodes = [str(episodes)] + elif episodes: + episodes = [str(e) for e in episodes] + + # 获取 fact_type + fact_type = getattr(edge, 'fact_type', None) or getattr(edge, 'name', '') or "" + + edges_data.append({ + "uuid": getattr(edge, 'uuid_', ''), + "name": getattr(edge, 'name', ''), + "fact": getattr(edge, 'fact', ''), + "fact_type": fact_type, + "source_node_uuid": getattr(edge, 'source_node_uuid', ''), + "target_node_uuid": getattr(edge, 'target_node_uuid', ''), + "source_node_name": node_map.get(getattr(edge, 'source_node_uuid', ''), ''), + 
"target_node_name": node_map.get(getattr(edge, 'target_node_uuid', ''), ''), + "attributes": getattr(edge, 'attributes', {}), + "created_at": str(created_at) if created_at else None, + "valid_at": str(valid_at) if valid_at else None, + "invalid_at": str(invalid_at) if invalid_at else None, + "expired_at": str(expired_at) if expired_at else None, + "episodes": episodes or [], + }) + return { "graph_id": graph_id, "nodes": nodes_data, @@ -496,5 +557,5 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def delete_graph(self, graph_id: str): """删除图谱""" - self.client.graph.delete(graph_id=graph_id) + self.kg.delete(graph_id=graph_id) diff --git a/backend/app/services/kg_adapter.py b/backend/app/services/kg_adapter.py new file mode 100644 index 00000000..ab4d2b8e --- /dev/null +++ b/backend/app/services/kg_adapter.py @@ -0,0 +1,730 @@ +""" +知识图谱适配器 +支持 Zep Cloud 和 Graphiti (本地) 两种模式 + +使用方式: + from app.services.kg_adapter import get_knowledge_graph_adapter + + kg = get_knowledge_graph_adapter() + kg.add_episode(graph_id="xxx", text="hello") + kg.search(graph_id="xxx", query="hello") +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional +import logging + +from ..config import Config +from graphiti_core.embedder import EmbedderClient as EmbeddingClient + +logger = logging.getLogger(__name__) + + +class KnowledgeGraphAdapter(ABC): + """知识图谱适配器抽象基类""" + + @abstractmethod + def create_graph(self, graph_id: str, name: str = None) -> Any: + """创建图谱""" + pass + + @abstractmethod + def add_episode(self, graph_id: str, text: str, **kwargs) -> Any: + """添加单条内容""" + pass + + @abstractmethod + def add_episodes_batch(self, graph_id: str, texts: List[str]) -> List[Any]: + """批量添加内容""" + pass + + @abstractmethod + def get_episode(self, episode_uuid: str) -> Any: + """获取单个 episode""" + pass + + @abstractmethod + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "all", reranker: str = None) -> List[Dict]: + """搜索""" + 
pass + + @abstractmethod + def get_nodes(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """获取节点""" + pass + + @abstractmethod + def get_node(self, node_uuid: str) -> Any: + """获取单个节点""" + pass + + @abstractmethod + def get_node_edges(self, node_uuid: str) -> List[Dict]: + """获取节点的所有边""" + pass + + @abstractmethod + def get_edges(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """获取边""" + pass + + @abstractmethod + def delete(self, graph_id: str) -> bool: + """删除图谱""" + pass + + @abstractmethod + def set_ontology(self, graph_id: str, ontology: Dict) -> bool: + """设置本体""" + pass + + @abstractmethod + def get_graph_info(self, graph_id: str) -> Dict: + """获取图谱信息""" + pass + + +class ZepCloudAdapter(KnowledgeGraphAdapter): + """Zep Cloud 适配器""" + + def __init__(self, api_key: str = None): + from zep_cloud.client import Zep + self.api_key = api_key or Config.ZEP_API_KEY + if not self.api_key: + raise ValueError("ZEP_API_KEY 未配置") + self.client = Zep(api_key=self.api_key) + logger.info("ZepCloudAdapter 初始化完成") + + def create_graph(self, graph_id: str, name: str = None) -> Any: + return self.client.graph.create(graph_id=graph_id, name=name or graph_id) + + def add_episode(self, graph_id: str, text: str, **kwargs) -> Any: + return self.client.graph.add(graph_id=graph_id, type="text", data=text) + + def add_episodes_batch(self, graph_id: str, texts: List[str]) -> List[Any]: + from zep_cloud.types import EpisodeData + episodes = [EpisodeData(data=t, type="text") for t in texts] + return self.client.graph.add_batch(episodes=episodes, graph_id=graph_id) + + def get_episode(self, episode_uuid: str) -> Any: + return self.client.graph.episode.get(uuid_=episode_uuid) + + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "all", reranker: str = None): + """搜索图谱 + + 返回 GraphSearchResults 对象: + - scope="edges": 结果在 .edges 中 + - scope="nodes": 结果在 .nodes 中 + - scope="all": 结果同时在 .edges 和 .nodes 中 + """ + try: 
+ result = self.client.graph.search(graph_id=graph_id, query=query, limit=limit, scope=scope, reranker=reranker) + logger.info(f"[ZepCloud search] query={query}, edges={len(result.edges) if hasattr(result, 'edges') and result.edges else 0}, nodes={len(result.nodes) if hasattr(result, 'nodes') and result.nodes else 0}") + return result + except Exception as e: + logger.error(f"[ZepCloud search] API调用失败: {e}") + # 返回空的 GraphSearchResults + from zep_cloud.types.graph_search_results import GraphSearchResults + return GraphSearchResults(edges=[], nodes=[], episodes=[]) + + def get_nodes(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + kwargs = {"limit": limit} + if cursor: + kwargs["uuid_cursor"] = cursor + return self.client.graph.node.get_by_graph_id(graph_id=graph_id, **kwargs) + + def get_node(self, node_uuid: str) -> Any: + return self.client.graph.node.get(uuid_=node_uuid) + + def get_node_edges(self, node_uuid: str) -> List[Dict]: + edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid) + return [e.model_dump() if hasattr(e, 'model_dump') else e for e in edges] + + def get_edges(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + kwargs = {"limit": limit} + if cursor: + kwargs["uuid_cursor"] = cursor + return self.client.graph.edge.get_by_graph_id(graph_id=graph_id, **kwargs) + + def delete(self, graph_id: str) -> bool: + self.client.graph.delete(graph_id=graph_id) + return True + + def set_ontology(self, graph_id: str, ontology: Dict) -> bool: + entities = ontology.get('entities', {}) + edges = ontology.get('edges', {}) + self.client.graph.set_ontology( + entities=entities, + edges=edges, + graph_ids=[graph_id] + ) + return True + + def get_graph_info(self, graph_id: str) -> Dict: + # Zep Cloud 没有直接的图谱信息 API,返回基本信息 + return {"graph_id": graph_id} + + +# 自定义 Embedder,支持可配置的批处理大小 +class SingleEmbeddingEmbedder(EmbeddingClient): + """自定义 embedder,支持可配置的批处理大小 + + Args: + base: 基础 embedder + batch_size: 
批量大小,默认 10(阿里百炼支持),设为 1 则逐个处理 + """ + + def __init__(self, base, batch_size: int = 10): + self.base = base + self.batch_size = batch_size + + async def create(self, input_data): + # 如果是列表,根据列表长度决定返回格式 + if isinstance(input_data, list): + if len(input_data) == 1: + return await self.base.create(input_data[0]) + elif len(input_data) == 0: + return await self.base.create("") + else: + # 多个输入,调用批量处理 + return await self.create_batch(input_data) + return await self.base.create(input_data) + + async def create_batch(self, input_data_list): + # 逐个处理,避免兼容性问题 + results = [] + for text in input_data_list: + embedding = await self.base.create(text) + results.append(embedding) + return results + + +class GraphitiAdapter(KnowledgeGraphAdapter): + """Graphiti 适配器 - 本地部署 + + 注意:类级别的 event loop 会在应用退出时自动释放, + 不需要手动关闭。长时间运行的服务器不需要关闭 event loop。 + """ + + # 类级别的 event loop,供所有实例共享 + _event_loop = None + + def __init__(self): + import os + import asyncio + from graphiti_core import Graphiti + from graphiti_core.llm_client.config import LLMConfig + from graphiti_core.llm_client.openai_client import OpenAIClient + from graphiti_core.embedder import OpenAIEmbedder, OpenAIEmbedderConfig + + if not all([Config.NEO4J_URI, Config.NEO4J_USER, Config.NEO4J_PASSWORD]): + raise ValueError("Neo4j 配置不完整,请检查 NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD") + + # 获取 API Key(优先使用独立配置,其次使用 LLM 配置) + api_key = Config.LLM_API_KEY or Config.OPENAI_API_KEY + llm_base_url = Config.LLM_BASE_URL + + # 嵌入模型独立配置 + embedding_api_key = Config.EMBEDDING_API_KEY or api_key + embedding_base_url = Config.EMBEDDING_BASE_URL or llm_base_url + + if not api_key: + raise ValueError("请配置 LLM_API_KEY") + + # 设置环境变量(Graphiti 内部组件会读取) + os.environ['OPENAI_API_KEY'] = api_key + os.environ['OPENAI_BASE_URL'] = llm_base_url + + # 配置 LLM 客户端(支持 OpenAI 兼容 API) + llm_config = LLMConfig( + api_key=api_key, + base_url=llm_base_url, + model=Config.LLM_MODEL_NAME, + small_model=Config.LLM_MODEL_NAME, # 使用相同模型 + ) + llm_client = 
OpenAIClient(config=llm_config) + + # 配置 Embedder 客户端(可独立配置) + # 注意:一些 embedding API 不支持批量请求 + # 因此我们创建一个自定义包装器来确保每次只处理单个文本 + embedder_config = OpenAIEmbedderConfig( + api_key=embedding_api_key, + base_url=embedding_base_url, + embedding_model=Config.EMBEDDING_MODEL, + embedding_dim=Config.EMBEDDING_DIM, + ) + logger.debug("embedding_base_url:" + embedding_base_url) + base_embedder = OpenAIEmbedder(config=embedder_config) + + # 使用自定义包装器,支持可配置的批处理大小 + # 默认 batch_size=10(阿里百炼支持),可在 Config 中配置 + batch_size = getattr(Config, 'EMBEDDING_BATCH_SIZE', 10) + embedder_client = SingleEmbeddingEmbedder(base_embedder, batch_size=batch_size) + logger.debug(f"model: {Config.EMBEDDING_MODEL}, batch_size: {batch_size}") + + self.client = Graphiti( + uri=Config.NEO4J_URI, + user=Config.NEO4J_USER, + password=Config.NEO4J_PASSWORD, + llm_client=llm_client, + embedder=embedder_client, + cross_encoder=None, # 禁用 reranker,需要时可配置 + ) + # graph_id 到 group 的映射(Graphiti 使用 group 区分不同的图) + self._graph_id_to_group: Dict[str, str] = {} + + # 使用同步驱动避免 asyncio 事件循环冲突 + from neo4j import GraphDatabase + self._sync_driver = GraphDatabase.driver( + Config.NEO4J_URI, + auth=(Config.NEO4J_USER, Config.NEO4J_PASSWORD) + ) + + # 初始化数据库索引 + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.client.build_indices_and_constraints()) + # 不关闭 loop,保存起来供后续使用 + GraphitiAdapter._event_loop = loop + logger.info("Graphiti 数据库索引初始化完成") + except Exception as e: + logger.warning(f"数据库索引初始化警告: {e}") + + logger.info("GraphitiAdapter 初始化完成") + + def _run_async(self, coro, timeout: int = 300): + """同步调用异步方法的包装器,使用类级别的 event loop,带超时保护""" + import asyncio + import concurrent.futures + + # 使用类级别的 event loop + if GraphitiAdapter._event_loop is None or GraphitiAdapter._event_loop.is_closed(): + GraphitiAdapter._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(GraphitiAdapter._event_loop) + + # 使用线程池执行,避免阻塞 + def run_in_loop(): + return 
GraphitiAdapter._event_loop.run_until_complete(coro) + + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_loop) + return future.result(timeout=timeout) + except concurrent.futures.TimeoutError: + logger.error(f"Graphiti 操作超时 ({timeout}秒)") + raise TimeoutError(f"Graphiti operation timed out after {timeout} seconds") + except Exception as e: + logger.error(f"Graphiti 操作失败: {str(e)}") + raise + + def _get_group(self, graph_id: str) -> str: + """获取或创建 group""" + if graph_id not in self._graph_id_to_group: + self._graph_id_to_group[graph_id] = graph_id + return self._graph_id_to_group[graph_id] + + def create_graph(self, graph_id: str, name: str = None) -> Any: + # Graphiti 不需要预创建图,通过 group 区分 + self._graph_id_to_group[graph_id] = graph_id + + # 创建 Group 节点 + with self._sync_driver.session() as session: + session.run(""" + MERGE (g:Group {name: $name}) + SET g.created_at = datetime() + """, name=graph_id) + + logger.info(f"Graphiti: 标记图谱 {graph_id}") + return {"status": "ok", "graph_id": graph_id} + + def add_episode(self, graph_id: str, text: str, **kwargs) -> Any: + """使用同步驱动添加 episode""" + import uuid + from datetime import datetime, timezone + + group = self._get_group(graph_id) + episode_uuid = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + # 直接使用同步驱动创建 episode + with self._sync_driver.session() as session: + query = """ + CREATE (e:Episodic { + uuid: $uuid, + name: $name, + content: $content, + created_at: $created_at, + valid_at: $valid_at, + group_id: $group_id, + source: 'text', + episode_type: 'text' + }) + RETURN e + """ + result = session.run( + query, + uuid=episode_uuid, + name=f"episode_{now.strftime('%Y%m%d%H%M%S')}", + content=text, + created_at=now, + valid_at=now, + group_id=group + ) + record = result.single() + return {"uuid": episode_uuid, "name": record["e"]["name"]} if record else None + + def add_episodes_batch(self, graph_id: str, texts: List[str], batch_size: int = 10) 
-> List[Any]: + """批量添加内容,使用 Graphiti 原生的 add_episode API""" + from datetime import datetime, timezone + from graphiti_core.nodes import EpisodeType + + results = [] + group = self._get_group(graph_id) + now = datetime.now(timezone.utc) + + # 获取实体类型(如果有的话) + entity_types = getattr(self, '_entity_types', None) + if entity_types: + logger.info(f"Graphiti: 使用 {len(entity_types)} 个实体类型进行提取: {list(entity_types.keys())}") + + # 使用 Graphiti 原生的 add_episode 方法 + # 它会自动:1. 用 embedder 做嵌入 2. 用 LLM 提取实体和关系 + for i, text in enumerate(texts): + episode_name = f"episode_{now.strftime('%Y%m%d%H%M%S')}_{i}" + + try: + # 调用 Graphiti 原生 API,传入 entity_types + result = self._run_async( + self.client.add_episode( + name=episode_name, + episode_body=text, + source_description="MiroFish document", + reference_time=now, + source=EpisodeType.text, + group_id=group, + entity_types=entity_types, # 传入实体类型定义 + ) + ) + + # 从返回结果中获取 episode 的 uuid + episode_uuid = None + if result and hasattr(result, 'episode'): + episode_uuid = getattr(result.episode, 'uuid_', None) or getattr(result.episode, 'uuid', None) + + logger.info(f"Graphiti 原生添加 episode {i+1}/{len(texts)}: {episode_uuid}") + results.append({"uuid": episode_uuid, "name": episode_name}) + + except Exception as e: + logger.error(f"添加 episode 失败: {str(e)}") + results.append({"uuid": None, "name": episode_name, "error": str(e)}) + + logger.info(f"Graphiti 实体全部构建完成: {graph_id}, 共 {len(results)} 条") + return results + + def get_episode(self, episode_uuid: str) -> Any: + """使用同步驱动获取 episode""" + with self._sync_driver.session() as session: + query = """ + MATCH (e:Episodic {uuid: $uuid}) + RETURN e.content as content, e.created_at as created_at, + e.valid_at as valid_at, e.uuid as uuid, + e.name as name, e.group_id as group_id + """ + result = session.run(query, uuid=episode_uuid) + record = result.single() + if record: + data = dict(record) + # Graphiti 模式下添加是同步的,返回 processed=True 表示已完成 + data['processed'] = True + 
logger.debug(f"[get_episode] uuid={episode_uuid}, processed=True") + return data + logger.warning(f"[get_episode] uuid={episode_uuid}, 未找到 episode") + return None + + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "all", reranker: str = None): + logger.info(f"[GraphitiAdapter.search] 调用") + """使用同步驱动搜索 + + 返回兼容对象格式: + - scope="edges": 返回带 .edges 属性的对象 + - scope="nodes": 返回带 .nodes 属性的对象 + - scope="all": 返回带 .edges 和 .nodes 属性的对象 + """ + from dataclasses import dataclass, field + import re + + @dataclass + class SearchResult: + edges: list = field(default_factory=list) + nodes: list = field(default_factory=list) + + group = self._get_group(graph_id) + result = SearchResult() + + # 从查询中提取关键词(移除"关于...的所有信息"等前缀) + search_keyword = query + if '的' in search_keyword: + match = re.search(r'关于(.+?)的', search_keyword) + if match: + search_keyword = match.group(1).strip() + if len(search_keyword) > 10: + search_keyword = search_keyword[:10] + + with self._sync_driver.session() as session: + # 搜索 Episodes 作为事实来源 + episode_query = """ + MATCH (e:Episodic {group_id: $gid}) + WHERE e.content CONTAINS $search + RETURN e.content as content, e.uuid as uuid, e.name as name + """ + episode_result = session.run( + episode_query, + gid=group, + search=search_keyword + ) + + episodes = [{"content": r["content"], "uuid": r["uuid"], "name": r["name"]} for r in episode_result] + + # 根据 scope 返回对应格式 + if scope in ("edges", "all"): + # 将 episodes 内容转为 fact 格式 + for ep in episodes: + class Edge: + def __init__(self, fact): + self.fact = fact + result.edges.append(Edge(ep.get("content", ""))) + + if scope in ("nodes", "all"): + # 搜索相关实体节点 + entity_query = """ + MATCH (e:Entity {group_id: $gid}) + WHERE e.name CONTAINS $search OR e.summary CONTAINS $search + RETURN e.uuid as uuid_, e.name as name, e.summary as summary + """ + entity_result = session.run( + entity_query, + gid=group, + search=search_keyword + ) + + for ent in entity_result: + class Node: + def 
__init__(self, name, summary): + self.name = name + self.summary = summary if summary else "" + result.nodes.append(Node(ent["name"], ent.get("summary"))) + + return result + + def get_nodes(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """通过同步驱动查询实体节点""" + with self._sync_driver.session() as session: + # Graphiti 使用 group_id 属性来区分不同的图谱 + query = """ + MATCH (e:Entity {group_id: $group_id}) + RETURN e.uuid as uuid_, e.name as name, labels(e) as labels, + e.summary as summary, e.created_at as created_at, + e.entity_type as entity_type + LIMIT $limit + """ + result = session.run(query, group_id=graph_id, limit=limit) + nodes = [dict(record) for record in result] + logger.info(f"[get_nodes] graph_id={graph_id}, 查询到 {len(nodes)} 个节点") + + # 转换格式以兼容前端 + for node in nodes: + if 'attributes' not in node: + node['attributes'] = {} + + # 优先使用 entity_type 属性 + entity_type = node.get('entity_type') + if entity_type: + node['labels'] = [entity_type] + node['attributes']['entity_type'] = entity_type + else: + # 从标签中提取实体类型(第一个非 Entity 的标签) + labels = node.get('labels', []) + found_type = None + for label in labels: + if label and label != 'Entity': + found_type = label + break + + if found_type: + node['labels'] = [found_type] + node['attributes']['entity_type'] = found_type + else: + # 如果都没有,使用节点名称作为类型 + node['labels'] = ['Entity'] + node['attributes']['entity_type'] = 'Entity' + + return nodes + + def get_node(self, node_uuid: str) -> Any: + """通过同步驱动获取单个节点""" + with self._sync_driver.session() as session: + query = """ + MATCH (e:Entity {uuid: $uuid}) + RETURN e.uuid as uuid_, e.name as name, labels(e) as labels, + e.summary as summary, e.created_at as created_at + """ + result = session.run(query, uuid=node_uuid) + record = result.single() + if record: + node = dict(record) + if 'attributes' not in node: + node['attributes'] = {} + return node + return None + + def get_node_edges(self, node_uuid: str) -> List[Dict]: + """通过同步驱动获取节点的所有边""" + with 
self._sync_driver.session() as session: + query = """ + MATCH (e1:Entity {uuid: $uuid})-[r]-(e2:Entity) + RETURN r.uuid as uuid_, type(r) as name, r.fact as fact, + r.fact_type as fact_type, + e1.uuid as source_node_uuid, e2.uuid as target_node_uuid, + e1.name as source_node_name, e2.name as target_node_name, + r.created_at as created_at + """ + result = session.run(query, uuid=node_uuid) + edges = [dict(record) for record in result] + return edges + + def get_edges(self, graph_id: str, limit: int = 100, cursor: str = None) -> List[Any]: + """通过同步驱动查询边""" + with self._sync_driver.session() as session: + query = """ + MATCH (e1:Entity {group_id: $group_id})-[r]-(e2:Entity {group_id: $group_id}) + RETURN r.uuid as uuid_, type(r) as name, r.fact as fact, + r.fact_type as fact_type, + e1.uuid as source_node_uuid, e2.uuid as target_node_uuid, + e1.name as source_node_name, e2.name as target_node_name, + r.created_at as created_at, r.valid_at as valid_at, + r.invalid_at as invalid_at, r.expired_at as expired_at + LIMIT $limit + """ + result = session.run(query, group_id=graph_id, limit=limit) + edges = [dict(record) for record in result] + # 兼容前端格式 + for edge in edges: + if 'attributes' not in edge: + edge['attributes'] = {} + if 'episodes' not in edge: + edge['episodes'] = [] + return edges + + def delete(self, graph_id: str) -> bool: + """使用同步驱动删除图谱""" + with self._sync_driver.session() as session: + # 删除关联边(使用 group_id 属性) + session.run(""" + MATCH (e1:Entity)-[r]-(e2:Entity) + WHERE e1.group_id = $group_id OR e2.group_id = $group_id + DELETE r + """, group_id=graph_id) + # 删除实体节点(使用 group_id 属性) + session.run(""" + MATCH (e:Entity {group_id: $group_id}) + DELETE e + """, group_id=graph_id) + + if graph_id in self._graph_id_to_group: + del self._graph_id_to_group[graph_id] + + logger.info(f"Graphiti: 删除图谱 {graph_id}") + return True + + def set_ontology(self, graph_id: str, ontology: Dict) -> bool: + """设置实体类型(Graphiti 模式)""" + import warnings + from typing import 
Optional + from pydantic import Field + from graphiti_core.nodes import EntityNode + + # graph_builder.set_ontology 已经把 ontology 转换成 Pydantic 类 + # 格式: {'entities': {类名: 类}, 'edges': {...}} + entity_types = {} + + if ontology.get("entities") and isinstance(ontology.get("entities"), dict): + entities = ontology.get("entities", {}) + for name, entity_class in entities.items(): + entity_types[name] = entity_class + logger.info(f"Graphiti: 使用已处理的实体类型,共 {len(entity_types)} 个") + else: + logger.warning(f"Graphiti: ontology 格式异常: {list(ontology.keys())}") + + # 存储到实例变量 + self._entity_types = entity_types + + return True + + def get_graph_info(self, graph_id: str) -> Dict: + """使用同步驱动获取图谱信息""" + with self._sync_driver.session() as session: + # 统计节点数量 - 使用 group_id 属性 + node_result = session.run(""" + MATCH (e:Entity {group_id: $group_id}) + RETURN count(e) as count + """, group_id=graph_id) + node_count = node_result.single()["count"] if node_result.single() else 0 + + # 统计边数量 - 使用 group_id 属性 + edge_result = session.run(""" + MATCH (e1:Entity {group_id: $group_id})-[r]-(e2:Entity {group_id: $group_id}) + RETURN count(r) as count + """, group_id=graph_id) + edge_count = edge_result.single()["count"] if edge_result.single() else 0 + + return { + "graph_id": graph_id, + "node_count": node_count, + "edge_count": edge_count, + } + + def _result_to_dict(self, result) -> Dict: + if hasattr(result, 'model_dump'): + return result.model_dump() + elif hasattr(result, 'dict'): + return result.dict() + return {} + + +# 全局缓存 +_adapter_cache: Optional[KnowledgeGraphAdapter] = None + + +def get_knowledge_graph_adapter(force_refresh: bool = True) -> KnowledgeGraphAdapter: + """ + 获取知识图谱适配器实例 + + Args: + force_refresh: 是否强制刷新缓存 + + Returns: + KnowledgeGraphAdapter: 适配器实例 + """ + global _adapter_cache + + if _adapter_cache is not None and not force_refresh: + return _adapter_cache + + mode = Config.KNOWLEDGE_GRAPH_MODE + logger.info(f"[kg_adapter] 使用模式: {mode}") + + if mode == 'local': + 
_adapter_cache = GraphitiAdapter() + elif mode == 'cloud': + _adapter_cache = ZepCloudAdapter() + else: + raise ValueError(f"未知的 KNOWLEDGE_GRAPH_MODE: {mode}") + + return _adapter_cache + + +def reset_adapter(): + """重置适配器缓存""" + global _adapter_cache + _adapter_cache = None diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 57836c53..4a99a29e 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -1,9 +1,10 @@ """ OASIS Agent Profile生成器 -将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 +将图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 +支持 Zep Cloud 和 Graphiti 两种模式 优化改进: -1. 调用Zep检索功能二次丰富节点信息 +1. 调用图谱检索功能二次丰富节点信息 2. 优化提示词生成非常详细的人设 3. 区分个人实体和抽象群体实体 """ @@ -16,11 +17,11 @@ from datetime import datetime from openai import OpenAI -from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from .zep_entity_reader import EntityNode, ZepEntityReader +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.oasis_profile') @@ -196,17 +197,15 @@ def __init__( api_key=self.api_key, base_url=self.base_url ) - - # Zep客户端用于检索丰富上下文 - self.zep_api_key = zep_api_key or Config.ZEP_API_KEY - self.zep_client = None + + # 图谱客户端用于检索丰富上下文 self.graph_id = graph_id - - if self.zep_api_key: - try: - self.zep_client = Zep(api_key=self.zep_api_key) - except Exception as e: - logger.warning(f"Zep客户端初始化失败: {e}") + # 使用适配器 + try: + self.kg = get_knowledge_graph_adapter() + except Exception as e: + logger.warning(f"图谱客户端初始化失败: {e}") + self.kg = None def generate_profile_from_entity( self, @@ -285,51 +284,54 @@ def _generate_username(self, name: str) -> str: def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: """ 使用Zep图谱混合搜索功能获取实体相关的丰富信息 - + Zep没有内置混合搜索接口,需要分别搜索edges和nodes然后合并结果。 使用并行请求同时搜索,提高效率。 - + Args: entity: 实体节点对象 - + Returns: 包含facts, node_summaries, context的字典 """ import concurrent.futures - - if 
not self.zep_client: + + if not self.kg: return {"facts": [], "node_summaries": [], "context": ""} - + entity_name = entity.name - + results = { "facts": [], "node_summaries": [], "context": "" } - + # 必须有graph_id才能进行搜索 if not self.graph_id: logger.debug(f"跳过Zep检索:未设置graph_id") return results - - comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景" - + + # 使用实体名称作为查询,而不是复杂的中文描述 + # Zep Cloud search 对中文支持有限,尝试直接用实体名搜索 + comprehensive_query = entity_name + def search_edges(): """搜索边(事实/关系)- 带重试机制""" max_retries = 3 last_exception = None delay = 2.0 - + for attempt in range(max_retries): try: - return self.zep_client.graph.search( + result = self.kg.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, scope="edges", reranker="rrf" ) + return result except Exception as e: last_exception = e if attempt < max_retries - 1: @@ -348,7 +350,7 @@ def search_nodes(): for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.kg.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index 02ca5bdc..47c8621f 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -567,6 +567,18 @@ def to_dict(self) -> Dict[str, Any]: - ❌ 不是对现实世界现状的分析 - ❌ 不是泛泛而谈的舆情综述 +【风格与主题约束】(重要!必须遵守) + +报告是给人看的,必须通俗易懂! + +1. 报告标题必须是预测结果的直接表述,像新闻标题一样一目了然 +2. 章节标题要简洁明了,回答"预测到了什么" +3. 禁止使用抽象、晦涩、诗意化的表达 +4. 
禁止把"变量""注入""模拟""状态""演化"等抽象词汇放在标题中 + +✅ 好的标题:「XX场景下用户行为预测」「XX产品发布后市场趋势预测」 +❌ 差的标题:「变量注入后社会状态预测」「智能时代的命运交响」 + 【章节数量限制】 - 最少2个章节,最多5个章节 - 不需要子章节,每个章节直接撰写完整内容 @@ -721,8 +733,8 @@ def to_dict(self) -> Dict[str, Any]: 选项A - 调用工具: 输出你的思考,然后用以下格式调用一个工具: - -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} + + 参数值 系统会执行工具并把结果返回给你。你不需要也不能自己编写工具返回结果。 @@ -844,8 +856,8 @@ def to_dict(self) -> Dict[str, Any]: {tools_description} 【工具调用格式】 - -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} + + 参数值 【回答风格】 @@ -1068,25 +1080,52 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: 从LLM响应中解析工具调用 支持的格式(按优先级): - 1. {"name": "tool_name", "parameters": {...}} - 2. 裸 JSON(响应整体或单行就是一个工具调用 JSON) + 1. XML格式(标准格式): value + 2. JSON格式(兜底): {"name": "tool_name", "parameters": {...}} + 3. 裸 JSON(兜底): 响应整体或单行就是一个工具调用 JSON """ tool_calls = [] - # 格式1: XML风格(标准格式) - xml_pattern = r'\s*(\{.*?\})\s*' + # 格式1: XML格式(标准格式)- 优先匹配 + # + # 参数值 + # + xml_pattern = r']*>(.*?)' for match in re.finditer(xml_pattern, response, re.DOTALL): + tool_name = match.group(1) + params_content = match.group(2) + + # 提取所有 value 标签 + param_pattern = r'([^<]*)' + params = {} + for param_match in re.finditer(param_pattern, params_content): + param_name = param_match.group(1) + param_value = param_match.group(2) + params[param_name] = param_value + + if tool_name: + tool_calls.append({ + "name": tool_name, + "parameters": params + }) + + if tool_calls: + return tool_calls + + # 格式2: JSON格式(旧格式兼容)- {"name": ...} + json_xml_pattern = r'\s*(\{.*?\})\s*' + for match in re.finditer(json_xml_pattern, response, re.DOTALL): try: call_data = json.loads(match.group(1)) - tool_calls.append(call_data) + if self._is_valid_tool_call(call_data): + tool_calls.append(call_data) except json.JSONDecodeError: pass if tool_calls: return tool_calls - # 格式2: 兜底 - LLM 直接输出裸 JSON(没包 标签) - # 只在格式1未匹配时尝试,避免误匹配正文中的 JSON + # 格式3: 兜底 - LLM 直接输出裸 JSON(没包 标签) stripped = response.strip() if stripped.startswith('{') and 
stripped.endswith('}'): try: @@ -1114,14 +1153,37 @@ def _is_valid_tool_call(self, data: dict) -> bool: """校验解析出的 JSON 是否是合法的工具调用""" # 支持 {"name": ..., "parameters": ...} 和 {"tool": ..., "params": ...} 两种键名 tool_name = data.get("name") or data.get("tool") - if tool_name and tool_name in self.VALID_TOOL_NAMES: - # 统一键名为 name / parameters - if "tool" in data: - data["name"] = data.pop("tool") - if "params" in data and "parameters" not in data: - data["parameters"] = data.pop("params") + if not tool_name: + return False + + # 精确匹配 + if tool_name in self.VALID_TOOL_NAMES: + self._normalize_tool_call(data) return True + + # 容错匹配:处理常见格式错误 + # interviewagents -> interview_agents + # quicksearch -> quick_search + # panoramasearch -> panorama_search + # insightforge -> insight_forge + normalized = tool_name.replace("search", "_search").replace("forge", "_forge") + normalized = normalized.replace("agents", "_agents") + + # 尝试匹配 + for valid_name in self.VALID_TOOL_NAMES: + if normalized == valid_name or tool_name == valid_name: + data["name"] = valid_name # 修正为正确名称 + self._normalize_tool_call(data) + return True + return False + + def _normalize_tool_call(self, data: dict): + """统一键名格式""" + if "tool" in data: + data["name"] = data.pop("tool") + if "params" in data and "parameters" not in data: + data["parameters"] = data.pop("params") def _get_tools_description(self) -> str: """生成工具描述文本""" @@ -1834,7 +1896,7 @@ def chat( if not tool_calls: # 没有工具调用,直接返回响应 - clean_response = re.sub(r'.*?', '', response, flags=re.DOTALL) + clean_response = re.sub(r']*>.*?', '', response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) return { @@ -1870,7 +1932,7 @@ def chat( ) # 清理响应 - clean_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL) + clean_response = re.sub(r']*>.*?', '', final_response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) return { diff --git a/backend/app/services/simulation_ipc.py 
b/backend/app/services/simulation_ipc.py index 9d70d0be..05e9bcf7 100644 --- a/backend/app/services/simulation_ipc.py +++ b/backend/app/services/simulation_ipc.py @@ -195,13 +195,13 @@ def send_interview( ) -> IPCResponse: """ 发送单个Agent采访命令 - + Args: agent_id: Agent ID prompt: 采访问题 platform: 指定平台(可选) - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 + - "reddit": 只采访Reddit平台 - None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台 timeout: 超时时间 diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index 96c496fd..0020a55d 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -526,3 +526,36 @@ def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" ) } + + def delete_simulation(self, simulation_id: str) -> bool: + """ + 删除模拟及其所有相关数据 + + Args: + simulation_id: 模拟ID + + Returns: + 删除是否成功 + """ + import shutil + + sim_dir = self._get_simulation_dir(simulation_id) + + # 检查模拟目录是否存在 + if not os.path.exists(sim_dir): + logger.warning(f"模拟目录不存在,跳过删除: {simulation_id}") + return False + + try: + # 删除模拟目录 + shutil.rmtree(sim_dir) + + # 从内存缓存中移除 + if simulation_id in self._simulations: + del self._simulations[simulation_id] + + logger.info(f"成功删除模拟: {simulation_id}") + return True + except Exception as e: + logger.error(f"删除模拟失败: {simulation_id}, 错误: {e}") + return False diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index 8c35380d..f225102f 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -1489,7 +1489,7 @@ def interview_agents_batch( simulation_id: str, interviews: List[Dict[str, Any]], platform: str = None, - timeout: float = 120.0 + timeout: float = 180.0 ) -> Dict[str, Any]: """ 批量采访多个Agent diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py 
index 71661be4..40578c78 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -1,17 +1,16 @@ """ -Zep实体读取与过滤服务 -从Zep图谱中读取节点,筛选出符合预定义实体类型的节点 +图谱实体读取与过滤服务 +从图谱中读取节点,筛选出符合预定义实体类型的节点 +支持 Zep Cloud 和 Graphiti 两种模式 """ import time from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.zep_entity_reader') @@ -70,20 +69,18 @@ def to_dict(self) -> Dict[str, Any]: class ZepEntityReader: """ - Zep实体读取与过滤服务 - + 图谱实体读取与过滤服务 + 主要功能: - 1. 从Zep图谱读取所有节点 + 1. 从图谱读取所有节点 2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点) 3. 获取每个实体的相关边和关联节点信息 """ - + def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + # 使用适配器 + self.kg = get_knowledge_graph_adapter() def _call_with_retry( self, @@ -136,15 +133,25 @@ def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + # 使用适配器获取节点 + nodes = self.kg.get_nodes(graph_id, limit=2000) nodes_data = [] for node in nodes: - nodes_data.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": node.name or "", - "labels": node.labels or [], - "summary": node.summary or "", + if isinstance(node, dict): + nodes_data.append({ + "uuid": node.get('uuid_', '') or node.get('uuid', ''), + "name": node.get('name', ''), + "labels": node.get('labels', []), + "summary": node.get('summary', ''), + "attributes": node.get('attributes', {}), + }) + else: + nodes_data.append({ + "uuid": getattr(node, 'uuid_', None) or 
getattr(node, 'uuid', ''), + "name": node.name or "", + "labels": node.labels or [], + "summary": node.summary or "", "attributes": node.attributes or {}, }) @@ -163,18 +170,29 @@ def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + # 使用适配器获取边 + edges = self.kg.get_edges(graph_id, limit=2000) edges_data = [] for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "attributes": edge.attributes or {}, - }) + if isinstance(edge, dict): + edges_data.append({ + "uuid": edge.get('uuid_', '') or edge.get('uuid', ''), + "name": edge.get('name', ''), + "fact": edge.get('fact', ''), + "source_node_uuid": edge.get('source_node_uuid', ''), + "target_node_uuid": edge.get('target_node_uuid', ''), + "attributes": edge.get('attributes', {}), + }) + else: + edges_data.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": edge.name or "", + "fact": edge.fact or "", + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + "attributes": edge.attributes or {}, + }) logger.info(f"共获取 {len(edges_data)} 条边") return edges_data @@ -190,23 +208,33 @@ def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: 边列表 """ try: - # 使用重试机制调用Zep API + # 使用重试机制调用图谱API edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), + func=lambda: self.kg.get_node_edges(node_uuid), operation_name=f"获取节点边(node={node_uuid[:8]}...)" ) - + edges_data = [] for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": 
edge.target_node_uuid, - "attributes": edge.attributes or {}, - }) - + if isinstance(edge, dict): + edges_data.append({ + "uuid": edge.get('uuid_', '') or edge.get('uuid', ''), + "name": edge.get('name', ''), + "fact": edge.get('fact', ''), + "source_node_uuid": edge.get('source_node_uuid', ''), + "target_node_uuid": edge.get('target_node_uuid', ''), + "attributes": edge.get('attributes', {}), + }) + else: + edges_data.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": edge.name or "", + "fact": edge.fact or "", + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + "attributes": edge.attributes or {}, + }) + return edges_data except Exception as e: logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") @@ -251,23 +279,27 @@ def filter_defined_entities( for node in all_nodes: labels = node.get("labels", []) - - # 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签 - custom_labels = [l for l in labels if l not in ["Entity", "Node"]] - - if not custom_labels: - # 只有默认标签,跳过 + + # 获取实体类型(优先从属性获取,其次从标签获取) + entity_type = None + if node.get("attributes"): + entity_type = node["attributes"].get("entity_type") + + if not entity_type: + # 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签 + custom_labels = [l for l in labels if l not in ["Entity", "Node"]] + if custom_labels: + entity_type = custom_labels[0] + + if not entity_type: + # 没有实体类型,跳过 continue - + # 如果指定了预定义类型,检查是否匹配 if defined_entity_types: - matching_labels = [l for l in custom_labels if l in defined_entity_types] - if not matching_labels: + if entity_type not in defined_entity_types: continue - entity_type = matching_labels[0] - else: - entity_type = custom_labels[0] - + entity_types_found.add(entity_type) # 创建实体节点对象 @@ -341,27 +373,27 @@ def get_entity_with_context( Args: graph_id: 图谱ID entity_uuid: 实体UUID - + Returns: EntityNode或None """ try: # 使用重试机制获取节点 node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), + func=lambda: 
self.kg.get_node(entity_uuid), operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" ) - + if not node: return None - + # 获取节点的边 edges = self.get_node_edges(entity_uuid) - + # 获取所有节点用于关联查找 all_nodes = self.get_all_nodes(graph_id) node_map = {n["uuid"]: n for n in all_nodes} - + # 处理相关边和节点 related_edges = [] related_node_uuids = set() diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index a8f3cecd..116f6d74 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -1,6 +1,7 @@ """ -Zep图谱记忆更新服务 -将模拟中的Agent活动动态更新到Zep图谱中 +图谱记忆更新服务 +将模拟中的Agent活动动态更新到图谱中 +支持 Zep Cloud 和 Graphiti 两种模式 """ import os @@ -12,10 +13,9 @@ from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.zep_graph_memory_updater') @@ -200,49 +200,47 @@ def _describe_generic(self) -> str: class ZepGraphMemoryUpdater: """ - Zep图谱记忆更新器 - - 监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。 - 按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。 - - 所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息: + 图谱记忆更新器 + + 监控模拟的actions日志文件,将新的agent活动实时更新到图谱中。 + 按平台分组,每累积BATCH_SIZE条活动后批量发送。 + + 所有有意义的行为都会被更新到图谱,action_args中会包含完整的上下文信息: - 点赞/踩的帖子原文 - 转发/引用的帖子原文 - 关注/屏蔽的用户名 - 点赞/踩的评论原文 """ - + # 批量发送大小(每个平台累积多少条后发送) BATCH_SIZE = 5 - + # 平台名称映射(用于控制台显示) PLATFORM_DISPLAY_NAMES = { 'twitter': '世界1', 'reddit': '世界2', } - + # 发送间隔(秒),避免请求过快 SEND_INTERVAL = 0.5 - + # 重试配置 MAX_RETRIES = 3 RETRY_DELAY = 2 # 秒 - + def __init__(self, graph_id: str, api_key: Optional[str] = None): """ 初始化更新器 - + Args: - graph_id: Zep图谱ID - api_key: Zep API Key(可选,默认从配置读取) + graph_id: 图谱ID + api_key: 保留参数(兼容旧代码) """ self.graph_id = graph_id - self.api_key = api_key or Config.ZEP_API_KEY - - if not self.api_key: - raise ValueError("ZEP_API_KEY未配置") - - self.client = 
Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + + # 使用适配器 + self.kg = get_knowledge_graph_adapter() # 活动队列 self._activity_queue: Queue = Queue() @@ -401,29 +399,28 @@ def _send_batch_activities(self, activities: List[AgentActivity], platform: str) # 将多条活动合并为一条文本,用换行分隔 episode_texts = [activity.to_episode_text() for activity in activities] combined_text = "\n".join(episode_texts) - + # 带重试的发送 for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( + self.kg.add_episode( graph_id=self.graph_id, - type="text", - data=combined_text + text=combined_text ) - + self._total_sent += 1 self._total_items_sent += len(activities) display_name = self._get_platform_display_name(platform) logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}") logger.debug(f"批量内容预览: {combined_text[:200]}...") return - + except Exception as e: if attempt < self.MAX_RETRIES - 1: - logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}") + logger.warning(f"批量发送到图谱失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}") time.sleep(self.RETRY_DELAY * (attempt + 1)) else: - logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}") + logger.error(f"批量发送到图谱失败,已重试{self.MAX_RETRIES}次: {e}") self._failed_count += 1 def _flush_remaining(self): diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 384cf540..438f9947 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -1,6 +1,7 @@ """ -Zep检索工具服务 +图谱检索工具服务 封装图谱搜索、节点读取、边查询等工具,供Report Agent使用 +支持 Zep Cloud 和 Graphiti 两种模式 核心检索工具(优化后): 1. 
InsightForge(深度洞察检索)- 最强大的混合检索,自动生成子问题并多维度检索 @@ -13,12 +14,10 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.llm_client import LLMClient -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .kg_adapter import get_knowledge_graph_adapter logger = get_logger('mirofish.zep_tools') @@ -61,7 +60,24 @@ class NodeInfo: labels: List[str] summary: str attributes: Dict[str, Any] - + + def __init__(self, uuid: str = "", name: str = "", labels: List[str] = None, + summary: str = "", attributes: Dict[str, Any] = None): + # 如果传入的是 dict,转换为对象属性 + if isinstance(uuid, dict): + d = uuid + self.uuid = d.get('uuid_') or d.get('uuid', '') + self.name = d.get('name', '') + self.labels = d.get('labels', []) + self.summary = d.get('summary', '') + self.attributes = d.get('attributes', {}) + else: + self.uuid = uuid or '' + self.name = name or '' + self.labels = labels or [] + self.summary = summary or '' + self.attributes = attributes or {} + def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, @@ -70,7 +86,7 @@ def to_dict(self) -> Dict[str, Any]: "summary": self.summary, "attributes": self.attributes } - + def to_text(self) -> str: """转换为文本格式""" entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "未知类型") @@ -92,7 +108,39 @@ class EdgeInfo: valid_at: Optional[str] = None invalid_at: Optional[str] = None expired_at: Optional[str] = None - + + def __init__(self, uuid: str = "", name: str = "", fact: str = "", + source_node_uuid: str = "", target_node_uuid: str = "", + source_node_name: str = None, target_node_name: str = None, + created_at: str = None, valid_at: str = None, + invalid_at: str = None, expired_at: str = None): + # 如果传入的是 dict,转换为对象属性 + if isinstance(uuid, dict): + d = uuid + self.uuid = d.get('uuid_') or d.get('uuid', '') + self.name = d.get('name', 'RELATED') + 
self.fact = d.get('fact', '') + self.source_node_uuid = d.get('source_node_uuid', '') + self.target_node_uuid = d.get('target_node_uuid', '') + self.source_node_name = d.get('source_node_name') + self.target_node_name = d.get('target_node_name') + self.created_at = d.get('created_at') + self.valid_at = d.get('valid_at') + self.invalid_at = d.get('invalid_at') + self.expired_at = d.get('expired_at') + else: + self.uuid = uuid or '' + self.name = name or 'RELATED' + self.fact = fact or '' + self.source_node_uuid = source_node_uuid or '' + self.target_node_uuid = target_node_uuid or '' + self.source_node_name = source_node_name + self.target_node_name = target_node_name + self.created_at = created_at + self.valid_at = valid_at + self.invalid_at = invalid_at + self.expired_at = expired_at + def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, @@ -422,11 +470,9 @@ class ZepToolsService: RETRY_DELAY = 2.0 def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.api_key = api_key # 保留参数兼容性 + # 使用适配器 + self.kg = get_knowledge_graph_adapter() # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client logger.info("ZepToolsService 初始化完成") @@ -485,15 +531,13 @@ def search_graph( """ logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...") - # 尝试使用Zep Cloud Search API + # 尝试使用图谱搜索 API try: search_results = self._call_with_retry( - func=lambda: self.client.graph.search( + func=lambda: self.kg.search( graph_id=graph_id, query=query, limit=limit, - scope=scope, - reranker="cross_encoder" ), operation_name=f"图谱搜索(graph={graph_id})" ) @@ -501,33 +545,62 @@ def search_graph( facts = [] edges = [] nodes = [] - - # 解析边搜索结果 - if hasattr(search_results, 'edges') and search_results.edges: - for edge in search_results.edges: - if hasattr(edge, 'fact') and edge.fact: - 
facts.append(edge.fact) - edges.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": getattr(edge, 'name', ''), - "fact": getattr(edge, 'fact', ''), - "source_node_uuid": getattr(edge, 'source_node_uuid', ''), - "target_node_uuid": getattr(edge, 'target_node_uuid', ''), - }) - - # 解析节点搜索结果 - if hasattr(search_results, 'nodes') and search_results.nodes: - for node in search_results.nodes: - nodes.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": getattr(node, 'name', ''), - "labels": getattr(node, 'labels', []), - "summary": getattr(node, 'summary', ''), - }) - # 节点摘要也算作事实 - if hasattr(node, 'summary') and node.summary: - facts.append(f"[{node.name}]: {node.summary}") - + + # 解析搜索结果(兼容对象和字典格式) + # 适配器返回的是 List[Dict],需要处理 + if isinstance(search_results, list): + for result in search_results: + # 判断是边还是节点 + if isinstance(result, dict): + if result.get('source_node_uuid') and result.get('target_node_uuid'): + # 边 + fact = result.get('fact', '') + if fact: + facts.append(fact) + edges.append({ + "uuid": result.get('uuid_', '') or result.get('uuid', ''), + "name": result.get('name', ''), + "fact": fact, + "source_node_uuid": result.get('source_node_uuid', ''), + "target_node_uuid": result.get('target_node_uuid', ''), + }) + else: + # 节点 + name = result.get('name', '') + summary = result.get('summary', '') + nodes.append({ + "uuid": result.get('uuid_', '') or result.get('uuid', ''), + "name": name, + "labels": result.get('labels', []), + "summary": summary, + }) + if summary: + facts.append(f"[{name}]: {summary}") + else: + # 原始对象格式(保留兼容性) + if hasattr(search_results, 'edges') and search_results.edges: + for edge in search_results.edges: + if hasattr(edge, 'fact') and edge.fact: + facts.append(edge.fact) + edges.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": getattr(edge, 'name', ''), + "fact": getattr(edge, 'fact', ''), + "source_node_uuid": getattr(edge, 
'source_node_uuid', ''), + "target_node_uuid": getattr(edge, 'target_node_uuid', ''), + }) + + if hasattr(search_results, 'nodes') and search_results.nodes: + for node in search_results.nodes: + nodes.append({ + "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + "name": getattr(node, 'name', ''), + "labels": getattr(node, 'labels', []), + "summary": getattr(node, 'summary', ''), + }) + if hasattr(node, 'summary') and node.summary: + facts.append(f"[{node.name}]: {node.summary}") + logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实") return SearchResult( @@ -659,7 +732,38 @@ def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + # 使用适配器分页获取所有节点 + # Zep Cloud API 使用 uuid_cursor 分页,但响应不返回 cursor + # 通过返回数量判断是否有更多:< limit 则说明是最后一页 + nodes = [] + cursor = None + max_pages = 100 # 最多获取 100 页,防止无限循环 + page_count = 0 + + while page_count < max_pages: + page = self.kg.get_nodes(graph_id, limit=100, cursor=cursor) + if not page: + break + # 将 dict 转换为 NodeInfo 对象 + for item in page: + if isinstance(item, dict): + nodes.append(NodeInfo(item)) + else: + nodes.append(item) + page_count += 1 + + # 如果返回数量 < limit,说明是最后一页 + if len(page) < 100: + break + + # 尝试获取下一页 - Zep Cloud 使用 uuid_cursor 参数 + # 由于 API 不返回 next_cursor,我们需要用最后一条的 uuid 作为 cursor + # 注意:page 中的元素可能是 dict(适配器返回格式),getattr 对 dict 永远返回 None, + # 必须先按 dict 取值,否则 >100 节点时分页会在第一页后静默中断(与 get_all_edges 保持一致) + last_item = page[-1] + cursor = (last_item.get('uuid_') or last_item.get('uuid')) if isinstance(last_item, dict) else (getattr(last_item, 'uuid_', None) or getattr(last_item, 'uuid', None)) + if not cursor: + break + + logger.info(f"分页获取完成,共 {page_count} 页,{len(nodes)} 个节点") result = [] for node in nodes: @@ -688,7 +792,33 @@ def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[Ed """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + # 使用适配器分页获取所有边 + edges = [] + cursor = None + max_pages = 100 + page_count = 0 + + while page_count < max_pages: + page = self.kg.get_edges(graph_id, limit=100, cursor=cursor) + if not page: + break + # 将 dict 转换为 EdgeInfo 
对象 + for item in page: + if isinstance(item, dict): + edges.append(EdgeInfo(item)) + else: + edges.append(item) + page_count += 1 + + if len(page) < 100: + break + + last_item = page[-1] + cursor = getattr(last_item, 'uuid_', None) or (last_item.get('uuid_') if isinstance(last_item, dict) else None) or (last_item.get('uuid') if isinstance(last_item, dict) else None) + if not cursor: + break + + logger.info(f"分页获取完成,共 {page_count} 页,{len(edges)} 条边") result = [] for edge in edges: @@ -724,23 +854,33 @@ def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: 节点信息或None """ logger.info(f"获取节点详情: {node_uuid[:8]}...") - + try: node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), + func=lambda: self.kg.get_node(node_uuid), operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)" ) - + if not node: return None - - return NodeInfo( - uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - attributes=node.attributes or {} - ) + + # 兼容对象和字典格式 + if isinstance(node, dict): + return NodeInfo( + uuid=node.get('uuid_', '') or node.get('uuid', ''), + name=node.get('name', ''), + labels=node.get('labels', []), + summary=node.get('summary', ''), + attributes=node.get('attributes', {}) + ) + else: + return NodeInfo( + uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + name=node.name or "", + labels=node.labels or [], + summary=node.summary or "", + attributes=node.attributes or {} + ) except Exception as e: logger.error(f"获取节点详情失败: {str(e)}") return None @@ -863,22 +1003,24 @@ def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]: 统计信息 """ logger.info(f"获取图谱 {graph_id} 的统计信息...") - + nodes = self.get_all_nodes(graph_id) edges = self.get_all_edges(graph_id) - - # 统计实体类型分布 + + # 统计实体类型分布(兼容 dict 和对象) entity_types = {} for node in nodes: - for label in node.labels: + labels = node.labels if hasattr(node, 'labels') else 
node.get('labels', []) + for label in labels: + if label not in ["Entity", "Node"]: + entity_types[label] = entity_types.get(label, 0) + 1 - - # 统计关系类型分布 + + # 统计关系类型分布(兼容 dict 和对象) relation_types = {} for edge in edges: - relation_types[edge.name] = relation_types.get(edge.name, 0) + 1 - + edge_name = edge.name if hasattr(edge, 'name') else edge.get('name', 'RELATED') + relation_types[edge_name] = relation_types.get(edge_name, 0) + 1 + return { "graph_id": graph_id, "total_nodes": len(nodes), @@ -1650,6 +1792,7 @@ def _generate_interview_questions( 4. 语言自然,像真实采访一样 5. 每个问题控制在50字以内,简洁明了 6. 直接提问,不要包含背景说明或前缀 +7. 问题数量根据采访需求的复杂度决定,简单主题1-3个,复杂主题最多5个 返回JSON格式:{"questions": ["问题1", "问题2", ...]}""" @@ -1676,8 +1819,7 @@ def _generate_interview_questions( logger.warning(f"生成采访问题失败: {e}") return [ f"关于{interview_requirement},您的观点是什么?", - "这件事对您或您所代表的群体有什么影响?", - "您认为应该如何解决或改进这个问题?" + "这件事对您或您所代表的群体有什么影响?" ] def _generate_interview_summary( diff --git a/backend/app/utils/llm_client.py b/backend/app/utils/llm_client.py index 6c1a81f4..7c8d5a24 100644 --- a/backend/app/utils/llm_client.py +++ b/backend/app/utils/llm_client.py @@ -36,26 +36,29 @@ def chat( self, messages: List[Dict[str, str]], temperature: float = 0.7, - max_tokens: int = 4096, + max_tokens: Optional[int] = None, response_format: Optional[Dict] = None ) -> str: """ 发送聊天请求 - + Args: messages: 消息列表 temperature: 温度参数 - max_tokens: 最大token数 + max_tokens: 最大token数(默认使用配置中的 LLM_MAX_TOKENS) response_format: 响应格式(如JSON模式) Returns: 模型响应文本 """ + # 如果未指定 max_tokens,使用配置中的默认值 + effective_max_tokens = max_tokens if max_tokens is not None else Config.LLM_MAX_TOKENS + kwargs = { "model": self.model, "messages": messages, "temperature": temperature, - "max_tokens": max_tokens, + "max_tokens": effective_max_tokens, } if response_format: @@ -71,16 +74,16 @@ def chat_json( self, messages: List[Dict[str, str]], temperature: float = 0.3, - max_tokens: int = 4096 + max_tokens: Optional[int] = None ) -> Dict[str, Any]: """ 发送聊天请求并返回JSON - 
+ Args: messages: 消息列表 temperature: 温度参数 - max_tokens: 最大token数 - + max_tokens: 最大token数(默认使用配置中的 LLM_MAX_TOKENS) + Returns: 解析后的JSON对象 """ diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d5..83d2dc6d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -12,23 +12,27 @@ dependencies = [ # 核心框架 "flask>=3.0.0", "flask-cors>=6.0.0", - + # LLM 相关 "openai>=1.0.0", - - # Zep Cloud - "zep-cloud==3.13.0", - + + # 知识图谱 - 根据 KNOWLEDGE_GRAPH_MODE 选择使用 + # cloud 模式: zep-cloud + # local 模式: graphiti-core + neo4j + "zep-cloud>=3.13.0", + "graphiti-core>=0.5.0", + "neo4j>=5.0.0", + # OASIS 社交媒体模拟 "camel-oasis==0.2.5", "camel-ai==0.2.78", - + # 文件处理 "PyMuPDF>=1.24.0", # 编码检测(支持非UTF-8编码的文本文件) "charset-normalizer>=3.0.0", "chardet>=5.0.0", - + # 工具库 "python-dotenv>=1.0.0", "pydantic>=2.0.0", diff --git a/backend/requirements.txt b/backend/requirements.txt index 4f146296..e23821b2 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,8 +13,13 @@ flask-cors>=6.0.0 # OpenAI SDK(统一使用 OpenAI 格式调用 LLM) openai>=1.0.0 -# ============= Zep Cloud ============= -zep-cloud==3.13.0 +# ============= 知识图谱 ============= +# 根据 KNOWLEDGE_GRAPH_MODE 选择使用 +# cloud 模式: zep-cloud +# local 模式: graphiti-core + neo4j +zep-cloud>=3.13.0 +graphiti-core>=0.5.0 +neo4j>=5.0.0 # ============= OASIS 社交媒体模拟 ============= # OASIS 社交模拟框架 diff --git a/backend/scripts/run_parallel_simulation.py b/backend/scripts/run_parallel_simulation.py index 2a627ffd..16b12075 100644 --- a/backend/scripts/run_parallel_simulation.py +++ b/backend/scripts/run_parallel_simulation.py @@ -514,23 +514,41 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict], self.send_response(command_id, "failed", error="没有成功的采访") return False + def _ensure_interview_index(self, db_path: str): + """确保trace表有Interview查询所需的索引""" + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + # 创建索引加速按action和user_id查询最新记录 + cursor.execute(""" + CREATE INDEX IF NOT 
EXISTS idx_trace_interview_lookup + ON trace(action, user_id, created_at DESC) + """) + conn.commit() + conn.close() + except Exception as e: + print(f" 创建索引失败: {e}") + def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: """从数据库获取最新的Interview结果""" db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db") - + result = { "agent_id": agent_id, "response": None, "timestamp": None } - + if not os.path.exists(db_path): return result - + try: + # 确保索引存在 + self._ensure_interview_index(db_path) + conn = sqlite3.connect(db_path) cursor = conn.cursor() - + # 查询最新的Interview记录 cursor.execute(""" SELECT user_id, info, created_at @@ -539,7 +557,7 @@ def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: ORDER BY created_at DESC LIMIT 1 """, (ActionType.INTERVIEW.value, agent_id)) - + row = cursor.fetchone() if row: user_id, info_json, created_at = row @@ -549,12 +567,12 @@ def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: result["timestamp"] = created_at except json.JSONDecodeError: result["response"] = info_json - + conn.close() - + except Exception as e: print(f" 读取Interview结果失败: {e}") - + return result async def process_commands(self) -> bool: diff --git a/backend/scripts/run_reddit_simulation.py b/backend/scripts/run_reddit_simulation.py index 14907cbd..6ab24d0c 100644 --- a/backend/scripts/run_reddit_simulation.py +++ b/backend/scripts/run_reddit_simulation.py @@ -297,23 +297,41 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) self.send_response(command_id, "failed", error=error_msg) return False + def _ensure_interview_index(self, db_path: str): + """确保trace表有Interview查询所需的索引""" + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + # 创建索引加速按action和user_id查询最新记录 + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_trace_interview_lookup + ON trace(action, user_id, created_at DESC) + """) + conn.commit() + conn.close() + except 
Exception as e: + print(f" 创建索引失败: {e}") + def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: """从数据库获取最新的Interview结果""" db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") - + result = { "agent_id": agent_id, "response": None, "timestamp": None } - + if not os.path.exists(db_path): return result - + try: + # 确保索引存在 + self._ensure_interview_index(db_path) + conn = sqlite3.connect(db_path) cursor = conn.cursor() - + # 查询最新的Interview记录 cursor.execute(""" SELECT user_id, info, created_at @@ -322,7 +340,7 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: ORDER BY created_at DESC LIMIT 1 """, (ActionType.INTERVIEW.value, agent_id)) - + row = cursor.fetchone() if row: user_id, info_json, created_at = row @@ -332,12 +350,12 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: result["timestamp"] = created_at except json.JSONDecodeError: result["response"] = info_json - + conn.close() - + except Exception as e: print(f" 读取Interview结果失败: {e}") - + return result async def process_commands(self) -> bool: diff --git a/backend/scripts/run_twitter_simulation.py b/backend/scripts/run_twitter_simulation.py index caab9e9d..41cb6a29 100644 --- a/backend/scripts/run_twitter_simulation.py +++ b/backend/scripts/run_twitter_simulation.py @@ -297,23 +297,41 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) self.send_response(command_id, "failed", error=error_msg) return False + def _ensure_interview_index(self, db_path: str): + """确保trace表有Interview查询所需的索引""" + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + # 创建索引加速按action和user_id查询最新记录 + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_trace_interview_lookup + ON trace(action, user_id, created_at DESC) + """) + conn.commit() + conn.close() + except Exception as e: + print(f" 创建索引失败: {e}") + def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: """从数据库获取最新的Interview结果""" db_path = 
os.path.join(self.simulation_dir, "twitter_simulation.db") - + result = { "agent_id": agent_id, "response": None, "timestamp": None } - + if not os.path.exists(db_path): return result - + try: + # 确保索引存在 + self._ensure_interview_index(db_path) + conn = sqlite3.connect(db_path) cursor = conn.cursor() - + # 查询最新的Interview记录 cursor.execute(""" SELECT user_id, info, created_at @@ -322,7 +340,7 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: ORDER BY created_at DESC LIMIT 1 """, (ActionType.INTERVIEW.value, agent_id)) - + row = cursor.fetchone() if row: user_id, info_json, created_at = row @@ -332,12 +350,12 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: result["timestamp"] = created_at except json.JSONDecodeError: result["response"] = info_json - + conn.close() - + except Exception as e: print(f" 读取Interview结果失败: {e}") - + return result async def process_commands(self) -> bool: diff --git a/backend/tests/test_kg_adapter.py b/backend/tests/test_kg_adapter.py new file mode 100644 index 00000000..463c7333 --- /dev/null +++ b/backend/tests/test_kg_adapter.py @@ -0,0 +1,152 @@ +""" +Knowledge Graph Adapter Unit Tests + +Tests the kg_adapter module API signatures and configuration. 
+Run with: uv run pytest tests/test_kg_adapter.py -v +""" +import pytest +from unittest.mock import Mock, patch +import os + + +class TestZepCloudAdapterAPI: + """Test ZepCloudAdapter API calls match Zep Cloud SDK""" + + def test_create_graph_signature(self): + """Test create_graph accepts graph_id and name""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.create_graph) + params = list(sig.parameters.keys()) + assert 'self' in params + assert 'graph_id' in params + assert 'name' in params + + def test_add_episode_signature(self): + """Test add_episode accepts graph_id and text""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.add_episode) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'text' in params + + def test_add_episodes_batch_signature(self): + """Test add_episodes_batch accepts graph_id and texts""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.add_episodes_batch) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'texts' in params + + def test_set_ontology_signature(self): + """Test set_ontology accepts graph_id and ontology""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.set_ontology) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'ontology' in params + + def test_search_signature(self): + """Test search accepts graph_id, query and limit""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.search) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'query' in params + assert 'limit' in params + + def test_get_nodes_signature(self): + """Test get_nodes accepts graph_id, limit and cursor""" + from 
app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.get_nodes) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'limit' in params + assert 'cursor' in params + + def test_get_edges_signature(self): + """Test get_edges accepts graph_id, limit and cursor""" + from app.services.kg_adapter import ZepCloudAdapter + import inspect + + sig = inspect.signature(ZepCloudAdapter.get_edges) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'limit' in params + assert 'cursor' in params + + +class TestGraphitiAdapterAPI: + """Test GraphitiAdapter API signatures""" + + def test_create_graph_signature(self): + """Test create_graph accepts graph_id and name""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.create_graph) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + + def test_add_episode_signature(self): + """Test add_episode accepts graph_id and text""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.add_episode) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'text' in params + + def test_add_episodes_batch_signature(self): + """Test add_episodes_batch accepts graph_id and texts""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.add_episodes_batch) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'texts' in params + + def test_search_signature(self): + """Test search accepts graph_id, query and limit""" + from app.services.kg_adapter import GraphitiAdapter + import inspect + + sig = inspect.signature(GraphitiAdapter.search) + params = list(sig.parameters.keys()) + assert 'graph_id' in params + assert 'query' in params + assert 'limit' in params + + +class TestAdapterFactory: + """Test 
adapter factory function""" + + def test_factory_returns_adapter(self): + """Test factory returns an adapter""" + from app.services.kg_adapter import get_knowledge_graph_adapter + + adapter = get_knowledge_graph_adapter() + assert adapter is not None + + def test_cloud_mode(self): + """Test knowledge graph mode is valid""" + from app.config import Config + + assert Config.KNOWLEDGE_GRAPH_MODE in ['cloud', 'local'] + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/backend/uv.lock b/backend/uv.lock index f1ce4b60..cb34fa5c 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -475,6 +475,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -592,6 +601,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl", hash = "sha256:8bf1fe301b7d8acfa6e8571e3b1c3d158f909666642431cc78a1b7b4dbc5ec5b", size = 201422, upload-time = 
"2025-12-03T15:23:41.434Z" }, ] +[[package]] +name = "graphiti-core" +version = "0.11.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "diskcache" }, + { name = "neo4j" }, + { name = "numpy" }, + { name = "openai" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "tenacity" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/94/3f84400e5f02ea8e9dc79784202de4173cbc16f4b3ad1bd4302da888e4d8/graphiti_core-0.11.6.tar.gz", hash = "sha256:31d26621834d7d4b8865059ab749feb18af15937b59c69598a640a5dfabea331", size = 71928, upload-time = "2025-05-15T17:58:02.304Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/2e/c8f22f01585bf173d1c82f6d4615511aebc75aeda764c69aa394446fa93c/graphiti_core-0.11.6-py3-none-any.whl", hash = "sha256:6ec4807a884f5ea88b942d0c8b7bcd2e107c7358ab4f98ef2a2092c229929707", size = 111001, upload-time = "2025-05-15T17:58:00.542Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -1248,6 +1275,8 @@ dependencies = [ { name = "charset-normalizer" }, { name = "flask" }, { name = "flask-cors" }, + { name = "graphiti-core" }, + { name = "neo4j" }, { name = "openai" }, { name = "pydantic" }, { name = "pymupdf" }, @@ -1276,6 +1305,8 @@ requires-dist = [ { name = "charset-normalizer", specifier = ">=3.0.0" }, { name = "flask", specifier = ">=3.0.0" }, { name = "flask-cors", specifier = ">=6.0.0" }, + { name = "graphiti-core", specifier = ">=0.5.0" }, + { name = "neo4j", specifier = ">=5.0.0" }, { name = "openai", specifier = ">=1.0.0" }, { name = "pipreqs", marker = "extra == 'dev'", specifier = ">=0.5.0" }, { name = "pydantic", specifier = ">=2.0.0" }, @@ -1283,7 +1314,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, - { name = "zep-cloud", specifier = "==3.13.0" }, + { name = "zep-cloud", specifier = 
">=3.13.0" }, ] provides-extras = ["dev"] @@ -2987,6 +3018,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "texttable" version = "1.7.0" diff --git a/docs/KNOWLEDGE_GRAPH_MODE.md b/docs/KNOWLEDGE_GRAPH_MODE.md new file mode 100644 index 00000000..d8973899 --- /dev/null +++ b/docs/KNOWLEDGE_GRAPH_MODE.md @@ -0,0 +1,147 @@ +# 知识图谱双模式说明 + +MiroFish 支持两种知识图谱模式,您可以根据需求选择使用 **Cloud 模式** 或 **Local 模式**。 + +## 模式简介 + +| 模式 | 部署方式 | 适用场景 | +|------|----------|----------| +| Cloud | Zep Cloud API | 快速上手、无需本地部署 | +| Local | Graphiti + Neo4j | 数据隐私、完全控制 | + +## 快速开始 + +### 1. 选择模式 + +在 `.env` 文件中设置 `KNOWLEDGE_GRAPH_MODE`: + +```env +# Cloud 模式 (默认) +KNOWLEDGE_GRAPH_MODE=cloud + +# Local 模式 +KNOWLEDGE_GRAPH_MODE=local +``` + +### 2. 配置对应参数 + +#### Cloud 模式配置 + +```env +KNOWLEDGE_GRAPH_MODE=cloud +ZEP_API_KEY=your_zep_api_key_here +``` + +**获取 Zep API Key:** +1. 访问 [Zep Cloud](https://app.getzep.com/) +2. 注册账号并创建项目 +3. 在项目设置中找到 API Key +4. 
每月免费额度即可支撑简单使用 + +#### Local 模式配置 + +```env +KNOWLEDGE_GRAPH_MODE=local + +# Neo4j 数据库配置 +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=your_neo4j_password + +# 嵌入向量 API (支持 OpenAI 兼容 API) +OPENAI_API_KEY=your_openai_key_here +# 或使用其他 OpenAI 兼容服务: +# OPENAI_BASE_URL=https://your-custom-api.com/v1 +``` + +**Local 模式前置要求:** + +1. **安装 Neo4j** + ```bash + # macOS (Homebrew) + brew install neo4j + brew services start neo4j + + # 或使用 Docker + docker run -d --name neo4j \ + -p 7474:7474 -p 7687:7687 \ + -e NEO4J_AUTH=neo4j/password \ + neo4j + ``` + +2. **配置嵌入向量 API** + - 支持 OpenAI、阿里云百炼、Cohere、Ollama、LM Studio 等 + - 确保 API 可访问 + +## 切换模式 + +修改 `.env` 文件中的 `KNOWLEDGE_GRAPH_MODE` 后,重启服务即可生效。 + +```bash +# 重启后端服务 +cd backend +python app.py +``` + +## 配置参数说明 + +### 知识图谱通用配置 + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| `KNOWLEDGE_GRAPH_MODE` | 模式选择: `cloud` 或 `local` | `cloud` | + +### Cloud 模式参数 + +| 参数 | 说明 | +|------|------| +| `ZEP_API_KEY` | Zep Cloud API Key | + +### Local 模式参数 + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| `NEO4J_URI` | Neo4j 连接地址 | `bolt://localhost:7687` | +| `NEO4J_USER` | Neo4j 用户名 | `neo4j` | +| `NEO4J_PASSWORD` | Neo4j 密码 | - | +| `OPENAI_API_KEY` | 嵌入向量 API Key | 使用 `LLM_API_KEY` | + +### 嵌入模型配置 (Local 模式) + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| `EMBEDDING_API_KEY` | 嵌入模型 API Key | 使用 `LLM_API_KEY` | +| `EMBEDDING_BASE_URL` | 嵌入模型 API 地址 | 使用 `LLM_BASE_URL` | +| `EMBEDDING_MODEL` | 嵌入模型名称 | `text-embedding-3-small` | +| `EMBEDDING_DIM` | 嵌入向量维度 | `1536` | +| `EMBEDDING_BATCH_SIZE` | 批处理大小 | `5` | + +## 常见问题 + +### Q1: 如何判断当前使用哪种模式? + +检查 `.env` 文件中 `KNOWLEDGE_GRAPH_MODE` 的值。 + +### Q2: Local 模式启动失败? + +1. 确认 Neo4j 已启动: `brew services list` 或 `docker ps` +2. 检查 `NEO4J_URI`、`NEO4J_USER`、`NEO4J_PASSWORD` 配置是否正确 +3. 确认嵌入向量 API 可访问 + +### Q3: Cloud 模式返回空结果? + +1. 确认 `ZEP_API_KEY` 正确配置 +2. 检查网络连接是否正常 +3. 确认 Zep Cloud 账户状态正常 + +### Q4: 可以在同一项目中切换模式吗? 
+ +可以。修改 `KNOWLEDGE_GRAPH_MODE` 并重启服务即可。但注意: +- Cloud 和 Local 的数据不互通 +- 切换后需要重新导入数据 + +## 相关文档 + +- [Zep Cloud 官方文档](https://docs.getzep.com/) +- [Graphiti GitHub](https://github.com/getzep/graphiti) +- [Neo4j 官方文档](https://neo4j.com/docs/) diff --git a/docs/KNOWLEDGE_GRAPH_MODE_EN.md b/docs/KNOWLEDGE_GRAPH_MODE_EN.md new file mode 100644 index 00000000..564e9166 --- /dev/null +++ b/docs/KNOWLEDGE_GRAPH_MODE_EN.md @@ -0,0 +1,147 @@ +# Knowledge Graph Dual-Mode Guide + +MiroFish supports two knowledge graph modes: **Cloud Mode** and **Local Mode**. Choose based on your requirements. + +## Mode Overview + +| Mode | Deployment | Use Case | +|------|------------|----------| +| Cloud | Zep Cloud API | Quick setup, no local deployment | +| Local | Graphiti + Neo4j | Data privacy, full control | + +## Quick Start + +### 1. Choose Mode + +Set `KNOWLEDGE_GRAPH_MODE` in your `.env` file: + +```env +# Cloud Mode (default) +KNOWLEDGE_GRAPH_MODE=cloud + +# Local Mode +KNOWLEDGE_GRAPH_MODE=local +``` + +### 2. Configure Corresponding Parameters + +#### Cloud Mode Configuration + +```env +KNOWLEDGE_GRAPH_MODE=cloud +ZEP_API_KEY=your_zep_api_key_here +``` + +**Get Zep API Key:** +1. Visit [Zep Cloud](https://app.getzep.com/) +2. Register and create a project +3. Find API Key in project settings +4. Free tier is sufficient for basic usage + +#### Local Mode Configuration + +```env +KNOWLEDGE_GRAPH_MODE=local + +# Neo4j Database Configuration +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=your_neo4j_password + +# Embedding API (OpenAI-compatible) +OPENAI_API_KEY=your_openai_key_here +# Or use other OpenAI-compatible services: +# OPENAI_BASE_URL=https://your-custom-api.com/v1 +``` + +**Local Mode Prerequisites:** + +1. **Install Neo4j** + ```bash + # macOS (Homebrew) + brew install neo4j + brew services start neo4j + + # Or use Docker + docker run -d --name neo4j \ + -p 7474:7474 -p 7687:7687 \ + -e NEO4J_AUTH=neo4j/password \ + neo4j + ``` + +2. 
**Configure Embedding API** + - Supports OpenAI, Alibaba Cloud Bailian, Cohere, Ollama, LM Studio, etc. + - Ensure API is accessible + +## Switching Modes + +Modify `KNOWLEDGE_GRAPH_MODE` in `.env` and restart the service. + +```bash +# Restart backend +cd backend +python app.py +``` + +## Configuration Reference + +### Common Knowledge Graph Config + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `KNOWLEDGE_GRAPH_MODE` | Mode: `cloud` or `local` | `cloud` | + +### Cloud Mode Parameters + +| Parameter | Description | +|-----------|-------------| +| `ZEP_API_KEY` | Zep Cloud API Key | + +### Local Mode Parameters + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `NEO4J_URI` | Neo4j connection URL | `bolt://localhost:7687` | +| `NEO4J_USER` | Neo4j username | `neo4j` | +| `NEO4J_PASSWORD` | Neo4j password | - | +| `OPENAI_API_KEY` | Embedding API Key | uses `LLM_API_KEY` | + +### Embedding Model Config (Local Mode) + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `EMBEDDING_API_KEY` | Embedding API Key | uses `LLM_API_KEY` | +| `EMBEDDING_BASE_URL` | Embedding API URL | uses `LLM_BASE_URL` | +| `EMBEDDING_MODEL` | Embedding model name | `text-embedding-3-small` | +| `EMBEDDING_DIM` | Embedding vector dimension | `1536` | +| `EMBEDDING_BATCH_SIZE` | Batch size | `5` | + +## FAQ + +### Q1: How to check which mode is active? + +Check the value of `KNOWLEDGE_GRAPH_MODE` in your `.env` file. + +### Q2: Local mode fails to start? + +1. Confirm Neo4j is running: `brew services list` or `docker ps` +2. Check `NEO4J_URI`, `NEO4J_USER`, `NEO4J_PASSWORD` are correct +3. Verify embedding API is accessible + +### Q3: Cloud mode returns empty results? + +1. Confirm `ZEP_API_KEY` is correctly configured +2. Check network connectivity +3. Verify Zep Cloud account status is active + +### Q4: Can I switch modes in the same project? + +Yes. Change `KNOWLEDGE_GRAPH_MODE` and restart. 
Note: +- Cloud and Local data are not shared +- You need to re-import data after switching + +## Related Documentation + +- [Zep Cloud Documentation](https://docs.getzep.com/) +- [Graphiti GitHub](https://github.com/getzep/graphiti) +- [Neo4j Documentation](https://neo4j.com/docs/) diff --git a/frontend/src/api/report.js b/frontend/src/api/report.js index c89a67d8..cd95f9d5 100644 --- a/frontend/src/api/report.js +++ b/frontend/src/api/report.js @@ -49,3 +49,13 @@ export const getReport = (reportId) => { export const chatWithReport = (data) => { return requestWithRetry(() => service.post('/api/report/chat', data), 3, 1000) } + +/** + * 下载报告文件 + * @param {string} reportId + */ +export const downloadReport = (reportId) => { + return service.get(`/api/report/${reportId}/download`, { + responseType: 'blob' + }) +} diff --git a/frontend/src/api/simulation.js b/frontend/src/api/simulation.js index f878586f..c045d777 100644 --- a/frontend/src/api/simulation.js +++ b/frontend/src/api/simulation.js @@ -185,3 +185,11 @@ export const getSimulationHistory = (limit = 20) => { return service.get('/api/simulation/history', { params: { limit } }) } +/** + * 删除模拟 + * @param {string} simulationId - 模拟ID + */ +export const deleteSimulation = (simulationId) => { + return service.delete(`/api/simulation/${simulationId}`) +} + diff --git a/frontend/src/components/HistoryDatabase.vue b/frontend/src/components/HistoryDatabase.vue index edc73f46..0f98cc12 100644 --- a/frontend/src/components/HistoryDatabase.vue +++ b/frontend/src/components/HistoryDatabase.vue @@ -89,11 +89,23 @@ {{ formatDate(project.created_at) }} {{ formatTime(project.created_at) }} - - {{ formatRounds(project) }} - + - +
@@ -187,13 +199,40 @@ + + + + + + +