-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
109 lines (90 loc) · 3.28 KB
/
main.py
File metadata and controls
109 lines (90 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import json
import time
from dotenv import load_dotenv
from crewai import Crew, Process
from agents import create_agents
from tasks import create_tasks
from pipeline_logger import logger
# Load environment variables (e.g. API keys / service endpoints) from a local .env file.
load_dotenv()

# Model identifier passed to the crew as the "model_name" input.
MODEL_NAME = "llama3.1:8b"

# Which benchmark to run; must be a key of BENCHMARK_NAME below.
BENCHMARK = "arc_challenge"  # causal_judgement | arc_challenge

# Human-readable display name for the selected benchmark (raises KeyError at
# import time if BENCHMARK is not a known key — fails fast on a typo).
BENCHMARK_NAME = {
    "causal_judgement": "BIG-Bench Hard causal_judgement",
    "arc_challenge": "ARC-Challenge",
}[BENCHMARK]

# Per-benchmark input results file and structured pipeline event log.
RESULTS_FILE = f"data/{BENCHMARK}/results.jsonl"
LOG_FILE = f"outputs/{BENCHMARK}/pipeline_log.jsonl"

# Agent roles in pipeline execution order; zipped with the task list in main()
# so each task-start log entry is attributed to the right role.
AGENT_ROLES = [
    "Benchmark Runner",
    "Benchmark Data Parser",
    "Statistical Analyst",
    "QA Detail Reviewer",
    "Comparative Research Analyst",
    "Report Publisher",
]
def main():
    """Run the six-agent benchmark pipeline end-to-end.

    Builds the agents and their sequential tasks for the configured
    BENCHMARK, executes them with CrewAI, and writes a report, a QA
    detail file, a GIF, a JSONL event log, and a JSON run summary
    under outputs/<benchmark>/.
    """
    out_dir = f"outputs/{BENCHMARK}"
    os.makedirs(out_dir, exist_ok=True)

    # Point the pipeline logger at this benchmark's log file.
    logger.set_log_file(LOG_FILE)

    # Remove any stale log from a previous run so entries start fresh.
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)

    # Install a stdout interceptor (captures crewai verbose output in real time).
    interceptor = logger.install_stdout_interceptor()

    try:
        runner, parser, analyst, qa_reviewer, comparator, publisher = create_agents()
        run_task, parse_task, stats_task, qa_task, compare_task, report_task = create_tasks(
            runner, parser, analyst, qa_reviewer, comparator, publisher,
            file_path=RESULTS_FILE,
            log_file=LOG_FILE,
            benchmark=BENCHMARK,
        )
        tasks = [run_task, parse_task, stats_task, qa_task, compare_task, report_task]

        # Log a start marker for each task, attributed to its agent role.
        for i, (task, role) in enumerate(zip(tasks, AGENT_ROLES)):
            logger.log_task_start(i, task.description, role)

        crew = Crew(
            agents=[runner, parser, analyst, qa_reviewer, comparator, publisher],
            tasks=tasks,
            process=Process.sequential,
            verbose=True,
            task_callback=logger.make_task_callback(),
        )

        print("\n" + "="*60)
        print(f" Benchmark: {BENCHMARK_NAME}")
        print(f" Model: {MODEL_NAME}")
        print(f" Log: {LOG_FILE}")
        print("="*60 + "\n")

        t0 = time.time()
        crew.kickoff(inputs={
            "model_name": MODEL_NAME,
            "benchmark_name": BENCHMARK_NAME,
            "benchmark": BENCHMARK,
        })
        elapsed = round(time.time() - t0, 1)
    finally:
        # Always hand stdout back, even if agent/task construction or
        # kickoff raises — otherwise stdout stays hijacked for the process.
        logger.restore_stdout(interceptor)

    report_file = f"{out_dir}/llama3.1_8b_report.md"
    summary = {
        "model": MODEL_NAME,
        "benchmark": BENCHMARK_NAME,
        "total_elapsed_s": elapsed,
        "log_entries": len(logger.entries),
        "outputs": {
            "report": report_file,
            "qa_detail": f"{out_dir}/qa_detail.md",
            "gif": f"{out_dir}/pipeline_flow.gif",
            "log": LOG_FILE,
        },
        "finished_at": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(f"{out_dir}/pipeline_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\n{'='*60}")
    print(f" 报告: {report_file}")
    print(f" 逐题对照: {out_dir}/qa_detail.md")
    print(f" GIF: {out_dir}/pipeline_flow.gif")
    print(f" 日志: {LOG_FILE} ({len(logger.entries)} 条事件)")
    print(f" 耗时: {elapsed}s")
    print("="*60)
# Script entry point: run the pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()