diff --git a/suber/__main__.py b/suber/__main__.py
index 5b5bebe..c311182 100644
--- a/suber/__main__.py
+++ b/suber/__main__.py
@@ -13,6 +13,7 @@
 from suber.metrics.suber_statistics import SubERStatisticsCollector
 from suber.metrics.sacrebleu_interface import calculate_sacrebleu_metric
 from suber.metrics.jiwer_interface import calculate_word_error_rate
+from suber.metrics.pyannote_interface import calculate_time_span_accuracy
 from suber.metrics.cer import calculate_character_error_rate
 from suber.metrics.length_ratio import calculate_length_ratio
 
@@ -69,6 +70,11 @@ def main():
             results[metric] = calculate_length_ratio(hypothesis=hypothesis_segments, reference=reference_segments)
             continue
 
+        if metric.startswith("time_span"):
+            results[metric] = calculate_time_span_accuracy(
+                hypothesis=hypothesis_segments, reference=reference_segments, metric=metric)
+            continue
+
         # When using existing parallel segments there will always be a word match in the end, don't count it.
         # On the other hand, if hypothesis gets aligned to reference a match is not guaranteed, so count it.
         score_break_at_segment_end = False
@@ -158,7 +164,10 @@ def check_metrics(metrics):
         "t-WER", "t-CER", "t-BLEU", "t-TER", "t-chrF", "t-WER-cased", "t-CER-cased",
         "t-WER-seg", "t-BLEU-seg", "t-TER-seg", "t-TER-br",
         # Hypothesis to reference length ratio in terms of number of tokens.
-        "length_ratio"}
+        "length_ratio",
+        # Metrics evaluating how well the hypothesized subtitle timings cover the reference (in terms of duration,
+        # independent of text).
+        "time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"}
 
     invalid_metrics = list(sorted(set(metrics) - allowed_metrics))
     if invalid_metrics:
@@ -168,7 +177,7 @@ def check_metrics(metrics):
 def check_file_formats(hypothesis_format, reference_format, metrics):
     is_plain_input = (hypothesis_format == "plain" or reference_format == "plain")
     for metric in metrics:
-        if ((metric == "SubER" or metric.startswith("t-")) and is_plain_input):
+        if ((metric == "SubER" or metric.startswith("t-") or metric.startswith("time_span")) and is_plain_input):
             raise ValueError(f"Metric '{metric}' requires timing information and can only be computed on SRT "
                              f"files (both hypothesis and reference).")
 
diff --git a/suber/metrics/pyannote_interface.py b/suber/metrics/pyannote_interface.py
new file mode 100644
index 0000000..9cb10cc
--- /dev/null
+++ b/suber/metrics/pyannote_interface.py
@@ -0,0 +1,37 @@
+from typing import List
+from suber.data_types import Subtitle
+
+from pyannote.core import Segment, Annotation
+from pyannote.metrics.detection import (
+    DetectionAccuracy,
+    DetectionPrecision,
+    DetectionRecall,
+    DetectionPrecisionRecallFMeasure,
+)
+
+
+def calculate_time_span_accuracy(hypothesis: List[Subtitle], reference: List[Subtitle], metric="time_span_accuracy"):
+    """Scores how well the hypothesis subtitle time spans cover the reference time spans; the text is ignored."""
+
+    reference_timings = Annotation()
+    for subtitle in reference:
+        reference_timings[Segment(subtitle.start_time, subtitle.end_time)] = str(subtitle.index)
+
+    hypothesis_timings = Annotation()
+    for subtitle in hypothesis:
+        hypothesis_timings[Segment(subtitle.start_time, subtitle.end_time)] = str(subtitle.index)
+
+    if metric == "time_span_accuracy":
+        pyannote_metric = DetectionAccuracy()
+    elif metric == "time_span_precision":
+        pyannote_metric = DetectionPrecision()
+    elif metric == "time_span_recall":
+        pyannote_metric = DetectionRecall()
+    elif metric == "time_span_f1":
+        pyannote_metric = DetectionPrecisionRecallFMeasure()
+    else:
+        raise ValueError(f"Unknown time span metric: '{metric}'")
+
+    score = pyannote_metric(reference_timings, hypothesis_timings)
+
+    return round(score, 3)
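For context, the pyannote detection metrics wrapped above compare plain time intervals; the subtitle text never enters the computation. Below is a minimal standalone sketch (not part of the patch) that reproduces the numbers asserted in the "some overlap" test further down, assuming pyannote.core and pyannote.metrics are installed:

    from pyannote.core import Segment, Annotation
    from pyannote.metrics.detection import DetectionAccuracy, DetectionPrecision, DetectionRecall

    reference = Annotation()
    reference[Segment(1.0, 2.0)] = "1"
    reference[Segment(3.0, 4.0)] = "2"

    hypothesis = Annotation()
    hypothesis[Segment(1.5, 2.5)] = "1"  # overlaps reference "1" for 0.5 s
    hypothesis[Segment(3.2, 3.8)] = "2"  # lies fully inside reference "2" (0.6 s)

    # Correctly covered duration (true positives): 0.5 s + 0.6 s = 1.1 s.
    print(DetectionPrecision()(reference, hypothesis))  # 1.1 / 1.6 = 0.6875 (hypothesis covers 1.6 s)
    print(DetectionRecall()(reference, hypothesis))     # 1.1 / 2.0 = 0.55   (reference covers 2.0 s)
    # Accuracy also credits the 0.5 s in [2.5, 3.0] where both are silent; without an
    # explicit evaluation map, pyannote scores over the union of both extents, [1.0, 4.0]:
    print(DetectionAccuracy()(reference, hypothesis))   # (1.1 + 0.5) / 3.0 ≈ 0.533

The segment labels themselves are irrelevant to the detection metrics; calculate_time_span_accuracy() attaches the subtitle indices only because pyannote's Annotation requires a label per segment.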
diff --git a/tests/test_pyannote_interface.py b/tests/test_pyannote_interface.py
new file mode 100644
index 0000000..942084a
--- /dev/null
+++ b/tests/test_pyannote_interface.py
@@ -0,0 +1,83 @@
+import unittest
+
+from suber.metrics.pyannote_interface import calculate_time_span_accuracy
+from .utilities import create_temporary_file_and_read_it
+
+
+class PyAnnoteInterfaceTest(unittest.TestCase):
+    def setUp(self):
+        reference = """
+            1
+            0:00:01.000 --> 0:00:02.000
+            This is a subtitle.
+
+            2
+            0:00:03.000 --> 0:00:04.000
+            And another one!"""
+
+        self._reference = create_temporary_file_and_read_it(reference)
+
+    def test_time_span_accuracy_empty(self):
+        for metric in ("time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"):
+            score = calculate_time_span_accuracy(hypothesis=[], reference=[], metric=metric)
+            self.assertAlmostEqual(score, 1.0)
+
+        accuracy = calculate_time_span_accuracy(hypothesis=[], reference=self._reference, metric="time_span_accuracy")
+        self.assertAlmostEqual(accuracy, 0.333)  # Total extent is 1 to 4 seconds; the 1-second gap is a true negative.
+        precision = calculate_time_span_accuracy(hypothesis=[], reference=self._reference, metric="time_span_precision")
+        self.assertAlmostEqual(precision, 1.0)
+        recall = calculate_time_span_accuracy(hypothesis=[], reference=self._reference, metric="time_span_recall")
+        self.assertAlmostEqual(recall, 0.0)
+
+        accuracy = calculate_time_span_accuracy(hypothesis=self._reference, reference=[], metric="time_span_accuracy")
+        self.assertAlmostEqual(accuracy, 0.333)
+        precision = calculate_time_span_accuracy(hypothesis=self._reference, reference=[], metric="time_span_precision")
+        self.assertAlmostEqual(precision, 0.0)
+        recall = calculate_time_span_accuracy(hypothesis=self._reference, reference=[], metric="time_span_recall")
+        self.assertAlmostEqual(recall, 1.0)
+
+    def test_time_span_accuracy_perfect(self):
+        for metric in ("time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"):
+            score = calculate_time_span_accuracy(hypothesis=self._reference, reference=self._reference, metric=metric)
+            self.assertAlmostEqual(score, 1.0)
+
+    def test_time_span_accuracy_no_overlap(self):
+        hypothesis = """
+            1
+            0:00:02.000 --> 0:00:03.000
+            This is a subtitle.
+
+            2
+            0:00:04.000 --> 0:00:05.000
+            And another one!"""
+        hypothesis = create_temporary_file_and_read_it(hypothesis)
+
+        for metric in ("time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"):
+            score = calculate_time_span_accuracy(hypothesis=hypothesis, reference=self._reference, metric=metric)
+            self.assertAlmostEqual(score, 0.0)
+
+    def test_time_span_accuracy_some_overlap(self):
+        hypothesis = """
+            1
+            0:00:01.500 --> 0:00:02.500
+            The text doesn't matter.
+
+            2
+            0:00:03.200 --> 0:00:03.800
+            """
+        hypothesis = create_temporary_file_and_read_it(hypothesis)
+
+        accuracy = calculate_time_span_accuracy(
+            hypothesis=hypothesis, reference=self._reference, metric="time_span_accuracy"
+        )
+        self.assertAlmostEqual(accuracy, 1.6 / 3, places=3)
+        precision = calculate_time_span_accuracy(
+            hypothesis=hypothesis, reference=self._reference, metric="time_span_precision"
+        )
+        self.assertAlmostEqual(precision, 1.1 / 1.6, places=3)
+        recall = calculate_time_span_accuracy(
+            hypothesis=hypothesis, reference=self._reference, metric="time_span_recall"
+        )
+        self.assertAlmostEqual(recall, 1.1 / 2, places=3)
+        f1 = calculate_time_span_accuracy(hypothesis=hypothesis, reference=self._reference, metric="time_span_f1")
+        self.assertAlmostEqual(f1, 2 * precision * recall / (precision + recall), places=3)
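With the patch applied, the time span metrics are dispatched in main() like any other metric and validated by check_metrics(). A hypothetical invocation (flag names assumed here; adapt to the actual SubER command line):

    python -m suber --hypothesis hypothesis.srt --reference reference.srt \
        --metrics time_span_accuracy time_span_precision time_span_recall time_span_f1

Note that both files must be SRT: check_file_formats() now rejects plain input for any metric starting with "time_span", since the scores are computed purely from subtitle timings.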