From d700c0c7d0a595f20fd415c476a4b3547e1bfb0d Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:58:42 +0530 Subject: [PATCH] feat(ragflow): add RAGflow as alternative graph backend with full pipeline compliance - Add RagflowGraphBuilderService and RagflowEntityReader for self-hosted graph support - Add _get_entity_reader() helper in simulation.py to auto-select reader by graph_id prefix - Fix 4 simulation endpoints (get_graph_entities, get_entity_detail, get_entities_by_type, generate_profiles) to support ragflow_ graph IDs - Guard ZepGraphMemoryManager.create_updater() to skip for RAGflow graph IDs - Add get_node_edges, get_entity_with_context, get_entities_by_type methods to RagflowEntityReader - Update config.py, project.py, graph.py, simulation_manager.py for dual-backend support - Document RAGflow config in .env.example, README.md, and README-EN.md --- .env.example | 8 + README-EN.md | 5 + README.md | 5 + backend/app/api/graph.py | 313 ++++++----- backend/app/api/simulation.py | 42 +- backend/app/config.py | 17 +- backend/app/models/project.py | 3 + backend/app/services/ragflow_entity_reader.py | 218 ++++++++ backend/app/services/ragflow_graph_builder.py | 527 ++++++++++++++++++ backend/app/services/simulation_manager.py | 30 +- backend/app/services/simulation_runner.py | 17 +- backend/pyproject.toml | 1 + 12 files changed, 1002 insertions(+), 184 deletions(-) create mode 100644 backend/app/services/ragflow_entity_reader.py create mode 100644 backend/app/services/ragflow_graph_builder.py diff --git a/.env.example b/.env.example index 78a3b72c..f44a8dd8 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,14 @@ LLM_MODEL_NAME=qwen-plus # 每月免费额度即可支撑简单使用:https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key_here +# ===== RAGflow图谱配置(可替代Zep,本地部署无免费额度限制)===== +# 设置 GRAPH_BACKEND=ragflow 后,ZEP_API_KEY 不再需要 +# GRAPH_BACKEND=ragflow +# RAGFLOW_BASE_URL=http://localhost +# RAGFLOW_API_KEY=your-ragflow-api-key +# RAGFLOW_LLM_ID= # 可选,留空使用RAGflow系统默认 +# RAGFLOW_EMBEDDING_MODEL= # 可选,留空使用RAGflow系统默认 + # ===== 加速 LLM 配置(可选)===== # 注意如果不使用加速配置,env文件中就不要出现下面的配置项 LLM_BOOST_API_KEY=your_api_key_here diff --git a/README-EN.md b/README-EN.md index cd24e83e..0e5898d1 100644 --- a/README-EN.md +++ b/README-EN.md @@ -125,6 +125,11 @@ LLM_MODEL_NAME=qwen-plus # Zep Cloud Configuration # Free monthly quota is sufficient for simple usage: https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key +# Graph backend choice (pick one, defaults to Zep) +# To use RAGflow instead, set GRAPH_BACKEND=ragflow; ZEP_API_KEY is then not required +# GRAPH_BACKEND=ragflow +# RAGFLOW_BASE_URL=http://localhost +# RAGFLOW_API_KEY=your-ragflow-api-key ``` #### 2. Install Dependencies diff --git a/README.md b/README.md index a47976c4..cbf900ba 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,11 @@ LLM_MODEL_NAME=qwen-plus # Zep Cloud 配置 # 每月免费额度即可支撑简单使用:https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key +# 图谱后端选择(二选一,默认 Zep) +# 如使用 RAGflow,将 GRAPH_BACKEND 设为 ragflow,不再需要 ZEP_API_KEY +# GRAPH_BACKEND=ragflow +# RAGFLOW_BASE_URL=http://localhost +# RAGFLOW_API_KEY=your-ragflow-api-key ``` #### 2. 安装依赖 diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 12ff1ba2..b24f67b0 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -6,18 +6,47 @@ import os import traceback import threading +from typing import Optional from flask import request, jsonify from . import graph_bp from ..config import Config from ..services.ontology_generator import OntologyGenerator from ..services.graph_builder import GraphBuilderService +from ..services.ragflow_graph_builder import RagflowGraphBuilderService from ..services.text_processor import TextProcessor from ..utils.file_parser import FileParser from ..utils.logger import get_logger from ..models.task import TaskManager, TaskStatus from ..models.project import ProjectManager, ProjectStatus + +def _is_ragflow_graph(graph_id: str) -> bool: + return graph_id.startswith("ragflow_") + + +def _get_builder(graph_id: Optional[str] = None, backend: Optional[str] = None): + """ + 根据graph_id前缀或backend参数返回合适的图谱构建服务实例。 + + 优先级:backend参数 > graph_id前缀 > GRAPH_BACKEND配置 + """ + # 从显式参数或graph_id前缀确定后端 + if backend is None: + if graph_id and _is_ragflow_graph(graph_id): + backend = "ragflow" + else: + backend = Config.GRAPH_BACKEND + + if backend == "ragflow": + if not Config.RAGFLOW_API_KEY: + raise ValueError("RAGFLOW_API_KEY 未配置,无法使用RAGflow后端") + return RagflowGraphBuilderService() + else: + if not Config.ZEP_API_KEY: + raise ValueError("ZEP_API_KEY 未配置,无法使用Zep后端") + return GraphBuilderService(api_key=Config.ZEP_API_KEY) + # 获取日志器 logger = get_logger('mirofish.api') @@ -281,22 +310,27 @@ def build_graph(): """ try: logger.info("=== 开始构建图谱 ===") - - # 检查配置 - errors = [] - if not Config.ZEP_API_KEY: - errors.append("ZEP_API_KEY未配置") - if errors: - logger.error(f"配置错误: {errors}") - return jsonify({ - "success": False, - "error": "配置错误: " + "; ".join(errors) - }), 500 - + # 解析请求 data = request.get_json() or {} project_id = data.get('project_id') - logger.debug(f"请求参数: project_id={project_id}") + # backend参数:显式指定后端("zep" 或 "ragflow"),不传则使用配置默认值 + requested_backend = data.get('backend', Config.GRAPH_BACKEND) + logger.debug(f"请求参数: project_id={project_id}, backend={requested_backend}") + + # 检查后端配置 + if requested_backend == 'ragflow': + if not Config.RAGFLOW_API_KEY: + return jsonify({ + "success": False, + "error": "RAGFLOW_API_KEY未配置,无法使用RAGflow后端" + }), 500 + else: + if not Config.ZEP_API_KEY: + return jsonify({ + "success": False, + "error": "ZEP_API_KEY未配置,无法使用Zep后端" + }), 500 if not project_id: return jsonify({ @@ -363,116 +397,129 @@ def build_graph(): # 创建异步任务 task_manager = TaskManager() task_id = task_manager.create_task(f"构建图谱: {graph_name}") - logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}") - + logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}, backend={requested_backend}") + # 更新项目状态 project.status = ProjectStatus.GRAPH_BUILDING project.graph_build_task_id = task_id + project.graph_backend = requested_backend ProjectManager.save_project(project) - + + # 获取项目文件路径(供RAGflow直接上传原始文件) + project_file_paths = ProjectManager.get_project_files(project_id) + # 启动后台任务 def build_task(): build_logger = get_logger('mirofish.build') try: - build_logger.info(f"[{task_id}] 开始构建图谱...") - task_manager.update_task( - task_id, - status=TaskStatus.PROCESSING, - message="初始化图谱构建服务..." - ) - - # 创建图谱构建服务 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) - - # 分块 - task_manager.update_task( - task_id, - message="文本分块中...", - progress=5 - ) - chunks = TextProcessor.split_text( - text, - chunk_size=chunk_size, - overlap=chunk_overlap - ) - total_chunks = len(chunks) - - # 创建图谱 - task_manager.update_task( - task_id, - message="创建Zep图谱...", - progress=10 - ) - graph_id = builder.create_graph(name=graph_name) - - # 更新项目的graph_id - project.graph_id = graph_id - ProjectManager.save_project(project) - - # 设置本体 + build_logger.info(f"[{task_id}] 开始构建图谱(后端: {requested_backend})...") task_manager.update_task( task_id, - message="设置本体定义...", - progress=15 + status=TaskStatus.PROCESSING, + message=f"初始化图谱构建服务({requested_backend})..." ) - builder.set_ontology(graph_id, ontology) - - # 添加文本(progress_callback 签名是 (msg, progress_ratio)) - def add_progress_callback(msg, progress_ratio): - progress = 15 + int(progress_ratio * 40) # 15% - 55% + + # 初始化变量(两种后端都需要) + graph_id = "" + node_count = 0 + edge_count = 0 + + if requested_backend == 'ragflow': + # ── RAGflow后端 ────────────────────────────────────────── + builder = RagflowGraphBuilderService() + task_manager.update_task(task_id, progress=5, + message="正在创建RAGflow数据集...") + + # RAGflow优先上传原始文件,没有文件时上传提取的文本 + builder_task_id = builder.build_graph_async( + text=text, + ontology=ontology, + graph_name=graph_name, + file_paths=project_file_paths if project_file_paths else None, + ) + + # 轮询RAGflow子任务直到完成 + import time as _time + while True: + sub_task = task_manager.get_task(builder_task_id) + if sub_task is None: + break + # 将子任务进度同步到主任务 + task_manager.update_task( + task_id, + progress=sub_task.progress or 0, + message=sub_task.message or "" + ) + if sub_task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED): + if sub_task.status == TaskStatus.FAILED: + raise RuntimeError(sub_task.error or "RAGflow构建失败") + # 从子任务结果获取graph_id + result = sub_task.result or {} + graph_id = result.get("graph_id", "") + node_count = result.get("node_count", 0) + edge_count = result.get("edge_count", 0) + break + _time.sleep(3) + + project.graph_id = graph_id + ProjectManager.save_project(project) + + else: + # ── Zep后端(原有逻辑)─────────────────────────────────── + builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + + task_manager.update_task(task_id, message="文本分块中...", progress=5) + chunks = TextProcessor.split_text( + text, chunk_size=chunk_size, overlap=chunk_overlap + ) + total_chunks = len(chunks) + + task_manager.update_task(task_id, message="创建Zep图谱...", progress=10) + graph_id = builder.create_graph(name=graph_name) + + project.graph_id = graph_id + ProjectManager.save_project(project) + + task_manager.update_task(task_id, message="设置本体定义...", progress=15) + builder.set_ontology(graph_id, ontology) + + def add_progress_callback(msg, progress_ratio): + progress = 15 + int(progress_ratio * 40) # 15% - 55% + task_manager.update_task(task_id, message=msg, progress=progress) + task_manager.update_task( task_id, - message=msg, - progress=progress + message=f"开始添加 {total_chunks} 个文本块...", + progress=15 ) - - task_manager.update_task( - task_id, - message=f"开始添加 {total_chunks} 个文本块...", - progress=15 - ) - - episode_uuids = builder.add_text_batches( - graph_id, - chunks, - batch_size=3, - progress_callback=add_progress_callback - ) - - # 等待Zep处理完成(查询每个episode的processed状态) - task_manager.update_task( - task_id, - message="等待Zep处理数据...", - progress=55 - ) - - def wait_progress_callback(msg, progress_ratio): - progress = 55 + int(progress_ratio * 35) # 55% - 90% + episode_uuids = builder.add_text_batches( + graph_id, chunks, batch_size=3, + progress_callback=add_progress_callback + ) + task_manager.update_task( - task_id, - message=msg, - progress=progress + task_id, message="等待Zep处理数据...", progress=55 ) - - builder._wait_for_episodes(episode_uuids, wait_progress_callback) - - # 获取图谱数据 - task_manager.update_task( - task_id, - message="获取图谱数据...", - progress=95 - ) - graph_data = builder.get_graph_data(graph_id) - - # 更新项目状态 + + def wait_progress_callback(msg, progress_ratio): + progress = 55 + int(progress_ratio * 35) # 55% - 90% + task_manager.update_task(task_id, message=msg, progress=progress) + + builder._wait_for_episodes(episode_uuids, wait_progress_callback) + + task_manager.update_task(task_id, message="获取图谱数据...", progress=95) + graph_data = builder.get_graph_data(graph_id) + node_count = graph_data.get("node_count", 0) + edge_count = graph_data.get("edge_count", 0) + + # ── 完成(两种后端共用)──────────────────────────────────── project.status = ProjectStatus.GRAPH_COMPLETED ProjectManager.save_project(project) - - node_count = graph_data.get("node_count", 0) - edge_count = graph_data.get("edge_count", 0) - build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}") - - # 完成 + + build_logger.info( + f"[{task_id}] 图谱构建完成: backend={requested_backend}, " + f"graph_id={graph_id}, 节点={node_count}, 边={edge_count}" + ) task_manager.update_task( task_id, status=TaskStatus.COMPLETED, @@ -481,28 +528,27 @@ def wait_progress_callback(msg, progress_ratio): result={ "project_id": project_id, "graph_id": graph_id, + "backend": requested_backend, "node_count": node_count, "edge_count": edge_count, - "chunk_count": total_chunks } ) - + except Exception as e: - # 更新项目状态为失败 build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}") build_logger.debug(traceback.format_exc()) - + project.status = ProjectStatus.FAILED project.error = str(e) ProjectManager.save_project(project) - + task_manager.update_task( task_id, status=TaskStatus.FAILED, message=f"构建失败: {str(e)}", error=traceback.format_exc() ) - + # 启动后台线程 thread = threading.Thread(target=build_task, daemon=True) thread.start() @@ -512,7 +558,8 @@ def wait_progress_callback(msg, progress_ratio): "data": { "project_id": project_id, "task_id": task_id, - "message": "图谱构建任务已启动,请通过 /task/{task_id} 查询进度" + "backend": requested_backend, + "message": f"图谱构建任务已启动(后端: {requested_backend}),请通过 /task/{{task_id}} 查询进度" } }) @@ -565,22 +612,16 @@ def list_tasks(): def get_graph_data(graph_id: str): """ 获取图谱数据(节点和边) + + 根据graph_id前缀自动选择后端: + - ragflow_* → RAGflow后端(读取本地缓存) + - 其他 → Zep后端 """ try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY未配置" - }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + builder = _get_builder(graph_id=graph_id) graph_data = builder.get_graph_data(graph_id) - - return jsonify({ - "success": True, - "data": graph_data - }) - + return jsonify({"success": True, "data": graph_data}) + except Exception as e: return jsonify({ "success": False, @@ -592,23 +633,17 @@ def get_graph_data(graph_id: str): @graph_bp.route('/delete/', methods=['DELETE']) def delete_graph(graph_id: str): """ - 删除Zep图谱 + 删除图谱 + + 根据graph_id前缀自动选择后端: + - ragflow_* → 删除RAGflow数据集及本地缓存 + - 其他 → 删除Zep图谱 """ try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY未配置" - }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + builder = _get_builder(graph_id=graph_id) builder.delete_graph(graph_id) - - return jsonify({ - "success": True, - "message": f"图谱已删除: {graph_id}" - }) - + return jsonify({"success": True, "message": f"图谱已删除: {graph_id}"}) + except Exception as e: return jsonify({ "success": False, diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a0f6816..0fa45099 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -10,6 +10,7 @@ from . import simulation_bp from ..config import Config from ..services.zep_entity_reader import ZepEntityReader +from ..services.ragflow_entity_reader import RagflowEntityReader from ..services.oasis_profile_generator import OasisProfileGenerator from ..services.simulation_manager import SimulationManager, SimulationStatus from ..services.simulation_runner import SimulationRunner, RunnerStatus @@ -19,6 +20,15 @@ logger = get_logger('mirofish.api.simulation') +def _get_entity_reader(graph_id: str): + """根据graph_id前缀选择合适的实体读取器""" + if graph_id.startswith("ragflow_"): + return RagflowEntityReader() + if not Config.ZEP_API_KEY: + raise ValueError("ZEP_API_KEY未配置") + return ZepEntityReader() + + # Interview prompt 优化前缀 # 添加此前缀可以避免Agent调用工具,直接用文本回复 INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:" @@ -56,19 +66,13 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY未配置" - }), 500 - entity_types_str = request.args.get('entity_types', '') entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None enrich = request.args.get('enrich', 'true').lower() == 'true' - + logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}") - - reader = ZepEntityReader() + + reader = _get_entity_reader(graph_id) result = reader.filter_defined_entities( graph_id=graph_id, defined_entity_types=entity_types, @@ -93,13 +97,7 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY未配置" - }), 500 - - reader = ZepEntityReader() + reader = _get_entity_reader(graph_id) entity = reader.get_entity_with_context(graph_id, entity_uuid) if not entity: @@ -126,15 +124,9 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": "ZEP_API_KEY未配置" - }), 500 - enrich = request.args.get('enrich', 'true').lower() == 'true' - - reader = ZepEntityReader() + + reader = _get_entity_reader(graph_id) entities = reader.get_entities_by_type( graph_id=graph_id, entity_type=entity_type, @@ -1396,7 +1388,7 @@ def generate_profiles(): use_llm = data.get('use_llm', True) platform = data.get('platform', 'reddit') - reader = ZepEntityReader() + reader = _get_entity_reader(graph_id) filtered = reader.filter_defined_entities( graph_id=graph_id, defined_entity_types=entity_types, diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50..7c0f4883 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -34,6 +34,15 @@ class Config: # Zep配置 ZEP_API_KEY = os.environ.get('ZEP_API_KEY') + + # RAGflow配置(知识图谱后端替代方案) + # GRAPH_BACKEND: 图谱构建后端,可选 "zep"(默认)或 "ragflow" + GRAPH_BACKEND = os.environ.get('GRAPH_BACKEND', 'zep') + RAGFLOW_BASE_URL = os.environ.get('RAGFLOW_BASE_URL', 'http://localhost') + RAGFLOW_API_KEY = os.environ.get('RAGFLOW_API_KEY') + # RAGflow可选配置:指定使用的LLM和Embedding模型(留空使用RAGflow系统默认) + RAGFLOW_LLM_ID = os.environ.get('RAGFLOW_LLM_ID', '') + RAGFLOW_EMBEDDING_MODEL = os.environ.get('RAGFLOW_EMBEDDING_MODEL', '') # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB @@ -69,7 +78,11 @@ def validate(cls): errors = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + if cls.GRAPH_BACKEND == 'ragflow': + if not cls.RAGFLOW_API_KEY: + errors.append("RAGFLOW_API_KEY 未配置(当前 GRAPH_BACKEND=ragflow)") + else: + if not cls.ZEP_API_KEY: + errors.append("ZEP_API_KEY 未配置") return errors diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 08978937..4387cb59 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -43,6 +43,7 @@ class Project: # 图谱信息(接口2完成后填充) graph_id: Optional[str] = None graph_build_task_id: Optional[str] = None + graph_backend: str = "zep" # 图谱构建后端: "zep" 或 "ragflow" # 配置 simulation_requirement: Optional[str] = None @@ -66,6 +67,7 @@ def to_dict(self) -> Dict[str, Any]: "analysis_summary": self.analysis_summary, "graph_id": self.graph_id, "graph_build_task_id": self.graph_build_task_id, + "graph_backend": self.graph_backend, "simulation_requirement": self.simulation_requirement, "chunk_size": self.chunk_size, "chunk_overlap": self.chunk_overlap, @@ -91,6 +93,7 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Project': analysis_summary=data.get('analysis_summary'), graph_id=data.get('graph_id'), graph_build_task_id=data.get('graph_build_task_id'), + graph_backend=data.get('graph_backend', 'zep'), simulation_requirement=data.get('simulation_requirement'), chunk_size=data.get('chunk_size', 500), chunk_overlap=data.get('chunk_overlap', 50), diff --git a/backend/app/services/ragflow_entity_reader.py b/backend/app/services/ragflow_entity_reader.py new file mode 100644 index 00000000..9cfbfd44 --- /dev/null +++ b/backend/app/services/ragflow_entity_reader.py @@ -0,0 +1,218 @@ +""" +RAGflow实体读取服务 +从本地缓存的RAGflow知识图谱中读取实体,返回与ZepEntityReader相同格式的数据 + +RAGflow图谱数据在构建时已保存到本地JSON文件,本服务直接读取该缓存。 +接口设计与ZepEntityReader保持一致,SimulationManager可无缝切换。 +""" + +import os +import json +from typing import Dict, Any, List, Optional, Set + +from ..utils.logger import get_logger +from .zep_entity_reader import EntityNode, FilteredEntities + +logger = get_logger('mirofish.ragflow_entity_reader') + +RAGFLOW_GRAPHS_DIR = os.path.join(os.path.dirname(__file__), '../../uploads/ragflow_graphs') + + +class RagflowEntityReader: + """ + RAGflow实体读取与过滤服务 + + 从本地缓存的RAGflow知识图谱JSON文件中读取实体和关系, + 并以与ZepEntityReader相同的接口返回FilteredEntities对象, + 使SimulationManager可以不加修改地使用RAGflow图谱进行模拟。 + """ + + def _load_graph_data(self, graph_id: str) -> Dict[str, Any]: + """从本地缓存加载图谱数据""" + graph_file = os.path.join(RAGFLOW_GRAPHS_DIR, graph_id, "graph_data.json") + if not os.path.exists(graph_file): + logger.error(f"RAGflow图谱缓存不存在: {graph_id},请确认图谱已成功构建(/graph/build)") + raise FileNotFoundError( + f"RAGflow图谱缓存不存在: {graph_id}。" + "请确认图谱已成功构建(/graph/build)。" + ) + with open(graph_file, 'r', encoding='utf-8') as f: + return json.load(f) + + def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: + """获取图谱所有节点""" + return self._load_graph_data(graph_id).get("nodes", []) + + def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: + """获取图谱所有边""" + return self._load_graph_data(graph_id).get("edges", []) + + def get_node_edges(self, node_uuid: str, graph_id: str) -> List[Dict[str, Any]]: + """获取指定节点的所有相关边(从本地缓存读取)""" + try: + all_edges = self.get_all_edges(graph_id) + return [ + e for e in all_edges + if e.get("source_node_uuid") == node_uuid + or e.get("target_node_uuid") == node_uuid + ] + except Exception as e: + logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") + return [] + + def filter_defined_entities( + self, + graph_id: str, + defined_entity_types: Optional[List[str]] = None, + enrich_with_edges: bool = True, + ) -> FilteredEntities: + """ + 筛选符合预定义实体类型的节点(与ZepEntityReader.filter_defined_entities接口一致) + + 筛选逻辑与ZepEntityReader相同: + - 节点labels中除"Entity"/"Node"外还有其他标签,才认为是有效实体 + - 如果指定了defined_entity_types,只保留匹配类型的实体 + + Args: + graph_id: RAGflow图谱ID(格式: ragflow_{dataset_id}) + defined_entity_types: 预定义实体类型列表(来自本体定义) + enrich_with_edges: 是否为实体附加关联边和节点信息 + + Returns: + FilteredEntities + """ + logger.info(f"从RAGflow图谱读取实体: {graph_id}") + + all_nodes = self.get_all_nodes(graph_id) + all_edges = self.get_all_edges(graph_id) if enrich_with_edges else [] + + total_count = len(all_nodes) + node_map = {n["uuid"]: n for n in all_nodes} + + filtered_entities: List[EntityNode] = [] + entity_types_found: Set[str] = set() + + for node in all_nodes: + labels = node.get("labels", []) + custom_labels = [label for label in labels if label not in ("Entity", "Node")] + + if not custom_labels: + # 没有自定义标签,跳过 + continue + + if defined_entity_types: + matching = [l for l in custom_labels if l in defined_entity_types] + if not matching: + continue + entity_type = matching[0] + else: + entity_type = custom_labels[0] + + entity_types_found.add(entity_type) + + entity = EntityNode( + uuid=node["uuid"], + name=node.get("name", ""), + labels=labels, + summary=node.get("summary", ""), + attributes=node.get("attributes", {}), + ) + + if enrich_with_edges: + related_edges: List[Dict] = [] + related_node_uuids: Set[str] = set() + + for edge in all_edges: + src = edge.get("source_node_uuid", "") + tgt = edge.get("target_node_uuid", "") + + if src == node["uuid"]: + related_edges.append({ + "direction": "outgoing", + "edge_name": edge.get("name", ""), + "fact": edge.get("fact", ""), + "target_node_uuid": tgt, + }) + related_node_uuids.add(tgt) + elif tgt == node["uuid"]: + related_edges.append({ + "direction": "incoming", + "edge_name": edge.get("name", ""), + "fact": edge.get("fact", ""), + "source_node_uuid": src, + }) + related_node_uuids.add(src) + + entity.related_edges = related_edges + + related_nodes = [] + for rel_uuid in related_node_uuids: + if rel_uuid in node_map: + rn = node_map[rel_uuid] + related_nodes.append({ + "uuid": rn["uuid"], + "name": rn.get("name", ""), + "labels": rn.get("labels", []), + "summary": rn.get("summary", ""), + }) + entity.related_nodes = related_nodes + + filtered_entities.append(entity) + + logger.info( + f"RAGflow实体筛选完成: 总节点={total_count}, " + f"符合条件={len(filtered_entities)}, 类型={entity_types_found}" + ) + + return FilteredEntities( + entities=filtered_entities, + entity_types=entity_types_found, + total_count=total_count, + filtered_count=len(filtered_entities), + ) + + def get_entity_with_context( + self, graph_id: str, entity_uuid: str + ) -> Optional[EntityNode]: + """获取单个实体及其完整上下文(从本地缓存读取)""" + try: + all_nodes = self.get_all_nodes(graph_id) + all_edges = self.get_all_edges(graph_id) + node_map = {n["uuid"]: n for n in all_nodes} + node = node_map.get(entity_uuid) + if not node: + return None + related_edges, related_node_uuids = [], set() + for edge in all_edges: + src, tgt = edge.get("source_node_uuid", ""), edge.get("target_node_uuid", "") + if src == entity_uuid: + related_edges.append({"direction": "outgoing", "edge_name": edge.get("name", ""), + "fact": edge.get("fact", ""), "target_node_uuid": tgt}) + related_node_uuids.add(tgt) + elif tgt == entity_uuid: + related_edges.append({"direction": "incoming", "edge_name": edge.get("name", ""), + "fact": edge.get("fact", ""), "source_node_uuid": src}) + related_node_uuids.add(src) + related_nodes = [ + {"uuid": node_map[u]["uuid"], "name": node_map[u].get("name", ""), + "labels": node_map[u].get("labels", []), "summary": node_map[u].get("summary", "")} + for u in related_node_uuids if u in node_map + ] + return EntityNode( + uuid=node["uuid"], name=node.get("name", ""), labels=node.get("labels", []), + summary=node.get("summary", ""), attributes=node.get("attributes", {}), + related_edges=related_edges, related_nodes=related_nodes, + ) + except Exception as e: + logger.error(f"获取RAGflow实体 {entity_uuid} 失败: {str(e)}") + return None + + def get_entities_by_type( + self, graph_id: str, entity_type: str, enrich_with_edges: bool = True + ) -> List[EntityNode]: + """获取指定类型的所有实体""" + return self.filter_defined_entities( + graph_id=graph_id, + defined_entity_types=[entity_type], + enrich_with_edges=enrich_with_edges, + ).entities diff --git a/backend/app/services/ragflow_graph_builder.py b/backend/app/services/ragflow_graph_builder.py new file mode 100644 index 00000000..0b99baa3 --- /dev/null +++ b/backend/app/services/ragflow_graph_builder.py @@ -0,0 +1,527 @@ +""" +RAGflow图谱构建服务 +使用RAGflow API构建知识图谱(Zep的替代方案) + +RAGflow支持本地部署,通过知识图谱模式解析文档并提取实体和关系。 +图谱数据构建后会缓存到本地,供后续模拟使用。 + +图谱ID格式: ragflow_{dataset_id} +""" + +import os +import json +import uuid +import time +import threading +import tempfile +import shutil +from typing import Dict, Any, List, Optional, Callable + +import requests + +from ..config import Config +from ..models.task import TaskManager, TaskStatus +from ..utils.logger import get_logger + +logger = get_logger('mirofish.ragflow_graph_builder') + +# RAGflow图谱数据本地缓存目录 +RAGFLOW_GRAPHS_DIR = os.path.join(os.path.dirname(__file__), '../../uploads/ragflow_graphs') + + +class RagflowGraphBuilderService: + """ + RAGflow图谱构建服务 + + 工作流: + 1. 创建RAGflow数据集(知识库),使用knowledge_graph解析模式 + 2. 上传文档(原始文件或将文本写成临时文件) + 3. 触发解析,RAGflow会自动提取实体和关系 + 4. 轮询等待解析完成 + 5. 从RAGflow获取知识图谱数据(实体和关系) + 6. 将结果缓存到本地JSON文件,供模拟读取使用 + """ + + def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None): + self.base_url = (base_url or Config.RAGFLOW_BASE_URL or "http://localhost").rstrip('/') + self.api_key = api_key or Config.RAGFLOW_API_KEY + if not self.api_key: + raise ValueError("RAGFLOW_API_KEY 未配置") + + self._auth_headers = {"Authorization": f"Bearer {self.api_key}"} + self._json_headers = {**self._auth_headers, "Content-Type": "application/json"} + self.task_manager = TaskManager() + + os.makedirs(RAGFLOW_GRAPHS_DIR, exist_ok=True) + + # ── HTTP请求工具 ────────────────────────────────────────────────────────── + + def _api_url(self, path: str) -> str: + return f"{self.base_url}/api/v1{path}" + + def _get(self, path: str, params: Optional[Dict] = None) -> Dict[str, Any]: + resp = requests.get(self._api_url(path), headers=self._auth_headers, + params=params, timeout=30) + resp.raise_for_status() + return resp.json() + + def _post(self, path: str, payload: Optional[Dict] = None) -> Dict[str, Any]: + resp = requests.post(self._api_url(path), headers=self._json_headers, + json=payload or {}, timeout=30) + resp.raise_for_status() + return resp.json() + + def _post_files(self, path: str, files: Dict, data: Optional[Dict] = None) -> Dict[str, Any]: + resp = requests.post(self._api_url(path), headers=self._auth_headers, + files=files, data=data or {}, timeout=60) + resp.raise_for_status() + return resp.json() + + def _delete(self, path: str, payload: Optional[Dict] = None) -> Dict[str, Any]: + resp = requests.delete(self._api_url(path), headers=self._json_headers, + json=payload or {}, timeout=30) + resp.raise_for_status() + return resp.json() + + @staticmethod + def _check_response(result: Dict, operation: str) -> Any: + """检查RAGflow API返回码,非0时抛出异常""" + if result.get("code", -1) != 0: + raise ValueError(f"{operation}失败: {result.get('message', '未知错误')}") + return result.get("data") + + # ── 数据集(知识库)管理 ────────────────────────────────────────────────── + + def create_dataset(self, name: str) -> str: + """ + 创建RAGflow数据集,使用knowledge_graph解析模式 + + Returns: + dataset_id + """ + payload: Dict[str, Any] = { + "name": name, + "chunk_method": "knowledge_graph", + "description": "MiroFish Knowledge Graph", + } + # 如果配置了RAGflow的LLM和Embedding,传入 + if Config.RAGFLOW_LLM_ID: + payload["llm_id"] = Config.RAGFLOW_LLM_ID + if Config.RAGFLOW_EMBEDDING_MODEL: + payload["embedding_model"] = Config.RAGFLOW_EMBEDDING_MODEL + + result = self._post("/datasets", payload) + data = self._check_response(result, "创建数据集") + dataset_id = data["id"] if isinstance(data, dict) else data + logger.info(f"RAGflow数据集已创建: {dataset_id}") + return dataset_id + + def delete_dataset(self, dataset_id: str): + """删除RAGflow数据集""" + result = self._delete("/datasets", {"ids": [dataset_id]}) + self._check_response(result, "删除数据集") + + # ── 文档上传 ────────────────────────────────────────────────────────────── + + def upload_text_as_document(self, dataset_id: str, text: str, + filename: str = "document.txt") -> str: + """将文本写成临时文件并上传到RAGflow数据集""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', + encoding='utf-8', delete=False) as f: + f.write(text) + tmp_path = f.name + try: + return self._upload_file(dataset_id, tmp_path, filename, "text/plain") + finally: + os.unlink(tmp_path) + + def upload_file_document(self, dataset_id: str, file_path: str) -> str: + """上传已有文件到RAGflow数据集""" + filename = os.path.basename(file_path) + if filename.endswith('.pdf'): + content_type = 'application/pdf' + elif filename.endswith('.md') or filename.endswith('.markdown'): + content_type = 'text/markdown' + else: + content_type = 'text/plain' + return self._upload_file(dataset_id, file_path, filename, content_type) + + def _upload_file(self, dataset_id: str, file_path: str, + filename: str, content_type: str) -> str: + with open(file_path, 'rb') as f: + files = {"file": (filename, f, content_type)} + result = self._post_files(f"/datasets/{dataset_id}/documents", files=files) + data = self._check_response(result, "上传文档") + docs = data if isinstance(data, list) else [data] + if not docs: + raise ValueError("上传成功但未返回文档信息") + doc_id = docs[0]["id"] + logger.info(f"文档已上传: {doc_id}") + return doc_id + + # ── 文档解析 ────────────────────────────────────────────────────────────── + + def start_parsing(self, dataset_id: str, document_ids: List[str]): + """触发文档解析(启动知识图谱构建)""" + result = self._post(f"/datasets/{dataset_id}/chunks", + {"document_ids": document_ids}) + self._check_response(result, "启动解析") + logger.info(f"解析已启动: dataset={dataset_id}, docs={document_ids}") + + def get_document_statuses(self, dataset_id: str) -> List[Dict[str, Any]]: + """获取数据集中所有文档的状态""" + result = self._get(f"/datasets/{dataset_id}/documents", + params={"page": 1, "page_size": 100}) + if result.get("code") == 0: + return result.get("data", {}).get("docs", []) + return [] + + def wait_for_parsing( + self, + dataset_id: str, + document_ids: List[str], + progress_callback: Optional[Callable[[str, float], None]] = None, + timeout: int = 600, + ): + """轮询等待文档解析完成(支持进度回调)""" + start_time = time.time() + pending = set(document_ids) + total = len(document_ids) + + # RAGflow run状态: 0=未处理, 1=运行中, 2=完成, 3=失败 + done_states = {"DONE", "done", "2", 2} + fail_states = {"FAIL", "fail", "3", 3, "CANCEL"} + + while pending: + if time.time() - start_time > timeout: + logger.warning(f"解析超时: 仍有 {len(pending)} 个文档未完成") + break + + for doc in self.get_document_statuses(dataset_id): + doc_id = doc.get("id") + if doc_id not in pending: + continue + run = doc.get("run", doc.get("status", "")) + if run in done_states: + pending.discard(doc_id) + logger.info(f"文档解析完成: {doc_id}") + elif run in fail_states: + pending.discard(doc_id) + logger.error(f"文档解析失败: {doc_id}, 状态: {run}") + + elapsed = int(time.time() - start_time) + completed = total - len(pending) + if progress_callback: + ratio = completed / total if total else 1.0 + progress_callback( + f"RAGflow解析中... {completed}/{total} 完成 ({elapsed}秒)", ratio + ) + + if pending: + time.sleep(5) + + if progress_callback: + progress_callback("文档解析完成", 1.0) + + # ── 实体与图谱数据获取 ──────────────────────────────────────────────────── + + def get_graph_entities(self, dataset_id: str) -> Dict[str, Any]: + """ + 从RAGflow获取知识图谱的实体和关系 + + 依次尝试: + 1. /datasets/{id}/graphs —— RAGflow 新版图谱API + 2. 解析chunks并提取实体/关系类型的条目 + """ + # 方式1:专用图谱端点(RAGflow v0.15+) + try: + result = self._get(f"/datasets/{dataset_id}/graphs") + if result.get("code") == 0: + data = result.get("data", {}) + nodes = data.get("nodes", data.get("entities", [])) + edges = data.get("edges", data.get("relations", data.get("relationships", []))) + if nodes or edges: + logger.info(f"通过图谱API获取到 {len(nodes)} 节点, {len(edges)} 边") + return { + "nodes": self._normalize_nodes(nodes), + "edges": self._normalize_edges(edges), + } + except Exception as e: + logger.debug(f"图谱专用API不可用,回退到chunks解析: {e}") + + # 方式2:从chunks中提取实体和关系 + nodes, edges = [], [] + try: + page = 1 + while True: + result = self._get( + f"/datasets/{dataset_id}/chunks", + params={"page": page, "page_size": 256}, + ) + if result.get("code") != 0: + break + data = result.get("data", {}) + chunks = data.get("chunks", []) + if not chunks: + break + n, e = self._parse_kg_chunks(chunks) + nodes.extend(n) + edges.extend(e) + if len(chunks) < 256: + break + page += 1 + logger.info(f"从chunks解析到 {len(nodes)} 节点, {len(edges)} 边") + except Exception as e: + logger.warning(f"从chunks获取数据失败: {e}") + + return {"nodes": nodes, "edges": edges} + + def _normalize_nodes(self, nodes: List[Dict]) -> List[Dict]: + """将RAGflow节点格式标准化为MiroFish格式""" + result = [] + for n in nodes: + nid = n.get("id", n.get("entity_id", str(uuid.uuid4()))) + name = n.get("name", n.get("entity_name", n.get("label", ""))) + entity_type = n.get("type", n.get("entity_type", "Entity")) + result.append({ + "uuid": nid, + "name": name, + "labels": [entity_type, "Entity"] if entity_type != "Entity" else ["Entity"], + "summary": n.get("description", n.get("summary", "")), + "attributes": n.get("attributes", n.get("properties", {})), + }) + return result + + def _normalize_edges(self, edges: List[Dict]) -> List[Dict]: + """将RAGflow边格式标准化为MiroFish格式""" + result = [] + for e in edges: + eid = e.get("id", e.get("relation_id", str(uuid.uuid4()))) + result.append({ + "uuid": eid, + "name": e.get("type", e.get("relation_type", e.get("label", ""))), + "fact": e.get("description", e.get("fact", "")), + "source_node_uuid": e.get("source_id", e.get("source", "")), + "target_node_uuid": e.get("target_id", e.get("target", "")), + "attributes": e.get("attributes", e.get("properties", {})), + }) + return result + + def _parse_kg_chunks(self, chunks: List[Dict]) -> tuple: + """从知识图谱chunks中解析实体和关系(chunks API回退方案)""" + nodes: List[Dict] = [] + edges: List[Dict] = [] + + for chunk in chunks: + chunk_type = chunk.get("type", chunk.get("chunk_type", "")).lower() + content = chunk.get("content", chunk.get("content_ltks", "")) + cid = chunk.get("chunk_id", chunk.get("id", str(uuid.uuid4()))) + + if chunk_type in ("entity", "node", "kg_entity"): + name = chunk.get("entity_name", chunk.get("name", content[:80])) + etype = chunk.get("entity_type", chunk.get("label", "Entity")) + nodes.append({ + "uuid": cid, + "name": name, + "labels": [etype, "Entity"] if etype != "Entity" else ["Entity"], + "summary": content, + "attributes": chunk.get("attributes", {}), + }) + elif chunk_type in ("relation", "edge", "kg_relation", "relationship"): + edges.append({ + "uuid": cid, + "name": chunk.get("relation_name", chunk.get("type_", "")), + "fact": content, + "source_node_uuid": chunk.get("source_id", chunk.get("src_id", "")), + "target_node_uuid": chunk.get("target_id", chunk.get("tgt_id", "")), + "attributes": chunk.get("attributes", {}), + }) + + return nodes, edges + + # ── 本地缓存 ────────────────────────────────────────────────────────────── + + def save_graph_locally(self, graph_id: str, + nodes: List[Dict], edges: List[Dict]) -> str: + """将图谱数据持久化到本地JSON文件""" + graph_dir = os.path.join(RAGFLOW_GRAPHS_DIR, graph_id) + os.makedirs(graph_dir, exist_ok=True) + + # 构建节点名称映射(用于边的补充信息) + node_name_map = {n["uuid"]: n.get("name", "") for n in nodes} + for edge in edges: + edge.setdefault("source_node_name", node_name_map.get(edge.get("source_node_uuid", ""), "")) + edge.setdefault("target_node_name", node_name_map.get(edge.get("target_node_uuid", ""), "")) + + graph_data = { + "graph_id": graph_id, + "nodes": nodes, + "edges": edges, + "node_count": len(nodes), + "edge_count": len(edges), + } + graph_file = os.path.join(graph_dir, "graph_data.json") + with open(graph_file, 'w', encoding='utf-8') as f: + json.dump(graph_data, f, ensure_ascii=False, indent=2) + logger.info(f"图谱数据已缓存: {graph_file} ({len(nodes)} 节点, {len(edges)} 边)") + return graph_file + + # ── 公开接口(与GraphBuilderService保持兼容)────────────────────────────── + + def build_graph_async( + self, + text: str, + ontology: Dict[str, Any], + graph_name: str = "MiroFish Graph", + chunk_size: int = 500, + chunk_overlap: int = 50, + batch_size: int = 3, + file_paths: Optional[List[str]] = None, + ) -> str: + """ + 异步构建知识图谱 + + Args: + text: 提取的文本(用于上传为单个文档) + ontology: 本体定义(RAGflow不使用本体,但保留参数以保持接口兼容) + graph_name: 图谱名称(用作RAGflow数据集名称) + file_paths: 可选,直接上传原始文件(优先于text) + + Returns: + task_id + """ + task_id = self.task_manager.create_task( + task_type="graph_build", + metadata={"graph_name": graph_name, "backend": "ragflow", + "text_length": len(text)}, + ) + thread = threading.Thread( + target=self._build_graph_worker, + args=(task_id, text, graph_name, file_paths), + daemon=True, + ) + thread.start() + return task_id + + def _build_graph_worker( + self, + task_id: str, + text: str, + graph_name: str, + file_paths: Optional[List[str]], + ): + """后台工作线程:调用RAGflow API完成图谱构建""" + try: + self.task_manager.update_task( + task_id, status=TaskStatus.PROCESSING, + progress=5, message="正在创建RAGflow数据集..." + ) + + # 1. 创建数据集 + dataset_id = self.create_dataset(graph_name) + graph_id = f"ragflow_{dataset_id}" + self.task_manager.update_task( + task_id, progress=10, + message=f"数据集已创建,正在上传文档..." + ) + + # 2. 上传文档 + document_ids: List[str] = [] + if file_paths: + total_files = len(file_paths) + for i, fp in enumerate(file_paths): + doc_id = self.upload_file_document(dataset_id, fp) + document_ids.append(doc_id) + progress = 10 + int((i + 1) / total_files * 20) + self.task_manager.update_task( + task_id, progress=progress, + message=f"已上传文件 {i + 1}/{total_files}" + ) + else: + doc_id = self.upload_text_as_document(dataset_id, text, "document.txt") + document_ids.append(doc_id) + self.task_manager.update_task(task_id, progress=20, message="文档上传完成") + + # 3. 触发解析 + self.task_manager.update_task( + task_id, progress=25, + message="正在启动知识图谱解析..." + ) + self.start_parsing(dataset_id, document_ids) + + # 4. 等待解析完成 + def wait_cb(msg: str, ratio: float): + progress = 25 + int(ratio * 55) # 25% → 80% + self.task_manager.update_task(task_id, progress=progress, message=msg) + + self.wait_for_parsing(dataset_id, document_ids, wait_cb) + + # 5. 获取图谱数据 + self.task_manager.update_task( + task_id, progress=82, message="正在获取知识图谱数据..." + ) + graph_data = self.get_graph_entities(dataset_id) + nodes = graph_data.get("nodes", []) + edges = graph_data.get("edges", []) + + # 6. 本地缓存 + self.task_manager.update_task( + task_id, progress=95, message="正在保存图谱数据..." + ) + self.save_graph_locally(graph_id, nodes, edges) + + # 完成 + self.task_manager.complete_task(task_id, { + "graph_id": graph_id, + "dataset_id": dataset_id, + "node_count": len(nodes), + "edge_count": len(edges), + }) + logger.info(f"RAGflow图谱构建完成: {graph_id}, 节点={len(nodes)}, 边={len(edges)}") + + except Exception as exc: + import traceback + error_msg = f"{str(exc)}\n{traceback.format_exc()}" + logger.error(f"RAGflow图谱构建失败: {error_msg}") + self.task_manager.fail_task(task_id, error_msg) + + def get_graph_data(self, graph_id: str) -> Dict[str, Any]: + """ + 获取图谱的完整数据(节点和边) + + 优先读取本地缓存,缓存不存在时从RAGflow实时获取。 + """ + graph_file = os.path.join(RAGFLOW_GRAPHS_DIR, graph_id, "graph_data.json") + if os.path.exists(graph_file): + with open(graph_file, 'r', encoding='utf-8') as f: + return json.load(f) + + # 本地缓存不存在,从RAGflow获取 + dataset_id = graph_id.removeprefix("ragflow_") + graph_data = self.get_graph_entities(dataset_id) + nodes = graph_data.get("nodes", []) + edges = graph_data.get("edges", []) + self.save_graph_locally(graph_id, nodes, edges) + + return { + "graph_id": graph_id, + "nodes": nodes, + "edges": edges, + "node_count": len(nodes), + "edge_count": len(edges), + } + + def delete_graph(self, graph_id: str): + """删除RAGflow数据集及本地缓存""" + dataset_id = graph_id.removeprefix("ragflow_") + + try: + self.delete_dataset(dataset_id) + logger.info(f"RAGflow数据集已删除: {dataset_id}") + except Exception as exc: + logger.warning(f"删除RAGflow数据集失败(可能已不存在): {exc}") + + graph_dir = os.path.join(RAGFLOW_GRAPHS_DIR, graph_id) + if os.path.exists(graph_dir): + shutil.rmtree(graph_dir) + logger.info(f"本地图谱缓存已删除: {graph_dir}") diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index 96c496fd..6d7a77b4 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -15,6 +15,7 @@ from ..config import Config from ..utils.logger import get_logger from .zep_entity_reader import ZepEntityReader, FilteredEntities +from .ragflow_entity_reader import RagflowEntityReader from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters @@ -269,14 +270,20 @@ def prepare_simulation( sim_dir = self._get_simulation_dir(simulation_id) # ========== 阶段1: 读取并过滤实体 ========== + is_ragflow = state.graph_id.startswith("ragflow_") + if progress_callback: - progress_callback("reading", 0, "正在连接Zep图谱...") - - reader = ZepEntityReader() - + backend_name = "RAGflow" if is_ragflow else "Zep" + progress_callback("reading", 0, f"正在连接{backend_name}图谱...") + + if is_ragflow: + reader = RagflowEntityReader() + else: + reader = ZepEntityReader() + if progress_callback: progress_callback("reading", 30, "正在读取节点数据...") - + filtered = reader.filter_defined_entities( graph_id=state.graph_id, defined_entity_types=defined_entity_types, @@ -311,8 +318,9 @@ def prepare_simulation( total=total_entities ) - # 传入graph_id以启用Zep检索功能,获取更丰富的上下文 - generator = OasisProfileGenerator(graph_id=state.graph_id) + # 对于RAGflow图谱,跳过Zep检索(graph_id=None),使用实体自带的related_edges + zep_graph_id = None if is_ragflow else state.graph_id + generator = OasisProfileGenerator(graph_id=zep_graph_id) def profile_progress(current, total, msg): if progress_callback: @@ -339,10 +347,10 @@ def profile_progress(current, total, msg): entities=filtered.entities, use_llm=use_llm_for_profiles, progress_callback=profile_progress, - graph_id=state.graph_id, # 传入graph_id用于Zep检索 - parallel_count=parallel_profile_count, # 并行生成数量 - realtime_output_path=realtime_output_path, # 实时保存路径 - output_platform=realtime_platform # 输出格式 + graph_id=zep_graph_id, # RAGflow图谱传None,跳过Zep检索 + parallel_count=parallel_profile_count, + realtime_output_path=realtime_output_path, + output_platform=realtime_platform ) state.profiles_count = len(profiles) diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index 8c35380d..ad9050e6 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -373,13 +373,16 @@ def start_simulation( if not graph_id: raise ValueError("启用图谱记忆更新时必须提供 graph_id") - try: - ZepGraphMemoryManager.create_updater(simulation_id, graph_id) - cls._graph_memory_enabled[simulation_id] = True - logger.info(f"已启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}") - except Exception as e: - logger.error(f"创建图谱记忆更新器失败: {e}") - cls._graph_memory_enabled[simulation_id] = False + if graph_id and graph_id.startswith("ragflow_"): + logger.info(f"RAGflow后端不支持实时图谱记忆更新,跳过: simulation_id={simulation_id}") + else: + try: + ZepGraphMemoryManager.create_updater(simulation_id, graph_id) + cls._graph_memory_enabled[simulation_id] = True + logger.info(f"已启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}") + except Exception as e: + logger.error(f"创建图谱记忆更新器失败: {e}") + cls._graph_memory_enabled[simulation_id] = False else: cls._graph_memory_enabled[simulation_id] = False diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d5..cbcc8c48 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ # 工具库 "python-dotenv>=1.0.0", "pydantic>=2.0.0", + "requests>=2.28.0", ] [project.optional-dependencies]