Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions backend/app/services/ontology_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import json
import re
from typing import Dict, Any, List, Optional
from ..utils.llm_client import LLMClient

Expand Down Expand Up @@ -254,6 +255,22 @@ def _build_user_message(

return message

def _to_pascal_case(self, name: str, default: str = "Entity") -> str:
    """Convert an arbitrary string to PascalCase using only ASCII letters and digits.

    The result is intended to satisfy Zep's format requirement for
    ``name`` / ``source`` / ``target`` values: an uppercase ASCII letter
    followed by ASCII letters or digits only.

    Args:
        name: The raw name to normalize. May contain separators such as
            spaces, underscores, hyphens, or non-ASCII characters.
        default: Fallback returned when ``name`` is empty or contains no
            ASCII alphanumeric characters. It is also used as a prefix when
            the normalized name would otherwise start with a digit, so it
            should itself be a valid PascalCase identifier.

    Returns:
        ``name`` unchanged if it is already valid PascalCase; otherwise the
        alphanumeric runs of ``name`` joined together with each run's first
        character upper-cased, or ``default`` if nothing usable remains.
    """
    if not name:
        return default

    # Already valid PascalCase (leading uppercase letter, alphanumerics
    # only): return it untouched so existing casing is preserved.
    if re.fullmatch(r"[A-Z][a-zA-Z0-9]*", name):
        return name

    # Split on every run of non-alphanumeric characters (this includes
    # underscores, whitespace, punctuation and non-ASCII text).
    parts = [part for part in re.split(r"[^a-zA-Z0-9]+", name) if part]
    if not parts:
        return default

    # Upper-case only the first character of each run so interior casing
    # such as "userID" survives as "UserID".
    result = "".join(part[:1].upper() + part[1:] for part in parts)

    # A leading digit (e.g. "3d model" -> "3dModel") is not valid PascalCase;
    # prefix the default so the result starts with a letter.
    if not result[0].isalpha():
        result = default + result
    return result

def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""验证和后处理结果"""

Expand All @@ -265,8 +282,15 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
if "analysis_summary" not in result:
result["analysis_summary"] = ""

# 验证实体类型
# 先规范化实体名称为 PascalCase,并做基本校验
normalized_entity_names: Dict[str, str] = {}
for entity in result["entity_types"]:
original_name = entity.get("name", "")
normalized_name = self._to_pascal_case(original_name or "Entity")
entity["name"] = normalized_name
if original_name:
normalized_entity_names[original_name] = normalized_name

if "attributes" not in entity:
entity["attributes"] = []
if "examples" not in entity:
Expand All @@ -275,14 +299,31 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
if len(entity.get("description", "")) > 100:
entity["description"] = entity["description"][:97] + "..."

# 验证关系类型
# 验证并规范化关系类型
for edge in result["edge_types"]:
# 关系名本身也规范为 PascalCase(满足 Zep name 约束)
edge_name = edge.get("name", "")
edge["name"] = self._to_pascal_case(edge_name or "Edge", default="Edge")

if "source_targets" not in edge:
edge["source_targets"] = []
if "attributes" not in edge:
edge["attributes"] = []
if len(edge.get("description", "")) > 100:
edge["description"] = edge["description"][:97] + "..."

# Zep 的约束:source / target 也必须是 PascalCase
for st in edge["source_targets"]:
src = st.get("source", "")
tgt = st.get("target", "")
if src in normalized_entity_names:
st["source"] = normalized_entity_names[src]
else:
st["source"] = self._to_pascal_case(src or "Entity")
if tgt in normalized_entity_names:
st["target"] = normalized_entity_names[tgt]
else:
st["target"] = self._to_pascal_case(tgt or "Entity")

# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
MAX_ENTITY_TYPES = 10
Expand Down