Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions autorag_research/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
from autorag_research.evaluation.metrics.generation import (
BertScoreConfig,
BleuConfig,
ExactMatchConfig,
MeteorConfig,
ResponseRelevancyConfig,
RougeConfig,
SemScoreConfig,
TokenF1Config,
UniEvalConfig,
bert_score,
bleu,
exact_match,
huggingface_evaluate,
meteor,
response_relevancy,
rouge,
sem_score,
token_f1,
unieval,
)
from autorag_research.evaluation.metrics.retrieval import (
Expand Down Expand Up @@ -47,6 +51,7 @@
__all__ = [
"BertScoreConfig",
"BleuConfig",
"ExactMatchConfig",
"F1Config",
"FullRecallConfig",
"MAPConfig",
Expand All @@ -58,12 +63,14 @@
"ResponseRelevancyConfig",
"RougeConfig",
"SemScoreConfig",
"TokenF1Config",
"UniEvalConfig",
"bert_score",
"bleu",
"calculate_cosine_similarity",
"calculate_inner_product",
"calculate_l2_distance",
"exact_match",
"huggingface_evaluate",
"meteor",
"metric",
Expand All @@ -78,5 +85,6 @@
"retrieval_recall",
"rouge",
"sem_score",
"token_f1",
"unieval",
]
97 changes: 92 additions & 5 deletions autorag_research/evaluation/metrics/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import re
from collections import Counter
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
Expand All @@ -24,7 +25,7 @@
from autorag_research.exceptions import EmbeddingError
from autorag_research.injection import with_embedding, with_llm
from autorag_research.schema import MetricInput
from autorag_research.util import convert_inputs_to_list, truncate_texts, unpack_and_run
from autorag_research.util import convert_inputs_to_list, normalize_string, truncate_texts, unpack_and_run

logger = logging.getLogger("AutoRAG-Research")

Expand Down Expand Up @@ -60,6 +61,55 @@
UNIEVAL_DIMENSIONS = ("coherence", "consistency", "fluency", "relevance")


def _get_normalized_tokens(text: str) -> list[str]:
    """Apply SQuAD-style normalization to *text* and return its whitespace tokens."""
    normalized = normalize_string(text)
    return normalized.split()


def _score_generation_against_references(
prediction: str,
references: list[str],
scorer: Callable[[str, str], float],
) -> float:
"""Return the best score for a prediction across all references."""
return max(scorer(prediction, reference) for reference in references)


def _compute_generation_reference_scores(
    metric_inputs: list[MetricInput],
    scorer: Callable[[str, str], float],
) -> list[float]:
    """Score every metric input, keeping each prediction's best-reference score.

    The casts mirror ``fields_to_check`` filtering done by ``metric_loop``:
    by the time we get here, ``generated_texts`` is a single string and
    ``generation_gt`` is a list of reference strings.
    """
    return [
        _score_generation_against_references(
            cast(str, item.generated_texts),
            cast(list[str], item.generation_gt),
            scorer,
        )
        for item in metric_inputs
    ]


def _exact_match_score(prediction: str, reference: str) -> float:
    """Return 1.0 when both strings are identical after SQuAD-style normalization, else 0.0."""
    normalized_prediction = normalize_string(prediction)
    normalized_reference = normalize_string(reference)
    return 1.0 if normalized_prediction == normalized_reference else 0.0


def _token_f1_score(prediction: str, reference: str) -> float:
    """Compute SQuAD-style bag-of-tokens F1 for one prediction/reference pair.

    Tokens are counted with multiplicity: a repeated token only contributes
    overlap up to the smaller count on either side (Counter intersection).
    """
    pred_counts = Counter(_get_normalized_tokens(prediction))
    ref_counts = Counter(_get_normalized_tokens(reference))

    # Degenerate case: a side that normalizes to nothing. Score 1.0 only
    # when both sides are empty, 0.0 otherwise.
    if not pred_counts or not ref_counts:
        return float(pred_counts == ref_counts)

    num_same = sum((pred_counts & ref_counts).values())
    if num_same == 0:
        return 0.0

    precision = num_same / sum(pred_counts.values())
    recall = num_same / sum(ref_counts.values())
    return float(2 * precision * recall / (precision + recall))


def _extract_llm_text(response: Any) -> str:
"""Extract text from LLM response object."""
if hasattr(response, "content"):
Expand Down Expand Up @@ -242,11 +292,10 @@ def huggingface_evaluate(instance: Any, key: str, metric_inputs: list[MetricInpu
The list of scores.
"""

def compute_score(gt: list[str], pred: str) -> float:
return max([instance.compute(predictions=[pred], references=[x], **kwargs)[key] for x in gt])
def compute_score(prediction: str, reference: str) -> float:
return instance.compute(predictions=[prediction], references=[reference], **kwargs)[key]

result = [compute_score(x.generation_gt, x.generated_texts) for x in metric_inputs] # ty: ignore
return result
return _compute_generation_reference_scores(metric_inputs, compute_score)


@metric_loop(fields_to_check=["generation_gt", "generated_texts"])
Expand Down Expand Up @@ -365,6 +414,18 @@ def rouge(
return result


@metric_loop(fields_to_check=["generation_gt", "generated_texts"])
def exact_match(metric_inputs: list[MetricInput]) -> list[float]:
    """SQuAD-style exact match: 1.0 when any reference equals the normalized prediction."""
    scores = _compute_generation_reference_scores(metric_inputs, _exact_match_score)
    return scores


@metric_loop(fields_to_check=["generation_gt", "generated_texts"])
def token_f1(metric_inputs: list[MetricInput]) -> list[float]:
    """SQuAD-style token F1: best bag-of-tokens overlap across references."""
    scores = _compute_generation_reference_scores(metric_inputs, _token_f1_score)
    return scores


@metric_loop(fields_to_check=["generation_gt", "generated_texts"])
@with_embedding()
def sem_score(
Expand Down Expand Up @@ -674,6 +735,32 @@ def get_metric_kwargs(self) -> dict[str, Any]:
}


@dataclass
class ExactMatchConfig(BaseGenerationMetricConfig):
    """Config wiring the SQuAD-style exact match metric into evaluation."""

    def get_metric_func(self) -> Callable:
        """Expose ``exact_match`` as the metric callable."""
        return exact_match

    def get_metric_kwargs(self) -> dict[str, Any]:
        """Exact match needs no extra keyword arguments."""
        return {}


@dataclass
class TokenF1Config(BaseGenerationMetricConfig):
    """Config wiring the SQuAD-style token F1 metric into evaluation."""

    def get_metric_func(self) -> Callable:
        """Expose ``token_f1`` as the metric callable."""
        return token_f1

    def get_metric_kwargs(self) -> dict[str, Any]:
        """Token F1 needs no extra keyword arguments."""
        return {}


@dataclass
class SemScoreConfig(BaseGenerationMetricConfig):
"""Configuration for SemScore (semantic similarity) metric.
Expand Down
2 changes: 2 additions & 0 deletions configs/metrics/generation/exact_match.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: autorag_research.evaluation.metrics.generation.ExactMatchConfig
description: "Exact Match (EM) - SQuAD-style normalized exact answer match"
2 changes: 2 additions & 0 deletions configs/metrics/generation/token_f1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: autorag_research.evaluation.metrics.generation.TokenF1Config
description: "Token F1 - SQuAD-style bag-of-tokens answer overlap"
2 changes: 2 additions & 0 deletions tests/autorag_research/cli/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def test_discover_metrics_finds_real_generation_configs(self, real_config_path:
"""discover_metrics finds rouge in real configs/metrics/generation/."""
result = discover_metrics("generation")

assert "exact_match" in result
assert "rouge" in result
assert "token_f1" in result
assert "unieval_coherence" in result
assert "unieval_consistency" in result
assert "unieval_fluency" in result
Expand Down
87 changes: 87 additions & 0 deletions tests/autorag_research/evaluation/metrics/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@

from autorag_research import cli
from autorag_research.evaluation.metrics.generation import (
ExactMatchConfig,
TokenF1Config,
bert_score,
bleu,
exact_match,
meteor,
response_relevancy,
rouge,
sem_score,
token_f1,
)
from autorag_research.schema import MetricInput
from tests.mock import mock_embed_documents
Expand Down Expand Up @@ -351,3 +355,86 @@ def test_response_relevancy_invalid_json_returns_nan():
scores = response_relevancy(metric_inputs, llm=llm, embedding_model=KeywordEmbeddings(), strictness=3)

assert scores[0] != scores[0] # NaN check


def test_exact_match_uses_squad_normalization():
    """Articles, punctuation, and case are stripped, but word order still matters."""
    inputs = [
        MetricInput(generated_texts="The Eiffel Tower!", generation_gt=["eiffel tower"]),
        MetricInput(generated_texts="Paris, France", generation_gt=["France Paris"]),
    ]

    assert exact_match(inputs) == [1.0, 0.0]


def test_exact_match_returns_best_score_across_references():
    """A prediction scores 1.0 if it matches at least one of several references."""
    metric_input = MetricInput(
        generated_texts="Pacific Ocean",
        generation_gt=["atlantic ocean", "the pacific ocean"],
    )

    scores = exact_match([metric_input])

    assert scores == [1.0]


def test_exact_match_handles_empty_normalized_answers():
    """Answers that normalize to nothing only match other empty-normalized answers."""
    scores = exact_match([
        MetricInput(generated_texts="the", generation_gt=["an"]),
        MetricInput(generated_texts="the", generation_gt=["cat"]),
    ])

    assert scores == [1.0, 0.0]


def test_token_f1_counts_repeated_tokens_with_bag_of_words_overlap():
    """Repeated tokens contribute overlap only up to the smaller multiplicity."""
    metric_input = MetricInput(
        generated_texts="red red blue",
        generation_gt=["red blue blue"],
    )

    # Overlap is 2 of 3 tokens on each side -> precision = recall = F1 = 2/3.
    assert token_f1([metric_input]) == [pytest.approx(2 / 3)]


def test_token_f1_uses_best_reference_overlap():
    """With multiple references, the highest per-reference F1 is reported."""
    metric_input = MetricInput(
        generated_texts="red blue",
        generation_gt=["red green", "red blue yellow"],
    )

    # Second reference: precision 2/2, recall 2/3 -> F1 = 0.8; beats the first.
    assert token_f1([metric_input]) == [pytest.approx(0.8)]


def test_token_f1_handles_empty_normalized_answers():
    """Token F1 mirrors exact match when normalization empties an answer."""
    scores = token_f1([
        MetricInput(generated_texts="the", generation_gt=["an"]),
        MetricInput(generated_texts="the", generation_gt=["cat"]),
    ])

    assert scores == [1.0, 0.0]


def test_new_metric_configs_expose_metric_functions_and_names():
    """The new configs resolve to the right metric callables with no extra kwargs."""
    for config, expected_name, expected_func in (
        (ExactMatchConfig(), "exact_match", exact_match),
        (TokenF1Config(), "token_f1", token_f1),
    ):
        assert config.get_metric_name() == expected_name
        assert config.get_metric_func() is expected_func
        assert config.get_metric_kwargs() == {}
Loading