diff --git a/imobench/__init__.py b/imobench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imobench/eval/README.md b/imobench/eval/README.md new file mode 100644 index 0000000..11181d9 --- /dev/null +++ b/imobench/eval/README.md @@ -0,0 +1,104 @@ +# IMO-AnswerBench Evaluation Harness + +A Python tool for evaluating model outputs against the IMO-AnswerBench ground +truth answers using mathematical equivalence checking. + +## Features + +- **Math equivalence checking** via SymPy: correctly identifies equivalent + mathematical expressions (e.g., `\frac{1}{2}` == `0.5`). +- **Multiple checking strategies**: exact match, numeric comparison, SymPy + symbolic equivalence, multi-answer set matching, and normalized string + comparison. +- **Detailed metrics**: accuracy breakdown by category, subcategory, and source. +- **CLI interface**: score predictions from the command line. +- **Flexible input**: accepts predictions as CSV or JSONL. + +## Installation + +```bash +pip install -r imobench/eval/requirements.txt + +# For running tests: +pip install -r imobench/eval/requirements-dev.txt +``` + +## Usage + +### Prepare predictions + +Create a CSV file with your model's predictions: + +```csv +Problem ID,Model Answer +imo-bench-algebra-001,3 +imo-bench-algebra-002,$\lfloor \log_2 a \rfloor + 1$ +``` + +Or a JSONL file: + +```jsonl +{"problem_id": "imo-bench-algebra-001", "answer": "3"} +{"problem_id": "imo-bench-algebra-002", "answer": "$\\lfloor \\log_2 a \\rfloor + 1$"} +``` + +### Run evaluation + +```bash +# Text report (default) +python -m imobench.eval.cli predictions.csv + +# JSON output +python -m imobench.eval.cli predictions.csv --format json + +# Save detailed results +python -m imobench.eval.cli predictions.csv --output results.json + +# Use a custom answerbench path +python -m imobench.eval.cli predictions.csv --answerbench path/to/answerbench_v2.csv +``` + +### Python API + +```python +from imobench.eval import check_answer, 
evaluate_predictions, compute_metrics, format_report +from imobench.eval.evaluate import load_predictions + +# Check a single answer +result = check_answer(r"\frac{1}{2}", "0.5") +print(result) # {'correct': True, 'method': 'sympy', 'details': ''} + +# Evaluate a batch of predictions +predictions = load_predictions("predictions.csv") +results = evaluate_predictions(predictions) +metrics = compute_metrics(results) +print(format_report(metrics)) +``` + +## Answer Checking Strategies + +The checker tries strategies in this order and returns the first definitive +result: + +| Strategy | Handles | Example | +|----------|---------|---------| +| Exact match | Identical normalized strings | `42` == `42` | +| Numeric | Plain numbers | `3.0` == `3` | +| Multi-answer | Comma-separated sets | `3, 1, 2` == `1, 2, 3` | +| String normalized | Case/whitespace differences | `Algebra` == `algebra` | +| SymPy | LaTeX math expressions | `\frac{1}{2}` == `0.5` | + +## Running Tests + +```bash +pytest imobench/eval/tests/ -v +``` + +## Limitations + +- SymPy's LaTeX parser does not handle all mathematical notation (e.g., some + piecewise functions, complex set-builder notation). +- Answers involving free variables or functions (e.g., `f(x) = 2x + c`) require + structural matching that may not always succeed. +- For proof-based problems (IMO-ProofBench), use LLM-based grading with the + autograder prompts instead. diff --git a/imobench/eval/__init__.py b/imobench/eval/__init__.py new file mode 100644 index 0000000..ed30384 --- /dev/null +++ b/imobench/eval/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IMO-AnswerBench evaluation harness. + +Provides tools to evaluate model outputs against IMO-AnswerBench ground truth +answers using mathematical equivalence checking. + +Usage: + from imobench.eval import check_answer, evaluate_predictions, compute_metrics + + result = check_answer("\\frac{1}{2}", "0.5") + assert result["correct"] is True +""" + +from imobench.eval.answer_checker import check_answer +from imobench.eval.evaluate import evaluate_predictions, load_predictions +from imobench.eval.metrics import compute_metrics, format_report + +__all__ = [ + "check_answer", + "evaluate_predictions", + "load_predictions", + "compute_metrics", + "format_report", +] diff --git a/imobench/eval/answer_checker.py b/imobench/eval/answer_checker.py new file mode 100644 index 0000000..9256ee9 --- /dev/null +++ b/imobench/eval/answer_checker.py @@ -0,0 +1,234 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mathematical answer equivalence checking for IMO-AnswerBench. 
+ +Supports numeric comparison, LaTeX expression parsing via SymPy, and +multi-answer (comma-separated) set matching. +""" + +import math +import re +import threading +from typing import Any + +from sympy import simplify +from sympy.parsing.latex import parse_latex + +_SYMPY_TIMEOUT_SEC = 5 + + +def normalize_latex(text: str) -> str: + """Strip LaTeX delimiters, whitespace, and trailing punctuation.""" + text = text.strip() + # Strip trailing punctuation before removing delimiters, + # since periods often appear outside: "$x+1$." + text = text.rstrip(".") + text = text.strip() + # Remove $ delimiters + text = re.sub(r"^\$+|\$+$", "", text) + # Remove \( \) and \[ \] delimiters (common in LLM outputs) + text = re.sub(r"^\\\(|\\\)$", "", text) + text = re.sub(r"^\\\[|\\\]$", "", text) + text = text.strip() + return text + + +def _looks_like_math(text: str) -> bool: + """Heuristic check for LaTeX math notation or math-like content.""" + # LaTeX commands, superscripts, subscripts, braces + if re.search(r"[\\^_{}]", text): + return True + # Common math function names + if re.search(r"\b(sqrt|log|ln|sin|cos|tan|exp|lim|sum|prod)\b", text): + return True + # Expressions with operators between digits/variables + if re.search(r"\d\s*[+\-*/]\s*\d", text): + return True + return False + + +def _split_multi_answer(text: str) -> list[str]: + """Split comma-separated answers respecting nested braces and parens.""" + parts: list[str] = [] + depth = 0 + current: list[str] = [] + for ch in text: + if ch in "({[": + depth += 1 + elif ch in ")}]": + depth = max(0, depth - 1) + elif ch == "," and depth == 0: + parts.append("".join(current).strip()) + current = [] + continue + current.append(ch) + parts.append("".join(current).strip()) + return [p for p in parts if p] + + +def _try_parse_number(text: str) -> float | None: + """Try to parse a normalized string as a finite number.""" + try: + val = float(text) + if math.isfinite(val): + return val + return None + except ValueError: + 
return None + + +def _run_with_timeout(fn, timeout_sec: int = _SYMPY_TIMEOUT_SEC) -> Any: + """Run a callable with a timeout. Returns None on timeout or error.""" + result: list[Any] = [None] + error: list[bool] = [False] + + def target(): + try: + result[0] = fn() + except Exception: + error[0] = True + + thread = threading.Thread(target=target, daemon=True) + thread.start() + thread.join(timeout=timeout_sec) + if thread.is_alive() or error[0]: + return None + return result[0] + + +def _try_parse_sympy(text: str) -> Any: + """Try to parse a LaTeX string into a SymPy expression with timeout.""" + return _run_with_timeout(lambda: parse_latex(text)) + + +def _expressions_equivalent(expr_a: Any, expr_b: Any) -> bool: + """Check if two SymPy expressions are mathematically equivalent.""" + def _check(): + diff = simplify(expr_a - expr_b) + if diff == 0: + return True + if hasattr(diff, "is_zero") and diff.is_zero: + return True + return False + + result = _run_with_timeout(_check) + if result is True: + return True + + def _check_equals(): + return bool(expr_a.equals(expr_b)) + + result = _run_with_timeout(_check_equals) + return result is True + + +def check_answer( + model_answer: str, ground_truth: str, _depth: int = 0 +) -> dict[str, Any]: + """Check if a model answer matches the ground truth. + + Tries multiple strategies in order: + 1. Exact string match (after normalization) + 2. Numeric comparison + 3. Multi-answer set matching (comma-separated) + 4. Normalized string comparison (case/whitespace insensitive) + 5. SymPy mathematical equivalence (only for math-like expressions) + + Args: + model_answer: The model's predicted answer (plain text or LaTeX). + ground_truth: The ground truth answer from the benchmark. + + Returns: + A dict with keys: + correct (bool): Whether the answer is correct. + method (str): Which strategy determined the result. + details (str): Additional information about the comparison. 
+ """ + model_norm = normalize_latex(model_answer) + truth_norm = normalize_latex(ground_truth) + + # 1. Exact string match + if model_norm == truth_norm: + return {"correct": True, "method": "exact_match", "details": ""} + + # 2. Numeric comparison + model_num = _try_parse_number(model_norm) + truth_num = _try_parse_number(truth_norm) + if model_num is not None and truth_num is not None: + if abs(model_num - truth_num) < 1e-9: + return {"correct": True, "method": "numeric", "details": ""} + return { + "correct": False, + "method": "numeric", + "details": f"Expected {truth_num}, got {model_num}", + } + + # 3. Multi-answer (comma-separated sets) + truth_parts = _split_multi_answer(truth_norm) + if len(truth_parts) > 1 and _depth < 2: + model_parts = _split_multi_answer(model_norm) + if len(model_parts) != len(truth_parts): + return { + "correct": False, + "method": "multi_answer", + "details": ( + f"Expected {len(truth_parts)} answers, " + f"got {len(model_parts)}" + ), + } + matched: set[int] = set() + for mp in model_parts: + for i, tp in enumerate(truth_parts): + if i not in matched: + sub_result = check_answer(mp, tp, _depth=_depth + 1) + if sub_result["correct"]: + matched.add(i) + break + if len(matched) == len(truth_parts): + return { + "correct": True, + "method": "multi_answer", + "details": "All parts matched", + } + return { + "correct": False, + "method": "multi_answer", + "details": f"Matched {len(matched)}/{len(truth_parts)} parts", + } + + # 4. Normalized string comparison + model_clean = re.sub(r"\s+", " ", model_norm.lower()) + truth_clean = re.sub(r"\s+", " ", truth_norm.lower()) + if model_clean == truth_clean: + return {"correct": True, "method": "string_normalized", "details": ""} + + # 5. 
SymPy expression comparison (only for math-like expressions) + if _looks_like_math(model_norm) or _looks_like_math(truth_norm): + model_expr = _try_parse_sympy(model_norm) + truth_expr = _try_parse_sympy(truth_norm) + if model_expr is not None and truth_expr is not None: + if _expressions_equivalent(model_expr, truth_expr): + return {"correct": True, "method": "sympy", "details": ""} + return { + "correct": False, + "method": "sympy", + "details": "Expressions not equivalent", + } + + return { + "correct": False, + "method": "no_match", + "details": "Could not determine equivalence", + } diff --git a/imobench/eval/cli.py b/imobench/eval/cli.py new file mode 100644 index 0000000..37db4a7 --- /dev/null +++ b/imobench/eval/cli.py @@ -0,0 +1,95 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Command-line interface for IMO-AnswerBench evaluation. 
+ +Usage: + python -m imobench.eval.cli predictions.csv + python -m imobench.eval.cli predictions.jsonl --output results.json + python -m imobench.eval.cli predictions.csv --answerbench path/to/answerbench_v2.csv +""" + +import argparse +import json +import sys + +from imobench.eval.evaluate import evaluate_predictions, load_predictions +from imobench.eval.metrics import compute_metrics, format_report + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="Evaluate model predictions against IMO-AnswerBench.", + ) + parser.add_argument( + "predictions", + help="Path to predictions file (CSV or JSONL).", + ) + parser.add_argument( + "--answerbench", + default=None, + help="Path to answerbench CSV (defaults to bundled answerbench_v2.csv).", + ) + parser.add_argument( + "--output", + default=None, + help="Path to write detailed results as JSON.", + ) + parser.add_argument( + "--format", + choices=["text", "json"], + default="text", + help="Output format for the report (default: text).", + ) + + args = parser.parse_args(argv) + + try: + predictions = load_predictions(args.predictions) + except FileNotFoundError: + print(f"Error: File not found: {args.predictions}", file=sys.stderr) + sys.exit(1) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + if not predictions: + print("Error: No predictions loaded.", file=sys.stderr) + sys.exit(1) + + try: + results = evaluate_predictions(predictions, args.answerbench) + except FileNotFoundError: + print( + f"Error: Answerbench not found: {args.answerbench}", + file=sys.stderr, + ) + sys.exit(1) + metrics = compute_metrics(results) + + if args.format == "json": + output = {"metrics": metrics, "results": results} + print(json.dumps(output, indent=2, ensure_ascii=False)) + else: + print(format_report(metrics)) + + if args.output: + output = {"metrics": metrics, "results": results} + with open(args.output, "w", encoding="utf-8") as f: + json.dump(output, 
f, indent=2, ensure_ascii=False) + print(f"\nDetailed results written to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/imobench/eval/evaluate.py b/imobench/eval/evaluate.py new file mode 100644 index 0000000..fea7889 --- /dev/null +++ b/imobench/eval/evaluate.py @@ -0,0 +1,153 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark evaluation runner for IMO-AnswerBench. + +Loads model predictions and scores them against the ground truth answers. +""" + +import csv +import json +from pathlib import Path +from typing import Any + +from imobench.eval.answer_checker import check_answer + +_BENCHMARKS_DIR = Path(__file__).resolve().parent.parent +_ANSWERBENCH_PATH = _BENCHMARKS_DIR / "answerbench_v2.csv" + + +def load_answerbench( + path: str | Path | None = None, +) -> dict[str, dict[str, str]]: + """Load IMO-AnswerBench ground truth. + + Args: + path: Path to answerbench CSV. Defaults to answerbench_v2.csv. + + Returns: + Dict mapping Problem ID to row data. + """ + if path is None: + path = _ANSWERBENCH_PATH + path = Path(path) + + problems: dict[str, dict[str, str]] = {} + with open(path, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + problems[row["Problem ID"]] = dict(row) + return problems + + +def load_predictions(path: str | Path) -> dict[str, str]: + """Load model predictions from CSV or JSONL. 
+ + Expected CSV format: + Problem ID,Model Answer + imo-bench-algebra-001,3 + imo-bench-algebra-002,$\\log_2 a + 1$ + + Expected JSONL format: + {"problem_id": "imo-bench-algebra-001", "answer": "3"} + + Args: + path: Path to predictions file. + + Returns: + Dict mapping Problem ID to model answer string. + """ + path = Path(path) + predictions: dict[str, str] = {} + + if path.suffix == ".jsonl": + with open(path, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + obj = json.loads(line) + pid = obj.get("problem_id", obj.get("Problem ID", "")) + if not pid: + raise ValueError( + f"Missing problem ID in {path} line {line_num}" + ) + answer = obj.get("answer", obj.get("Model Answer", "")) + predictions[pid] = str(answer) + else: + with open(path, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row_num, row in enumerate(reader, 2): + pid = row.get("Problem ID", row.get("problem_id", "")) + if not pid: + raise ValueError( + f"Missing problem ID in {path} row {row_num}" + ) + answer = row.get("Model Answer", row.get("answer", "")) + predictions[pid] = str(answer) + + return predictions + + +def evaluate_predictions( + predictions: dict[str, str], + answerbench_path: str | Path | None = None, +) -> list[dict[str, Any]]: + """Evaluate model predictions against IMO-AnswerBench. + + Args: + predictions: Dict mapping Problem ID to model answer. + answerbench_path: Path to answerbench CSV. Defaults to bundled v2. 
+ + Returns: + List of result dicts, one per problem, each containing: + problem_id (str) + category (str) + subcategory (str) + source (str) + ground_truth (str) + model_answer (str) + correct (bool) + method (str) + details (str) + """ + benchmark = load_answerbench(answerbench_path) + results: list[dict[str, Any]] = [] + + for pid, problem in benchmark.items(): + model_answer = predictions.get(pid, "") + ground_truth = problem["Short Answer"] + + if not model_answer: + result = { + "correct": False, + "method": "missing", + "details": "No prediction provided", + } + else: + result = check_answer(model_answer, ground_truth) + + results.append( + { + "problem_id": pid, + "category": problem.get("Category", ""), + "subcategory": problem.get("Subcategory", ""), + "source": problem.get("Source", ""), + "ground_truth": ground_truth, + "model_answer": model_answer, + **result, + } + ) + + return results diff --git a/imobench/eval/metrics.py b/imobench/eval/metrics.py new file mode 100644 index 0000000..176ac6c --- /dev/null +++ b/imobench/eval/metrics.py @@ -0,0 +1,145 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metrics computation and reporting for IMO-AnswerBench evaluation.""" + +from collections import defaultdict +from typing import Any + + +def compute_metrics(results: list[dict[str, Any]]) -> dict[str, Any]: + """Compute accuracy metrics from evaluation results. 
+ + Args: + results: List of result dicts from evaluate_predictions. + + Returns: + Dict containing: + overall: Overall accuracy and counts. + by_category: Accuracy broken down by Category. + by_subcategory: Accuracy broken down by Category/Subcategory. + by_source: Accuracy broken down by Source. + by_method: Count of answers resolved by each checking method. + """ + total = len(results) + correct = sum(1 for r in results if r["correct"]) + + # By category + cat_counts: dict[str, dict[str, int]] = defaultdict( + lambda: {"total": 0, "correct": 0} + ) + for r in results: + cat = r.get("category", "Unknown") + cat_counts[cat]["total"] += 1 + if r["correct"]: + cat_counts[cat]["correct"] += 1 + + # By subcategory + subcat_counts: dict[str, dict[str, int]] = defaultdict( + lambda: {"total": 0, "correct": 0} + ) + for r in results: + key = f"{r.get('category', 'Unknown')}/{r.get('subcategory', 'Unknown')}" + subcat_counts[key]["total"] += 1 + if r["correct"]: + subcat_counts[key]["correct"] += 1 + + # By source + source_counts: dict[str, dict[str, int]] = defaultdict( + lambda: {"total": 0, "correct": 0} + ) + for r in results: + source = r.get("source", "Unknown") + source_counts[source]["total"] += 1 + if r["correct"]: + source_counts[source]["correct"] += 1 + + # By method + method_counts: dict[str, int] = defaultdict(int) + for r in results: + method_counts[r.get("method", "unknown")] += 1 + + def _accuracy(counts: dict[str, int]) -> float: + if counts["total"] == 0: + return 0.0 + return counts["correct"] / counts["total"] + + return { + "overall": { + "total": total, + "correct": correct, + "accuracy": correct / total if total > 0 else 0.0, + }, + "by_category": { + cat: {**counts, "accuracy": _accuracy(counts)} + for cat, counts in sorted(cat_counts.items()) + }, + "by_subcategory": { + key: {**counts, "accuracy": _accuracy(counts)} + for key, counts in sorted(subcat_counts.items()) + }, + "by_source": { + src: {**counts, "accuracy": _accuracy(counts)} + for 
src, counts in sorted(source_counts.items()) + }, + "by_method": dict(sorted(method_counts.items())), + } + + +def format_report(metrics: dict[str, Any]) -> str: + """Format metrics into a human-readable report. + + Args: + metrics: Output from compute_metrics. + + Returns: + Formatted string report. + """ + lines: list[str] = [] + + overall = metrics["overall"] + lines.append("=" * 60) + lines.append("IMO-AnswerBench Evaluation Report") + lines.append("=" * 60) + lines.append( + f"Overall Accuracy: {overall['correct']}/{overall['total']} " + f"({overall['accuracy']:.1%})" + ) + lines.append("") + + lines.append("Accuracy by Category:") + lines.append("-" * 40) + for cat, data in metrics["by_category"].items(): + lines.append( + f" {cat:<25} {data['correct']:>3}/{data['total']:<3} " + f"({data['accuracy']:.1%})" + ) + lines.append("") + + lines.append("Accuracy by Subcategory:") + lines.append("-" * 40) + for key, data in metrics["by_subcategory"].items(): + lines.append( + f" {key:<35} {data['correct']:>3}/{data['total']:<3} " + f"({data['accuracy']:.1%})" + ) + lines.append("") + + lines.append("Checking Methods Used:") + lines.append("-" * 40) + for method, count in metrics["by_method"].items(): + lines.append(f" {method:<25} {count:>4}") + lines.append("") + + return "\n".join(lines) diff --git a/imobench/eval/requirements-dev.txt b/imobench/eval/requirements-dev.txt new file mode 100644 index 0000000..9573d48 --- /dev/null +++ b/imobench/eval/requirements-dev.txt @@ -0,0 +1,2 @@ +-r requirements.txt +pytest>=7.0 diff --git a/imobench/eval/requirements.txt b/imobench/eval/requirements.txt new file mode 100644 index 0000000..1991c1f --- /dev/null +++ b/imobench/eval/requirements.txt @@ -0,0 +1,4 @@ +sympy>=1.12 +# Pinned for compatibility with sympy.parsing.latex (LaTeX → SymPy parser). +# Different SymPy versions may bundle different ANTLR grammars. 
antlr4-python3-runtime==4.11.1
diff --git a/imobench/eval/tests/__init__.py b/imobench/eval/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/imobench/eval/tests/test_answer_checker.py b/imobench/eval/tests/test_answer_checker.py
new file mode 100644
index 0000000..36ace42
--- /dev/null
+++ b/imobench/eval/tests/test_answer_checker.py
@@ -0,0 +1,217 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the answer equivalence checker."""

import pytest

from imobench.eval.answer_checker import (
    _expressions_equivalent,
    _looks_like_math,
    _split_multi_answer,
    _try_parse_number,
    _try_parse_sympy,
    check_answer,
    normalize_latex,
)


class TestNormalizeLatex:
    """Stripping of $ delimiters, whitespace, and trailing punctuation."""

    def test_strips_dollar_signs(self):
        assert normalize_latex("$x+1$") == "x+1"

    def test_strips_double_dollar_signs(self):
        assert normalize_latex("$$x+1$$") == "x+1"

    def test_strips_trailing_period(self):
        assert normalize_latex("$x+1$.") == "x+1"

    def test_strips_whitespace(self):
        assert normalize_latex(" 42 ") == "42"

    def test_combined(self):
        # Whitespace + delimiters + trailing period, all at once.
        assert normalize_latex(" $\\frac{1}{2}$. ") == "\\frac{1}{2}"


class TestSplitMultiAnswer:
    """Comma splitting that respects nesting depth."""

    def test_single_answer(self):
        assert _split_multi_answer("42") == ["42"]

    def test_comma_separated(self):
        assert _split_multi_answer("1, 2, 3") == ["1", "2", "3"]

    def test_nested_braces(self):
        # The comma inside f(x,y) must not split the answer.
        result = _split_multi_answer("f(x,y), g(x)")
        assert result == ["f(x,y)", "g(x)"]

    def test_empty_parts_filtered(self):
        result = _split_multi_answer("1,,2")
        assert result == ["1", "2"]


class TestCheckAnswer:
    """End-to-end checks covering each matching strategy."""

    # --- Exact match ---
    def test_exact_match(self):
        result = check_answer("42", "42")
        assert result["correct"] is True
        assert result["method"] == "exact_match"

    def test_exact_match_with_latex(self):
        result = check_answer("$\\frac{1}{2}$", "$\\frac{1}{2}$.")
        assert result["correct"] is True
        assert result["method"] == "exact_match"

    # --- Numeric ---
    def test_numeric_match(self):
        result = check_answer("3.0", "3")
        assert result["correct"] is True
        assert result["method"] == "numeric"

    def test_numeric_mismatch(self):
        result = check_answer("4", "3")
        assert result["correct"] is False
        assert result["method"] == "numeric"

    def test_negative_numeric(self):
        result = check_answer("-768", "-768.0")
        assert result["correct"] is True

    # --- Multi-answer ---
    def test_multi_answer_match(self):
        result = check_answer("1, 2, 3", "1, 2, 3")
        assert result["correct"] is True

    def test_multi_answer_reordered(self):
        # Set semantics: order must not matter.
        result = check_answer("3, 1, 2", "1, 2, 3")
        assert result["correct"] is True
        assert result["method"] == "multi_answer"

    def test_multi_answer_wrong_count(self):
        result = check_answer("1, 2", "1, 2, 3")
        assert result["correct"] is False

    # --- SymPy equivalence ---
    def test_sympy_fraction(self):
        result = check_answer("\\frac{1}{2}", "0.5")
        assert result["correct"] is True
        assert result["method"] == "sympy"

    def test_sympy_equivalent_expression(self):
        result = check_answer("2^{3}", "8")
        assert result["correct"] is True

    def test_sympy_mismatch(self):
        result = check_answer("\\frac{1}{3}", "\\frac{1}{2}")
        assert result["correct"] is False

    # --- String normalization ---
    def test_string_case_insensitive(self):
        result = check_answer("Algebra", "algebra")
        assert result["correct"] is True
        assert result["method"] == "string_normalized"

    # --- NaN/inf rejected ---
    def test_nan_rejected(self):
        result = check_answer("nan", "42")
        assert result["correct"] is False
        assert result["method"] != "numeric"

    def test_inf_rejected(self):
        result = check_answer("inf", "inf")
        assert result["method"] != "numeric"

    # --- Unbalanced brackets ---
    def test_unbalanced_brackets(self):
        # Closers without openers must not suppress the comma split.
        parts = _split_multi_answer("a), b)")
        assert len(parts) == 2

    # --- No match ---
    def test_no_match(self):
        result = check_answer("foo", "bar")
        assert result["correct"] is False
        assert result["method"] == "no_match"


class TestNormalizeLatexDelimiters:
    """\\( \\) and \\[ \\] delimiters common in LLM outputs."""

    def test_backslash_parens(self):
        assert normalize_latex(r"\(x+1\)") == "x+1"

    def test_backslash_brackets(self):
        assert normalize_latex(r"\[x+1\]") == "x+1"

    def test_backslash_parens_with_period(self):
        assert normalize_latex(r"\(x+1\).") == "x+1"


class TestTryParseNumber:
    """Finite-number parsing; nan/inf must be rejected."""

    def test_integer(self):
        assert _try_parse_number("42") == 42.0

    def test_negative(self):
        assert _try_parse_number("-3") == -3.0

    def test_float(self):
        assert _try_parse_number("3.14") == 3.14

    def test_nan_rejected(self):
        assert _try_parse_number("nan") is None

    def test_inf_rejected(self):
        assert _try_parse_number("inf") is None

    def test_negative_inf_rejected(self):
        assert _try_parse_number("-inf") is None

    def test_not_a_number(self):
        assert _try_parse_number("abc") is None


class TestTryParseSympy:
    """LaTeX -> SymPy parsing with the timeout wrapper."""

    def test_simple_expression(self):
        expr = _try_parse_sympy(r"\frac{1}{2}")
        assert expr is not None

    def test_empty_string(self):
        # Parse errors are swallowed and surface as None.
        expr = _try_parse_sympy("")
        assert expr is None


class TestExpressionsEquivalent:
    """Symbolic equivalence of already-parsed expressions."""

    def test_equal(self):
        from sympy import Rational
        assert _expressions_equivalent(Rational(1, 2), Rational(2, 4)) is True

    def test_not_equal(self):
        from sympy import Rational
        assert _expressions_equivalent(Rational(1, 2), Rational(1, 3)) is False


class TestLooksLikeMath:
    """Heuristic gate that decides whether SymPy is attempted."""

    def test_latex_backslash(self):
        assert _looks_like_math(r"\frac{1}{2}") is True

    def test_caret(self):
        assert _looks_like_math("2^3") is True

    def test_plain_text(self):
        assert _looks_like_math("hello") is False

    def test_math_function(self):
        assert _looks_like_math("sqrt(2)") is True

    def test_log(self):
        assert _looks_like_math("log(x)") is True

    def test_digit_operator(self):
        assert _looks_like_math("2+3") is True
diff --git a/imobench/eval/tests/test_cli.py b/imobench/eval/tests/test_cli.py
new file mode 100644
index 0000000..26d463c
--- /dev/null
+++ b/imobench/eval/tests/test_cli.py
@@ -0,0 +1,92 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the CLI module."""

import csv
import json
from pathlib import Path

import pytest

from imobench.eval.cli import main


@pytest.fixture
def sample_answerbench(tmp_path: Path) -> Path:
    """Write a two-problem answerbench CSV and return its path."""
    rows = [
        ["Problem ID", "Problem", "Short Answer", "Category", "Subcategory", "Source"],
        ["p-001", "Q1", "2", "Algebra", "Op", "Test"],
        ["p-002", "Q2", "5", "Algebra", "Eq", "Test"],
    ]
    path = tmp_path / "answerbench.csv"
    with path.open("w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows)
    return path


@pytest.fixture
def sample_predictions(tmp_path: Path) -> Path:
    """Write a predictions CSV with one correct and one wrong answer."""
    rows = [
        ["Problem ID", "Model Answer"],
        ["p-001", "2"],
        ["p-002", "3"],
    ]
    path = tmp_path / "preds.csv"
    with path.open("w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows)
    return path


class TestCli:
    """End-to-end tests for the command-line entry point."""

    def test_text_output(self, sample_predictions, sample_answerbench, capsys):
        main([str(sample_predictions), "--answerbench", str(sample_answerbench)])
        out = capsys.readouterr().out
        # One of the two predictions is wrong, so the report shows 1/2.
        assert "Accuracy" in out
        assert "1/2" in out

    def test_json_output(self, sample_predictions, sample_answerbench, capsys):
        argv = [
            str(sample_predictions),
            "--answerbench", str(sample_answerbench),
            "--format", "json",
        ]
        main(argv)
        data = json.loads(capsys.readouterr().out)
        assert "metrics" in data
        assert "results" in data
        assert data["metrics"]["overall"]["total"] == 2

    def test_output_file(self, sample_predictions, sample_answerbench, tmp_path):
        out_path = tmp_path / "results.json"
        argv = [
            str(sample_predictions),
            "--answerbench", str(sample_answerbench),
            "--output", str(out_path),
        ]
        main(argv)
        assert out_path.exists()
        data = json.loads(out_path.read_text(encoding="utf-8"))
        assert data["metrics"]["overall"]["total"] == 2

    def test_missing_file_exits(self):
        with pytest.raises(SystemExit) as exc_info:
            main(["nonexistent_file.csv"])
        assert exc_info.value.code == 1

    def test_empty_predictions_exits(self, tmp_path):
        # A header-only predictions file contains no usable predictions.
        path = tmp_path / "empty.csv"
        with path.open("w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(["Problem ID", "Model Answer"])
        with pytest.raises(SystemExit) as exc_info:
            main([str(path)])
        assert exc_info.value.code == 1
"""Tests for the benchmark evaluation runner."""

import csv
import json
from pathlib import Path

import pytest

from imobench.eval.evaluate import (
    evaluate_predictions,
    load_predictions,
)


# Shared fixture data: three problems covering plain-numeric, LaTeX-fraction,
# and multi-answer ground truths.
_BENCH_ROWS = [
    ["Problem ID", "Problem", "Short Answer", "Category", "Subcategory", "Source"],
    ["p-001", "What is 1+1?", "2", "Algebra", "Operation", "Test"],
    ["p-002", "Simplify.", "$\\frac{1}{2}$", "Algebra", "Equation", "Test"],
    ["p-003", "Find all x.", "1, 2, 3", "Combinatorics", "Other", "Test"],
]
# Predictions that are all equivalent to the ground truths above.
_PREDICTIONS = [("p-001", "2"), ("p-002", "0.5"), ("p-003", "2, 1, 3")]


@pytest.fixture
def sample_answerbench(tmp_path: Path) -> Path:
    """Create a minimal answerbench CSV for testing."""
    path = tmp_path / "answerbench.csv"
    with path.open("w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(_BENCH_ROWS)
    return path


@pytest.fixture
def sample_predictions_csv(tmp_path: Path) -> Path:
    """Create a sample predictions CSV."""
    path = tmp_path / "predictions.csv"
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["Problem ID", "Model Answer"])
        writer.writerows([pid, answer] for pid, answer in _PREDICTIONS)
    return path


@pytest.fixture
def sample_predictions_jsonl(tmp_path: Path) -> Path:
    """Create a sample predictions JSONL."""
    path = tmp_path / "predictions.jsonl"
    lines = [
        json.dumps({"problem_id": pid, "answer": answer})
        for pid, answer in _PREDICTIONS
    ]
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return path


class TestLoadPredictions:
    """Tests for loading predictions from CSV and JSONL inputs."""

    def test_load_csv(self, sample_predictions_csv: Path):
        preds = load_predictions(sample_predictions_csv)
        assert len(preds) == 3
        assert preds["p-001"] == "2"
        assert preds["p-002"] == "0.5"

    def test_load_jsonl(self, sample_predictions_jsonl: Path):
        preds = load_predictions(sample_predictions_jsonl)
        assert len(preds) == 3
        assert preds["p-001"] == "2"
        assert preds["p-002"] == "0.5"


class TestEvaluatePredictions:
    """Tests for scoring a prediction map against the answerbench."""

    def test_all_correct(
        self, sample_predictions_csv: Path, sample_answerbench: Path
    ):
        preds = load_predictions(sample_predictions_csv)
        results = evaluate_predictions(preds, sample_answerbench)
        assert len(results) == 3
        assert all(r["correct"] for r in results)

    def test_missing_prediction(self, sample_answerbench: Path):
        # Only p-001 is predicted; p-002 and p-003 should surface as "missing".
        results = evaluate_predictions({"p-001": "2"}, sample_answerbench)
        assert sum(1 for r in results if r["correct"]) == 1
        assert sum(1 for r in results if r["method"] == "missing") == 2

    def test_empty_problem_id_raises(self, tmp_path: Path):
        path = tmp_path / "bad.csv"
        with path.open("w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerows(
                [["Problem ID", "Model Answer"], ["", "42"]]
            )
        with pytest.raises(ValueError, match="Missing problem ID"):
            load_predictions(path)

    def test_result_structure(
        self, sample_predictions_csv: Path, sample_answerbench: Path
    ):
        preds = load_predictions(sample_predictions_csv)
        results = evaluate_predictions(preds, sample_answerbench)
        required_keys = {
            "problem_id",
            "category",
            "subcategory",
            "ground_truth",
            "model_answer",
            "correct",
            "method",
        }
        for result in results:
            assert required_keys <= result.keys()
"""Tests for the metrics computation module."""

from imobench.eval.metrics import compute_metrics, format_report


def _result(problem_id, category, subcategory, ground_truth, model_answer,
            correct, method, details=""):
    """Build one evaluation-result record with the common fields filled in."""
    return {
        "problem_id": problem_id,
        "category": category,
        "subcategory": subcategory,
        "source": "Test",
        "ground_truth": ground_truth,
        "model_answer": model_answer,
        "correct": correct,
        "method": method,
        "details": details,
    }


# Two correct results (Algebra) and one wrong result (Combinatorics).
SAMPLE_RESULTS = [
    _result("p-001", "Algebra", "Operation", "2", "2", True, "exact_match"),
    _result("p-002", "Algebra", "Equation", "$\\frac{1}{2}$", "0.5", True, "sympy"),
    _result(
        "p-003", "Combinatorics", "Other", "5", "3", False, "numeric",
        details="Expected 5, got 3",
    ),
]


class TestComputeMetrics:
    """Tests for aggregate metric computation."""

    def test_overall_accuracy(self):
        overall = compute_metrics(SAMPLE_RESULTS)["overall"]
        assert overall["total"] == 3
        assert overall["correct"] == 2
        assert abs(overall["accuracy"] - 2 / 3) < 1e-9

    def test_by_category(self):
        by_category = compute_metrics(SAMPLE_RESULTS)["by_category"]
        assert "Algebra" in by_category
        assert by_category["Algebra"]["total"] == 2
        assert by_category["Algebra"]["correct"] == 2
        assert "Combinatorics" in by_category
        assert by_category["Combinatorics"]["correct"] == 0

    def test_by_subcategory(self):
        by_subcategory = compute_metrics(SAMPLE_RESULTS)["by_subcategory"]
        for key in ("Algebra/Operation", "Algebra/Equation"):
            assert key in by_subcategory

    def test_by_method(self):
        by_method = compute_metrics(SAMPLE_RESULTS)["by_method"]
        # Each checking strategy was used exactly once in the sample.
        for method in ("exact_match", "sympy", "numeric"):
            assert by_method[method] == 1

    def test_empty_results(self):
        # An empty result list must not divide by zero.
        overall = compute_metrics([])["overall"]
        assert overall["total"] == 0
        assert overall["accuracy"] == 0.0


class TestFormatReport:
    """Tests for the human-readable report renderer."""

    def test_report_contains_accuracy(self):
        report = format_report(compute_metrics(SAMPLE_RESULTS))
        for fragment in ("2/3", "66.7%"):
            assert fragment in report

    def test_report_contains_categories(self):
        report = format_report(compute_metrics(SAMPLE_RESULTS))
        for category in ("Algebra", "Combinatorics"):
            assert category in report