@@ -98,15 +98,15 @@ def parse_output(output):
9898
def grade(judge_name: str, key: str, ground_truths: List[List[dict]], preds: List[List[dict]], eval_info: dict | None) -> dict:
    """Score a single prediction set against ground truths with an LLM judge.

    When the ground-truth entries carry a "material" field, both the
    "material" and "paper_title" fields are judged and the two metric sets
    are combined multiplicatively; otherwise only "paper_title" is judged.

    Args:
        judge_name: Model name passed through to ``compute_metrics`` as the judge.
        key: Unused here; presumably kept for a shared grader signature — TODO confirm.
        ground_truths: Ground-truth records; only ``ground_truths[0]`` is consumed.
        preds: Predicted records; must contain exactly one inner list.
        eval_info: Unused here; presumably kept for a shared grader signature — TODO confirm.

    Returns:
        dict with "precision", "recall", and "f1" floats.

    Raises:
        ValueError: if ``preds`` does not contain exactly one prediction list.
    """
    client = OpenAI()

    # Validate with an explicit raise rather than `assert`, which is
    # stripped under `python -O` and would let malformed input through.
    if len(preds) != 1:
        raise ValueError(f"expected exactly one prediction list, got {len(preds)}")

    # compute metrics
    if "material" in ground_truths[0][0]:  # membership on the dict itself; .keys() is redundant
        m_precision, m_recall, m_f1 = compute_metrics(
            ground_truths[0], preds[0], key="material", client=client, judge_name=judge_name
        )
        p_precision, p_recall, p_f1 = compute_metrics(
            ground_truths[0], preds[0], key="paper_title", client=client, judge_name=judge_name
        )
        # Combined score: both fields must match, so metrics multiply.
        return {
            "precision": m_precision * p_precision,
            "recall": m_recall * p_recall,
            "f1": m_f1 * p_f1,
        }

    precision, recall, f1 = compute_metrics(
        ground_truths[0], preds[0], key="paper_title", client=client, judge_name=judge_name
    )
    return {"precision": precision, "recall": recall, "f1": f1}
0 commit comments