|
| 1 | +import json |
| 2 | +from openai import OpenAI |
| 3 | +from utils import retry_on_exception |
| 4 | + |
| 5 | +from typing import List |
| 6 | + |
| 7 | +prompts = { |
| 8 | + 'dict': """ |
| 9 | +You are evaluating whether two dictionaries refer to the same information. Your task is to determine whether each corresponding value is equivalent. |
| 10 | +
|
| 11 | +For each key-value pair in the ground truth dictionary: |
| 12 | +- If the key is absent in the model output dictionary, it is not equivalent. |
| 13 | +- If the value is a number (integer or float), they are equivalent within a 1% margin of error. |
| 14 | +- If the value is a string, consider them equivalent if they are similar in meaning, or if one is a subset of the other. |
| 15 | + - Acronyms or shortened names are equivalent to their expanded forms. |
| 16 | + - "BERT" and "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" -> equivalent |
| 17 | + - "ICLR" and "International Conference on Learning Representations" -> equivalent |
| 18 | + - "CNN" and "ResNet" -> equivalent (both are convolutional neural networks) |
| 19 | + - "Adam" and "SGD" -> not equivalent (different optimizers) |
| 20 | + - "0.8" and "0.80" -> equivalent numerically |
| 21 | +- If the value is a URL, they must be exactly the same, or a shortened version of the same URL. |
| 22 | +- If the value is a list, check if every element in the ground truth list is present in the model output list, regardless of order. Follow the same rules for strings as above. |
| 23 | +- If the value is a dictionary, check if every key-value pair in the ground truth dictionary is present in the model output dictionary, regardless of order. Again, follow the same rules for strings as above. |
| 24 | +- Some values may depend on information from other keys in the dictionary. Use such context to determine whether two values are equivalent, even if they are not identical on their own. |
| 25 | +- Ignore trivial formatting differences such as casing, whitespace, or common abbreviations if the meaning is preserved. |
| 26 | +
|
| 27 | +Now compare the following dictionaries: |
| 28 | +
|
| 29 | +Ground Truth: |
| 30 | +{dict_gt} |
| 31 | +
|
| 32 | +Model Output: |
| 33 | +{dict_pred} |
| 34 | +
|
| 35 | +Your output should be a JSON object with the same keys as the input dictionaries. For each key, the value should be: |
| 36 | +- 3: The predicted value closely matches the ground-truth, preserving almost all important information with minimal or no loss of meaning. |
| 37 | +- 2: The predicted value captures the main idea of the ground-truth but omits key details or has minor inaccuracies. |
| 38 | +- 1: The predicted value shows some overlap with the ground-truth but misses most of the essential meaning. |
| 39 | +- 0: The predicted value is largely incorrect, unrelated, or does not correspond meaningfully to the ground-truth value. |
| 40 | +
|
| 41 | +IMPORTANT: |
| 42 | +- Do NOT explain your reasoning. |
| 43 | +- Do NOT include any extra text. |
| 44 | +- Only output the final JSON object. Do not wrap it in markdown or say anything else. |
| 45 | +""", |
| 46 | + 'list_of_dicts_match': """ |
| 47 | +You are given a list of dictionaries (the ground truth) and a single predicted dictionary (the model output). Your task is to determine which dictionary in the ground truth list corresponds with the predicted dictionary, using a set of primary keys to identify the best match. |
| 48 | +
|
| 49 | +- You will be a given a list of primary keys that can be used to compare the dictionaries. If a primary key is not present in two dictionaries, make sure to use the other primary keys to determine the match. |
| 50 | +- Do not consider ANY other fields in the dictionaries except for the primary keys. |
| 51 | +- If the value is a string, consider them equivalent if they are similar in meaning, or if one is a subset of the other. |
| 52 | +- Match based on semantic similarity or if one is a subset of the other, not exact string equality. |
| 53 | + - Acronyms, shortened names, or partial names are equivalent to their expanded forms. |
| 54 | + - "BERT" and "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" -> equivalent |
| 55 | + - "ICLR" and "International Conference on Learning Representations" -> equivalent |
| 56 | + - A prediction may include extra info (e.g., a year or qualifier). If the ground truth is a subset and meaning is preserved, treat them as equivalent. If the ground truth includes a qualifier, the prediction's qualifier must match if present. |
| 57 | +- You may also be provided with a set of extra evaluation notes that provide more specific evaluation criteria. These extra rules may override the general rules above. |
| 58 | +- If after all this no match exists, return an empty dictionary. |
| 59 | +
|
| 60 | +Now compare the following: |
| 61 | +
|
| 62 | +Ground Truth List: |
| 63 | +{list_gt} |
| 64 | +
|
| 65 | +Model Output: |
| 66 | +{dict_pred} |
| 67 | +
|
| 68 | +Primary Keys: |
| 69 | +{primary_keys} |
| 70 | +
|
| 71 | +Extra Evaluation Notes: |
| 72 | +{note} |
| 73 | +
|
| 74 | +Your output should be the dictionary from the ground truth list that best matches the model output, or an empty dictionary if no match is found. |
| 75 | +
|
| 76 | +IMPORTANT: |
| 77 | +- Do NOT explain your reasoning. |
| 78 | +- Do NOT include any extra text. |
| 79 | +- Only output the final dictionary as a JSON object. Do not wrap it in markdown or say anything else. |
| 80 | +""", |
| 81 | +} |
| 82 | + |
| 83 | +def find_and_remove_dict(dict_to_remove, list_of_dicts, primary_keys): |
| 84 | + if dict_to_remove is None: |
| 85 | + return list_of_dicts |
| 86 | + |
| 87 | + for i, d in enumerate(list_of_dicts): |
| 88 | + if all(d.get(pk) == dict_to_remove.get(pk) for pk in primary_keys): |
| 89 | + return list_of_dicts[:i] + list_of_dicts[i+1:] |
| 90 | + return list_of_dicts |
| 91 | + |
| 92 | +def evaluate_dict_item(client, judge_name, key, gt_value, pred_value): |
| 93 | + prompt = prompts['dict_item'].format(key=key, gt_value=json.dumps(gt_value), pred_value=json.dumps(pred_value)) |
| 94 | + response = client.chat.completions.create( |
| 95 | + model=judge_name, |
| 96 | + messages=[{"role": "user", "content": prompt}], |
| 97 | + max_tokens=1, |
| 98 | + temperature=0.0 |
| 99 | + ) |
| 100 | + eval_output = int(response.choices[0].message.content.strip()) |
| 101 | + return eval_output |
| 102 | + |
| 103 | +@retry_on_exception(max_retries=3, delay=1) |
| 104 | +def evaluate_dict(client, judge_name, gt_dict, pred_dict): |
| 105 | + prompt = prompts['dict'].format(dict_gt=json.dumps(gt_dict), dict_pred=json.dumps(pred_dict)) |
| 106 | + response = client.chat.completions.create( |
| 107 | + model=judge_name, |
| 108 | + messages=[{"role": "user", "content": prompt}], |
| 109 | + max_tokens=1024, |
| 110 | + temperature=0.0 |
| 111 | + ) |
| 112 | + output = response.choices[0].message.content.replace('```json', '').replace('```', '').strip() |
| 113 | + eval_output = json.loads(output) |
| 114 | + return eval_output |
| 115 | + |
| 116 | +@retry_on_exception(max_retries=3, delay=1) |
| 117 | +def evaluate_list_of_dicts(client, judge_name, list_gt, dict_pred, eval_info): |
| 118 | + primary_keys = eval_info['primary_keys'] |
| 119 | + note = eval_info.get('evaluation_note', {}) |
| 120 | + match_prompt = prompts['list_of_dicts_match'].format(list_gt=json.dumps(list_gt), dict_pred=json.dumps(dict_pred), primary_keys=json.dumps(primary_keys), note=json.dumps(note)) |
| 121 | + response = client.chat.completions.create( |
| 122 | + model=judge_name, |
| 123 | + messages=[{"role": "user", "content": match_prompt}], |
| 124 | + max_tokens=1024, |
| 125 | + temperature=0.0 |
| 126 | + ) |
| 127 | + output = response.choices[0].message.content.replace('```json', '').replace('```', '').strip() |
| 128 | + extracted_gt_dict = json.loads(output) |
| 129 | + |
| 130 | + if len(extracted_gt_dict) == 0: |
| 131 | + return None, None |
| 132 | + |
| 133 | + eval_prompt = prompts['dict'].format(dict_gt=json.dumps(extracted_gt_dict), dict_pred=json.dumps(dict_pred)) |
| 134 | + response = client.chat.completions.create( |
| 135 | + model=judge_name, |
| 136 | + messages=[{"role": "user", "content": eval_prompt}], |
| 137 | + max_tokens=1024, |
| 138 | + temperature=0.0 |
| 139 | + ) |
| 140 | + |
| 141 | + output = response.choices[0].message.content.replace('```json', '').replace('```', '').strip() |
| 142 | + eval_output = json.loads(output) |
| 143 | + return extracted_gt_dict, eval_output |
| 144 | + |
| 145 | +def grade(judge_name: str, key: str, ground_truths: List[List[dict] | dict], preds: List[List[dict] | dict], eval_info: dict | None) -> dict: |
| 146 | + client = OpenAI() |
| 147 | + |
| 148 | + evaluations = [] |
| 149 | + for gt, pred in zip(ground_truths, preds): |
| 150 | + if isinstance(gt, dict): |
| 151 | + eval_output = evaluate_dict(client, judge_name, gt, pred) |
| 152 | + evaluations.append({ 'type': 'dict', 'extracted_gt': None, 'pred': pred, 'eval_output': eval_output }) |
| 153 | + elif isinstance(gt, list) and isinstance(gt[0], dict): |
| 154 | + for row in pred: |
| 155 | + extracted_gt_dict, eval_output = evaluate_list_of_dicts(client, judge_name, gt, row, eval_info) |
| 156 | + if eval_output is None: |
| 157 | + evaluations.append({ 'type': 'list_of_dicts_match', 'extracted_gt': None, 'pred': row, 'eval_output': None }) |
| 158 | + continue |
| 159 | + |
| 160 | + gt = find_and_remove_dict(extracted_gt_dict, gt, eval_info['primary_keys']) |
| 161 | + evaluations.append({ 'type': 'list_of_dicts_match', 'extracted_gt': extracted_gt_dict, 'pred': row, 'eval_output': eval_output }) |
| 162 | + else: |
| 163 | + raise ValueError(f"Unsupported ground truth type: {type(gt)} for key {key}") |
| 164 | + |
| 165 | + main_claims = eval_info.get('main_claims', []) |
| 166 | + ignore_keys = set(eval_info.get('ignore_keys', [])) |
| 167 | + |
| 168 | + gt_key_count = 0 |
| 169 | + for gt in ground_truths: |
| 170 | + if type(gt) == dict: |
| 171 | + gt_key_count += len([k for k in gt if k not in ignore_keys]) |
| 172 | + elif type(gt) == list and type(gt[0]) == dict: |
| 173 | + for item in gt: |
| 174 | + gt_key_count += len([k for k in item if k not in ignore_keys]) |
| 175 | + |
| 176 | + eval_key_count = 0 |
| 177 | + for evaluation in evaluations: |
| 178 | + pred = evaluation['pred'] |
| 179 | + assert type(pred) is dict |
| 180 | + eval_key_count += len([k for k in pred if k not in ignore_keys]) |
| 181 | + |
| 182 | + # search for a possible dataset/flight identification task. if this is not correct, there is no point in evaluating the rest of the claims |
| 183 | + is_identification_claim_correct = True |
| 184 | + for evaluation in evaluations: |
| 185 | + evaluation_type = evaluation['type'] |
| 186 | + if evaluation_type != 'dict': |
| 187 | + continue |
| 188 | + |
| 189 | + pred, eval_output = evaluation['pred'], evaluation['eval_output'] |
| 190 | + |
| 191 | + for k, v in eval_output.items(): |
| 192 | + if k in main_claims and v <= 1: |
| 193 | + is_identification_claim_correct = False |
| 194 | + break |
| 195 | + |
| 196 | + if not is_identification_claim_correct: |
| 197 | + break |
| 198 | + |
| 199 | + # now do the actual evaluation |
| 200 | + eval_score = 0 |
| 201 | + for evaluation in evaluations: |
| 202 | + pred, eval_output = evaluation['pred'], evaluation['eval_output'] |
| 203 | + if eval_output is None: |
| 204 | + continue |
| 205 | + |
| 206 | + # check if all main_extractive_claims, if any, are correct |
| 207 | + # is_identification_claim_correct will be False if the dataset or flight number identification task was not correct |
| 208 | + # leading to everything being considered incorrect |
| 209 | + are_main_extractive_claims_correct = is_identification_claim_correct |
| 210 | + for k, v in eval_output.items(): |
| 211 | + if k in main_claims and v <= 1: |
| 212 | + are_main_extractive_claims_correct = False |
| 213 | + break |
| 214 | + |
| 215 | + # if main claims are wrong, all claims for this paritular evaluation are not considered correct |
| 216 | + if not are_main_extractive_claims_correct: |
| 217 | + for k, v in eval_output.items(): |
| 218 | + eval_output[k] = 0 |
| 219 | + |
| 220 | + # binarize |
| 221 | + for k, v in eval_output.items(): |
| 222 | + eval_output[k] = 0 if v <= 1 else 3 |
| 223 | + |
| 224 | + eval_score += sum([v for k, v in eval_output.items() if k not in ignore_keys]) |
| 225 | + |
| 226 | + eval_score /= 3 # normalize to [0, 1] range |
| 227 | + assert eval_score <= eval_key_count, f"Eval score {eval_score} exceeds eval key count {eval_key_count} for key {key}" |
| 228 | + assert eval_score <= gt_key_count, f"Eval score {eval_score} exceeds ground truth key count {gt_key_count} for key {key}" |
| 229 | + |
| 230 | + precision = eval_score / eval_key_count if eval_key_count > 0 else 0 |
| 231 | + recall = eval_score / gt_key_count if gt_key_count > 0 else 0 |
| 232 | + f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 |
| 233 | + # print(key, precision, recall, f1) |
| 234 | + |
| 235 | + return { 'precision': precision, 'recall': recall, 'f1': f1 } |
0 commit comments