diff --git a/metrics/cer/cer.py b/metrics/cer/cer.py
index c5f4a9072..a9ff290e4 100644
--- a/metrics/cer/cer.py
+++ b/metrics/cer/cer.py
@@ -29,39 +29,41 @@
 else:
     import importlib.metadata as importlib_metadata
 
-
-SENTENCE_DELIMITER = ""
-
-
-if version.parse(importlib_metadata.version("jiwer")) < version.parse("2.3.0"):
-
-    class SentencesToListOfCharacters(tr.AbstractTransform):
-        def __init__(self, sentence_delimiter: str = " "):
-            self.sentence_delimiter = sentence_delimiter
-
-        def process_string(self, s: str):
-            return list(s)
-
-        def process_list(self, inp: List[str]):
-            chars = []
-            for sent_idx, sentence in enumerate(inp):
-                chars.extend(self.process_string(sentence))
-                if self.sentence_delimiter is not None and self.sentence_delimiter != "" and sent_idx < len(inp) - 1:
-                    chars.append(self.sentence_delimiter)
-            return chars
-
-    cer_transform = tr.Compose(
-        [tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
-    )
-else:
-    cer_transform = tr.Compose(
-        [
-            tr.RemoveMultipleSpaces(),
-            tr.Strip(),
-            tr.ReduceToSingleSentence(SENTENCE_DELIMITER),
-            tr.ReduceToListOfListOfChars(),
-        ]
-    )
+if hasattr(jiwer, "compute_measures"):
+    SENTENCE_DELIMITER = ""
+    if version.parse(importlib_metadata.version("jiwer")) < version.parse("2.3.0"):
+
+        class SentencesToListOfCharacters(tr.AbstractTransform):
+            def __init__(self, sentence_delimiter: str = " "):
+                self.sentence_delimiter = sentence_delimiter
+
+            def process_string(self, s: str):
+                return list(s)
+
+            def process_list(self, inp: List[str]):
+                chars = []
+                for sent_idx, sentence in enumerate(inp):
+                    chars.extend(self.process_string(sentence))
+                    if (
+                        self.sentence_delimiter is not None
+                        and self.sentence_delimiter != ""
+                        and sent_idx < len(inp) - 1
+                    ):
+                        chars.append(self.sentence_delimiter)
+                return chars
+
+        cer_transform = tr.Compose(
+            [tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
+        )
+    else:
+        cer_transform = tr.Compose(
+            [
+                tr.RemoveMultipleSpaces(),
+                tr.Strip(),
+                tr.ReduceToSingleSentence(SENTENCE_DELIMITER),
+                tr.ReduceToListOfListOfChars(),
+            ]
+        )
 
 
 _CITATION = """\
@@ -136,24 +138,43 @@ def _info(self):
         )
 
     def _compute(self, predictions, references, concatenate_texts=False):
-        if concatenate_texts:
-            return jiwer.compute_measures(
-                references,
-                predictions,
-                truth_transform=cer_transform,
-                hypothesis_transform=cer_transform,
-            )["wer"]
-
-        incorrect = 0
-        total = 0
-        for prediction, reference in zip(predictions, references):
-            measures = jiwer.compute_measures(
-                reference,
-                prediction,
-                truth_transform=cer_transform,
-                hypothesis_transform=cer_transform,
-            )
-            incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
-            total += measures["substitutions"] + measures["deletions"] + measures["hits"]
-
-        return incorrect / total
+        if hasattr(jiwer, "compute_measures"):
+            if concatenate_texts:
+                return jiwer.compute_measures(
+                    references,
+                    predictions,
+                    truth_transform=cer_transform,
+                    hypothesis_transform=cer_transform,
+                )["wer"]
+
+            incorrect = 0
+            total = 0
+            for prediction, reference in zip(predictions, references):
+                measures = jiwer.compute_measures(
+                    reference,
+                    prediction,
+                    truth_transform=cer_transform,
+                    hypothesis_transform=cer_transform,
+                )
+                incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
+                total += measures["substitutions"] + measures["deletions"] + measures["hits"]
+
+            return incorrect / total
+        else:
+            if concatenate_texts:
+                return jiwer.process_characters(
+                    references,
+                    predictions,
+                ).cer
+
+            incorrect = 0
+            total = 0
+            for prediction, reference in zip(predictions, references):
+                measures = jiwer.process_characters(
+                    reference,
+                    prediction,
+                )
+                incorrect += measures.substitutions + measures.deletions + measures.insertions
+                total += measures.substitutions + measures.deletions + measures.hits
+
+            return incorrect / total
diff --git a/metrics/wer/wer.py b/metrics/wer/wer.py
index 214d5b22e..f63aad021 100644
--- a/metrics/wer/wer.py
+++ b/metrics/wer/wer.py
@@ -14,7 +14,7 @@
 """ Word Error Ratio (WER) metric. """
 
 import datasets
-from jiwer import compute_measures
+import jiwer
 
 import evaluate
 
@@ -94,13 +94,25 @@ def _info(self):
         )
 
     def _compute(self, predictions=None, references=None, concatenate_texts=False):
-        if concatenate_texts:
-            return compute_measures(references, predictions)["wer"]
+        if hasattr(jiwer, "compute_measures"):
+            if concatenate_texts:
+                return jiwer.compute_measures(references, predictions)["wer"]
+            else:
+                incorrect = 0
+                total = 0
+                for prediction, reference in zip(predictions, references):
+                    measures = jiwer.compute_measures(reference, prediction)
+                    incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
+                    total += measures["substitutions"] + measures["deletions"] + measures["hits"]
+                return incorrect / total
         else:
-            incorrect = 0
-            total = 0
-            for prediction, reference in zip(predictions, references):
-                measures = compute_measures(reference, prediction)
-                incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
-                total += measures["substitutions"] + measures["deletions"] + measures["hits"]
-            return incorrect / total
+            if concatenate_texts:
+                return jiwer.process_words(references, predictions).wer
+            else:
+                incorrect = 0
+                total = 0
+                for prediction, reference in zip(predictions, references):
+                    measures = jiwer.process_words(reference, prediction)
+                    incorrect += measures.substitutions + measures.deletions + measures.insertions
+                    total += measures.substitutions + measures.deletions + measures.hits
+                return incorrect / total
diff --git a/metrics/xtreme_s/xtreme_s.py b/metrics/xtreme_s/xtreme_s.py
index b4c052fc5..9c9b5b610 100644
--- a/metrics/xtreme_s/xtreme_s.py
+++ b/metrics/xtreme_s/xtreme_s.py
@@ -91,13 +91,14 @@
 SENTENCE_DELIMITER = ""
 
 try:
-    from jiwer import transforms as tr
+    import jiwer
 
     _jiwer_available = True
 except ImportError:
     _jiwer_available = False
 
 if _jiwer_available and version.parse(importlib_metadata.version("jiwer")) < version.parse("2.3.0"):
+    from jiwer import transforms as tr
 
     class SentencesToListOfCharacters(tr.AbstractTransform):
         def __init__(self, sentence_delimiter: str = " "):
@@ -117,7 +118,9 @@ def process_list(self, inp: List[str]):
     cer_transform = tr.Compose(
         [tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
     )
-elif _jiwer_available:
+elif _jiwer_available and hasattr(jiwer, "compute_measures"):
+    from jiwer import transforms as tr
+
     cer_transform = tr.Compose(
         [
             tr.RemoveMultipleSpaces(),
             tr.Strip(),
@@ -187,35 +190,59 @@ def bleu(
 
 def wer_and_cer(preds, labels, concatenate_texts, config_name):
     try:
-        from jiwer import compute_measures
+        import jiwer
     except ImportError:
         raise ValueError(
             f"jiwer has to be installed in order to apply the wer metric for {config_name}."
             "You can install it via `pip install jiwer`."
         )
 
-    if concatenate_texts:
-        wer = compute_measures(labels, preds)["wer"]
+    if hasattr(jiwer, "compute_measures"):
+        if concatenate_texts:
+            wer = jiwer.compute_measures(labels, preds)["wer"]
 
-        cer = compute_measures(labels, preds, truth_transform=cer_transform, hypothesis_transform=cer_transform)["wer"]
-        return {"wer": wer, "cer": cer}
+            cer = jiwer.compute_measures(
+                labels, preds, truth_transform=cer_transform, hypothesis_transform=cer_transform
+            )["wer"]
+            return {"wer": wer, "cer": cer}
+        else:
+
+            def compute_score(preds, labels, score_type="wer"):
+                incorrect = 0
+                total = 0
+                for prediction, reference in zip(preds, labels):
+                    if score_type == "wer":
+                        measures = jiwer.compute_measures(reference, prediction)
+                    elif score_type == "cer":
+                        measures = jiwer.compute_measures(
+                            reference, prediction, truth_transform=cer_transform, hypothesis_transform=cer_transform
+                        )
+                    incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
+                    total += measures["substitutions"] + measures["deletions"] + measures["hits"]
+                return incorrect / total
+
+            return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
     else:
+        if concatenate_texts:
+            wer = jiwer.process_words(labels, preds).wer
+
+            cer = jiwer.process_characters(labels, preds).cer
+            return {"wer": wer, "cer": cer}
+        else:
 
-        def compute_score(preds, labels, score_type="wer"):
-            incorrect = 0
-            total = 0
-            for prediction, reference in zip(preds, labels):
-                if score_type == "wer":
-                    measures = compute_measures(reference, prediction)
-                elif score_type == "cer":
-                    measures = compute_measures(
-                        reference, prediction, truth_transform=cer_transform, hypothesis_transform=cer_transform
-                    )
-                incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
-                total += measures["substitutions"] + measures["deletions"] + measures["hits"]
-            return incorrect / total
-
-        return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
+            def compute_score(preds, labels, score_type="wer"):
+                incorrect = 0
+                total = 0
+                for prediction, reference in zip(preds, labels):
+                    if score_type == "wer":
+                        measures = jiwer.process_words(reference, prediction)
+                    elif score_type == "cer":
+                        measures = jiwer.process_characters(reference, prediction)
+                    incorrect += measures.substitutions + measures.deletions + measures.insertions
+                    total += measures.substitutions + measures.deletions + measures.hits
+                return incorrect / total
+
+            return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
 
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)