diff --git a/README.md b/README.md
index 72cf044..f3796b4 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,11 @@ pip install subtitle-edit-rate
 will install the `suber` command line tool. Alternatively, check out this git repository and run the contained `suber` module with `python -m suber`.

+For Japanese and/or Korean support (via `-l`, see below), specify `ja` and/or `ko` as optional dependencies:
+```console
+pip install subtitle-edit-rate[ja,ko]
+```
+
 ## Basic Usage

 Currently, we expect subtitle files to come in [SubRip text (SRT)](https://en.wikipedia.org/wiki/SubRip) format. Given a human reference subtitle file `reference.srt` and a hypothesis file `hypothesis.srt` (typically the output of an automatic subtitling system) the SubER score can be calculated by running:

@@ -28,6 +33,9 @@ Also, note that `<i>`, `<b>` and `<u>` formatting tags are ignored if present in
 #### Punctuation and Case-Sensitivity
 The main SubER metric is computed on normalized text, which means case-insensitive and without taking punctuation into account, as we observe higher correlation with human judgements and post-edit effort in this setting. We provide an implementation of a case-sensitive variant which also uses a tokenizer to take punctuation into account as separate tokens which you can use "at your own risk" or to reassess our findings. For this, add `--metrics SubER-cased` to the command above. Please do not report results using this variant as "SubER" unless explicitly mentioning the punctuation-/case-sensitivity.

+#### Language Support
+SubER is expected to give meaningful scores for all languages that separate words with spaces, as English does. In addition, versions `>=0.4.0` explicitly support __Chinese__, __Japanese__ and __Korean__. (Korean does use spaces, but we follow [SacreBLEU](https://github.com/mjpost/sacrebleu) by using [mecab-ko](https://github.com/NoUnique/pymecab-ko) tokenization.) For these languages it is __required__ to set the `-l`/`--language` option to the corresponding two-letter language code, e.g. `suber -H hypothesis.srt -R reference.srt -l ja` for Japanese files. Thai is an example of a scriptio continua language that is currently not supported. As a workaround, you can run your own tokenization / word segmentation on the SRT files before calling `suber`.
+
 ## Other Metrics

 The SubER tool supports computing the following other metrics directly on subtitle files:
@@ -52,6 +60,8 @@ $ suber -H hypothesis.srt -R reference.srt --metrics WER BLEU TER chrF CER
 ```
 In this mode, the text from each parallel subtitle pair is considered to be a sentence pair.

+For __Chinese__, __Japanese__ and __Korean__ files, the language code must also be specified here via the `-l`/`--language` option to obtain correct BLEU, TER and WER scores. (This sets the `asian_support` option of TER, and for BLEU and WER enables tokenization via SacreBLEU's dedicated tokenizers `TokenizerZh`, `TokenizerJaMecab`, and `TokenizerKoMecab`, respectively.)
+
 ### Scoring Non-Parallel Subtitle Files

 In the general case, subtitle files for the same video can have different numbers of subtitles with different time stamps. All metrics - except SubER - usually require to be calculated on parallel segments. To apply these metrics to general subtitle files, the hypothesis file has to be re-segmented to correspond to the reference subtitles.
The SubER tool implements two options: diff --git a/pyproject.toml b/pyproject.toml index 653fe40..3110d6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "sacrebleu==2.5.1", "jiwer==4.0.0", "numpy", + "regex", "dataclasses;python_version<'3.7'", ] requires-python = ">= 3.6" @@ -31,6 +32,11 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] +[project.optional-dependencies] +# Installs MeCab for Japanese/Korean word segmentation. +ja = ["sacrebleu[ja]==2.5.1"] +ko = ["sacrebleu[ko]==2.5.1"] + [project.urls] Homepage = "https://github.com/apptek/SubER" Issues = "https://github.com/apptek/SubER/issues" diff --git a/suber/__main__.py b/suber/__main__.py index 5b5bebe..b3e7f70 100644 --- a/suber/__main__.py +++ b/suber/__main__.py @@ -30,8 +30,17 @@ def parse_arguments(): help="The reference files. Usually just one file, but we support test sets consisting of " "multiple files.") parser.add_argument("-m", "--metrics", nargs="+", default=["SubER"], help="The metrics to compute.") - parser.add_argument("-f", "--hypothesis-format", default="SRT", help="Hypothesis file format, 'SRT' or 'plain'.") - parser.add_argument("-F", "--reference-format", default="SRT", help="Reference file format, 'SRT' or 'plain'.") + parser.add_argument("-f", "--hypothesis-format", default="SRT", choices=["SRT", "plain"], + help="Hypothesis file format, 'SRT' or 'plain'.") + parser.add_argument("-F", "--reference-format", default="SRT", choices=["SRT", "plain"], + help="Reference file format, 'SRT' or 'plain'.") + parser.add_argument("-l", "--language", choices=["zh", "ja", "ko"], + help='Set to "zh", "ja" or "ko" to enable correct tokenization of Chinese, Japanese or Korean ' + "text, respectively. We follow sacrebleu and use its BLEU tokenizers 'zh', 'ja-mecab' and " + "'ko-mecab' for these three languages, respectively. We employ those tokenizers for SubER " + "and WER computation too, in favor of TercomTokenizer. That's because TercomTokenizer's " + '"asian_support" is questionable, it does not split Japanese Hiragana/Katakana at all. ' + 'Only for TER itself the original TercomTokenizer with "asian_support" is used.') parser.add_argument("--suber-statistics", action="store_true", help="If set, will create an '#info' field in the output containing statistics about the " "different edit operations used to calculate the SubER score.") @@ -66,7 +75,8 @@ def main(): continue # specified multiple times by the user if metric == "length_ratio": - results[metric] = calculate_length_ratio(hypothesis=hypothesis_segments, reference=reference_segments) + results[metric] = calculate_length_ratio( + hypothesis=hypothesis_segments, reference=reference_segments, language=args.language) continue # When using existing parallel segments there will always be a word match in the end, don't count it. @@ -82,7 +92,7 @@ def main(): # AS-WER and AS-BLEU were introduced by Matusov et al. 
https://aclanthology.org/2005.iwslt-1.19.pdf if levenshtein_aligned_hypothesis_segments is None: levenshtein_aligned_hypothesis_segments = levenshtein_align_hypothesis_to_reference( - hypothesis=hypothesis_segments, reference=reference_segments) + hypothesis=hypothesis_segments, reference=reference_segments, language=args.language) hypothesis_segments_to_use = levenshtein_aligned_hypothesis_segments metric = metric[len("AS-"):] @@ -94,7 +104,7 @@ def main(): # https://www.isca-archive.org/interspeech_2021/cherry21_interspeech.pdf if time_aligned_hypothesis_segments is None: time_aligned_hypothesis_segments = time_align_hypothesis_to_reference( - hypothesis=hypothesis_segments, reference=reference_segments) + hypothesis=hypothesis_segments, reference=reference_segments, language=args.language) hypothesis_segments_to_use = time_aligned_hypothesis_segments metric = metric[len("t-"):] @@ -110,7 +120,7 @@ def main(): metric_score = calculate_SubER( hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric, - statistics_collector=statistics_collector) + statistics_collector=statistics_collector, language=args.language) if statistics_collector: additional_outputs[full_metric_name] = statistics_collector.get_statistics() @@ -118,7 +128,7 @@ def main(): elif metric.startswith("WER"): metric_score = calculate_word_error_rate( hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric, - score_break_at_segment_end=score_break_at_segment_end) + score_break_at_segment_end=score_break_at_segment_end, language=args.language) elif metric.startswith("CER"): metric_score = calculate_character_error_rate( @@ -127,7 +137,7 @@ def main(): else: metric_score = calculate_sacrebleu_metric( hypothesis=hypothesis_segments_to_use, reference=reference_segments, metric=metric, - score_break_at_segment_end=score_break_at_segment_end) + score_break_at_segment_end=score_break_at_segment_end, language=args.language) results[full_metric_name] = metric_score diff --git a/suber/constants.py b/suber/constants.py index 5ab9707..3845208 100644 --- a/suber/constants.py +++ b/suber/constants.py @@ -3,3 +3,12 @@ END_OF_BLOCK_SYMBOL = "" MASK_SYMBOL = "" + +# These are the languages for which we enable "asian_support" for TER computation. +# TODO: Korean included as a precaution, does it make sense? "asian_support=True" should only have an effect in very +# rare cases for Korean text? +# For SubER and WER we actually use sacrebleu's TokenizerZh, TokenizerJaMecab, and TokenizerKoMecab instead of +# TercomTokenizer with "asian_support". 
+EAST_ASIAN_LANGUAGE_CODES = ["zh", "ja", "ko"] + +SPACE_ESCAPE = "▁" diff --git a/suber/file_readers/srt_file_reader.py b/suber/file_readers/srt_file_reader.py index 70a7649..7a0479c 100644 --- a/suber/file_readers/srt_file_reader.py +++ b/suber/file_readers/srt_file_reader.py @@ -1,10 +1,9 @@ import re import datetime -import numpy - from suber.file_readers.file_reader_base import FileReaderBase from suber.data_types import LineBreak, TimedWord, Subtitle +from suber.utilities import set_approximate_word_times class SRTFormatError(Exception): @@ -81,7 +80,7 @@ def _parse_lines(self, file_object): if word_list: # might be an empty subtitle word_list[-1].line_break = LineBreak.END_OF_BLOCK - self._set_approximate_word_times(word_list, start_time, end_time) + set_approximate_word_times(word_list, start_time, end_time) subtitles.append( Subtitle(word_list=word_list, index=subtitle_index, start_time=start_time, end_time=end_time)) @@ -98,32 +97,13 @@ def _parse_lines(self, file_object): if word_list: # might be an empty subtitle word_list[-1].line_break = LineBreak.END_OF_BLOCK - self._set_approximate_word_times(word_list, start_time, end_time) + set_approximate_word_times(word_list, start_time, end_time) subtitles.append( Subtitle(word_list=word_list, index=subtitle_index, start_time=start_time, end_time=end_time)) return subtitles - @classmethod - def _set_approximate_word_times(cls, word_list, start_time, end_time): - """ - Linearly interpolates word times from the subtitle start and end time as described in - https://www.isca-archive.org/interspeech_2021/cherry21_interspeech.pdf - """ - # Remove small margin to guarantee the first and last word will always be counted as within the subtitle. - epsilon = 1e-8 - start_time = start_time + epsilon - end_time = end_time - epsilon - - num_words = len(word_list) - duration = end_time - start_time - assert duration >= 0 - - approximate_word_times = numpy.linspace(start=start_time, stop=end_time, num=num_words) - for word_time, word in zip(approximate_word_times, word_list): - word.approximate_word_time = word_time - @classmethod def _parse_time_stamp(cls, time_stamp): time_stamp_tokens = time_stamp.split() diff --git a/suber/hyp_to_ref_alignment/levenshtein_alignment.py b/suber/hyp_to_ref_alignment/levenshtein_alignment.py index dbfa12b..984aac7 100644 --- a/suber/hyp_to_ref_alignment/levenshtein_alignment.py +++ b/suber/hyp_to_ref_alignment/levenshtein_alignment.py @@ -1,19 +1,29 @@ import numpy +import regex import string from itertools import zip_longest -from typing import List, Tuple +from typing import List, Optional, Tuple from suber import lib_levenshtein +from suber.constants import EAST_ASIAN_LANGUAGE_CODES, SPACE_ESCAPE from suber.data_types import Segment +from suber.tokenizers import reversibly_tokenize_segments, detokenize_segments -def levenshtein_align_hypothesis_to_reference(hypothesis: List[Segment], reference: List[Segment]) -> List[Segment]: +def levenshtein_align_hypothesis_to_reference( + hypothesis: List[Segment], reference: List[Segment], language: Optional[str] = None) -> List[Segment]: """ Runs the Levenshtein algorithm to get the minimal set of edit operations to convert the full list of hypothesis words into the full list of reference words. The edit operations implicitly define an alignment between hypothesis and reference words. Using this alignment, the hypotheses are re-segmented to match the reference segmentation. 
""" + if language in EAST_ASIAN_LANGUAGE_CODES: + # Punctuation kept attached because we want to remove it below to normalize the tokens before alignment, but + # there we cannot change the number of tokens (and must not create empty tokens). + hypothesis = reversibly_tokenize_segments(hypothesis, language, keep_punctuation_attached=True) + reference = reversibly_tokenize_segments(reference, language, keep_punctuation_attached=True) + remove_punctuation_table = str.maketrans('', '', string.punctuation) def normalize_word(word): @@ -21,7 +31,17 @@ def normalize_word(word): Lower-cases and removes punctuation as this increases the alignment accuracy. """ word = word.lower() - word_without_punctuation = word.translate(remove_punctuation_table) + + if language in EAST_ASIAN_LANGUAGE_CODES: + # Space escape needed for detokenization, but we don't want it to influence the alignment. + if word.startswith(SPACE_ESCAPE): + word = word[1:] + assert word, "Word should not be only space escape character." + word_without_punctuation = regex.sub(r"\p{P}", "", word) + else: + # Backwards compatibility: keep old behavior for other languages, even though removing non-ASCII punctuation + # would also make sense here. + word_without_punctuation = word.translate(remove_punctuation_table) if not word_without_punctuation: return word # keep tokens that are purely punctuation @@ -85,6 +105,9 @@ def normalize_word(word): aligned_hypothesis = [Segment(word_list=word_list) for word_list in aligned_hypothesis_word_lists] + if language in EAST_ASIAN_LANGUAGE_CODES: + aligned_hypothesis = detokenize_segments(aligned_hypothesis) + return aligned_hypothesis diff --git a/suber/hyp_to_ref_alignment/time_alignment.py b/suber/hyp_to_ref_alignment/time_alignment.py index 38144a2..2ec138f 100644 --- a/suber/hyp_to_ref_alignment/time_alignment.py +++ b/suber/hyp_to_ref_alignment/time_alignment.py @@ -1,16 +1,23 @@ import numpy +from typing import List, Optional -from typing import List +from suber.constants import EAST_ASIAN_LANGUAGE_CODES from suber.data_types import Segment, Subtitle +from suber.tokenizers import reversibly_tokenize_segments, detokenize_segments -def time_align_hypothesis_to_reference(hypothesis: List[Segment], reference: List[Subtitle]) -> List[Subtitle]: +def time_align_hypothesis_to_reference( + hypothesis: List[Segment], reference: List[Subtitle], language: Optional[str] = None) -> List[Subtitle]: """ Re-segments the hypothesis segments according to the reference subtitle timings. The output hypothesis subtitles will have the same time stamps as the reference, and each will contain the words whose approximate times falls into these intervals, i.e. reference_subtitle.start_time < word.approximate_word_time < reference_subtitle.end_time. Hypothesis words that do not fall into any subtitle will be dropped. 
""" + + if language in EAST_ASIAN_LANGUAGE_CODES: + hypothesis = reversibly_tokenize_segments(hypothesis, language) + aligned_hypothesis_word_lists = [[] for _ in reference] reference_start_times = numpy.array([subtitle.start_time for subtitle in reference]) @@ -40,4 +47,7 @@ def time_align_hypothesis_to_reference(hypothesis: List[Segment], reference: Lis aligned_hypothesis.append(subtitle) + if language in EAST_ASIAN_LANGUAGE_CODES: + aligned_hypothesis = detokenize_segments(aligned_hypothesis) + return aligned_hypothesis diff --git a/suber/metrics/cer.py b/suber/metrics/cer.py index f4ac6a6..aef8cfb 100644 --- a/suber/metrics/cer.py +++ b/suber/metrics/cer.py @@ -1,6 +1,7 @@ -import string from typing import List +import regex + from suber import lib_levenshtein from suber.data_types import Segment from suber.utilities import segment_to_string @@ -14,12 +15,8 @@ def calculate_character_error_rate(hypothesis: List[Segment], reference: List[Se reference_strings = [segment_to_string(segment) for segment in reference] if metric != "CER-cased": - remove_punctuation_table = str.maketrans('', '', string.punctuation) - def normalize_string(string): - string = string.translate(remove_punctuation_table) - # Ellipsis is a common character in subtitles which is not included in string.punctuation. - string = string.replace('…', '') + string = regex.sub(r"\p{P}", "", string) string = string.lower() return string diff --git a/suber/metrics/jiwer_interface.py b/suber/metrics/jiwer_interface.py index e336fa6..fcdc4e5 100644 --- a/suber/metrics/jiwer_interface.py +++ b/suber/metrics/jiwer_interface.py @@ -2,14 +2,14 @@ import functools from typing import List -from sacrebleu.tokenizers.tokenizer_ter import TercomTokenizer - from suber.data_types import Segment +from suber.constants import EAST_ASIAN_LANGUAGE_CODES +from suber.tokenizers import get_sacrebleu_tokenizer from suber.utilities import segment_to_string, get_segment_to_string_opts_from_metric def calculate_word_error_rate(hypothesis: List[Segment], reference: List[Segment], metric="WER", - score_break_at_segment_end=True) -> float: + score_break_at_segment_end=True, language: str = None) -> float: assert len(hypothesis) == len(reference), ( "Number of hypothesis segments does not match reference, alignment step missing?") @@ -18,19 +18,25 @@ def calculate_word_error_rate(hypothesis: List[Segment], reference: List[Segment transformations = jiwer.Compose([ # Note: the original release used no tokenization here. We find this change to have a minor positive effect # on correlation with post-edit effort (-0.657 vs. -0.650 in Table 1, row 2, "Combined" in our paper.) - TercomTokenize(), + Tokenize(language), jiwer.ReduceToListOfListOfWords(), ]) metric = "WER" else: - transformations = jiwer.Compose([ + transformations = [ jiwer.ToLowerCase(), jiwer.RemovePunctuation(), # Ellipsis is a common character in subtitles that older jiwer versions would not remove by default. jiwer.RemoveSpecificWords(['…']), jiwer.ReduceToListOfListOfWords(), - ]) + ] + # For most languages no tokenizer needed when punctuation is removed. Not true though for languages that do not + # use spaces to separate words. 
+ if language in EAST_ASIAN_LANGUAGE_CODES: + transformations.insert(3, Tokenize(language)) + + transformations = jiwer.Compose(transformations) include_breaks, mask_words, metric = get_segment_to_string_opts_from_metric(metric) assert metric == "WER" @@ -51,9 +57,12 @@ def calculate_word_error_rate(hypothesis: List[Segment], reference: List[Segment return round(wer_score * 100, 3) -class TercomTokenize(jiwer.AbstractTransform): - def __init__(self): - self.tokenizer = TercomTokenizer(normalized=True, no_punct=False, case_sensitive=True) +class Tokenize(jiwer.AbstractTransform): + def __init__(self, language: str): + # For backwards-compatibility, TercomTokenizer is used for all languages except "ja", "ko", and "zh". + self.tokenizer = get_sacrebleu_tokenizer(language, default_to_tercom=True) def process_string(self, s: str): + # TercomTokenizer would split "" into "< eol >" + s = s.replace("", "eol").replace("", "eob") return self.tokenizer(s) diff --git a/suber/metrics/length_ratio.py b/suber/metrics/length_ratio.py index b310bc1..9e2e2a6 100644 --- a/suber/metrics/length_ratio.py +++ b/suber/metrics/length_ratio.py @@ -1,19 +1,19 @@ from typing import List -from suber.data_types import Segment -from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a +from suber.data_types import Segment +from suber.tokenizers import get_sacrebleu_tokenizer -def calculate_length_ratio(hypothesis: List[Segment], reference: List[Segment]) -> float: +def calculate_length_ratio(hypothesis: List[Segment], reference: List[Segment], language: str = None) -> float: all_hypothesis_words = [word.string for segment in hypothesis for word in segment.word_list] all_reference_words = [word.string for segment in reference for word in segment.word_list] full_hypothesis_string = " ".join(all_hypothesis_words) full_reference_string = " ".join(all_reference_words) - # Default tokenizer used for BLEU calculation in SacreBLEU, so length ratio we calculate here should correspond - # to the "ratio" printed by SacreBLEU. - tokenizer = Tokenizer13a() + # Same tokenizer as used by default for BLEU calculation in SacreBLEU depending on the language, so length ratio we + # calculate here should correspond to the "ratio" printed by SacreBLEU. 
+ tokenizer = get_sacrebleu_tokenizer(language) num_tokens_hypothesis = len(tokenizer(full_hypothesis_string).split()) num_tokens_reference = len(tokenizer(full_reference_string).split()) diff --git a/suber/metrics/sacrebleu_interface.py b/suber/metrics/sacrebleu_interface.py index e130f6f..877bba0 100644 --- a/suber/metrics/sacrebleu_interface.py +++ b/suber/metrics/sacrebleu_interface.py @@ -4,11 +4,12 @@ from sacrebleu.metrics import BLEU, TER, CHRF from suber.data_types import Segment +from suber.constants import EAST_ASIAN_LANGUAGE_CODES from suber.utilities import segment_to_string, get_segment_to_string_opts_from_metric def calculate_sacrebleu_metric(hypothesis: List[Segment], reference: List[Segment], - metric="BLEU", score_break_at_segment_end=True) -> float: + metric="BLEU", score_break_at_segment_end=True, language: str = None) -> float: assert len(hypothesis) == len(reference), ( "Number of hypothesis segments does not match reference, alignment step missing?") @@ -16,9 +17,22 @@ def calculate_sacrebleu_metric(hypothesis: List[Segment], reference: List[Segmen include_breaks, mask_words, metric = get_segment_to_string_opts_from_metric(metric) if metric == "BLEU": - sacrebleu_metric = BLEU() + sacrebleu_metric = BLEU(trg_lang=language or "") elif metric == "TER": - sacrebleu_metric = TER() + # Setting 'asian_support' only has an effect if 'normalized' is set as well. + # TODO: using TER with default options was probably a bad idea in the first place, 'normalized' should always be + # set (unless input would already be tokenized). Probably also 'case_sensitive'. The original TER paper mentions + # case sensitivity and punctuation as separate tokens already. But until someone really cares let's not break + # current behavior or add new command line options. For languages that use spaces, the default behavior is not + # completely unreasonable. + asian_support = language in EAST_ASIAN_LANGUAGE_CODES + sacrebleu_metric = TER(asian_support=asian_support, normalized=asian_support) + + if asian_support and mask_words: + raise NotImplementedError( + f"TER-br not implemented for language '{language}'. Would require doing the TER tokenization " + "separately before replacing with mask tokens and then calling sacrebleu's TER.") + elif metric == "chrF": sacrebleu_metric = CHRF() else: diff --git a/suber/metrics/suber.py b/suber/metrics/suber.py index a2fecfd..2815e2c 100644 --- a/suber/metrics/suber.py +++ b/suber/metrics/suber.py @@ -1,18 +1,19 @@ import string from typing import List +import regex + from suber.data_types import Subtitle, TimedWord, LineBreak -from suber.constants import END_OF_BLOCK_SYMBOL, END_OF_LINE_SYMBOL +from suber.constants import END_OF_BLOCK_SYMBOL, END_OF_LINE_SYMBOL, EAST_ASIAN_LANGUAGE_CODES from suber.metrics import lib_ter from suber.metrics.suber_statistics import SubERStatisticsCollector - -from sacrebleu.tokenizers.tokenizer_ter import TercomTokenizer # only used for "SubER-cased" +from suber.tokenizers import get_sacrebleu_tokenizer def calculate_SubER(hypothesis: List[Subtitle], reference: List[Subtitle], metric="SubER", - statistics_collector: SubERStatisticsCollector = None) -> float: + statistics_collector: SubERStatisticsCollector = None, language: str = None) -> float: """ - Main function to caculate the SubER score. It is computed on normalized text, which means case-insensitive and + Main function to calculate the SubER score. 
It is computed on normalized text, which means case-insensitive and without taking punctuation into account, as we observed higher correlation with human judgements and post-edit effort in this setting. You can set the 'metric' parameter to "SubER-cased" to calculate a score on cased and punctuated text nevertheless. In this case punctuation will be treated as separate words by using a tokenizer. @@ -30,7 +31,8 @@ def calculate_SubER(hypothesis: List[Subtitle], reference: List[Subtitle], metri hypothesis_part, reference_part = part num_edits, reference_length = _calculate_num_edits_for_part( - hypothesis_part, reference_part, normalize=normalize, statistics_collector=statistics_collector) + hypothesis_part, reference_part, normalize=normalize, statistics_collector=statistics_collector, + language=language) total_num_edits += num_edits total_reference_length += reference_length @@ -47,7 +49,7 @@ def calculate_SubER(hypothesis: List[Subtitle], reference: List[Subtitle], metri def _calculate_num_edits_for_part(hypothesis_part: List[Subtitle], reference_part: List[Subtitle], normalize=True, - statistics_collector: SubERStatisticsCollector = None): + statistics_collector: SubERStatisticsCollector = None, language: str = None): """ Returns number of edits (word or break edits and shifts) and the total number of reference tokens (words + breaks) for the current part. @@ -58,13 +60,14 @@ def _calculate_num_edits_for_part(hypothesis_part: List[Subtitle], reference_par if normalize: # Although casing and punctuation are important aspects of subtitle quality, we observe higher correlation with # human post edit effort when normalizing the words. - all_hypothesis_words = _normalize_words(all_hypothesis_words) - all_reference_words = _normalize_words(all_reference_words) - else: + all_hypothesis_words = _normalize_words(all_hypothesis_words, language=language) + all_reference_words = _normalize_words(all_reference_words, language=language) + + if not normalize or language in EAST_ASIAN_LANGUAGE_CODES: # When not normalizing punctuation symbols are kept. We treat them as separate tokens by splitting them off # the words using sacrebleu's TercomTokenizer. - all_hypothesis_words = _tokenize_words(all_hypothesis_words) - all_reference_words = _tokenize_words(all_reference_words) + all_hypothesis_words = _tokenize_words(all_hypothesis_words, language=language) + all_reference_words = _tokenize_words(all_reference_words, language=language) all_hypothesis_words = _add_breaks_as_words(all_hypothesis_words) all_reference_words = _add_breaks_as_words(all_reference_words) @@ -107,17 +110,26 @@ def _add_breaks_as_words(words: List[TimedWord]) -> List[TimedWord]: remove_punctuation_table = str.maketrans('', '', string.punctuation) -def _normalize_words(words: List[TimedWord]) -> List[TimedWord]: +def _normalize_words(words: List[TimedWord], language: str = None) -> List[TimedWord]: """ Lower-cases Words and removes punctuation. 
""" output_words = [] for word in words: normalized_string = word.string.lower() - normalized_string_without_punctuation = normalized_string.translate(remove_punctuation_table) - normalized_string_without_punctuation = normalized_string_without_punctuation.replace('…', '') - - if normalized_string_without_punctuation: # keep tokens that are purely punctuation + if language in EAST_ASIAN_LANGUAGE_CODES: + normalized_string_without_punctuation = regex.sub(r"\p{P}", "", normalized_string) + else: + # Backwards compatibility: keep old behavior for other languages, even though removing non-ASCII punctuation + # would also make sense here. + normalized_string_without_punctuation = normalized_string.translate(remove_punctuation_table) + normalized_string_without_punctuation = normalized_string_without_punctuation.replace('…', '') + + # Keep tokens that are purely punctuation. + # TODO: this rule is questionable, for example in French '?', '!', etc. are not attached and thus taken into + # account. Also leading dialogue dashes are often followed by a space. But also here, better to not change + # original behavior for now. + if normalized_string_without_punctuation: normalized_string = normalized_string_without_punctuation output_words.append( @@ -131,21 +143,26 @@ def _normalize_words(words: List[TimedWord]) -> List[TimedWord]: return output_words -_tokenizer = None # created if needed in _tokenize_words(), has to be cached... +_tokenizers = {} # language -> callable -def _tokenize_words(words: List[TimedWord]) -> List[TimedWord]: +def _tokenize_words(words: List[TimedWord], language: str = None) -> List[TimedWord]: """ Not used for the main SubER metric, only for the "SubER-cased" variant. Applies sacrebleu's TercomTokenizer to all words in the input, which will create a new list of words containing punctuation symbols as separate elements. """ - global _tokenizer - if not _tokenizer: - _tokenizer = TercomTokenizer(normalized=True, no_punct=False, case_sensitive=True) + global _tokenizers + + if language not in _tokenizers: + # For all languages except "ja", "ko", "zh" we use TercomTokenizer to stay close to the reference TER + # implementation. + _tokenizers[language] = get_sacrebleu_tokenizer(language, default_to_tercom=True) + + tokenizer = _tokenizers[language] output_words = [] for word in words: - tokenized_word_string = _tokenizer(word.string) + tokenized_word_string = tokenizer(word.string) tokens = tokenized_word_string.split() if len(tokens) == 1: diff --git a/suber/tokenizers.py b/suber/tokenizers.py new file mode 100644 index 0000000..b87fa05 --- /dev/null +++ b/suber/tokenizers.py @@ -0,0 +1,189 @@ +import regex +from typing import Callable, List, Optional + +from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a +from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab +from sacrebleu.tokenizers.tokenizer_ter import TercomTokenizer +from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh + +from suber.constants import SPACE_ESCAPE +from suber.data_types import LineBreak, Segment, Subtitle, Word, TimedWord +from suber.utilities import set_approximate_word_times + + +def get_sacrebleu_tokenizer(language: str, default_to_tercom: bool = False) -> Callable[[str], str]: + """ + Returns the default tokenizer as used by sacrebleu for BLEU calculation. If 'default_to_tercom' is set, will return + (case-sensitive) TercomTokenizer instead, if language is not "ja", "ko", "zh". 
The reasoning is that for TER-based + metrics we want to stay close to the original implementation, also we want to keep the behavior of our original + implementation which always used TercomTokenizer. But the "asian_support" of TercomTokenizer is questionable, + especially that for Japanese sequences of Hiragana and Katakana characters are never split. So for those languages + we switch to the dedicated default BLEU tokenizers. + """ + if language == "ja": + tokenizer = TokenizerJaMecab() + elif language == "ko": + # Import only here to keep compatible with sacrebleu versions < 2.2 for all other languages. + from sacrebleu.tokenizers.tokenizer_ko_mecab import TokenizerKoMecab + + tokenizer = TokenizerKoMecab() + elif language == "zh": + tokenizer = TokenizerZh() + elif not default_to_tercom: + tokenizer = Tokenizer13a() + else: + tokenizer = TercomTokenizer(normalized=True, no_punct=False, case_sensitive=True) + + return tokenizer + + +def reversibly_tokenize_segments( + segments: List[Segment], language: str, keep_punctuation_attached: bool = False) -> List[Segment]: + """ + For each segment, splits words by applying the tokenizer function to the Word.string attributes. Uses a "▁" prefix + to represent the original word boundaries which are positions of spaces (similar to SentencePiece). If the input + Segments are Subtitles, the output will also be Subtitles. If the input contains TimedWords, the output will too. + For that, subtitle timings are carried over and approximate word times are recomputed. + If 'keep_punctuation_attached' is set, do not split off tokens from (space-separated) input words which would + consist of only punctuation. Most useful for Japanese / Chinese to run word segmentation without creating extra + punctuation tokens. + """ + + if keep_punctuation_attached: + tokenizer = lambda string: _reattach_punctuation(get_sacrebleu_tokenizer(language)(string)) + else: + tokenizer = get_sacrebleu_tokenizer(language) + + tokenized_segments = [] + words_are_timed = None + + for segment in segments: + tokenized_word_list = [] + + for word in segment.word_list: + assert word, "Words must not be empty." + + tokens = tokenizer(word.string).split() + assert tokens, "Tokenizer deleted word." + + for token_index, token in enumerate(tokens): + # Prefix the first token to mark original space. + if token_index == 0: + token = SPACE_ESCAPE + token + + # Only the last token inherits the original line break. + line_break = word.line_break if token_index == len(tokens) - 1 else LineBreak.NONE + + if isinstance(word, TimedWord): + assert words_are_timed is None or words_are_timed, "Either all or no words must be timed." + words_are_timed = True + + tokenized_word_list.append( + TimedWord(string=token, line_break=line_break, + subtitle_start_time=segment.start_time, subtitle_end_time=segment.end_time)) + else: + assert not words_are_timed, "Either all or no words must be timed." 
+ words_are_timed = False + tokenized_word_list.append(Word(string=token, line_break=line_break)) + + if isinstance(segment, Subtitle): + if words_are_timed: + set_approximate_word_times(tokenized_word_list, segment.start_time, segment.end_time) + tokenized_segment = Subtitle(word_list=tokenized_word_list, index=segment.index, + start_time=segment.start_time, end_time=segment.end_time) + else: + tokenized_segment = Segment(word_list=tokenized_word_list) + + tokenized_segments.append(tokenized_segment) + + return tokenized_segments + + +def detokenize_segments(segments: List[Segment]) -> List[Segment]: + """ + Inverse of 'reversibly_tokenize_segments()'. + """ + + def add_word(current_tokens: List[str], detokenized_word_list: List[Word], words_are_timed: bool, + current_line_break: LineBreak, current_subtitle_start_time: Optional[float], + current_subtitle_end_time: Optional[float]): + """ + Helper function. Joins 'current_tokens' into a word and appends to 'detokenized_word_list'. Clears + 'current_tokens' afterwards. + """ + if not current_tokens: + return + + detokenized_word_string = "".join(current_tokens) + if words_are_timed: + detokenized_word = TimedWord( + string=detokenized_word_string, line_break=current_line_break, + subtitle_start_time=current_subtitle_start_time, subtitle_end_time=current_subtitle_end_time) + else: + detokenized_word = Word(string=detokenized_word_string, line_break=current_line_break) + detokenized_word_list.append(detokenized_word) + current_tokens.clear() + + detokenized_segments = [] + words_are_timed = None + + for segment in segments: + detokenized_word_list = [] + + current_tokens = [] + current_line_break = LineBreak.NONE + current_subtitle_start_time = None # Taken from TimedWord.subtitle_start/end_time. All tokens originating from + current_subtitle_end_time = None # a given word should have identical timings, we don't check this here. + + for token in segment.word_list: + if isinstance(token, TimedWord): + assert words_are_timed is None or words_are_timed, "Either all or no words must be timed." + words_are_timed = True + else: + assert not words_are_timed, "Either all or no words must be timed." + words_are_timed = False + + token_string = token.string + + if token_string.startswith(SPACE_ESCAPE): + assert len(token_string) > 1, "Space escape character should not appear as separate word." + token_string = token_string[1:] # strip space escape character + + # Flush the previous word if there is one. + add_word(current_tokens, detokenized_word_list, words_are_timed, current_line_break, + current_subtitle_start_time, current_subtitle_end_time) + + current_tokens.append(token_string) + current_line_break = token.line_break + if words_are_timed: + current_subtitle_start_time = token.subtitle_start_time + current_subtitle_end_time = token.subtitle_end_time + + # Flush remaining tokens. 
+ add_word(current_tokens, detokenized_word_list, words_are_timed, current_line_break, + current_subtitle_start_time, current_subtitle_end_time) + + if isinstance(segment, Subtitle): + if words_are_timed: + set_approximate_word_times(detokenized_word_list, segment.start_time, segment.end_time) + detokenized_segment = Subtitle(word_list=detokenized_word_list, index=segment.index, + start_time=segment.start_time, end_time=segment.end_time) + else: + detokenized_segment = Segment(word_list=detokenized_word_list) + + detokenized_segments.append(detokenized_segment) + + return detokenized_segments + + +def _reattach_punctuation(word: str) -> str: + """ + To be used on a 'word' string that is the result of applying a tokenizer to a single (space-separated) word. + 'word' is therefore expected to contain spaces that split the word into tokens. This function removes spaces such + that tokens consisting of punctuation characters only get attached to the token to their left, except for leading + punctuation which gets attached right. + """ + word = regex.sub(r" (\p{P}+)(?= |$)", r"\1", word) + # Now the only possible punctuation token remaining should be at start of the string, remove space after it. + word = regex.sub(r"^(\p{P}+) ", r"\1", word) + return word diff --git a/suber/tools/align_hyp_to_ref.py b/suber/tools/align_hyp_to_ref.py index a3b73a2..dba1a9b 100644 --- a/suber/tools/align_hyp_to_ref.py +++ b/suber/tools/align_hyp_to_ref.py @@ -16,9 +16,14 @@ def parse_arguments(): parser.add_argument("-R", "--reference", required=True, help="The reference file.") parser.add_argument("-o", "--aligned-hypothesis", required=True, help="The aligned hypothesis output file in plain format.") - parser.add_argument("-f", "--hypothesis-format", default="SRT", help="Hypothesis file format, 'SRT' or 'plain'.") - parser.add_argument("-F", "--reference-format", default="SRT", help="Reference file format, 'SRT' or 'plain'.") - parser.add_argument("-m", "--method", default="levenshtein", + parser.add_argument("-f", "--hypothesis-format", default="SRT", choices=["SRT", "plain"], + help="Hypothesis file format, 'SRT' or 'plain'.") + parser.add_argument("-F", "--reference-format", default="SRT", choices=["SRT", "plain"], + help="Reference file format, 'SRT' or 'plain'.") + parser.add_argument("-l", "--language", choices=["zh", "ja", "ko"], + help='Set to "zh", "ja" or "ko" to enable correct tokenization of Chinese, Japanese or Korean ' + "text, respectively.") + parser.add_argument("-m", "--method", default="levenshtein", choices=["levenshtein", "time"], help="The alignment method, either 'levenshtein' or 'time'. See the " "'suber.hyp_to_ref_alignment' module. 
'time' only supported if both hypothesis and " "reference are given in SRT format.") @@ -29,7 +34,7 @@ def parse_arguments(): def main(): args = parse_arguments() - if args.method == "time" and not args.hypothesis_format == "SRT" and args.reference_format == "SRT": + if args.method == "time" and not (args.hypothesis_format == "SRT" and args.reference_format == "SRT"): raise ValueError("For time alignment, both hypothesis and reference have to be given in SRT format.") hypothesis_segments = read_input_file(args.hypothesis, file_format=args.hypothesis_format) @@ -37,10 +42,10 @@ def main(): if args.method == "levenshtein": aligned_hypothesis_segments = levenshtein_align_hypothesis_to_reference( - hypothesis=hypothesis_segments, reference=reference_segments) + hypothesis=hypothesis_segments, reference=reference_segments, language=args.language) elif args.method == "time": aligned_hypothesis_segments = time_align_hypothesis_to_reference( - hypothesis=hypothesis_segments, reference=reference_segments) + hypothesis=hypothesis_segments, reference=reference_segments, language=args.language) with open(args.aligned_hypothesis, "w", encoding="utf-8") as output_file_object: for segment in aligned_hypothesis_segments: diff --git a/suber/utilities.py b/suber/utilities.py index 7ec971d..aee5757 100644 --- a/suber/utilities.py +++ b/suber/utilities.py @@ -1,5 +1,8 @@ -from suber.data_types import LineBreak, Segment +import numpy +from typing import List + from suber.constants import END_OF_LINE_SYMBOL, END_OF_BLOCK_SYMBOL, MASK_SYMBOL +from suber.data_types import LineBreak, Segment, TimedWord def segment_to_string(segment: Segment, include_line_breaks=False, include_last_break=True, @@ -36,3 +39,22 @@ def get_segment_to_string_opts_from_metric(metric: str): metric = metric[:-len("-seg")] return include_breaks, mask_words, metric + + +def set_approximate_word_times(word_list: List[TimedWord], subtitle_start_time: float, subtitle_end_time: float): + """ + Linearly interpolates word times from the subtitle start and end time as described in + https://www.isca-archive.org/interspeech_2021/cherry21_interspeech.pdf + """ + # Remove small margin to guarantee the first and last word will always be counted as within the subtitle. 
+ epsilon = 1e-8 + subtitle_start_time = subtitle_start_time + epsilon + subtitle_end_time = subtitle_end_time - epsilon + + num_words = len(word_list) + duration = subtitle_end_time - subtitle_start_time + assert duration >= 0 + + approximate_word_times = numpy.linspace(start=subtitle_start_time, stop=subtitle_end_time, num=num_words) + for word_time, word in zip(approximate_word_times, word_list): + word.approximate_word_time = word_time diff --git a/tests/test_cer.py b/tests/test_cer.py index 9e557ee..dafe6ae 100644 --- a/tests/test_cer.py +++ b/tests/test_cer.py @@ -42,6 +42,44 @@ def test_cer(self): # 2 edits / 68 characters self.assertAlmostEqual(cer_cased_score, 2.941) + def test_cer_japanese(self): + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + これは簡単な最初のブロックです + + 2 + 00:00:01,000 --> 00:00:02,000 + これは二つの行を持つ + 別のブロックです。""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 「これは簡単な最初のブロックです」 + + 2 + 00:00:01,000 --> 00:00:02,000 + これは二つの行を + 持つ別のブロックです""" + + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + cer_score = calculate_character_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="CER") + + # TODO: this is not 0 because a space is added at line breaks and thus the second blocks differ in the position + # of this space. We might want to not add such a space for Japanese? + # 1 space insertion, 1 space deletion, 34 reference characters (including the spaces). + self.assertAlmostEqual(cer_score, 5.882) + + cer_cased_score = calculate_character_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="CER-cased") + + # 3 punctuation character errors, 2 space edits as above, now 35 reference characters. + self.assertAlmostEqual(cer_cased_score, 14.286) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_jiwer_interface.py b/tests/test_jiwer_interface.py index 700cd68..c6c658f 100644 --- a/tests/test_jiwer_interface.py +++ b/tests/test_jiwer_interface.py @@ -54,6 +54,159 @@ def test_wer(self): # (1 break deletion + 1 break insertion) / (13 words + 1 breaks) self.assertAlmostEqual(wer_seg_score, 14.286) + def test_wer_chinese(self): + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 这是一个简单的第一帧 + + 2 + 00:00:01,000 --> 00:00:02,000 + 这是另一个有 + 两条线的帧""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 这是一个简单的第一帧。 + + 2 + 00:00:01,000 --> 00:00:02,000 + 这是另一个 + 有两条线的帧。""" + + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + wer_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER", language="zh") + + self.assertAlmostEqual(wer_score, 0.0) + + wer_cased_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-cased", language="zh") + + # TokenizerZh expected to split all characters. 
+ # 2 punctuation errors / 21 tokenized characters + self.assertAlmostEqual(wer_cased_score, 9.524) + + wer_seg_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg", language="zh") + + # (1 break deletion + 1 break insertion) / (21 tokenized characters + 3 breaks) + self.assertAlmostEqual(wer_seg_score, 8.333) + + wer_seg_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg", + score_break_at_segment_end=False, language="zh") + + # (1 break deletion + 1 break insertion) / (21 tokenized characters + 1 breaks) + self.assertAlmostEqual(wer_seg_score, 9.091) + + def test_wer_japanese(self): + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + これは簡単な最初のブロックです + + 2 + 00:00:01,000 --> 00:00:02,000 + これは二つの行を持つ + 別のブロックです""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + これは簡単な最初のブロックです。 + + 2 + 00:00:01,000 --> 00:00:02,000 + これは二つの行を + 持つ別のブロックです。""" + + # TokenizerJaMecab expected to tokenize into this: + # "これ は 簡単 な 最初 の ブロック です" + # "これ は 二つ の 行 を 持つ 別 の ブロック です" + + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + wer_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER", language="ja") + + self.assertAlmostEqual(wer_score, 0.0) + + wer_cased_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-cased", language="ja") + + # 2 punctuation errors / 19 tokenized words + self.assertAlmostEqual(wer_cased_score, 10.526) + + wer_seg_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg", language="ja") + + # (1 break deletion + 1 break insertion) / (19 tokenized words + 3 breaks) + self.assertAlmostEqual(wer_seg_score, 9.091) + + wer_seg_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg", + score_break_at_segment_end=False, language="ja") + + # (1 break deletion + 1 break insertion) / (19 tokenized words + 1 breaks) + self.assertAlmostEqual(wer_seg_score, 10.0) + + def test_wer_korean(self): + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 이것은 간단한 첫 번째 프레임입니다 + + 2 + 00:00:01,000 --> 00:00:02,000 + 이것은 두 줄로 이루어진 또 + 다른 프레임입니다""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 이것은 간단한 첫 번째 프레임입니다. 
+ + 2 + 00:00:01,000 --> 00:00:02,000 + 이것은 두 줄로 이루어진 + 또 다른 프레임입니다.""" + + # TokenizerKoMecab expected to tokenize into this: + # "이것 은 간단 한 첫 번 째 프레임 입니다" + # "이것 은 두 줄 로 이루어진 또 다른 프레임 입니다" + + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + wer_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER", language="ko") + + self.assertAlmostEqual(wer_score, 0.0) + + wer_cased_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-cased", language="ko") + + # 2 punctuation errors / 19 tokenized words + self.assertAlmostEqual(wer_cased_score, 10.526) + + wer_seg_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg", language="ko") + + # (1 break deletion + 1 break insertion) / (19 tokenized characters + 3 breaks) + self.assertAlmostEqual(wer_seg_score, 9.091) + + wer_seg_score = calculate_word_error_rate( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, metric="WER-seg", + score_break_at_segment_end=False, language="ko") + + # (1 break deletion + 1 break insertion) / (19 tokenized characters + 1 breaks) + self.assertAlmostEqual(wer_seg_score, 10.0) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_length_ratio.py b/tests/test_length_ratio.py index 1ecc195..1c7cded 100644 --- a/tests/test_length_ratio.py +++ b/tests/test_length_ratio.py @@ -5,7 +5,7 @@ class LengthRatioTest(unittest.TestCase): - def setUp(self): + def test_length_ratio(self): # Punctuation marks should count as separate tokens. reference_file_content = """ 1 @@ -30,11 +30,118 @@ def setUp(self): 00:00:01,500 --> 00:00:02,000 six?""" - self._reference_subtitles = create_temporary_file_and_read_it(reference_file_content) - self._hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + length_ratio = calculate_length_ratio(hypothesis=hypothesis_subtitles, reference=reference_subtitles) + + self.assertAlmostEqual(length_ratio, 7 / 9 * 100, places=3) + + def test_length_ratio_chinese(self): + # Should be split into characters, including punctuation, except for the English words, which are handled + # separately by the tokenizer. + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 一二三。 + + 2 + 00:00:01,000 --> 00:00:02,000 + 五六 + 七八?Plus three!""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 一二。 + + 2 + 00:00:01,000 --> 00:00:02,000 + 四五 + + 3 + 00:00:01,500 --> 00:00:02,000 + 六? + Plus three!""" + + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) - def test_length_ratio(self): length_ratio = calculate_length_ratio( - hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles) + hypothesis=hypothesis_subtitles, reference=reference_subtitles, language="zh") + + # As in English test_length_ratio(), but both reference and hypothesis "plus three" tokens. 
+ self.assertAlmostEqual(length_ratio, 10 / 12 * 100, places=3) + def test_length_ratio_japanese(self): + # TODO: Not sure what to expect here, some numbers are split into characters, others not. (Without commas it + # looks even less consistent to me.) Need language expertise. :D + # Could also test kanji, but then it would be the same as Chinese, i.e. characters? + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + いち、に、さん。 + + 2 + 00:00:01,000 --> 00:00:02,000 + ご、ろく + しち、はち?""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + いち、に。 + + 2 + 00:00:01,000 --> 00:00:02,000 + し、ご + + 3 + 00:00:01,500 --> 00:00:02,000 + ろく?""" + + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + length_ratio = calculate_length_ratio( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, language="ja") + + self.assertAlmostEqual(length_ratio, 10 / 16 * 100, places=3) + + def test_length_ratio_korean(self): + # Tokenizer expected to split into numbers. + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 하나둘셋. + + 2 + 00:00:01,000 --> 00:00:02,000 + 다섯여섯 + 일곱여덟?""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + 하나둘. + + 2 + 00:00:01,000 --> 00:00:02,000 + 넷다섯 + + 3 + 00:00:01,500 --> 00:00:02,000 + 여섯?""" + + reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + length_ratio = calculate_length_ratio( + hypothesis=hypothesis_subtitles, reference=reference_subtitles, language="ko") + + # As in English test_length_ratio(). self.assertAlmostEqual(length_ratio, 7 / 9 * 100, places=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_sacrebleu_interface.py b/tests/test_sacrebleu_interface.py index 88b7451..573eb3a 100644 --- a/tests/test_sacrebleu_interface.py +++ b/tests/test_sacrebleu_interface.py @@ -87,5 +87,88 @@ def test_chrF(self): self.assertAlmostEqual(chrF_score, 100.0) +class SacreBleuInterfaceTestJapanese(unittest.TestCase): + def setUp(self): + reference_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + これは簡単な最初のブロックです + + 2 + 00:00:01,000 --> 00:00:02,000 + これは二つの行を持つ + 別のブロックです""" + + hypothesis_file_content = """ + 1 + 00:00:00,000 --> 00:00:01,000 + これは簡単な最初のブロックです + + 2 + 00:00:01,000 --> 00:00:02,000 + これは二つの行を + 持つ別のブロックです""" + + # TokenizerJaMecab used for BLEU expected to tokenize into this: + # "これ は 簡単 な 最初 の ブロック です" + # "これ は 二つ の 行 を 持つ 別 の ブロック です" + + # TercomTokenizer(normalized=True, asian_support=True) used for TER expected to tokenize into this: + # "これは 簡 単 な 最 初 のブロックです" + # "これは 二 つの 行 を 持 つ 別 のブロックです" + + self._reference_subtitles = create_temporary_file_and_read_it(reference_file_content) + self._hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis_file_content) + + def test_bleu(self): + bleu_score = calculate_sacrebleu_metric( + hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="BLEU", language="ja") + + self.assertAlmostEqual(bleu_score, 100.0) + + bleu_seg_score = calculate_sacrebleu_metric( + hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="BLEU-seg", + language="ja") + + # Manually checked that internal sacrebleu tokenization is as expected with these results, including break + # tokens. 
+ self.assertAlmostEqual(bleu_seg_score, 82.108) + + bleu_seg_score = calculate_sacrebleu_metric( + hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="BLEU-seg", + score_break_at_segment_end=False, language="ja") + + self.assertAlmostEqual(bleu_seg_score, 79.616) + + def test_TER(self): + ter_score = calculate_sacrebleu_metric( + hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="TER", language="ja") + + self.assertAlmostEqual(ter_score, 0.0) + + ter_seg_score = calculate_sacrebleu_metric( + hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="TER-seg", language="ja") + + # 1 break shift / (16 words + 3 breaks) + self.assertAlmostEqual(ter_seg_score, 5.263) + + ter_seg_score = calculate_sacrebleu_metric( + hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="TER-seg", + score_break_at_segment_end=False, language="ja") + + # 1 break shift / (16 words + 1 breaks) + self.assertAlmostEqual(ter_seg_score, 5.882) + + def test_chrF(self): + chrF_score = calculate_sacrebleu_metric( + hypothesis=self._hypothesis_subtitles, reference=self._reference_subtitles, metric="chrF") + + self.assertAlmostEqual(chrF_score, 100.0) + + +# TODO: we should probably add tests for Chinese and Korean here. But only affects the BLEU metric, and we anyways only +# pass the language code (same way as for Japanese) and otherwise rely on sacrebleu to do the right thing... + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_suber_metric.py b/tests/test_suber_metric.py index 650d5dc..b6c0a96 100644 --- a/tests/test_suber_metric.py +++ b/tests/test_suber_metric.py @@ -21,11 +21,11 @@ def setUp(self): 0:00:03.000 --> 0:00:04.000 And another one!""" - def _run_test(self, hypothesis, reference, expected_score): + def _run_test(self, hypothesis, reference, expected_score, language=None): hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis) reference_subtitles = create_temporary_file_and_read_it(reference) - SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles) + SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles, language=language) self.assertAlmostEqual(SubER_score, expected_score) @@ -161,6 +161,96 @@ def test_split_into_three_with_one_shift(self): # (1 shift + 2 break insertions) / (7 words + 2 breaks) self._run_test(hypothesis, self._reference2, expected_score=33.333) + def test_split_into_three_with_one_shift_chinese(self): + reference = """ + 1 + 0:00:01.000 --> 0:00:02.000 + 这是字幕。 + + 2 + 0:00:03.000 --> 0:00:04.000 + 还有一个!""" + + hypothesis = """ + 1 + 0:00:01.000 --> 0:00:01.500 + 这是 + + 2 + 0:00:01.500 --> 0:00:03.500 + 还有 + 字幕。 + + 2 + 0:00:03.500 --> 0:00:04.000 + 一个!""" + + # TokenizerZh expected to split all characters. 
+ # (1 shift + 2 break insertions) / (8 characters + 2 breaks) + self._run_test(hypothesis, reference, expected_score=30.0, language="zh") + + def test_split_into_three_with_one_shift_japanese(self): + reference = """ + 1 + 0:00:01.000 --> 0:00:02.000 + これは字幕です。 + + 2 + 0:00:03.000 --> 0:00:04.000 + そしてもう一つ。""" + + hypothesis = """ + 1 + 0:00:01.000 --> 0:00:01.500 + これは + + 2 + 0:00:01.500 --> 0:00:03.500 + そして + 字幕です。 + + 2 + 0:00:03.500 --> 0:00:04.000 + もう一つ!""" + + # TercomTokenizer splits into this: + # "これ は 字幕 です そして もう 一つ" + # "これ は そして 字幕 です もう 一つ" + + # (1 shift + 2 break insertions) / (7 words + 2 breaks) + self._run_test(hypothesis, reference, expected_score=33.333, language="ja") + + def test_split_into_three_with_one_shift_korean(self): + reference = """ + 1 + 0:00:01.000 --> 0:00:02.000 + 이것은 자막입니다. + + 2 + 0:00:03.000 --> 0:00:04.000 + 또 하나 나왔네요.""" + + hypothesis = """ + 1 + 0:00:01.000 --> 0:00:01.500 + 이것은 + + 2 + 0:00:01.500 --> 0:00:03.500 + 또 하나 + 자막입니다. + + 2 + 0:00:03.500 --> 0:00:04.000 + 나왔네요!""" + + # TercomTokenizer splits into this: + # "이것 은 자막 입니다" + # "또 하나 나왔 네요" + + # (1 shift + 2 break insertions) / (8 words + 2 breaks) + self._run_test(hypothesis, reference, expected_score=30.0, language="ko") + class SubERCasedMetricTests(unittest.TestCase): def test_SubER_cased(self): @@ -193,9 +283,117 @@ def test_SubER_cased(self): SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles, metric="SubER-cased") # After tokenization there should be 9 reference words + 2 reference break tokens. - # 1 shift and 2 break deletions as above for SubER, plus 2 substitutions: "," -> "."; "and" -> "And" + # 1 shift and 2 break insertions as above for SubER, plus 2 substitutions: "," -> "."; "and" -> "And" self.assertAlmostEqual(SubER_score, 45.455) + def test_SubER_cased_chinese(self): + reference = """ + 1 + 0:00:01.000 --> 0:00:02.000 + 这是字幕。 + + 2 + 0:00:03.000 --> 0:00:04.000 + 还有一个。""" + + hypothesis = """ + 1 + 0:00:01.000 --> 0:00:01.500 + 这是 + + 2 + 0:00:01.500 --> 0:00:03.500 + 还有 + 字幕。 + + 2 + 0:00:03.500 --> 0:00:04.000 + 一个!""" + + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis) + reference_subtitles = create_temporary_file_and_read_it(reference) + + SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles, metric="SubER-cased", language="zh") + + # TokenizerZh expected to split all characters. + # 10 words after tokenization + 2 reference break tokens. + # 1 shift and 2 break insertions as above for SubER, plus 1 substitution '!' -> '。' + self.assertAlmostEqual(SubER_score, 33.333) + + def test_SubER_cased_japanese(self): + reference = """ + 1 + 0:00:01.000 --> 0:00:02.000 + これは字幕です。 + + 2 + 0:00:03.000 --> 0:00:04.000 + そしてもう一つ。""" + + hypothesis = """ + 1 + 0:00:01.000 --> 0:00:01.500 + これは + + 2 + 0:00:01.500 --> 0:00:03.500 + そして + 字幕です。 + + 2 + 0:00:03.500 --> 0:00:04.000 + もう一つ!""" + + # TokenizerJaMecab splits into this: + # "これ は 字幕 です 。 そして もう 一つ 。" + # "これ は そして 字幕 です 。 もう 一つ !" + + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis) + reference_subtitles = create_temporary_file_and_read_it(reference) + + SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles, metric="SubER-cased", language="ja") + + # 9 words after tokenization + 2 reference break tokens. + # 1 shift and 2 break insertions as above for SubER, plus 1 substitution '!' 
-> '。' + self.assertAlmostEqual(SubER_score, 36.364) + + def test_SubER_cased_korean(self): + reference = """ + 1 + 0:00:01.000 --> 0:00:02.000 + 이것은 자막입니다. + + 2 + 0:00:03.000 --> 0:00:04.000 + 또 하나 나왔네요.""" + + hypothesis = """ + 1 + 0:00:01.000 --> 0:00:01.500 + 이것은 + + 2 + 0:00:01.500 --> 0:00:03.500 + 또 하나 + 자막입니다. + + 2 + 0:00:03.500 --> 0:00:04.000 + 나왔네요!""" + + # TercomTokenizer splits into this: + # "이것 은 자막 입니다 ." + # "또 하나 나왔 네요 ." + + hypothesis_subtitles = create_temporary_file_and_read_it(hypothesis) + reference_subtitles = create_temporary_file_and_read_it(reference) + + SubER_score = calculate_SubER(hypothesis_subtitles, reference_subtitles, metric="SubER-cased", language="ko") + + # 10 words after tokenization + 2 reference break tokens. + # 1 shift and 2 break insertions as above for SubER, plus 1 substitution '!' -> '.' + self.assertAlmostEqual(SubER_score, 33.333) + class SubERHelperFunctionTests(unittest.TestCase): diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py new file mode 100644 index 0000000..fed6413 --- /dev/null +++ b/tests/test_tokenizers.py @@ -0,0 +1,145 @@ +import unittest + +from suber.tokenizers import _reattach_punctuation, detokenize_segments, reversibly_tokenize_segments + +from .utilities import create_temporary_file_and_read_it + + +class ReversibleTokenizationTests(unittest.TestCase): + def test_reversible_tokenization(self): + example_srt = """ +1 +00:00:00,000 --> 00:00:01,380 +入れ墨入れたの? + +2 +00:00:01,380 --> 00:00:03,000 +入れ墨入れたよ + +3 +00:00:03,020 --> 00:00:04,000 +新品同様か? + +4 +00:00:04,000 --> 00:00:05,240 +うん + +5 +00:00:05,240 --> 00:00:06,700 +なんでそんな顔してるの? + +6 +00:00:06,700 --> 00:00:07,960 +ユーチューブの皆さん +どうも! + +7 +00:00:07,960 --> 00:00:09,180 +ぼくはジェイソン + +8 +00:00:09,180 --> 00:00:10,640 +二年目の医学生です + +9 +00:00:10,640 --> 00:00:11,680 +ようこそ ぼくのチャンネル + +10 +00:00:11,680 --> 00:00:12,740 +「信じて見据えよう」へ + +11 +00:00:12,740 --> 00:00:14,500 +簡単に背景を説明すると + +12 +00:00:14,500 --> 00:00:17,640 +ずっと胸の入れ墨を入れたい +と思っていて + +13 +00:00:17,640 --> 00:00:20,020 +下調べとかもしてきたんだ + +14 +00:00:20,020 --> 00:00:24,240 +でも頭の中では +入れ墨が全くない状態から + +15 +00:00:24,240 --> 00:00:26,040 +いきなり胸に入れ墨を入れる +つもりはなかったんだ + +16 +00:00:26,040 --> 00:00:27,240 +それでぼくがしたことは + +17 +00:00:27,240 --> 00:00:29,960 +胸に貼るステッカー式の +入れ墨を買ったんだ + +18 +00:00:29,960 --> 00:00:32,960 +胸の入れ墨がどう見えるか +確かめるために + +19 +00:00:32,960 --> 00:00:35,300 +問題は +ぼくが両親に電話して + +20 +00:00:35,300 --> 00:00:37,160 +そのステッカーの入れ墨が +""" + subtitles = create_temporary_file_and_read_it(example_srt) + + tokenized_subtitles = reversibly_tokenize_segments(subtitles, language="ja", keep_punctuation_attached=False) + tokenized_subtitles_punct_attached = reversibly_tokenize_segments( + subtitles, language="ja", keep_punctuation_attached=True) + + num_words = sum(len(subtitle.word_list) for subtitle in subtitles) + num_tokens = sum(len(subtitle.word_list) for subtitle in tokenized_subtitles) + num_tokens_punct_attached = sum(len(subtitle.word_list) for subtitle in tokenized_subtitles_punct_attached) + + self.assertTrue(num_words < num_tokens_punct_attached < num_tokens) + + all_characters_from_words = "".join(word.string for subtitle in subtitles for word in subtitle.word_list) + all_characters_from_tokens = "".join( + word.string for subtitle in tokenized_subtitles for word in subtitle.word_list) + all_characters_from_tokens_punct_attached = "".join( + word.string for subtitle in tokenized_subtitles_punct_attached for word in subtitle.word_list) + + self.assertEqual(all_characters_from_words, 
all_characters_from_tokens.replace("▁", "")) + self.assertEqual(all_characters_from_words, all_characters_from_tokens_punct_attached.replace("▁", "")) + + previous_word_time = None + for subtitle in tokenized_subtitles: + for word in subtitle.word_list: + self.assertEqual(word.subtitle_start_time, subtitle.start_time) + self.assertEqual(word.subtitle_end_time, subtitle.end_time) + self.assertTrue(subtitle.start_time <= word.approximate_word_time <= subtitle.end_time) + if previous_word_time is not None: + self.assertTrue(previous_word_time < word.approximate_word_time) + previous_word_time = word.approximate_word_time + + detokenize_subtitles = detokenize_segments(tokenized_subtitles) + self.assertEqual(subtitles, detokenize_subtitles) + detokenize_subtitles = detokenize_segments(tokenized_subtitles_punct_attached) + self.assertEqual(subtitles, detokenize_subtitles) + + def test_reattach_punctuation(self): + self.assertEqual(_reattach_punctuation("No punctuation"), "No punctuation") + self.assertEqual(_reattach_punctuation("シンプル な 句読点 。"), "シンプル な 句読点。") + self.assertEqual(_reattach_punctuation("¿ Esto funciona ?"), "¿Esto funciona?") + self.assertEqual(_reattach_punctuation("アルバート ・ アインシュタイン"), "アルバート・ アインシュタイン") + self.assertEqual(_reattach_punctuation("Multiple . .. ... tokens"), "Multiple...... tokens") + self.assertEqual(_reattach_punctuation(". .. ... Multiple tokens"), "......Multiple tokens") + self.assertEqual(_reattach_punctuation("Multiple tokens . .. ..."), "Multiple tokens......") + + +if __name__ == '__main__': + unittest.main()
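A note for readers of the new `test_reattach_punctuation` cases: the assertions pin down a simple contract, namely that a punctuation-only token is glued onto the preceding word when one exists, and otherwise onto the following word. The sketch below reproduces that contract so the expected strings are easier to verify. It is only an illustration, not the `_reattach_punctuation` implementation in `suber.tokenizers`; the helper name, the whitespace splitting and the `unicodedata`-based punctuation check are assumptions made for the example.

```python
import unicodedata


def reattach_punctuation_sketch(tokenized_line: str) -> str:
    """Glue punctuation-only tokens to the previous word, or to the next word if none precedes."""

    def is_punctuation_token(token: str) -> bool:
        # Every character must be in a Unicode punctuation category ("P*").
        return all(unicodedata.category(character).startswith("P") for character in token)

    chunks = []          # completed output chunks, one per remaining word
    pending_prefix = ""  # punctuation collected before the first word
    for token in tokenized_line.split():
        if is_punctuation_token(token):
            if chunks:
                chunks[-1] += token  # attach to the preceding word
            else:
                pending_prefix += token
        else:
            chunks.append(pending_prefix + token)
            pending_prefix = ""
    if pending_prefix:  # input consisted of punctuation only
        chunks.append(pending_prefix)
    return " ".join(chunks)


# The same expectations as in test_reattach_punctuation above:
assert reattach_punctuation_sketch("シンプル な 句読点 。") == "シンプル な 句読点。"
assert reattach_punctuation_sketch("¿ Esto funciona ?") == "¿Esto funciona?"
assert reattach_punctuation_sketch("アルバート ・ アインシュタイン") == "アルバート・ アインシュタイン"
assert reattach_punctuation_sketch(". .. ... Multiple tokens") == "......Multiple tokens"
```

The reversible-tokenization test above exercises a complementary invariant: concatenating the tokens produced by `reversibly_tokenize_segments` and stripping the "▁" joiner must reproduce the original subtitle characters exactly, and `detokenize_segments` must map the tokenized segments back to subtitles equal to the input.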