
Commit 83b7c83

add eval code and dataset
1 parent c70747e commit 83b7c83

File tree

10 files changed: +2534 additions, -35 deletions

data/livedrbench-subset.csv

Lines changed: 147 additions & 0 deletions
Large diffs are not rendered by default.

data/livedrbench.bin

Lines changed: 0 additions & 1 deletion
This file was deleted.

data/livedrbench.csv

Lines changed: 1626 additions & 0 deletions
Large diffs are not rendered by default.

src/decrypt_and_eval.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

src/evals/datasets_flights.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
import json
from openai import OpenAI
from utils import retry_on_exception

from typing import List

prompts = {
    'dict': """
You are evaluating whether two dictionaries refer to the same information. Your task is to determine whether each corresponding value is equivalent.

For each key-value pair in the ground truth dictionary:
- If the key is absent in the model output dictionary, it is not equivalent.
- If the value is a number (integer or float), consider the values equivalent if they agree within a 1% margin of error.
- If the value is a string, consider them equivalent if they are similar in meaning, or if one is a subset of the other.
  - Acronyms or shortened names are equivalent to their expanded forms.
  - "BERT" and "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" -> equivalent
  - "ICLR" and "International Conference on Learning Representations" -> equivalent
  - "CNN" and "ResNet" -> equivalent (both are convolutional neural networks)
  - "Adam" and "SGD" -> not equivalent (different optimizers)
  - "0.8" and "0.80" -> equivalent numerically
- If the value is a URL, they must be exactly the same, or a shortened version of the same URL.
- If the value is a list, check if every element in the ground truth list is present in the model output list, regardless of order. Follow the same rules for strings as above.
- If the value is a dictionary, check if every key-value pair in the ground truth dictionary is present in the model output dictionary, regardless of order. Again, follow the same rules for strings as above.
- Some values may depend on information from other keys in the dictionary. Use such context to determine whether two values are equivalent, even if they are not identical on their own.
- Ignore trivial formatting differences such as casing, whitespace, or common abbreviations if the meaning is preserved.

Now compare the following dictionaries:

Ground Truth:
{dict_gt}

Model Output:
{dict_pred}

Your output should be a JSON object with the same keys as the input dictionaries. For each key, the value should be:
- 3: The predicted value closely matches the ground truth, preserving almost all important information with minimal or no loss of meaning.
- 2: The predicted value captures the main idea of the ground truth but omits key details or has minor inaccuracies.
- 1: The predicted value shows some overlap with the ground truth but misses most of the essential meaning.
- 0: The predicted value is largely incorrect, unrelated, or does not correspond meaningfully to the ground truth value.

IMPORTANT:
- Do NOT explain your reasoning.
- Do NOT include any extra text.
- Only output the final JSON object. Do not wrap it in markdown or say anything else.
""",
    'list_of_dicts_match': """
You are given a list of dictionaries (the ground truth) and a single predicted dictionary (the model output). Your task is to determine which dictionary in the ground truth list corresponds to the predicted dictionary, using a set of primary keys to identify the best match.

- You will be given a list of primary keys to compare the dictionaries on. If a primary key is absent from either dictionary, use the other primary keys to determine the match.
- Do not consider ANY other fields in the dictionaries except for the primary keys.
- If the value is a string, consider them equivalent if they are similar in meaning, or if one is a subset of the other.
  - Match based on semantic similarity or if one is a subset of the other, not exact string equality.
  - Acronyms, shortened names, or partial names are equivalent to their expanded forms.
  - "BERT" and "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" -> equivalent
  - "ICLR" and "International Conference on Learning Representations" -> equivalent
- A prediction may include extra info (e.g., a year or qualifier). If the ground truth is a subset and meaning is preserved, treat them as equivalent. If the ground truth includes a qualifier, the prediction's qualifier must match if present.
- You may also be provided with a set of extra evaluation notes that give more specific evaluation criteria. These extra rules may override the general rules above.
- If after all this no match exists, return an empty dictionary.

Now compare the following:

Ground Truth List:
{list_gt}

Model Output:
{dict_pred}

Primary Keys:
{primary_keys}

Extra Evaluation Notes:
{note}

Your output should be the dictionary from the ground truth list that best matches the model output, or an empty dictionary if no match is found.

IMPORTANT:
- Do NOT explain your reasoning.
- Do NOT include any extra text.
- Only output the final dictionary as a JSON object. Do not wrap it in markdown or say anything else.
""",
}

def find_and_remove_dict(dict_to_remove, list_of_dicts, primary_keys):
    if dict_to_remove is None:
        return list_of_dicts

    for i, d in enumerate(list_of_dicts):
        if all(d.get(pk) == dict_to_remove.get(pk) for pk in primary_keys):
            return list_of_dicts[:i] + list_of_dicts[i+1:]
    return list_of_dicts

def evaluate_dict_item(client, judge_name, key, gt_value, pred_value):
    # NOTE: relies on a 'dict_item' prompt template that is not defined in
    # `prompts` above; this helper is not called elsewhere in this file.
    prompt = prompts['dict_item'].format(key=key, gt_value=json.dumps(gt_value), pred_value=json.dumps(pred_value))
    response = client.chat.completions.create(
        model=judge_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1,
        temperature=0.0
    )
    eval_output = int(response.choices[0].message.content.strip())
    return eval_output

@retry_on_exception(max_retries=3, delay=1)
def evaluate_dict(client, judge_name, gt_dict, pred_dict):
    prompt = prompts['dict'].format(dict_gt=json.dumps(gt_dict), dict_pred=json.dumps(pred_dict))
    response = client.chat.completions.create(
        model=judge_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1024,
        temperature=0.0
    )
    output = response.choices[0].message.content.replace('```json', '').replace('```', '').strip()
    eval_output = json.loads(output)
    return eval_output

@retry_on_exception(max_retries=3, delay=1)
def evaluate_list_of_dicts(client, judge_name, list_gt, dict_pred, eval_info):
    primary_keys = eval_info['primary_keys']
    note = eval_info.get('evaluation_note', {})
    match_prompt = prompts['list_of_dicts_match'].format(list_gt=json.dumps(list_gt), dict_pred=json.dumps(dict_pred), primary_keys=json.dumps(primary_keys), note=json.dumps(note))
    response = client.chat.completions.create(
        model=judge_name,
        messages=[{"role": "user", "content": match_prompt}],
        max_tokens=1024,
        temperature=0.0
    )
    output = response.choices[0].message.content.replace('```json', '').replace('```', '').strip()
    extracted_gt_dict = json.loads(output)

    if len(extracted_gt_dict) == 0:
        return None, None

    eval_prompt = prompts['dict'].format(dict_gt=json.dumps(extracted_gt_dict), dict_pred=json.dumps(dict_pred))
    response = client.chat.completions.create(
        model=judge_name,
        messages=[{"role": "user", "content": eval_prompt}],
        max_tokens=1024,
        temperature=0.0
    )

    output = response.choices[0].message.content.replace('```json', '').replace('```', '').strip()
    eval_output = json.loads(output)
    return extracted_gt_dict, eval_output

def grade(judge_name: str, key: str, ground_truths: List[List[dict] | dict], preds: List[List[dict] | dict], eval_info: dict | None) -> dict:
    client = OpenAI()

    evaluations = []
    for gt, pred in zip(ground_truths, preds):
        if isinstance(gt, dict):
            eval_output = evaluate_dict(client, judge_name, gt, pred)
            evaluations.append({ 'type': 'dict', 'extracted_gt': None, 'pred': pred, 'eval_output': eval_output })
        elif isinstance(gt, list) and isinstance(gt[0], dict):
            for row in pred:
                extracted_gt_dict, eval_output = evaluate_list_of_dicts(client, judge_name, gt, row, eval_info)
                if eval_output is None:
                    evaluations.append({ 'type': 'list_of_dicts_match', 'extracted_gt': None, 'pred': row, 'eval_output': None })
                    continue

                gt = find_and_remove_dict(extracted_gt_dict, gt, eval_info['primary_keys'])
                evaluations.append({ 'type': 'list_of_dicts_match', 'extracted_gt': extracted_gt_dict, 'pred': row, 'eval_output': eval_output })
        else:
            raise ValueError(f"Unsupported ground truth type: {type(gt)} for key {key}")

    main_claims = eval_info.get('main_claims', [])
    ignore_keys = set(eval_info.get('ignore_keys', []))

    gt_key_count = 0
    for gt in ground_truths:
        if isinstance(gt, dict):
            gt_key_count += len([k for k in gt if k not in ignore_keys])
        elif isinstance(gt, list) and isinstance(gt[0], dict):
            for item in gt:
                gt_key_count += len([k for k in item if k not in ignore_keys])

    eval_key_count = 0
    for evaluation in evaluations:
        pred = evaluation['pred']
        assert type(pred) is dict
        eval_key_count += len([k for k in pred if k not in ignore_keys])

    # search for a possible dataset/flight identification task; if it was
    # answered incorrectly, there is no point in evaluating the rest of the claims
    is_identification_claim_correct = True
    for evaluation in evaluations:
        evaluation_type = evaluation['type']
        if evaluation_type != 'dict':
            continue

        pred, eval_output = evaluation['pred'], evaluation['eval_output']

        for k, v in eval_output.items():
            if k in main_claims and v <= 1:
                is_identification_claim_correct = False
                break

        if not is_identification_claim_correct:
            break

    # now do the actual evaluation
    eval_score = 0
    for evaluation in evaluations:
        pred, eval_output = evaluation['pred'], evaluation['eval_output']
        if eval_output is None:
            continue

        # check whether all main extractive claims, if any, are correct.
        # is_identification_claim_correct is False if the dataset or flight
        # number identification task was graded incorrect, in which case
        # everything is considered incorrect
        are_main_extractive_claims_correct = is_identification_claim_correct
        for k, v in eval_output.items():
            if k in main_claims and v <= 1:
                are_main_extractive_claims_correct = False
                break

        # if main claims are wrong, all claims for this particular evaluation
        # are considered incorrect
        if not are_main_extractive_claims_correct:
            for k, v in eval_output.items():
                eval_output[k] = 0

        # binarize
        for k, v in eval_output.items():
            eval_output[k] = 0 if v <= 1 else 3

        eval_score += sum([v for k, v in eval_output.items() if k not in ignore_keys])

    eval_score /= 3  # each binarized key contributes 3, so each matched key now counts as 1
    assert eval_score <= eval_key_count, f"Eval score {eval_score} exceeds eval key count {eval_key_count} for key {key}"
    assert eval_score <= gt_key_count, f"Eval score {eval_score} exceeds ground truth key count {gt_key_count} for key {key}"

    precision = eval_score / eval_key_count if eval_key_count > 0 else 0
    recall = eval_score / gt_key_count if gt_key_count > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return { 'precision': precision, 'recall': recall, 'f1': f1 }
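Once the per-key judge scores are collected, grade() binarizes and normalizes them before computing precision, recall, and F1. A minimal standalone sketch of that arithmetic, using made-up judge scores and an assumed ground-truth key count (the key names here are hypothetical, not from the dataset):

```python
# Sketch of the scoring arithmetic in grade(), with made-up judge scores.
# Each per-key score in {0..3} is binarized (<= 1 -> 0, otherwise 3),
# summed, and divided by 3 so each matched key counts as 1.
judge_scores = {'flight_number': 3, 'date': 2, 'delay_minutes': 1}  # hypothetical
binarized = {k: (0 if v <= 1 else 3) for k, v in judge_scores.items()}
eval_score = sum(binarized.values()) / 3   # 2.0 keys judged correct
eval_key_count = len(judge_scores)         # keys present in the prediction
gt_key_count = 4                           # assumed keys in the ground truth
precision = eval_score / eval_key_count if eval_key_count > 0 else 0
recall = eval_score / gt_key_count if gt_key_count > 0 else 0
f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
```

With these numbers, a score of 2 survives binarization while a 1 does not, giving precision 2/3 and recall 1/2.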

src/evals/entities.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
import json
from openai import OpenAI
from utils import retry_on_exception

from typing import List

prompts = {
    'str_search': """
You are evaluating whether a predicted string is present in a ground-truth list of strings.

- Allow for minor differences in formatting, such as casing, whitespace, or common abbreviations.
- Allow for minor variation in names, such as missing middle names or switched first and last names.
- If a match is found, return the matching string from the ground truth list in JSON format: {{ "match": <matched_string> }}
- If not, return an empty dictionary.

Now compare the following:

Ground Truth List:
{gt_list}

Predicted String:
{pred_str}

IMPORTANT:
- Do NOT explain your reasoning.
- Do NOT include any extra text.
- Only output the JSON object. Do not wrap it in markdown or say anything else.
"""
}

def find_and_remove_str(str_to_remove, list_of_strs):
    if str_to_remove is None:
        return list_of_strs

    for i, s in enumerate(list_of_strs):
        if s.lower() == str_to_remove.lower():
            return list_of_strs[:i] + list_of_strs[i+1:]
    return list_of_strs

@retry_on_exception(max_retries=3, delay=1)
def evaluate_list_of_str(client, judge_name, gt_list, pred_value):
    prompt = prompts['str_search'].format(gt_list=json.dumps(gt_list), pred_str=json.dumps(pred_value))
    response = client.chat.completions.create(
        model=judge_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=500,
        temperature=0.0
    )
    output = response.choices[0].message.content.strip()
    eval_output = json.loads(output)

    if 'match' not in eval_output:
        return None
    return eval_output

def grade(judge_name: str, key: str, ground_truths: List[List[str]], preds: List[List[str]], eval_info: dict | None) -> dict:
    client = OpenAI()

    evaluations = []
    for gt, pred in zip(ground_truths, preds):
        for row in pred:
            eval_output = evaluate_list_of_str(client, judge_name, gt, row)
            if eval_output is None:
                evaluations.append({ 'type': 'str_search', 'extracted_gt': None, 'pred': row, 'eval_output': None })
                continue

            extracted_gt = eval_output['match']
            gt = find_and_remove_str(extracted_gt, gt)
            evaluations.append({ 'type': 'str_search', 'extracted_gt': extracted_gt, 'pred': row, 'eval_output': extracted_gt })

    gt_key_count = 0
    for gt in ground_truths:
        gt_key_count += len(gt)

    eval_key_count = len(evaluations)
    eval_score = sum([1 for evaluation in evaluations if evaluation['extracted_gt'] is not None])

    assert eval_score <= eval_key_count, f"Eval score {eval_score} exceeds eval key count {eval_key_count} for key {key}"
    assert eval_score <= gt_key_count, f"Eval score {eval_score} exceeds ground truth key count {gt_key_count} for key {key}"

    precision = eval_score / eval_key_count if eval_key_count > 0 else 0
    recall = eval_score / gt_key_count if gt_key_count > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return { 'precision': precision, 'recall': recall, 'f1': f1 }
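The grade() loop above removes each matched ground-truth string from the pool, so duplicate predictions cannot double-count a single entity. A standalone sketch of that greedy consumption, with the LLM judge swapped out for exact case-insensitive lookup and made-up names, purely for illustration:

```python
# Greedy matching sketch mirroring grade() in src/evals/entities.py; the
# LLM judge is replaced by exact case-insensitive lookup, names are made up.
def find_and_remove_str(str_to_remove, list_of_strs):
    if str_to_remove is None:
        return list_of_strs
    for i, s in enumerate(list_of_strs):
        if s.lower() == str_to_remove.lower():
            return list_of_strs[:i] + list_of_strs[i + 1:]
    return list_of_strs

gt = ["Ada Lovelace", "Grace Hopper"]
preds = ["ada lovelace", "Alan Turing", "ADA LOVELACE"]
gt_key_count = len(gt)
matched = 0
for p in preds:
    remaining = find_and_remove_str(p, gt)
    if remaining != gt:   # a ground-truth entry was consumed by this match
        matched += 1
    gt = remaining

precision = matched / len(preds)   # fraction of predictions that matched
recall = matched / gt_key_count    # fraction of ground-truth entities found
```

The third prediction repeats the first, but since "Ada Lovelace" has already been removed from the pool, only one match is counted: precision 1/3, recall 1/2.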
