diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py index 7611b18b..d36167a6 100644 --- a/simuleval/evaluator/scorers/quality_scorer.py +++ b/simuleval/evaluator/scorers/quality_scorer.py @@ -10,6 +10,8 @@ from typing import Dict from sacrebleu.metrics.bleu import BLEU import subprocess +import editdistance +from simuleval.evaluator.scorers.tokenizer import EvaluationTokenizer QUALITY_SCORERS_DICT = {} @@ -34,6 +36,27 @@ def add_args(parser): pass +def add_wer_args(parser): + parser.add_argument( + "--lowercase", + action="store_true", + default=False, + help="Lowercasing", + ) + parser.add_argument( + "--remove-punct", + action="store_true", + default=False, + help="Remove puncturaiton marks", + ) + parser.add_argument( + "--char-level", + action="store_true", + default=False, + help="Character-level evaluation", + ) + + def add_sacrebleu_args(parser): parser.add_argument( "--sacrebleu-tokenizer", @@ -44,6 +67,68 @@ def add_sacrebleu_args(parser): ) +@register_quality_scorer("WER") +class WERScorer(QualityScorer): + """ + WER Scorer + + Usage: + :code:`--quality-metrics WER` + + Additional command line arguments: + + .. argparse:: + :ref: simuleval.evaluator.scorers.quality_scorer.add_wer_args + :passparser: + :prog: + """ + + def __init__( + self, + tokenizer: str = "13a", + lowercase: bool = False, + remove_punct: bool = False, + char_level: bool = False, + ) -> None: + super().__init__() + self.logger = logging.getLogger("simuleval.scorer.wer") + self.tokenizer = EvaluationTokenizer( + tokenizer_type=tokenizer, + lowercase=lowercase, + punctuation_removal=remove_punct, + character_tokenization=char_level, + ) + + def __call__(self, instances: Dict) -> float: + try: + hyps = [ + self.tokenizer.tokenize(ins.prediction) for ins in instances.values() + ] + refs = [ + self.tokenizer.tokenize(ins.reference) for ins in instances.values() + ] + err_rates = ( + sum([editdistance.eval(hyp, ref) for hyp, ref in zip(hyps, refs)]) + * 100 + / sum([len(ref) for ref in refs]) + ) + return err_rates + except Exception as e: + self.logger.error(str(e)) + return 0 + + @staticmethod + def add_args(parser): + add_wer_args(parser) + add_sacrebleu_args(parser) + + @classmethod + def from_args(cls, args): + return cls( + args.sacrebleu_tokenizer, args.lowercase, args.remove_punct, args.char_level + ) + + @register_quality_scorer("BLEU") class SacreBLEUScorer(QualityScorer): """ diff --git a/simuleval/evaluator/scorers/tokenizer.py b/simuleval/evaluator/scorers/tokenizer.py new file mode 100644 index 00000000..b0cedd50 --- /dev/null +++ b/simuleval/evaluator/scorers/tokenizer.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unicodedata + +import sacrebleu as sb + +from fairseq.dataclass import ChoiceEnum + +SACREBLEU_V2_ABOVE = int(sb.__version__[0]) >= 2 + + +class EvaluationTokenizer(object): + """A generic evaluation-time tokenizer, which leverages built-in tokenizers + in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides + lowercasing, punctuation removal and character tokenization, which are + applied after sacreBLEU tokenization. + + Args: + tokenizer_type (str): the type of sacreBLEU tokenizer to apply. + lowercase (bool): lowercase the text. + punctuation_removal (bool): remove punctuation (based on unicode + category) from text. + character_tokenization (bool): tokenize the text to characters. + """ + + SPACE = chr(32) + SPACE_ESCAPE = chr(9601) + _ALL_TOKENIZER_TYPES = ( + sb.BLEU.TOKENIZERS + if SACREBLEU_V2_ABOVE + else ["none", "13a", "intl", "zh", "ja-mecab"] + ) + ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES) + + def __init__( + self, + tokenizer_type: str = "13a", + lowercase: bool = False, + punctuation_removal: bool = False, + character_tokenization: bool = False, + ): + + assert ( + tokenizer_type in self._ALL_TOKENIZER_TYPES + ), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}" + self.lowercase = lowercase + self.punctuation_removal = punctuation_removal + self.character_tokenization = character_tokenization + if SACREBLEU_V2_ABOVE: + self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer + else: + self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]() + + @classmethod + def remove_punctuation(cls, sent: str): + """Remove punctuation based on Unicode category.""" + return cls.SPACE.join( + t + for t in sent.split(cls.SPACE) + if not all(unicodedata.category(c)[0] == "P" for c in t) + ) + + def tokenize(self, sent: str): + tokenized = self.tokenizer(sent) + + if self.punctuation_removal: + tokenized = self.remove_punctuation(tokenized) + + if self.character_tokenization: + tokenized = self.SPACE.join( + list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)) + ) + + if self.lowercase: + tokenized = tokenized.lower() + + return tokenized