13 changes: 11 additions & 2 deletions suber/__main__.py
@@ -13,6 +13,7 @@
 from suber.metrics.suber_statistics import SubERStatisticsCollector
 from suber.metrics.sacrebleu_interface import calculate_sacrebleu_metric
 from suber.metrics.jiwer_interface import calculate_word_error_rate
+from suber.metrics.pyannote_interface import calculate_time_span_accuracy
 from suber.metrics.cer import calculate_character_error_rate
 from suber.metrics.length_ratio import calculate_length_ratio
 
@@ -69,6 +70,11 @@ def main():
             results[metric] = calculate_length_ratio(hypothesis=hypothesis_segments, reference=reference_segments)
             continue
 
+        if metric.startswith("time_span"):
+            results[metric] = calculate_time_span_accuracy(
+                hypothesis=hypothesis_segments, reference=reference_segments, metric=metric)
+            continue
+
         # When using existing parallel segments there will always be a <eob> word match in the end, don't count it.
         # On the other hand, if hypothesis gets aligned to reference a match is not guaranteed, so count it.
         score_break_at_segment_end = False
@@ -158,7 +164,10 @@ def check_metrics(metrics):
"t-WER", "t-CER", "t-BLEU", "t-TER", "t-chrF", "t-WER-cased", "t-CER-cased", "t-WER-seg", "t-BLEU-seg",
"t-TER-seg", "t-TER-br",
# Hypothesis to reference length ratio in terms of number of tokens.
"length_ratio"}
"length_ratio",
# Metrics evaluating how well the hypothesized subtitle timings cover the reference (in terms of duration,
# independent of text).
"time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"}

invalid_metrics = list(sorted(set(metrics) - allowed_metrics))
if invalid_metrics:
@@ -168,7 +177,7 @@
 def check_file_formats(hypothesis_format, reference_format, metrics):
     is_plain_input = (hypothesis_format == "plain" or reference_format == "plain")
     for metric in metrics:
-        if ((metric == "SubER" or metric.startswith("t-")) and is_plain_input):
+        if ((metric == "SubER" or metric.startswith("t-") or metric.startswith("time_span")) and is_plain_input):
             raise ValueError(f"Metric '{metric}' requires timing information and can only be computed on SRT "
                              f"files (both hypothesis and reference).")
 
34 changes: 34 additions & 0 deletions suber/metrics/pyannote_interface.py
@@ -0,0 +1,34 @@
from typing import List
from suber.data_types import Subtitle

from pyannote.core import Segment, Annotation
from pyannote.metrics.detection import (
    DetectionAccuracy,
    DetectionPrecision,
    DetectionRecall,
    DetectionPrecisionRecallFMeasure,
)


def calculate_time_span_accuracy(hypothesis: List[Subtitle], reference: List[Subtitle], metric="time_span_accuracy"):

    reference_timings = Annotation()
    for subtitle in reference:
        reference_timings[Segment(subtitle.start_time, subtitle.end_time)] = str(subtitle.index)

    hypothesis_timings = Annotation()
    for subtitle in hypothesis:
        hypothesis_timings[Segment(subtitle.start_time, subtitle.end_time)] = str(subtitle.index)

    if metric == "time_span_accuracy":
        pyannote_metric = DetectionAccuracy()
    elif metric == "time_span_precision":
        pyannote_metric = DetectionPrecision()
    elif metric == "time_span_recall":
        pyannote_metric = DetectionRecall()
    elif metric == "time_span_f1":
        pyannote_metric = DetectionPrecisionRecallFMeasure()

    score = pyannote_metric(reference_timings, hypothesis_timings)

    return round(score, 3)
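
For intuition: the pyannote detection scores used above reduce to interval arithmetic over the subtitle time spans. Precision is the duration covered by both hypothesis and reference divided by the total hypothesis duration, recall divides the same overlap by the total reference duration, and accuracy additionally credits time where neither side shows a subtitle. The sketch below is not part of the PR and only illustrates these quantities; it assumes non-overlapping subtitles within each file, non-empty inputs, and an evaluation extent derived from the annotations themselves (which is how pyannote approximates the extent when no explicit UEM is supplied).

def total_duration(spans):
    # Summed duration of (start, end) spans, assumed non-overlapping within one list.
    return sum(end - start for start, end in spans)


def overlap_duration(spans_a, spans_b):
    # Summed pairwise overlap between two lists of (start, end) spans.
    return sum(max(0.0, min(end_a, end_b) - max(start_a, start_b))
               for start_a, end_a in spans_a
               for start_b, end_b in spans_b)


def time_span_scores(hypothesis_spans, reference_spans):
    overlap = overlap_duration(hypothesis_spans, reference_spans)
    precision = overlap / total_duration(hypothesis_spans)  # how much hypothesis time is justified
    recall = overlap / total_duration(reference_spans)      # how much reference time is covered
    # Accuracy also credits time covered by neither side ("true negatives"),
    # measured over the joint extent of both annotations.
    boundaries = [t for spans in (hypothesis_spans, reference_spans) for start, end in spans for t in (start, end)]
    extent = max(boundaries) - min(boundaries)
    covered_by_either = total_duration(hypothesis_spans) + total_duration(reference_spans) - overlap
    accuracy = (overlap + (extent - covered_by_either)) / extent
    return accuracy, precision, recall
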
83 changes: 83 additions & 0 deletions tests/test_pyannote_interface.py
@@ -0,0 +1,83 @@
import unittest

from suber.metrics.pyannote_interface import calculate_time_span_accuracy
from .utilities import create_temporary_file_and_read_it


class PyAnnoteInterfaceTest(unittest.TestCase):
    def setUp(self):
        reference = """
1
0:00:01.000 --> 0:00:02.000
This is a subtitle.

2
0:00:03.000 --> 0:00:04.000
And another one!"""

        self._reference = create_temporary_file_and_read_it(reference)

    def test_time_span_accuracy_empty(self):
        for metric in ("time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"):
            score = calculate_time_span_accuracy(hypothesis=[], reference=[], metric=metric)
            self.assertAlmostEqual(score, 1.0)

        accuracy = calculate_time_span_accuracy(hypothesis=[], reference=self._reference, metric="time_span_accuracy")
        self.assertAlmostEqual(accuracy, 0.333)  # Total interval is from 1 to 4 seconds, 1 second gap is true negative.
        precision = calculate_time_span_accuracy(hypothesis=[], reference=self._reference, metric="time_span_precision")
        self.assertAlmostEqual(precision, 1.0)
        recall = calculate_time_span_accuracy(hypothesis=[], reference=self._reference, metric="time_span_recall")
        self.assertAlmostEqual(recall, 0.0)

        accuracy = calculate_time_span_accuracy(hypothesis=self._reference, reference=[], metric="time_span_accuracy")
        self.assertAlmostEqual(accuracy, 0.333)
        precision = calculate_time_span_accuracy(hypothesis=self._reference, reference=[], metric="time_span_precision")
        self.assertAlmostEqual(precision, 0.0)
        recall = calculate_time_span_accuracy(hypothesis=self._reference, reference=[], metric="time_span_recall")
        self.assertAlmostEqual(recall, 1.0)

    def test_time_span_accuracy_perfect(self):
        for metric in ("time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"):
            score = calculate_time_span_accuracy(hypothesis=self._reference, reference=self._reference, metric=metric)
            self.assertAlmostEqual(score, 1.0)

    def test_time_span_accuracy_no_overlap(self):
        hypothesis = """
1
0:00:02.000 --> 0:00:03.000
This is a subtitle.

2
0:00:04.000 --> 0:00:05.000
And another one!"""
        hypothesis = create_temporary_file_and_read_it(hypothesis)

        for metric in ("time_span_accuracy", "time_span_precision", "time_span_recall", "time_span_f1"):
            score = calculate_time_span_accuracy(hypothesis=hypothesis, reference=self._reference, metric=metric)
            self.assertAlmostEqual(score, 0.0)

    def test_time_span_accuracy_some_overlap(self):
        hypothesis = """
1
0:00:01.500 --> 0:00:02.500
The text doesn't matter.

2
0:00:03.200 --> 0:00:03.800
"""
        hypothesis = create_temporary_file_and_read_it(hypothesis)

        accuracy = calculate_time_span_accuracy(
            hypothesis=hypothesis, reference=self._reference, metric="time_span_accuracy"
        )
        self.assertAlmostEqual(accuracy, 1.6 / 3, places=3)
        precision = calculate_time_span_accuracy(
            hypothesis=hypothesis, reference=self._reference, metric="time_span_precision"
        )
        self.assertAlmostEqual(precision, 1.1 / 1.6, places=3)
        recall = calculate_time_span_accuracy(
            hypothesis=hypothesis, reference=self._reference, metric="time_span_recall"
        )
        self.assertAlmostEqual(recall, 1.1 / 2, places=3)
        f1 = calculate_time_span_accuracy(hypothesis=hypothesis, reference=self._reference, metric="time_span_f1")
        self.assertAlmostEqual(f1, 2 * precision * recall / (precision + recall), places=3)
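
As a sanity check, the expected values in test_time_span_accuracy_some_overlap can be reproduced by hand from the SRT snippets above, assuming the interval interpretation sketched after pyannote_interface.py:

# Reference spans [1.0, 2.0] and [3.0, 4.0]; hypothesis spans [1.5, 2.5] and [3.2, 3.8].
overlap = (2.0 - 1.5) + (3.8 - 3.2)                 # 1.1 s covered by both hypothesis and reference
precision = overlap / 1.6                           # hypothesis duration 1.0 + 0.6 = 1.6 s  ->  0.6875
recall = overlap / 2.0                              # reference duration 1.0 + 1.0 = 2.0 s   ->  0.55
true_negative = 3.0 - (1.6 + 2.0 - overlap)         # 0.5 s of the 3 s extent [1.0, 4.0] covered by neither
accuracy = (overlap + true_negative) / 3.0          # 1.6 / 3 = 0.533...
f1 = 2 * precision * recall / (precision + recall)  # approx. 0.611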