-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
78 lines (60 loc) · 2.04 KB
/
evaluate.py
File metadata and controls
78 lines (60 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numbers
from typing import cast

import pandas as pd
import requests
from tqdm import tqdm
# Maps a 1-based numeric answer key to its multiple-choice letter.
_ANSWER_LETTERS: dict[int, str] = {1: "A", 2: "B", 3: "C", 4: "D"}

# Sentinel prediction recorded when the API call fails; it is compared
# against the ground truth like any other prediction.
_FAILED_PREDICTION = "Error"


def _build_query(row: pd.Series) -> str:
    """Format one CSV row as the question followed by its four choices."""
    return (
        f"{row['question']}\n"
        f"A. {row['A']}\n"
        f"B. {row['B']}\n"
        f"C. {row['C']}\n"
        f"D. {row['D']}"
    )


def _request_prediction(api_url: str, query_text: str, timeout: float) -> str:
    """POST the query to the agent and return its upper-cased answer.

    Returns the ``"Error"`` sentinel when the request fails, the status code
    is not 200, the body is not valid JSON, or the ``answer`` field is not a
    string. Failures never raise, so one bad request cannot abort the run.
    """
    try:
        response = requests.post(
            api_url,
            json={"query": query_text},
            timeout=timeout,
        )
        if response.status_code != 200:
            return _FAILED_PREDICTION
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers JSON decoding errors raised by response.json().
        print(f"\nRequest failed: {exc}")
        return _FAILED_PREDICTION

    answer = payload.get("answer", "") if isinstance(payload, dict) else None
    if not isinstance(answer, str):
        print(f"\nRequest failed: unexpected response payload {payload!r}")
        return _FAILED_PREDICTION
    return answer.strip().upper()


def _normalize_answer(raw: object) -> str:
    """Convert a ground-truth cell into an upper-case answer string.

    Integers 1-4 (including integral floats such as ``1.0`` and digit
    strings such as ``" 2 "``) are mapped to ``A``-``D``; numbers outside the
    map are returned as their decimal text. Any other value is stringified,
    stripped and upper-cased.
    """
    if isinstance(raw, numbers.Integral):
        index = int(raw)
        return _ANSWER_LETTERS.get(index, str(index))
    # A missing cell makes pandas read the whole column as float, so the
    # answer keys arrive as 1.0, 2.0, ... and must still be mapped.
    if isinstance(raw, numbers.Real) and float(raw).is_integer():
        index = int(raw)
        return _ANSWER_LETTERS.get(index, str(index))

    text = str(raw).strip()
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
    if text.isdecimal():
        return _ANSWER_LETTERS.get(int(text), text)
    return text.upper()


def evaluate_agent(csv_path: str, api_url: str, *, timeout: float = 10) -> None:
    """Score an agent API on a multiple-choice CSV and print its accuracy.

    Each row is turned into a query (question plus choices A-D), sent to
    ``api_url`` as ``{"query": ...}``, and the ``answer`` field of the JSON
    response is compared with the row's ``answer`` column. Failed requests
    are recorded as the ``"Error"`` sentinel prediction.

    Args:
        csv_path: Path to a CSV file with the columns ``question``, ``A``,
            ``B``, ``C``, ``D`` and ``answer``.
        api_url: URL that receives the POST request for every row.
        timeout: Per-request timeout in seconds passed to ``requests.post``.

    Returns:
        None. Progress and the final accuracy are printed to stdout. If the
        CSV file does not exist, an error message is printed instead.
    """
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: '{csv_path}' 파일을 찾을 수 없습니다.")
        return

    total_count = len(df)
    correct_count = 0

    print("평가 시작")

    for _, row in tqdm(df.iterrows(), total=total_count):
        prediction = _request_prediction(api_url, _build_query(row), timeout)
        if prediction == _normalize_answer(row["answer"]):
            correct_count += 1

    # Guard against an empty CSV so the division cannot fail.
    accuracy = correct_count / total_count if total_count > 0 else 0.0

    print("\n" + "=" * 30)
    print(f"정확도: {accuracy:.4f}")
    print("=" * 30)
if __name__ == "__main__":
    # Script entry point: evaluate the locally running agent server on the
    # development data file.
    evaluate_agent(
        csv_path="data/dev.csv",
        api_url="http://127.0.0.1:8000/request",
    )