Skip to content
This repository was archived by the owner on Sep 18, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions simuleval/evaluator/scorers/quality_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Dict
from sacrebleu.metrics.bleu import BLEU
import subprocess
import editdistance
from simuleval.evaluator.scorers.tokenizer import EvaluationTokenizer

# Module-level registry of quality scorers, created empty at import time.
# NOTE(review): presumably filled by the `register_quality_scorer` decorator
# (defined outside this excerpt), keyed by metric name — confirm.
QUALITY_SCORERS_DICT = {}

Expand All @@ -34,6 +36,27 @@ def add_args(parser):
pass


def add_wer_args(parser):
    """Register the WER-specific text normalization flags on ``parser``.

    Adds three boolean switches, all off by default:

    * ``--lowercase``: lowercase the text.
    * ``--remove-punct``: remove punctuation marks.
    * ``--char-level``: evaluate at character level.

    Args:
        parser: an ``argparse``-style parser (anything exposing
            ``add_argument``) that is extended in place.
    """
    parser.add_argument(
        "--lowercase",
        action="store_true",
        default=False,
        help="Lowercasing",
    )
    parser.add_argument(
        "--remove-punct",
        action="store_true",
        default=False,
        # Fixed typo in the user-facing help text ("puncturaiton").
        help="Remove punctuation marks",
    )
    parser.add_argument(
        "--char-level",
        action="store_true",
        default=False,
        help="Character-level evaluation",
    )


def add_sacrebleu_args(parser):
parser.add_argument(
"--sacrebleu-tokenizer",
Expand All @@ -44,6 +67,68 @@ def add_sacrebleu_args(parser):
)


@register_quality_scorer("WER")
class WERScorer(QualityScorer):
    """
    WER Scorer

    Computes a corpus-level word error rate: the summed edit distance
    between tokenized predictions and references, divided by the total
    number of reference tokens, times 100.

    Usage:
        :code:`--quality-metrics WER`

    Additional command line arguments:

    .. argparse::
        :ref: simuleval.evaluator.scorers.quality_scorer.add_wer_args
        :passparser:
        :prog:
    """

    def __init__(
        self,
        tokenizer: str = "13a",
        lowercase: bool = False,
        remove_punct: bool = False,
        char_level: bool = False,
    ) -> None:
        """Build the scorer and its evaluation tokenizer.

        Args:
            tokenizer: sacreBLEU tokenizer type handed to
                ``EvaluationTokenizer``.
            lowercase: lowercase the text before scoring.
            remove_punct: drop punctuation tokens before scoring.
            char_level: score on characters instead of words.
        """
        super().__init__()
        self.logger = logging.getLogger("simuleval.scorer.wer")
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=tokenizer,
            lowercase=lowercase,
            punctuation_removal=remove_punct,
            character_tokenization=char_level,
        )

    def __call__(self, instances: Dict) -> float:
        """Return the corpus-level error rate (in percent) over ``instances``.

        Args:
            instances: mapping whose values expose ``prediction`` and
                ``reference`` text attributes.

        Returns:
            ``100 * total_edit_distance / total_reference_tokens``. If
            anything fails (including an empty reference set, which would
            divide by zero) the error is logged and ``0`` is returned.
        """
        try:
            total_errors = 0
            total_ref_tokens = 0
            for ins in instances.values():
                # The tokenizer returns a space-joined string. It must be
                # split into tokens: running edit distance on the raw strings
                # would compare characters (spaces included) and yield a
                # character error rate instead of a word error rate.
                hyp = self.tokenizer.tokenize(ins.prediction).split()
                ref = self.tokenizer.tokenize(ins.reference).split()
                total_errors += editdistance.eval(hyp, ref)
                total_ref_tokens += len(ref)
            return total_errors * 100 / total_ref_tokens
        except Exception as e:
            # NOTE(review): best-effort fallback kept from the original. For
            # an error rate 0 reads as a perfect score, so check the log when
            # a 0 shows up unexpectedly.
            self.logger.error(str(e))
            return 0

    @staticmethod
    def add_args(parser):
        """Register the WER flags and the sacreBLEU tokenizer option."""
        add_wer_args(parser)
        add_sacrebleu_args(parser)

    @classmethod
    def from_args(cls, args):
        """Build a scorer from parsed command line arguments."""
        return cls(
            tokenizer=args.sacrebleu_tokenizer,
            lowercase=args.lowercase,
            remove_punct=args.remove_punct,
            char_level=args.char_level,
        )


@register_quality_scorer("BLEU")
class SacreBLEUScorer(QualityScorer):
"""
Expand Down
80 changes: 80 additions & 0 deletions simuleval/evaluator/scorers/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unicodedata

import sacrebleu as sb

from fairseq.dataclass import ChoiceEnum

# True when the installed sacreBLEU has major version >= 2. Parse the whole
# major component rather than only the first character, so that a
# multi-digit major version (e.g. "10.0.0") is not misread as 1.
SACREBLEU_V2_ABOVE = int(sb.__version__.split(".")[0]) >= 2


class EvaluationTokenizer(object):
    """Evaluation-time tokenizer built on the tokenizers shipped with
    sacreBLEU (https://github.com/mjpost/sacrebleu).

    On top of the sacreBLEU tokenization it can optionally remove
    punctuation, split the text into characters and lowercase it; these
    steps run after the sacreBLEU tokenizer, in that order.

    Args:
        tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
        lowercase (bool): lowercase the text.
        punctuation_removal (bool): remove punctuation (based on unicode
            category) from text.
        character_tokenization (bool): tokenize the text to characters.
    """

    # Token separator, and the visible stand-in (U+2581) used for original
    # spaces when the text is split into characters.
    SPACE = chr(32)
    SPACE_ESCAPE = chr(9601)

    # sacreBLEU >= 2 exposes the tokenizer names itself; older versions get
    # a fixed list.
    if SACREBLEU_V2_ABOVE:
        _ALL_TOKENIZER_TYPES = sb.BLEU.TOKENIZERS
    else:
        _ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"]
    ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES)

    def __init__(
        self,
        tokenizer_type: str = "13a",
        lowercase: bool = False,
        punctuation_removal: bool = False,
        character_tokenization: bool = False,
    ):
        """Validate ``tokenizer_type`` and instantiate the sacreBLEU tokenizer."""
        assert (
            tokenizer_type in self._ALL_TOKENIZER_TYPES
        ), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}"

        self.lowercase = lowercase
        self.punctuation_removal = punctuation_removal
        self.character_tokenization = character_tokenization

        # The lookup of the tokenizer callable differs between sacreBLEU
        # major versions.
        if not SACREBLEU_V2_ABOVE:
            self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]()
        else:
            self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer

    @classmethod
    def remove_punctuation(cls, sent: str):
        """Drop every space-separated token made only of punctuation.

        A token is kept when at least one of its characters is outside the
        Unicode "P*" categories; empty tokens are therefore dropped too.
        """
        kept_tokens = []
        for token in sent.split(cls.SPACE):
            has_non_punct = any(
                unicodedata.category(ch)[0] != "P" for ch in token
            )
            if has_non_punct:
                kept_tokens.append(token)
        return cls.SPACE.join(kept_tokens)

    def tokenize(self, sent: str):
        """Tokenize ``sent`` and apply the configured post-processing steps.

        Returns the processed text as a single space-separated string.
        """
        text = self.tokenizer(sent)

        if self.punctuation_removal:
            text = self.remove_punctuation(text)

        if self.character_tokenization:
            # Escape existing spaces first so they stay visible once every
            # character is separated by a space.
            escaped = text.replace(self.SPACE, self.SPACE_ESCAPE)
            text = self.SPACE.join(escaped)

        return text.lower() if self.lowercase else text