-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_baseline.py
More file actions
183 lines (148 loc) · 5.8 KB
/
evaluate_baseline.py
File metadata and controls
183 lines (148 loc) · 5.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
Full Context Baseline 评估脚本
只需要跑一次,结果保存后供 evaluate.py compare 使用
"""
import json
import time
from pathlib import Path
from datetime import datetime
from typing import List
from llm_client import LLMClient
from config import DATA_PATHS, LLM_CONFIG, LLM_PROVIDERS
from evaluate import compute_metrics
def format_full_context(conversations: List[dict]) -> str:
    """Render the whole conversation history as one LLM context string.

    Each conversation becomes a paragraph of the form
    ``[id] [date]\\nUser: ...\\nAssistant: ...``; paragraphs are separated
    by a blank line. Missing fields default to the empty string.

    Args:
        conversations: list of dicts with keys ``id``, ``session_date``,
            ``user``, ``assistant`` (all optional).

    Returns:
        The joined context string ("" for an empty list).
    """
    return "\n\n".join(
        f"[{conv.get('id', '')}] [{conv.get('session_date', '')}]\n"
        f"User: {conv.get('user', '')}\n"
        f"Assistant: {conv.get('assistant', '')}"
        for conv in conversations
    )
def generate_baseline_answer(query: str, full_context: str, llm: LLMClient) -> str:
    """Answer `query` by prompting the LLM with the entire conversation history.

    This is the no-retrieval baseline: the full formatted context is pasted
    into a single prompt. Returns "" when the client yields a falsy result.
    """
    rules = "\n".join([
        "- Answer based ONLY on the information provided above",
        "- Output ONLY the final answer (no explanation, no extra words, no quotes)",
        "- Be concise and direct",
        '- If the question asks "when"/time/date: output an ABSOLUTE date/month/year (e.g., "February, 2023", "20 June, 2023"); avoid relative terms like "next month", "tomorrow", "yesterday"',
        "- Prefer explicit dates mentioned in the conversation text; if only relative time is given, infer the absolute time from the note date shown in brackets",
        '- If the question is yes/no: output ONLY "Yes" or "No"',
    ])
    prompt = (
        "Based on the following conversation history, answer the question.\n"
        f"{full_context}\n"
        f"Question: {query}\n"
        "Instructions:\n"
        f"{rules}\n"
        "Answer:"
    )
    response = llm.call_for_answer(prompt)
    return response or ""
def run_baseline(num_queries: int = None):
    """
    Run the full-context baseline evaluation.

    Every query is answered by prompting the LLM with the *entire*
    conversation history (no retrieval step). Metrics and per-query results
    are written to ``eval_results/baseline_<date>_<model>_<temp>_<seq>.json``
    so that ``evaluate.py compare`` can reuse them later; this only needs to
    run once per configuration.

    Args:
        num_queries: optional cap on the number of queries evaluated
            (falsy / None = evaluate all of them).

    Returns:
        The output dict that was serialized to disk (config, metrics,
        usage, total_time, results).
    """
    print("=" * 60)
    print("Full Context Baseline Evaluation")
    print("=" * 60)
    # Load the benchmark (queries + raw conversations) relative to this file.
    script_dir = Path(__file__).parent
    benchmark_file = script_dir / DATA_PATHS["benchmark_file"]
    with open(benchmark_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    queries = data.get("queries", [])
    conversations = data.get("conversations", [])
    # NOTE: truthiness check kept on purpose — both None and 0 mean "all".
    if num_queries:
        queries = queries[:num_queries]
    # Build the shared full-history context once; it is reused for every query.
    full_context = format_full_context(conversations)
    context_char_count = len(full_context)
    print(f"\n[Baseline] {len(queries)} queries, {len(conversations)} conversations")
    print(f"[Baseline] Full context: {context_char_count:,} chars")
    # One client instance so token usage accumulates across all calls.
    llm = LLMClient()
    results = []
    total_time = 0.0
    for i, q in enumerate(queries):
        query_id = q.get("id", f"Q{i+1}")
        query_text = q.get("query", "")
        gold_notes = q.get("gold_notes", [])
        gold_answer = q.get("gold_answer", "")
        category = q.get("category", 0)
        if (i + 1) % 5 == 0:
            print(f"[Baseline] Processing {i + 1}/{len(queries)}...")
        start = time.time()
        answer = generate_baseline_answer(query_text, full_context, llm)
        elapsed = time.time() - start
        total_time += elapsed
        results.append({
            "id": query_id,
            "query": query_text,
            "category": category,
            "gold_notes": gold_notes,
            "gold_answer": gold_answer,
            "retrieved_indices": [],  # the baseline performs no retrieval
            "answer": answer,
            "elapsed": elapsed
        })
    # Aggregate token usage reported by the client.
    usage = llm.get_answer_usage()
    # Score the generated answers against the gold answers.
    metrics = compute_metrics(results, evaluate_answers=True)
    # Report the headline numbers.
    print("\n" + "=" * 60)
    print("Baseline Results")
    print("=" * 60)
    print(f"\nAnswer Metrics:")
    print(f"  F1:     {metrics['answer']['overall']['f1']:.4f}")
    print(f"  BLEU-1: {metrics['answer']['overall']['bleu1']:.4f}")
    print(f"\nToken Usage:")
    print(f"  Prompt tokens:     {usage['prompt_tokens']:,}")
    print(f"  Completion tokens: {usage['completion_tokens']:,}")
    print(f"  Total tokens:      {usage['total_tokens']:,}")
    print(f"  Calls:             {usage['call_count']}")
    print(f"\nTotal time: {total_time:.2f}s")
    if queries:  # guard against ZeroDivisionError on an empty benchmark
        print(f"Avg time per query: {total_time/len(queries):.2f}s")
    # Persist the run under a name that encodes model, temperature and date.
    provider = LLM_CONFIG["active_provider"]
    model_name = LLM_PROVIDERS[provider]["model"].split("/")[-1]
    temp_str = str(LLM_CONFIG["temperature"]).replace(".", "_")
    date_str = datetime.now().strftime("%m%d")
    output_dir = script_dir / "eval_results"
    output_dir.mkdir(exist_ok=True)
    output = {
        "config": {
            "num_queries": len(queries),
            "conversation_count": len(conversations),
            "context_chars": context_char_count,
            "llm_provider": provider,
            "model": LLM_PROVIDERS[provider]["model"],
            "temperature": LLM_CONFIG["temperature"],
        },
        "metrics": metrics,
        "usage": usage,
        "total_time": total_time,
        "results": results,  # per-query detail for later comparison
    }
    # Pick the next free sequence number so previous runs are never
    # overwritten (pathlib glob replaces the former ad-hoc `import glob`).
    max_id = 0
    for existing in output_dir.glob(f"baseline_{date_str}_{model_name}_{temp_str}_*.json"):
        try:
            max_id = max(max_id, int(existing.stem.split("_")[-1]))
        except ValueError:
            pass  # ignore files whose trailing segment is not a number
    next_id = max_id + 1
    output_file = output_dir / f"baseline_{date_str}_{model_name}_{temp_str}_{next_id:03d}.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"\n[Baseline] Results saved to {output_file}")
    return output
if __name__ == "__main__":
    import sys

    # The first purely-numeric CLI argument (if any) caps the query count.
    limit = next((int(arg) for arg in sys.argv[1:] if arg.isdigit()), None)
    run_baseline(num_queries=limit)