-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtasks.py
More file actions
196 lines (165 loc) · 6.97 KB
/
tasks.py
File metadata and controls
196 lines (165 loc) · 6.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from crewai import Task
# Human-readable task descriptions per benchmark
_BENCHMARK_LABELS = {
"causal_judgement": "BIG-Bench Hard causal_judgement",
"arc_challenge": "ARC-Challenge (allenai/ai2_arc)",
}
_QA_DATASET_DESC = {
"causal_judgement": "加载 BBH causal_judgement 原始数据集获取完整题目",
"arc_challenge": "直接从 results.jsonl 中读取完整题目和选项(ARC 数据已内嵌在结果文件中)",
}
_QA_ERROR_DESC = {
"causal_judgement": (
"- Yes/No 偏向(模型倾向答哪个)\n"
" - 错误的题目类型(哪类因果推理场景更难)\n"
" - 典型错误样例(列出 2-3 条最有代表性的错误)"
),
"arc_challenge": (
"- 模型最常误选的选项(A/B/C/D 分布)\n"
" - 错误的题目科学领域分布\n"
" - 典型错误样例(列出 2-3 条最有代表性的错误)"
),
}
_COMPARE_TASK_DESC = {
"causal_judgement": (
"调用工具查询 BIG-Bench causal_judgement 任务上的公开基线分数,重点查找:\n"
"1. GPT-4、GPT-3.5、PaLM 2 在 causal_judgement 上的 accuracy\n"
"2. LLaMA 系列开源模型的公开分数\n"
"3. Human baseline"
),
"arc_challenge": (
"调用工具查询 ARC-Challenge 任务上的公开基线分数,重点查找:\n"
"1. GPT-4、GPT-3.5、PaLM 2 在 ARC-Challenge 上的 accuracy\n"
"2. LLaMA 系列开源模型的公开分数(尤其是 LLaMA-3.1-8B)\n"
"3. Human baseline 和 Random baseline (25%)"
),
}
def create_tasks(runner, parser, analyst, qa_reviewer, comparator, publisher,
file_path: str, log_file: str, benchmark: str = "arc_challenge"):
bench_label = _BENCHMARK_LABELS.get(benchmark, benchmark)
report_file = f"outputs/{benchmark}/llama3.1_8b_report.md"
# ── Task 1: 确保推理结果就绪 ──────────────────────────────────────────────
run_task = Task(
description=f"""
检查推理结果文件 {file_path} 是否存在。
- 若文件不存在:调用工具对 {bench_label} 运行推理,
使用 llama3.1:8b,取前 30 条样本,将结果写入 {file_path}。
工具调用时传入 benchmark="{benchmark}"。
- 若文件已存在:直接返回文件的基本摘要(样本数、准确率)。
""",
expected_output="一句话说明结果文件状态(已存在/新生成)和准确率。",
agent=runner,
)
# ── Task 2: 解析数据结构 ──────────────────────────────────────────────────
parse_task = Task(
description=f"""
读取评测结果文件:{file_path}
请报告:
1. 文件总记录数
2. 所有字段名称
3. category 分布(每个类别的样本数量)
4. score 字段的整体分布(min/max/mean)
5. 平均推理延迟(latency_s 字段均值)
6. 随机列出 3 条样例记录(含 question_id、score、model_answer、targets)
如文件不存在或为空,明确报告错误原因。
""",
expected_output="""
结构化数据摘要,包含:
- 记录总数、字段列表
- category 分布
- score 统计数字
- latency 统计
- 3 条样例
""",
agent=parser,
context=[run_task],
)
# ── Task 3: 统计分析 ──────────────────────────────────────────────────────
stats_task = Task(
description=f"""
基于解析员提供的数据结构,对文件 {file_path} 做完整统计分析:
1. 按 category 分别计算:mean score、std、pass@1(score≥0.5 视为通过)
2. 计算 overall mean 和总体 pass@1
3. 统计答错题目总数和比例
4. 列出前 2 条答错样例(model_answer vs targets)
所有数字保留 4 位小数。
""",
expected_output="""
分维度统计表:
[category] n=? | mean=? | std=? | pass@1=?
以及 overall 汇总行。
答错分析:总错误数 / 比例 / 样例。
""",
agent=analyst,
context=[parse_task],
)
# ── Task 4: 逐题对照分析 ──────────────────────────────────────────────────
qa_desc = _QA_DATASET_DESC.get(benchmark, "")
error_desc = _QA_ERROR_DESC.get(benchmark, "")
qa_task = Task(
description=f"""
调用工具对 {file_path} 进行逐题对照分析:
1. {qa_desc}
2. 与推理结果合并,生成逐题对照表,保存到 outputs/qa_detail.md
3. 分析错误模式:
{error_desc}
""",
expected_output="""
- 确认 outputs/qa_detail.md 已生成
- 错误模式分析:偏向类型、难点类别、2-3 条代表性错误样例
""",
agent=qa_reviewer,
context=[stats_task],
)
# ── Task 5: 横向对比 ──────────────────────────────────────────────────────
compare_desc = _COMPARE_TASK_DESC.get(benchmark, "")
compare_task = Task(
description=f"""
{compare_desc}
调用工具时传入 task_name="{benchmark}"。
将 {{model_name}} 的 overall mean 与上述结果对比,说明其相对水平。
""",
expected_output="""
对比表:模型名 | accuracy | 来源
以及 {model_name} 的定位描述(1-2 句话)。
""",
agent=comparator,
context=[stats_task],
)
# ── Task 6: 生成报告 + GIF ────────────────────────────────────────────────
report_task = Task(
description=f"""
完成两件事:
**1. 生成 pipeline 可视化 GIF**
调用工具读取日志文件 {log_file},生成动画 GIF 到 outputs/pipeline_flow.gif。
**2. 撰写完整评测报告**
将前面所有分析师的输出整合为 Markdown 报告,保存到文件。
报告结构:
# {{model_name}} 在 {{benchmark_name}} 上的评测报告
## 执行摘要
(3-5 句话,含关键数字)
## 评测设置
- 模型、Benchmark、任务类型、样本量、推理方式
## 统计结果
(格式化为表格)
## 逐题错误分析
(来自 QA 分析师:错误模式、偏向分析、代表性错误样例)
## 与公开 Baseline 的对比
(来自比较分析师的对比表)
## 局限性说明
## 附件
- 逐题对照:outputs/qa_detail.md
- Pipeline 可视化:outputs/pipeline_flow.gif
- 原始日志:{log_file}
保存到:{report_file}
""",
expected_output=f"""
确认:
1. GIF 已生成(文件大小 KB)
2. 报告已保存到 {report_file}(含所有章节)
""",
agent=publisher,
context=[parse_task, stats_task, qa_task, compare_task],
output_file=report_file,
)
return run_task, parse_task, stats_task, qa_task, compare_task, report_task