From c4b60273302d586465e1cad17e741ff21faa2e9e Mon Sep 17 00:00:00 2001 From: ibanesh <3632454+ibanesh@users.noreply.github.com> Date: Tue, 12 Sep 2023 11:23:58 -0700 Subject: [PATCH 1/2] Whisper normalization for WER scorer --- simuleval/evaluator/scorers/quality_scorer.py | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py index a5a42ed7..12fc0eca 100644 --- a/simuleval/evaluator/scorers/quality_scorer.py +++ b/simuleval/evaluator/scorers/quality_scorer.py @@ -4,7 +4,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import re import logging import sacrebleu from pathlib import Path @@ -47,6 +46,22 @@ def add_sacrebleu_args(parser): ) +class WhisperNormalizer(object): + def __init__(self, en=True, char_level=False): + try: + from whisper.normalizers import EnglishTextNormalizer, BasicTextNormalizer + except ImportError: + raise ImportError("Please install whisper: pip install openai-whisper") + self.normalizer = EnglishTextNormalizer() if en else BasicTextNormalizer() + self.char_level = char_level + + def tokenize(self, s): + s = self.normalizer(s) + if self.char_level: + s = " ".join(s.replace(" ", "")) + return s + + @register_quality_scorer("WER") class WERScorer(QualityScorer): """ @@ -64,23 +79,44 @@ def __init__(self, args) -> None: raise ImportError("Please install editdistance to use WER scorer") self.logger = logging.getLogger("simuleval.scorer.wer") self.logger.warning("WER scorer only support language with spaces.") - self.logger.warning( - "Current WER scorer is on raw text (un-tokenized with punctuations)." - ) + self.ed = ed + self.wer_whisper_en_norm = args.wer_whisper_en_norm + if self.wer_whisper_en_norm: + self.normalizer = WhisperNormalizer(True, False) + else: + self.logger.warning("Current WER scorer is on raw text (un-tokenized with punctuations).") def __call__(self, instances: Dict) -> float: distance = 0 ref_length = 0 for ins in instances.values(): - distance += self.ed.eval(ins.prediction.split(), ins.reference.split()) - ref_length += len(ins.reference.split()) + if self.wer_whisper_en_norm: + reference = self.normalizer.tokenize(ins.reference) + prediction = self.normalizer.tokenize(ins.prediction) + else: + reference = ins.reference + prediction = ins.prediction + d = self.ed.eval(prediction.split(), reference.split()) + distance += d + r = len(reference.split()) + ref_length += r + print(f"index {ins.index} : {100*d/r} ") if ref_length == 0: self.logger.warning("Reference length is 0. Return WER as 0.") return 0 return 100.0 * distance / ref_length + @staticmethod + def add_args(parser): + parser.add_argument( + "--wer-whisper-en-norm", + action="store_true", + default=False, + help="Apply Whisper English normalizer", + ) + @classmethod def from_args(cls, args): return cls(args) From 0c2a48e003a9f337a3d7cf00ae390fab78f084a8 Mon Sep 17 00:00:00 2001 From: ibanesh <3632454+ibanesh@users.noreply.github.com> Date: Tue, 12 Sep 2023 11:24:40 -0700 Subject: [PATCH 2/2] Revert evaluator changes --- simuleval/evaluator/evaluator.py | 11 ++--------- simuleval/evaluator/scorers/quality_scorer.py | 4 +++- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py index 6b766f1c..a577047d 100644 --- a/simuleval/evaluator/evaluator.py +++ b/simuleval/evaluator/evaluator.py @@ -214,19 +214,12 @@ def is_finished(self, instance) -> bool: return instance.finish_prediction def __call__(self, system): - system.reset() for instance in self.instance_iterator: - while not self.is_finished(instance): + system.reset() + while not instance.finish_prediction: input_segment = instance.send_source(self.source_segment_size) output_segment = system.pushpop(input_segment) instance.receive_prediction(output_segment) - if instance.finish_prediction: - # if instance.finish_prediction where set by the reader, - # source_finished_reading will be set as well. If it is - # set by any of the intermediate components, then we didn't - # end yet. We are going to clear the state and continue - # processing the rest of the input. - system.reset() if not self.score_only: self.write_log(instance) diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py index 12fc0eca..014ba4c0 100644 --- a/simuleval/evaluator/scorers/quality_scorer.py +++ b/simuleval/evaluator/scorers/quality_scorer.py @@ -85,7 +85,9 @@ def __init__(self, args) -> None: if self.wer_whisper_en_norm: self.normalizer = WhisperNormalizer(True, False) else: - self.logger.warning("Current WER scorer is on raw text (un-tokenized with punctuations).") + self.logger.warning( + "Current WER scorer is on raw text (un-tokenized with punctuations)." + ) def __call__(self, instances: Dict) -> float: distance = 0