Skip to content

Commit 7df7685

Browse files
committed
wip
1 parent 3d89db5 commit 7df7685

2 files changed

Lines changed: 282 additions & 0 deletions

File tree

evalbench/scorers/llmrater_v2.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""
2+
LLMRater
3+
In this comparison strategy, an LLM compares the golden execution results with the generated sql execution results.
4+
It returns a score between 0 and 100, with a score of 100 for concrete positive cases,
5+
where either there is a mismatch of columns names, extra relevant columns, or harmless unrequested sorting/limits in Generated SQL.
6+
7+
Evaluation rules given to LLM:
8+
1. Assume OUTPUT #1 is the gold standard and is ALWAYS correct.
9+
2. The order of columns in OUTPUT #2 does not matter.
10+
3. The order of rows in OUTPUT #2 does not matter UNLESS explicitly requested in the prompt.
11+
4. Allow slight variations due to differences in rounding or precision, for calculated values.
12+
5. Allow acceptable divergences based on relaxed criteria (ambiguous counts, null handling, relative dates, IDs vs Names, ambiguous limits, etc.).
13+
6. The mapped column names might differ, do not make any assumptions based on them.
14+
15+
Run Configuration Options:
16+
1. model_config: Required
17+
- File that defines the configuration settings for the LLM model to be used in evaluation.
18+
"""
19+
20+
from typing import Tuple
21+
from generators.models import get_generator
22+
from scorers import setmatcher
23+
import logging
24+
25+
from scorers import comparator
26+
from .util import make_hashable, with_cache_execute
27+
from databases.util import get_cache_client
28+
29+
# Prompt template for a second LLM pass, issued only when the rating pass
# scored 0: the model is asked to tag the failure using the fixed error
# taxonomy below. Filled via str.format in the compare method with the NL
# prompt, both SQL queries, and both (already deduplicated/truncated)
# execution results.
ERROR_CATEGORIZATION_PROMPT: str = """
You are an expert SQL evaluator. Your task is to analyze a "Generated SQL" query against a "Golden SQL" (ground truth) query and their respective execution results.

### Input Data
**NL Prompt:** {nl_prompt}
**Golden SQL:** {golden_sql}
**Golden Result:** {golden_execution_result}
**Generated SQL:** {generated_sql}
**Generated Result:** {generated_execution_result}

### Task
Compare the queries and results to identify specific errors in the Generated SQL. If the Generated SQL is functionally equivalent to the Golden SQL (even if syntax differs), mark it as correct.

### Error Taxonomy
If errors exist, categorize them using ONLY the following tags:

1. [EntityError] - Wrong table or entity was used.
2. [ValueLinkingError] - Wrong literal value (string/number) was used.
3. [ColumnLinkingError] - Wrong column was selected or used in a condition.
4. [OrderingError] - Sorting order (ASC/DESC) or column is incorrect (only flag if prompt explicitly requested sorting).
5. [InstructionError] - Failed to follow specific constraints in the prompt (e.g., "return top 5").
6. [IntentError] - Misinterpreted the user's fundamental request.
7. [DataTypesError] - Incorrect handling of data types (e.g., casting, dates).
8. [CountingError] - Aggregation or counting logic is flawed.
9. [FilterError] - Correct columns used, but wrong logical operator or filter condition.
10. [LogicError] - Fundamental logic flaw not covered by other categories (e.g., wrong join type).
11. [OtherError] - Any other error not covered by the above categories.

### Output Format
Provide your response in the following format:

**Reasoning:**
<Analyze the differences between the queries and results here>

**Tags & Explanations:**
<Tag 1>: <One-line explanation of the specific error>
<Tag 2>: <One-line explanation of the specific error>
"""
68+
69+
class LLMRaterV2(comparator.Comparator):
    """
    LLMRaterV2 implements the Comparator base class.

    An LLM compares the golden execution results with the generated SQL
    execution results. When the rating pass scores 0, a second LLM pass
    categorizes the failure using a fixed error taxonomy and appends the
    tags to the returned explanation.

    Attributes:
    1. name: Name of the comparator. Set to "llmrater_v2"
    2. model_config: File that defines the configuration settings for the LLM model used in evaluation.
    """

    def __init__(self, config: dict, global_models):
        # score.py registers this scorer as llmrater_v2.LLMRaterV2 under the
        # "llmrater_v2" key, so both the class name and self.name must stay
        # distinct from the v1 "llmrater" comparator.
        self.name = "llmrater_v2"
        # Cheap exact-match pre-check; a hit skips the LLM call entirely.
        self.set_match_checker = setmatcher.SetMatcher({})
        self.cache_client = get_cache_client(config)
        self.model_config = config.get("model_config") or ""
        if not self.model_config:
            raise ValueError("model_config is required for LLM Rater")
        self.model = get_generator(global_models, self.model_config)

    def _is_exact_match(
        self,
        nl_prompt: str,
        golden_query: str,
        query_type: str,
        golden_execution_result: list,
        golden_eval_result: str,
        golden_error: str,
        generated_query: str,
        generated_execution_result: list,
        generated_eval_result: str,
        generated_error: str,
    ) -> bool:
        """Returns True when the set matcher scores the results a perfect 100."""
        score, _ = self.set_match_checker.compare(
            nl_prompt,
            golden_query,
            query_type,
            golden_execution_result,
            golden_eval_result,
            golden_error,
            generated_query,
            generated_execution_result,
            generated_eval_result,
            generated_error,
        )
        return score == 100

    def _inference_without_caching(self, prompt):
        """Runs a single uncached LLM generation for the given prompt.

        Raises:
            RuntimeError: If the model was never initialized.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized")
        return self.model.generate(prompt)

    @staticmethod
    def take_n_uniques(output_list: list, n: int) -> list:
        """Takes n number of unique (non duplicate) values from the output list.

        Args:
            output_list: The execution output result set
            n: Max number of unique values needed.

        Returns:
            The execution output result set without duplicates in a size of n values or less.
        """
        seen_dicts = set()
        new_list = []
        for d in output_list:
            # Convert the dictionary to a hashable frozenset for efficient lookup
            t = frozenset((k, make_hashable(v)) for k, v in d.items())
            if t not in seen_dicts:
                seen_dicts.add(t)
                new_list.append(d)
                if len(new_list) == n:
                    break
        return new_list

    def compare(
        self,
        nl_prompt: str,
        golden_query: str,
        query_type: str,
        golden_execution_result: list,
        golden_eval_result: str,
        golden_error: str,
        generated_query: str,
        generated_execution_result: list,
        generated_eval_result: str,
        generated_error: str,
    ) -> Tuple[float, str]:
        """Scores the generated query's execution result against the golden one.

        Returns:
            A (score, explanation) tuple. Score is 100 on an exact set match,
            or when the LLM answers INFORMATION_MATCHES or EXTRA_INFORMATION;
            otherwise 0, with an LLM error categorization appended to the
            explanation.
        """
        # Fast path: a strict set match needs no LLM call at all.
        if self._is_exact_match(
            nl_prompt,
            golden_query,
            query_type,
            golden_execution_result,
            golden_eval_result,
            golden_error,
            generated_query,
            generated_execution_result,
            generated_eval_result,
            generated_error,
        ):
            return 100, "Skipped. Exact Match was found."

        if golden_error:
            return 0, "Golden query failed to execute."
        if generated_error:
            return 0, "Generated query failed to execute."

        # Cap both result sets at 50 unique rows to keep the prompt bounded.
        only_first_n = 50

        golden_execution_result = self.take_n_uniques(
            golden_execution_result, only_first_n
        )
        generated_execution_result = self.take_n_uniques(
            generated_execution_result, only_first_n
        )

        prompt = f"""
We are trying to answer this question by querying a database:

QUESTION: {nl_prompt}

The correct answer to this question is:

OUTPUT #1 (Gold Standard):

{golden_execution_result}


We get the following answer from a generated query:

OUTPUT #2 (Generated Result):

{generated_execution_result}


Thinking step by step, compare the two outputs and look for differences in data and presentation.
Here are steps to follow:

1. Analyze the QUESTION: Does it explicitly ask for a specific sorting order (e.g., "ordered by date", "top 5")? Does it explicitly ask for a limit?
2. Column Mapping: Ensure that every column in OUTPUT #1 has a corresponding column in OUTPUT #2 that represents the same information. OUTPUT #2 is allowed to have additional descriptive columns.
3. Data Comparison: Compare the data within each mapped column pair.
4. Row Order: Ignore differences in row order UNLESS the QUESTION explicitly requested a specific sorting. Treat the data as unordered sets if no order is specified.
5. Extra Rows: If OUTPUT #2 has extra rows but contains all of OUTPUT #1, evaluate if the extra rows violate the prompt's constraints. If the prompt was ambiguous about limits (e.g. "Identify the MSA with the highest growth" and the model returns a ranked list instead of a single row), treating it as EXTRA_INFORMATION is acceptable and correct.

RULES & RELAXED EVALUATION CRITERIA - These MUST be strictly followed:

1. Assume OUTPUT #1 is the gold standard and its core data values are ALWAYS mathematically/logically correct.
2. The mapped column names might differ, do not make any assumptions based on them.
3. Do NOT penalize OUTPUT #2 if it differs from OUTPUT #1 for ANY of the following reasons:
- Column/Row Order: Differences in column names, column order, or row order when no requirements are specified in the QUESTION.
- Rounding: Differences in integer/decimal rounding or precision when the QUESTION lacks specific guidelines.
- Ambiguous "Top X": The QUESTION requests "first" or "top" X entries but doesn't specify an ordering field, yielding different subsets.
- Null/NA Handling: Differences in including vs. excluding 'null' or 'NA' values when the QUESTION does not specify.
- Ambiguous Limit: The QUESTION asks for "top/highest" or "bottom/lowest" entries but doesn't specify a concrete limit, leading to different numbers of entries.
- Entity Representation: The QUESTION asks for a list of items but doesn't specify IDs or names, leading one output to return IDs and the other names.
- Extra Columns: OUTPUT #2 has a small number of extra columns that are not explicitly excluded and don't render the overall result incorrect.
- Truncation/Subsets: Truncation for display, or differing subsets of data when ordering is not specified.
- Fewer than X: Returning fewer than X entries for a "top X" request because fewer entries meet the underlying criteria.
- Relative Time/Date: Differences arising because queries were evaluated at different assumed current times/dates (e.g., "last two years").

FINAL QUESTION: Does OUTPUT #2 provide the same information as OUTPUT #1?
FINAL ANSWER: Choose ONLY ONE
- INFORMATION_MATCHES -- OUTPUT #1 and OUTPUT #2 provide the same core information (or differences fall under the acceptable relaxed criteria).
- MISSING_INFORMATION -- Something important requested by the QUESTION is missing from OUTPUT #2 (e.g. data points dropped, missing expected columns).
- EXTRA_INFORMATION -- OUTPUT #2 includes the correct answer but added non-harmful extra relevant columns, or harmless extra rows due to an ambiguous limit/sorting constraint in the QUESTION.
- INCORRECT_INFORMATION -- OUTPUT #2 contains mathematically or logically incorrect data, wrong aggregations, bad joins, missing expected rows, or violates explicit constraints in the QUESTION.
"""

        logging.debug("\n --------- prompt: --------- \n %s ", prompt)

        if self.cache_client:
            response = with_cache_execute(
                prompt,
                self.model_config,
                self._inference_without_caching,
                self.cache_client,
            )
        else:
            response = self._inference_without_caching(prompt)

        logging.debug(
            "\n --------- llm_rater_output: --------- \n %s ", response)

        # Scoring Logic: Both INFORMATION_MATCHES and EXTRA_INFORMATION are rewarded as correct.
        score = (
            100
            if ("INFORMATION_MATCHES" in response or "EXTRA_INFORMATION" in response)
            else 0
        )

        if score == 0:
            # Second pass: ask the LLM to tag the failure so the report
            # carries actionable error categories alongside the raw rating.
            prompt = ERROR_CATEGORIZATION_PROMPT.format(
                nl_prompt=nl_prompt,
                golden_sql=golden_query,
                golden_execution_result=golden_execution_result,
                generated_sql=generated_query,
                generated_execution_result=generated_execution_result,
            )
            if self.cache_client:
                error_categorization_response = with_cache_execute(
                    prompt,
                    self.model_config,
                    self._inference_without_caching,
                    self.cache_client,
                )
            else:
                error_categorization_response = self._inference_without_caching(
                    prompt)

            response += "\nError analysis:\n\n" + error_categorization_response

        return score, response


# Backward-compatible alias for any code that imported the original name.
LLMRater = LLMRaterV2

evalbench/scorers/score.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from scorers import recallmatcher
77
from scorers import setmatcher
88
from scorers import llmrater
9+
from scorers import llmrater_v2
910
from scorers import returnedsql
1011
from scorers import executablesql
1112
from scorers import trajectorymatcher
@@ -44,6 +45,9 @@ def compare(
4445
if "llmrater" in scorers:
4546
comparators.append(llmrater.LLMRater(
4647
scorers["llmrater"], global_models))
48+
if "llmrater_v2" in scorers:
49+
comparators.append(llmrater_v2.LLMRaterV2(
50+
scorers["llmrater_v2"], global_models))
4751
if "regexp_matcher" in scorers:
4852
comparators.append(
4953
generatedqueryregexpmatcher.GeneratedQueryRegexpMatcher(

0 commit comments

Comments
 (0)