Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 75 additions & 54 deletions metrics/cer/cer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,39 +29,41 @@
else:
import importlib.metadata as importlib_metadata


# Text placed between consecutive sentences when they are flattened into a
# single character sequence. Empty, so sentences are joined with nothing added.
SENTENCE_DELIMITER = ""

# `cer_transform` is only passed to the legacy `jiwer.compute_measures` entry
# point, so it is only built when that function exists in the installed jiwer.
if hasattr(jiwer, "compute_measures"):
    if version.parse(importlib_metadata.version("jiwer")) < version.parse("2.3.0"):

        class SentencesToListOfCharacters(tr.AbstractTransform):
            """Flatten a list of sentences into one flat list of characters.

            Used for jiwer releases older than 2.3.0; newer releases take the
            built-in transform pipeline in the `else` branch below.
            """

            def __init__(self, sentence_delimiter: str = " "):
                """Remember the token to insert between consecutive sentences."""
                self.sentence_delimiter = sentence_delimiter

            def process_string(self, s: str):
                """Split one sentence into its individual characters."""
                return list(s)

            def process_list(self, inp: List[str]):
                """Concatenate the characters of every sentence in `inp`.

                The delimiter is appended after each sentence except the last,
                and only when it is neither `None` nor the empty string.
                """
                chars = []
                last_idx = len(inp) - 1
                for sent_idx, sentence in enumerate(inp):
                    chars.extend(self.process_string(sentence))
                    # Truthiness covers both the `None` and the `""` case.
                    if self.sentence_delimiter and sent_idx < last_idx:
                        chars.append(self.sentence_delimiter)
                return chars

        cer_transform = tr.Compose(
            [tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
        )
    else:
        cer_transform = tr.Compose(
            [
                tr.RemoveMultipleSpaces(),
                tr.Strip(),
                tr.ReduceToSingleSentence(SENTENCE_DELIMITER),
                tr.ReduceToListOfListOfChars(),
            ]
        )


_CITATION = """\
Expand Down Expand Up @@ -136,24 +138,43 @@ def _info(self):
)

def _compute(self, predictions, references, concatenate_texts=False):
    """Compute the character error rate of `predictions` against `references`.

    Args:
        predictions: predicted texts, paired by position with `references`.
        references: ground-truth texts.
        concatenate_texts: when true, pass both collections to jiwer in one
            call and return the rate it reports; otherwise score every pair
            on its own and pool the edit counts.

    Returns:
        With `concatenate_texts`, the aggregate rate reported by jiwer.
        Otherwise (substitutions + deletions + insertions) divided by
        (substitutions + deletions + hits), summed over all pairs.

    Raises:
        ZeroDivisionError: in per-pair mode when no counts were accumulated,
            e.g. when the inputs are empty.
    """
    # jiwer releases that still ship `compute_measures` return a dict and need
    # the character-level transform; otherwise `process_characters` is used and
    # its result exposes the same counts as attributes.
    legacy_api = hasattr(jiwer, "compute_measures")

    if concatenate_texts:
        if legacy_api:
            # The "wer" entry is the error rate over whatever units the
            # transforms produce; `cer_transform` is supplied for both sides.
            return jiwer.compute_measures(
                references,
                predictions,
                truth_transform=cer_transform,
                hypothesis_transform=cer_transform,
            )["wer"]
        return jiwer.process_characters(references, predictions).cer

    incorrect = 0
    total = 0
    for prediction, reference in zip(predictions, references):
        if legacy_api:
            measures = jiwer.compute_measures(
                reference,
                prediction,
                truth_transform=cer_transform,
                hypothesis_transform=cer_transform,
            )
            substitutions = measures["substitutions"]
            deletions = measures["deletions"]
            insertions = measures["insertions"]
            hits = measures["hits"]
        else:
            measures = jiwer.process_characters(reference, prediction)
            substitutions = measures.substitutions
            deletions = measures.deletions
            insertions = measures.insertions
            hits = measures.hits
        incorrect += substitutions + deletions + insertions
        total += substitutions + deletions + hits

    return incorrect / total
32 changes: 22 additions & 10 deletions metrics/wer/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
""" Word Error Ratio (WER) metric. """

import datasets
from jiwer import compute_measures
import jiwer

import evaluate

Expand Down Expand Up @@ -94,13 +94,25 @@ def _info(self):
)

def _compute(self, predictions=None, references=None, concatenate_texts=False):
    """Compute the word error rate of `predictions` against `references`.

    Args:
        predictions: predicted texts, paired by position with `references`.
        references: ground-truth texts.
        concatenate_texts: when true, pass both collections to jiwer in one
            call and return the rate it reports; otherwise score every pair
            on its own and pool the edit counts.

    Returns:
        With `concatenate_texts`, the aggregate rate reported by jiwer.
        Otherwise (substitutions + deletions + insertions) divided by
        (substitutions + deletions + hits), summed over all pairs.

    Raises:
        ZeroDivisionError: in per-pair mode when no counts were accumulated,
            e.g. when the inputs are empty.
    """
    # Use `compute_measures` (dict result) when the installed jiwer still has
    # it, and `process_words` (attribute result) otherwise.
    legacy_api = hasattr(jiwer, "compute_measures")

    if concatenate_texts:
        if legacy_api:
            return jiwer.compute_measures(references, predictions)["wer"]
        return jiwer.process_words(references, predictions).wer

    incorrect = 0
    total = 0
    for prediction, reference in zip(predictions, references):
        if legacy_api:
            measures = jiwer.compute_measures(reference, prediction)
            substitutions = measures["substitutions"]
            deletions = measures["deletions"]
            insertions = measures["insertions"]
            hits = measures["hits"]
        else:
            measures = jiwer.process_words(reference, prediction)
            substitutions = measures.substitutions
            deletions = measures.deletions
            insertions = measures.insertions
            hits = measures.hits
        incorrect += substitutions + deletions + insertions
        total += substitutions + deletions + hits
    return incorrect / total
71 changes: 49 additions & 22 deletions metrics/xtreme_s/xtreme_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@
SENTENCE_DELIMITER = ""

try:
from jiwer import transforms as tr
import jiwer

_jiwer_available = True
except ImportError:
_jiwer_available = False

if _jiwer_available and version.parse(importlib_metadata.version("jiwer")) < version.parse("2.3.0"):
from jiwer import transforms as tr

class SentencesToListOfCharacters(tr.AbstractTransform):
def __init__(self, sentence_delimiter: str = " "):
Expand All @@ -117,7 +118,9 @@ def process_list(self, inp: List[str]):
cer_transform = tr.Compose(
[tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
)
elif _jiwer_available:
elif _jiwer_available and hasattr(jiwer, "compute_measures"):
from jiwer import transforms as tr

cer_transform = tr.Compose(
[
tr.RemoveMultipleSpaces(),
Expand Down Expand Up @@ -187,35 +190,59 @@ def bleu(

def wer_and_cer(preds, labels, concatenate_texts, config_name):
    """Compute word and character error rates of `preds` against `labels`.

    Args:
        preds: predicted texts, paired by position with `labels`.
        labels: ground-truth texts.
        concatenate_texts: when true, pass both collections to jiwer in one
            call per metric and return the rates it reports; otherwise score
            every pair on its own and pool the edit counts.
        config_name: configuration name, used only in the error message
            raised when jiwer is missing.

    Returns:
        A dict with the keys "wer" and "cer".

    Raises:
        ValueError: if jiwer cannot be imported.
        ZeroDivisionError: in per-pair mode when no counts were accumulated,
            e.g. when the inputs are empty.
    """
    try:
        import jiwer
    except ImportError as err:
        # Keep the ValueError callers already expect, but chain the cause and
        # separate the two sentences of the message with a space.
        raise ValueError(
            f"jiwer has to be installed in order to apply the wer metric for {config_name}. "
            "You can install it via `pip install jiwer`."
        ) from err

    if hasattr(jiwer, "compute_measures"):
        # Legacy jiwer API: results are dicts, and the character-level score is
        # obtained by passing `cer_transform` and reading the "wer" entry.
        def word_measures(reference, prediction):
            return jiwer.compute_measures(reference, prediction)

        def char_measures(reference, prediction):
            return jiwer.compute_measures(
                reference, prediction, truth_transform=cer_transform, hypothesis_transform=cer_transform
            )

        def read(measures, field):
            return measures[field]

        wer_field = cer_field = "wer"
    else:
        # Without `compute_measures`, use the dedicated word/character entry
        # points, whose results expose the same fields as attributes.
        word_measures = jiwer.process_words
        char_measures = jiwer.process_characters
        read = getattr
        wer_field, cer_field = "wer", "cer"

    if concatenate_texts:
        return {
            "wer": read(word_measures(labels, preds), wer_field),
            "cer": read(char_measures(labels, preds), cer_field),
        }

    def compute_score(measure_fn):
        """Pool edit counts over all pairs and return errors / reference units."""
        incorrect = 0
        total = 0
        for prediction, reference in zip(preds, labels):
            measures = measure_fn(reference, prediction)
            substitutions = read(measures, "substitutions")
            deletions = read(measures, "deletions")
            insertions = read(measures, "insertions")
            hits = read(measures, "hits")
            incorrect += substitutions + deletions + insertions
            total += substitutions + deletions + hits
        return incorrect / total

    return {"wer": compute_score(word_measures), "cer": compute_score(char_measures)}


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
Expand Down
Loading