-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimport_data.py
More file actions
129 lines (107 loc) · 4.63 KB
/
import_data.py
File metadata and controls
129 lines (107 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
V4 数据导入脚本 - 简化版 (w/o LineNode)
使用 LLM 提取实体和关系构建真正的知识图谱
消融实验:移除 LineNode 演化链
"""
import json
import argparse
from pathlib import Path
from knowledge_graph_nolinenode import KnowledgeGraphNoLineNode as KnowledgeGraphV2
from kg_builder_nolinenode import KGBuilderNoLineNode as KnowledgeGraphBuilder
from config import DATA_PATHS, get_neo4j_config
def _prepare_graph(kg):
    """Make sure the target graph may be (re)populated.

    If the graph already holds notes (``note_count > 0`` in ``kg.get_stats()``),
    print its statistics and ask the user whether to wipe it.

    Args:
        kg: The knowledge-graph wrapper to inspect (and possibly clear).

    Returns:
        True if the import should proceed (graph empty, or cleared on request),
        False if the user chose to keep the existing data.
    """
    current_stats = kg.get_stats()
    if current_stats.get("note_count", 0) <= 0:
        return True

    print("\n[Warning] Graph already contains data:")
    for k, v in current_stats.items():
        print(f" {k}: {v}")
    try:
        response = input("\nClear existing data? (y/n): ").strip().lower()
    except EOFError:
        # Non-interactive stdin (e.g. piped/CI run): never wipe data implicitly,
        # fall back to the non-destructive answer instead of crashing.
        response = "n"

    if response != "y":
        print("[Skip] Keeping existing data")
        return False

    print("[Clear] Clearing existing data...")
    kg.clear_all()
    return True


def _estimate_cost(usage, input_price_per_m=0.15, output_price_per_m=0.60):
    """Estimate LLM spend in USD from a token-usage dict.

    Args:
        usage: Mapping with ``prompt_tokens`` and ``completion_tokens`` counts.
        input_price_per_m: USD per 1M input (prompt) tokens; defaults to the
            rate this script has always used ($0.15).
        output_price_per_m: USD per 1M output (completion) tokens; defaults to
            the rate this script has always used ($0.60).

    Returns:
        Dict with ``input``, ``output`` and ``total`` costs in USD.
    """
    input_cost = usage["prompt_tokens"] / 1_000_000 * input_price_per_m
    output_cost = usage["completion_tokens"] / 1_000_000 * output_price_per_m
    return {
        "input": input_cost,
        "output": output_cost,
        "total": input_cost + output_cost,
    }


def _report_token_usage(builder, script_dir, conversation):
    """Print construction-phase token usage / cost and persist it as JSON.

    Does nothing when the builder's LLM client reports zero calls.

    Args:
        builder: The graph builder; its ``extractor.llm.get_total_usage()``
            supplies the usage dict (keys ``prompt_tokens``,
            ``completion_tokens``, ``total_tokens``, ``call_count``).
        script_dir: Directory the ``construction_usage_<id>.json`` file is
            written to.
        conversation: Conversation ID, used in the output file name and payload.
    """
    total_usage = builder.extractor.llm.get_total_usage()
    if total_usage["call_count"] <= 0:
        return

    print("\n" + "=" * 60)
    print("Token Usage & Cost (Construction Phase)")
    print("=" * 60)
    print("\nToken Usage (All LLM Calls):")
    print(f" Input tokens: {total_usage['prompt_tokens']:,}")
    print(f" Output tokens: {total_usage['completion_tokens']:,}")
    print(f" Total tokens: {total_usage['total_tokens']:,}")
    print(f" API calls: {total_usage['call_count']}")

    cost = _estimate_cost(total_usage)
    print("\nEstimated Cost:")
    print(f" Input: ${cost['input']:.4f}")
    print(f" Output: ${cost['output']:.4f}")
    print(f" Total: ${cost['total']:.4f}")

    # Save the construction-phase token usage next to this script.
    construction_usage_file = script_dir / f"construction_usage_{conversation}.json"
    with open(construction_usage_file, "w", encoding="utf-8") as f:
        json.dump({
            "conversation": conversation,
            "token_usage": total_usage,
            "cost": cost,
        }, f, indent=2)
    print(f"\n[Save] Construction usage saved to {construction_usage_file}")


def main():
    """Import one LoCoMo conversation benchmark into the knowledge graph.

    Reads ``--conversation`` (default ``conv-26``), loads
    ``../../datasets/locomo10_split/locomo_<id>_benchmark.json`` relative to
    this script, connects to the database configured for that conversation,
    optionally clears pre-existing data after confirmation, builds the graph
    with LLM extraction, prints statistics and sample entities, and reports
    the LLM token usage / estimated cost.

    Returns early (printing an error) if the benchmark file does not exist.
    The database connection is always closed once it has been opened.
    """
    parser = argparse.ArgumentParser(description="Import conversation data into EvoGraphV4")
    parser.add_argument("--conversation", "-c", type=str, default="conv-26",
                        help="Conversation ID (e.g., conv-26, conv-30)")
    args = parser.parse_args()

    print("=" * 60)
    print("EvoGraph V2 Data Import")
    print("LLM Entity Extraction + Real Knowledge Graph")
    print("=" * 60)

    # Resolve the data path relative to this script, not the working directory.
    script_dir = Path(__file__).parent
    benchmark_file = script_dir / f"../../datasets/locomo10_split/locomo_{args.conversation}_benchmark.json"
    if not benchmark_file.exists():
        print(f"[ERROR] Benchmark file not found: {benchmark_file}")
        print("[TIP] Available conversations: conv-26, conv-30, conv-41, conv-42, conv-43, conv-44, conv-47, conv-48, conv-49, conv-50")
        return

    # Load the benchmark data.
    print(f"\n[Load] Reading from {benchmark_file}...")
    with open(benchmark_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    conversations = data.get("conversations", [])
    print(f"[Load] Found {len(conversations)} conversations")

    # Create the graph using the database config for this conversation.
    print(f"\n[Init] Creating KnowledgeGraphV2 for {args.conversation}...")
    db_config = get_neo4j_config(args.conversation)
    print(f"[DB] Connecting to {db_config['uri']}...")
    kg = KnowledgeGraphV2(neo4j_config=db_config)

    # Everything after this point may raise (DB errors, LLM failures, Ctrl-C
    # during the long build); the finally guarantees the connection is closed.
    try:
        if not _prepare_graph(kg):
            return

        # Build the graph.
        print("\n" + "=" * 60)
        print("Building Knowledge Graph with LLM Extraction...")
        print("=" * 60 + "\n")
        builder = KnowledgeGraphBuilder(kg)
        stats = builder.build_from_conversations(conversations)

        # Print the build results.
        print("\n" + "=" * 60)
        print("Import Complete")
        print("=" * 60)
        print("\nGraph Statistics:")
        for k, v in stats.items():
            print(f" {k}: {v}")

        # Print a sample of the extracted entities.
        print("\nSample Entities:")
        entities = kg.get_all_entities()[:10]
        for e in entities:
            print(f" {e['name']} ({e.get('type', 'Unknown')}): {e.get('display_name', e['name'])}")

        # Token usage statistics and cost estimate.
        _report_token_usage(builder, script_dir, args.conversation)
    finally:
        kg.close()

    print("\n[Done]")
# Script entry point: run the import only when executed directly, not when imported.
if __name__ == "__main__":
    main()