Skip to content
This repository was archived by the owner on Sep 18, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions simuleval/evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,12 @@ def is_finished(self, instance) -> bool:
return instance.finish_prediction

def __call__(self, system):
system.reset()
for instance in self.instance_iterator:
while not self.is_finished(instance):
system.reset()
while not instance.finish_prediction:
input_segment = instance.send_source(self.source_segment_size)
output_segment = system.pushpop(input_segment)
instance.receive_prediction(output_segment)
if instance.finish_prediction:
# if instance.finish_prediction was set by the reader,
# source_finished_reading will be set as well. If it is
# set by any of the intermediate components, then we didn't
# end yet. We are going to clear the state and continue
# processing the rest of the input.
system.reset()

if not self.score_only:
self.write_log(instance)
Expand Down
50 changes: 44 additions & 6 deletions simuleval/evaluator/scorers/quality_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import re
import logging
import sacrebleu
from pathlib import Path
Expand Down Expand Up @@ -47,6 +46,22 @@ def add_sacrebleu_args(parser):
)


class WhisperNormalizer:
    """Text normalizer backed by the normalizers from ``whisper.normalizers``.

    Args:
        en: When true, wrap ``EnglishTextNormalizer``; otherwise wrap
            ``BasicTextNormalizer``.
        char_level: When true, ``tokenize`` drops the spaces from the
            normalized text and re-joins the remaining characters with
            single spaces.

    Raises:
        ImportError: If ``whisper.normalizers`` cannot be imported. The
            original import failure is attached as ``__cause__``.
    """

    def __init__(self, en=True, char_level=False):
        # The import lives here so whisper is only needed when a normalizer
        # is actually constructed.
        try:
            from whisper.normalizers import (
                BasicTextNormalizer,
                EnglishTextNormalizer,
            )
        except ImportError as err:
            # Chain the underlying error so the real reason for the failure
            # stays visible in the traceback.
            raise ImportError(
                "Please install whisper: pip install openai-whisper"
            ) from err
        self.normalizer = EnglishTextNormalizer() if en else BasicTextNormalizer()
        self.char_level = char_level

    def tokenize(self, s):
        """Return ``s`` normalized, optionally split into characters.

        Args:
            s: Text to normalize.

        Returns:
            The output of the wrapped normalizer. With ``char_level`` set,
            spaces are removed and every remaining character is separated
            by one space.
        """
        s = self.normalizer(s)
        if self.char_level:
            s = " ".join(s.replace(" ", ""))
        return s


@register_quality_scorer("WER")
class WERScorer(QualityScorer):
"""
Expand All @@ -64,23 +79,46 @@ def __init__(self, args) -> None:
raise ImportError("Please install editdistance to use WER scorer")
self.logger = logging.getLogger("simuleval.scorer.wer")
self.logger.warning("WER scorer only support language with spaces.")
self.logger.warning(
"Current WER scorer is on raw text (un-tokenized with punctuations)."
)

self.ed = ed
self.wer_whisper_en_norm = args.wer_whisper_en_norm
if self.wer_whisper_en_norm:
self.normalizer = WhisperNormalizer(True, False)
else:
self.logger.warning(
"Current WER scorer is on raw text (un-tokenized with punctuations)."
)

def __call__(self, instances: Dict) -> float:
    """Compute the corpus-level word error rate, in percent.

    Each instance's prediction and reference are split on whitespace and
    compared with ``self.ed.eval``. When ``self.wer_whisper_en_norm`` is
    set, both strings first go through ``self.normalizer.tokenize``.

    Args:
        instances: Mapping whose values expose ``reference``,
            ``prediction`` and ``index`` attributes.

    Returns:
        ``100 * total_distance / total_reference_words``, or ``0`` when
        the total reference length is zero.
    """
    distance = 0
    ref_length = 0
    for ins in instances.values():
        if self.wer_whisper_en_norm:
            reference = self.normalizer.tokenize(ins.reference)
            prediction = self.normalizer.tokenize(ins.prediction)
        else:
            reference = ins.reference
            prediction = ins.prediction

        # Accumulate exactly once per instance, on the (possibly
        # normalized) text.
        ref_words = reference.split()
        d = self.ed.eval(prediction.split(), ref_words)
        r = len(ref_words)
        distance += d
        ref_length += r

        if r:
            print(f"index {ins.index} : {100*d/r} ")
        else:
            # An empty reference has no per-instance WER; dividing by
            # zero here would abort the whole scoring run.
            self.logger.warning(
                "index %s : empty reference, per-instance WER skipped.",
                ins.index,
            )

    if ref_length == 0:
        self.logger.warning("Reference length is 0. Return WER as 0.")
        return 0

    return 100.0 * distance / ref_length

@staticmethod
def add_args(parser):
parser.add_argument(
"--wer-whisper-en-norm",
action="store_true",
default=False,
help="Apply Whisper English normalizer",
)

@classmethod
def from_args(cls, args):
    """Build a scorer by passing the parsed ``args`` to the constructor."""
    scorer = cls(args)
    return scorer
Expand Down