Skip to content

Commit 6ff85b6

Browse files
scifacts minor fix: handle list of lists
1 parent e7420d1 commit 6ff85b6

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/evals/scifacts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,15 @@ def parse_output(output):
9898

9999
def grade(judge_name: str, key: str, ground_truths: List[List[dict]], preds: List[List[dict]], eval_info: dict | None) -> dict:
100100
client = OpenAI()
101-
101+
assert len(preds) == 1
102102
# compute metrics
103103
if "material" in ground_truths[0][0].keys():
104-
m_precision, m_recall, m_f1 = compute_metrics(ground_truths[0], preds, key="material", client=client, judge_name=judge_name)
105-
p_precision, p_recall, p_f1 = compute_metrics(ground_truths[0], preds, key="paper_title", client=client, judge_name=judge_name)
104+
m_precision, m_recall, m_f1 = compute_metrics(ground_truths[0], preds[0], key="material", client=client, judge_name=judge_name)
105+
p_precision, p_recall, p_f1 = compute_metrics(ground_truths[0], preds[0], key="paper_title", client=client, judge_name=judge_name)
106106

107107
return {"precision": m_precision*p_precision, "recall": m_recall*p_recall, "f1": m_f1*p_f1}
108108

109109
else:
110-
precision, recall, f1 = compute_metrics(ground_truths[0], preds, key="paper_title", client=client, judge_name=judge_name)
110+
precision, recall, f1 = compute_metrics(ground_truths[0], preds[0], key="paper_title", client=client, judge_name=judge_name)
111111

112112
return {"precision": precision, "recall": recall, "f1": f1}

0 commit comments

Comments (0)