From ee217c471d6a14d8b3289a8782f7b02b45542305 Mon Sep 17 00:00:00 2001
From: RomanKoshkin
Date: Sun, 21 Apr 2024 11:35:05 +0900
Subject: [PATCH 1/5] FIX: custom CLI arguments could not be passed through
 to the policy

---
 simuleval/cli.py         | 4 ++--
 simuleval/utils/agent.py | 2 +-
 simuleval/utils/slurm.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/simuleval/cli.py b/simuleval/cli.py
index 03ae24b..593f32d 100644
--- a/simuleval/cli.py
+++ b/simuleval/cli.py
@@ -87,7 +87,7 @@ def scoring():
     options.add_evaluator_args(parser)
     options.add_scorer_args(parser)
     options.add_dataloader_args(parser)
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
     evaluator = SentenceLevelEvaluator.from_args(args)
     print(evaluator.results)

@@ -98,7 +98,7 @@ def remote_evaluate():
     options.add_dataloader_args(parser)
     options.add_evaluator_args(parser)
     options.add_scorer_args(parser)
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
     evaluator = build_remote_evaluator(args)

     # evaluate system
diff --git a/simuleval/utils/agent.py b/simuleval/utils/agent.py
index 06c13bd..f2b05e5 100644
--- a/simuleval/utils/agent.py
+++ b/simuleval/utils/agent.py
@@ -142,7 +142,7 @@ def build_system_args(
     args, _ = parser.parse_known_args(cli_argument_list(config_dict))
     system = system_class.from_args(args)

-    args = parser.parse_args(cli_argument_list(config_dict))
+    args, _ = parser.parse_known_args(cli_argument_list(config_dict))
     dtype = args.dtype if args.dtype else "fp16" if args.fp16 else "fp32"
     logger.info(f"System will run on device: {args.device}. dtype: {dtype}")

diff --git a/simuleval/utils/slurm.py b/simuleval/utils/slurm.py
index 88c3417..e0e27cc 100644
--- a/simuleval/utils/slurm.py
+++ b/simuleval/utils/slurm.py
@@ -59,7 +59,7 @@ def submit_slurm_job(
     options.add_dataloader_args(parser, cli_arguments)
     system_class = get_agent_class(config_dict)
     system_class.add_args(parser)
-    args = parser.parse_args(cli_argument_list(config_dict))
+    args, _ = parser.parse_known_args(cli_argument_list(config_dict))
     args.output = os.path.abspath(args.output)
     assert mkdir_output_dir(args.output)


From e1a9efc6564c938067800e1364ad032804986404 Mon Sep 17 00:00:00 2001
From: RomanKoshkin
Date: Wed, 8 May 2024 14:06:34 +0900
Subject: [PATCH 2/5] add eval wall-clock TIME and custom CLI args to the
 results; add RTF latency and CHRF quality scorers

---
 simuleval/evaluator/evaluator.py              | 25 +++++++++++++++--
 simuleval/evaluator/scorers/latency_scorer.py | 17 ++++++++++++
 simuleval/evaluator/scorers/quality_scorer.py | 27 ++++++++++++++++++-
 3 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py
index b417f13..8b213fb 100644
--- a/simuleval/evaluator/evaluator.py
+++ b/simuleval/evaluator/evaluator.py
@@ -3,12 +3,13 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
+# test
 import contextlib
 import json
 import logging
 import numbers
-import os
+import os, time
+import argparse
 from argparse import Namespace
 from pathlib import Path
 from typing import Dict, Generator, Optional
@@ -153,6 +154,7 @@ def __init__(
             )
         else:
             self.iterator = iterable
+        self.start_t = time.time()

     def write_log(self, instance):
         if self.output is not None:
@@ -224,6 +226,25 @@ def results(self):

     def dump_results(self) -> None:
         results = self.results
+        finish_t = time.time()
+        results["TIME"] = finish_t - self.start_t
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument("--use_api", action='store_true')
+        parser.add_argument("--k", type=int, default=4)
+        parser.add_argument("--dir", type=str, default=None)
+        parser.add_argument("--output", type=str, default=None)
+        parser.add_argument("--model_id", type=str, default=None)
+        parser.add_argument("--start-index", type=int, default=None)
+        parser.add_argument("--end-index", type=int, default=None)
+        custom_args, _ = parser.parse_known_args()
+        results["k"] = custom_args.k
+        results["dir"] = custom_args.dir
+        results["output"] = custom_args.output
+        results["use_api"] = custom_args.use_api
+        results["model_id"] = custom_args.model_id
+        results["start_index"] = custom_args.start_index
+        results["end_index"] = custom_args.end_index

         if self.output:
             results.to_csv(self.output / "scores.tsv", sep="\t", index=False)
diff --git a/simuleval/evaluator/scorers/latency_scorer.py b/simuleval/evaluator/scorers/latency_scorer.py
index 8cebf01..fa24622 100644
--- a/simuleval/evaluator/scorers/latency_scorer.py
+++ b/simuleval/evaluator/scorers/latency_scorer.py
@@ -22,6 +22,7 @@
 )
 from argparse import ArgumentParser, Namespace
 from subprocess import Popen, PIPE
+import numpy as np

 logger = logging.getLogger("simuleval.latency_scorer")

@@ -111,6 +112,22 @@ def from_args(cls, args: Namespace):
         )


+@register_latency_scorer("RTF")
+class RTFScorer(LatencyScorer):
+    """Real time factor
+
+    Usage:
+        --latency-metrics RTF
+    """
+
+    def __call__(self, instances) -> float:
+        scores = []
+        for ins in instances.values():
+            scores.append(ins.delays[-1] / ins.source_length)
+
+        return np.mean(scores)
+
+
 @register_latency_scorer("AL")
 class ALScorer(LatencyScorer):
     r"""
diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py
index acf867b..f6af654 100644
--- a/simuleval/evaluator/scorers/quality_scorer.py
+++ b/simuleval/evaluator/scorers/quality_scorer.py
@@ -9,17 +9,22 @@
 import subprocess
 from pathlib import Path
 from typing import Dict
+import numpy as np

 import sacrebleu
 import tqdm
 from sacrebleu.metrics.bleu import BLEU
+from argparse import ArgumentParser, Namespace
+

 QUALITY_SCORERS_DICT = {}
+QUALITY_SCORERS_NAME_DICT = {}


 def register_quality_scorer(name):
     def register(cls):
         QUALITY_SCORERS_DICT[name] = cls
+        QUALITY_SCORERS_NAME_DICT[cls.__name__] = name
         return cls

     return register
@@ -33,9 +38,12 @@ def __call__(self, instances: Dict) -> float:
         raise NotImplementedError

     @staticmethod
-    def add_args(parser):
+    def add_args(parser: ArgumentParser):
         pass

+    @classmethod
+    def from_args(cls, args: Namespace):
+        return cls()

 def add_sacrebleu_args(parser):
     parser.add_argument(
         "--sacrebleu-tokenizer",
         type=str,
         default="13a",
         help="Tokenizer in sacrebleu",
     )

+@register_quality_scorer("CHRF")
+class CHRFScorer(QualityScorer):
+    """chrF
+
+    Usage:
+        --quality-metrics CHRF
+    """
+
+    def __call__(self, instances) -> float:
+        scores = []
+        for ins in instances.values():
+            ref = ins.reference
+            hyp = ins.prediction
+            chrf = sacrebleu.sentence_chrf(hyp, [ref])
+            scores.append(chrf.score)
+        return np.mean(scores)
+
 @register_quality_scorer("WER")
 class WERScorer(QualityScorer):
     """

From e002e7e63da29e846476c2c4fcdc015b88d7bc69 Mon Sep 17 00:00:00 2001
From: RomanKoshkin
Date: Sat, 11 May 2024 00:20:09 +0900
Subject: [PATCH 3/5] major extensions: eval with background info

---
 simuleval/data/dataloader/dataloader.py     | 13 +++-
 simuleval/data/dataloader/s2t_dataloader.py | 20 +++---
 simuleval/data/dataloader/t2t_dataloader.py | 22 +++++--
 simuleval/evaluator/evaluator.py            | 67 +++++++++++++++++++--
 simuleval/evaluator/instance.py             |  2 +
 simuleval/options.py                        |  3 +
 6 files changed, 109 insertions(+), 18 deletions(-)

diff --git a/simuleval/data/dataloader/dataloader.py b/simuleval/data/dataloader/dataloader.py
index 206aa9e..84c828a 100644
--- a/simuleval/data/dataloader/dataloader.py
+++ b/simuleval/data/dataloader/dataloader.py
@@ -28,7 +28,7 @@ def register_dataloader_class(name, cls):

 class GenericDataloader:
     """
-    Load source and target data
+    Load source, target and background data

     .. argparse::
         :ref: simuleval.options.add_data_args
@@ -41,10 +41,12 @@ def __init__(
         self,
         source_list: List[str],
         target_list: Union[List[str], List[None]],
+        background_list: Union[List[str], List[None]],
         tgt_lang_list: Optional[List[str]] = None,
     ) -> None:
         self.source_list = source_list
         self.target_list = target_list
+        self.background_list = background_list
         self.tgt_lang_list = tgt_lang_list
         assert len(self.source_list) == len(self.target_list)

@@ -56,6 +58,9 @@ def get_source(self, index: int) -> Any:

     def get_target(self, index: int) -> Any:
         return self.preprocess_target(self.target_list[index])
+
+    def get_background(self, index: int) -> Any:
+        return self.background_list[index]

     def get_tgt_lang(self, index: int) -> Optional[str]:
         if getattr(self, "tgt_lang_list", None) is None or index >= len(
@@ -69,6 +74,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
         return {
             "source": self.get_source(index),
             "target": self.get_target(index),
+            "background": self.get_background(index),
             "tgt_lang": self.get_tgt_lang(index),
         }

@@ -94,6 +100,11 @@ def add_args(parser: ArgumentParser):
             type=str,
             help="Target file.",
         )
+        parser.add_argument(
+            "--background",
+            type=str,
+            help="Background info file.",
+        )
         parser.add_argument(
             "--source-type",
             type=str,
diff --git a/simuleval/data/dataloader/s2t_dataloader.py b/simuleval/data/dataloader/s2t_dataloader.py
index 537176d..875d3eb 100644
--- a/simuleval/data/dataloader/s2t_dataloader.py
+++ b/simuleval/data/dataloader/s2t_dataloader.py
@@ -71,9 +71,9 @@ def __init__(
         self,
         source_list: List[str],
         target_list: List[str],
-        tgt_lang_list: Optional[List[str]] = None,
+        bgd_info_list: Union[List[str], List[None]],
     ) -> None:
-        super().__init__(source_list, target_list, tgt_lang_list)
+        super().__init__(source_list, target_list, bgd_info_list)

     def preprocess_source(self, source: Union[Path, str]) -> List[float]:
         assert IS_IMPORT_SOUNDFILE, "Please make sure soundfile is properly installed."
@@ -95,21 +95,27 @@ def from_files(
         cls,
         source: Union[Path, str],
         target: Union[Path, str],
-        tgt_lang: Union[Path, str],
+        background: Optional[Union[Path, str]]
     ) -> SpeechToTextDataloader:
         source_list = load_list_from_file(source)
         target_list = load_list_from_file(target)
         tgt_lang_list = []
-        if tgt_lang is not None:
-            tgt_lang_list = load_list_from_file(tgt_lang)
-        dataloader = cls(source_list, target_list, tgt_lang_list)
+        # if tgt_lang is not None:
+        #     tgt_lang_list = load_list_from_file(tgt_lang)
+        if background:
+            with open(background) as f:
+                background_list = f.readlines()
+        else:
+            background_list = [None for _ in source_list]
+
+        dataloader = cls(source_list, target_list, background_list)
         return dataloader

     @classmethod
     def from_args(cls, args: Namespace):
         args.source_type = "speech"
         args.target_type = "text"
-        return cls.from_files(args.source, args.target, args.tgt_lang)
+        return cls.from_files(args.source, args.target, args.background)


 @register_dataloader("speech-to-speech")
diff --git a/simuleval/data/dataloader/t2t_dataloader.py b/simuleval/data/dataloader/t2t_dataloader.py
index 16d5d41..96813b3 100644
--- a/simuleval/data/dataloader/t2t_dataloader.py
+++ b/simuleval/data/dataloader/t2t_dataloader.py
@@ -15,11 +15,15 @@
 @register_dataloader("text-to-text")
 class TextToTextDataloader(GenericDataloader):
     def __init__(
-        self, source_list: List[str], target_list: Union[List[str], List[None]]
+        self,
+        source_list: List[str],
+        target_list: Union[List[str], List[None]],
+        bgd_info_list: Union[List[str], List[None]],
     ) -> None:
-        super().__init__(source_list, target_list)
+        super().__init__(source_list, target_list, bgd_info_list)
         self.source_splitter = lambda x: x.split()
         self.target_splitter = lambda x: x
+        self.bgd_splitter = lambda x: x

     def set_source_splitter(self, function: Callable) -> None:
         # TODO, make is configurable
@@ -33,7 +37,10 @@ def preprocess_target(self, target: str) -> List:

     @classmethod
     def from_files(
-        cls, source: Union[Path, str], target: Optional[Union[Path, str]]
+        cls,
+        source: Union[Path, str],
+        target: Optional[Union[Path, str]],
+        background: Optional[Union[Path, str]]
     ) -> TextToTextDataloader:
         assert source
         with open(source) as f:
@@ -43,11 +50,16 @@ def from_files(
             target_list = f.readlines()
         else:
             target_list = [None for _ in source_list]
-        dataloader = cls(source_list, target_list)
+        if background:
+            with open(background) as f:
+                background_list = f.readlines()
+        else:
+            background_list = [None for _ in source_list]
+        dataloader = cls(source_list, target_list, background_list)
         return dataloader

     @classmethod
     def from_args(cls, args: Namespace):
         args.source_type = "text"
         args.target_type = "text"
-        return cls.from_files(args.source, args.target)
+        return cls.from_files(args.source, args.target, args.background)
diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py
index 8b213fb..771cded 100644
--- a/simuleval/evaluator/evaluator.py
+++ b/simuleval/evaluator/evaluator.py
@@ -13,10 +13,13 @@
 from argparse import Namespace
 from pathlib import Path
 from typing import Dict, Generator, Optional
+import editdistance
+from pydub import AudioSegment

 import pandas
 import yaml
-from simuleval.data.dataloader import GenericDataloader, build_dataloader
+from simuleval.data.dataloader import GenericDataloader
+from simuleval.data.dataloader import build_dataloader
 from simuleval.data.dataloader.dataloader import IterableDataloader
 from tqdm import tqdm

@@ -36,6 +39,36 @@
 logger = logging.getLogger("simuleval.sentence_level_evaluator")
logging.getLogger("simuleval.sentence_level_evaluator") +def get_audio_duration(wav_path): + audio = AudioSegment.from_wav(wav_path) + return len(audio) / 1000.0 # pydub provides length in milliseconds + +def get_RTF(eval_took_s, audio_file_list, start_index, end_index): + with open(audio_file_list, 'r') as f: + fnames = [i[:-1] for i in f.readlines()[start_index:end_index]] + dur = list(map(get_audio_duration, fnames)) + return eval_took_s/sum(dur) + +def get_real_wer(output_path, source_file, start_index, end_index): + + """ Calculate the WER between the ASR output and the source text.""" + total_distance = 0 + total_ref_words = 0 + with open(f"{output_path}/asr", "r") as f: + hyp_setences = [i.replace("\n", "") for i in f.readlines()] + with open(f"{source_file}.txt", "r") as f: + ref_sentences = [i.replace("\n", "") for i in f.readlines()][start_index:end_index] + + for hyp_sentence, ref_sentence in zip(hyp_setences, ref_sentences): + hyp_words = hyp_sentence.split() + ref_words = ref_sentence.split() + + total_distance += editdistance.eval(ref_words, hyp_words) + total_ref_words += len(ref_words) + + return round(100.0 * total_distance / total_ref_words, 2) + + class SentenceLevelEvaluator(object): """ Sentence Level evaluator. It iterates over sentence pairs and run evaluation. @@ -228,8 +261,11 @@ def dump_results(self) -> None: results = self.results finish_t = time.time() results["TIME"] = finish_t - self.start_t - + parser = argparse.ArgumentParser() + parser.add_argument("--source", type=str, default="") + parser.add_argument("--target", type=str, default="") + parser.add_argument("--background", type=str, default=None) parser.add_argument("--use_api", action='store_true') parser.add_argument("--k", type=int, default=4) parser.add_argument("--dir", type=str, default=None) @@ -237,14 +273,32 @@ def dump_results(self) -> None: parser.add_argument("--model_id", type=str, default=None) parser.add_argument("--start-index", type=int, default=None) parser.add_argument("--end-index", type=int, default=None) + parser.add_argument("--source-segment-size", type=int, default=None) + parser.add_argument("--use_asr_api", action='store_true') + parser.add_argument("--asr_model_size", type=str, default=None) + parser.add_argument("--prompt_id", type=int, default=0) custom_args, _ = parser.parse_known_args() + + audio_file_list = custom_args.source.replace(".txt", "") + results["RTF1"] = get_RTF(results["TIME"], audio_file_list, custom_args.start_index, custom_args.end_index) + + if custom_args.asr_model_size is not None: + results["WER"] = get_real_wer(custom_args.output, custom_args.source, custom_args.start_index, custom_args.end_index) + else: + results["WER"] = None results["k"] = custom_args.k results["dir"] = custom_args.dir results["output"] = custom_args.output results["use_api"] = custom_args.use_api results["model_id"] = custom_args.model_id - results["start_index"] = custom_args.start_index results["end_index"] = custom_args.end_index + results["source_segment_size"] = custom_args.source_segment_size + results["use_asr_api"] = custom_args.use_asr_api + results["asr_model_size"] = custom_args.asr_model_size + results["prompt_id"] = custom_args.prompt_id + results["background"] = custom_args.background + + if self.output: results.to_csv(self.output / "scores.tsv", sep="\t", index=False) @@ -267,7 +321,7 @@ def __call__(self, system): self.output / "instances.log", "a" ) if self.output else contextlib.nullcontext() as file: system.reset() - for sample in self.iterator: + for sample in 
                 instance = (
                     self.instance_class(
                         self.dataloader.cur_index, self.dataloader, self.args
                     )
                     if isinstance(self.dataloader, IterableDataloader)
                     else sample
                 )
+                # update background info for the sentence
+                if self.args.background is not None:
+                    system._set_background(sample.background)
                 while not self.is_finished(instance):
                     input_segment = instance.send_source(self.source_segment_size)
                     output_segment = system.pushpop(input_segment)
-                    instance.receive_prediction(output_segment)
+                    instance.receive_prediction(output_segment)
                     if instance.finish_prediction:
                         # if instance.finish_prediction where set by the reader,
                         # source_finished_reading will be set as well. If it is
diff --git a/simuleval/evaluator/instance.py b/simuleval/evaluator/instance.py
index 9f33466..4cb24e3 100644
--- a/simuleval/evaluator/instance.py
+++ b/simuleval/evaluator/instance.py
@@ -46,6 +46,8 @@ def __init__(
         self.source = self.dataloader[self.index]["source"]
         self.reference = self.dataloader[self.index]["target"]
         self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
+        if args is not None and args.background is not None:
+            self.background = self.dataloader[self.index]["background"]
         self.reset()

         if args is not None:
diff --git a/simuleval/options.py b/simuleval/options.py
index 1c1b40c..072e145 100644
--- a/simuleval/options.py
+++ b/simuleval/options.py
@@ -222,6 +222,9 @@ def general_parser(
     parser.add_argument(
         "--device", type=str, default="cpu", help="Device to run the model."
     )
+    parser.add_argument(
+        "--background", type=str, help="Path to background info.", default=None
+    )
     dtype_arg_group = parser.add_mutually_exclusive_group()
     dtype_arg_group.add_argument(
         "--dtype",

From c2aac63e300d65319320cec87814a85131ee33df Mon Sep 17 00:00:00 2001
From: RomanKoshkin
Date: Sun, 19 May 2024 10:31:39 +0900
Subject: [PATCH 4/5] minor fixes: formatting, read ASR output from asr.log,
 add BLEUp quality scorer

---
 simuleval/evaluator/evaluator.py              | 82 ++++++-----------
 simuleval/evaluator/scorers/quality_scorer.py | 70 +++++++++++++---
 2 files changed, 92 insertions(+), 60 deletions(-)

diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py
index 771cded..81def2c 100644
--- a/simuleval/evaluator/evaluator.py
+++ b/simuleval/evaluator/evaluator.py
@@ -43,22 +43,23 @@ def get_audio_duration(wav_path):
     audio = AudioSegment.from_wav(wav_path)
     return len(audio) / 1000.0  # pydub provides length in milliseconds

+
 def get_RTF(eval_took_s, audio_file_list, start_index, end_index):
-    with open(audio_file_list, 'r') as f:
+    with open(audio_file_list, "r") as f:
         fnames = [i[:-1] for i in f.readlines()[start_index:end_index]]
     dur = list(map(get_audio_duration, fnames))
-    return eval_took_s/sum(dur)
+    return eval_took_s / sum(dur)

-def get_real_wer(output_path, source_file, start_index, end_index):
-
-    """ Calculate the WER between the ASR output and the source text."""
+
+def get_real_wer(output_path, source_file, start_index, end_index):
+    """Calculate the WER between the ASR output and the source text."""
     total_distance = 0
     total_ref_words = 0
-    with open(f"{output_path}/asr", "r") as f:
+    with open(f"{output_path}/asr.log", "r") as f:
         hyp_sentences = [i.replace("\n", "") for i in f.readlines()]
     with open(f"{source_file}.txt", "r") as f:
         ref_sentences = [i.replace("\n", "") for i in f.readlines()][start_index:end_index]
-    
+
     for hyp_sentence, ref_sentence in zip(hyp_sentences, ref_sentences):
         hyp_words = hyp_sentence.split()
         ref_words = ref_sentence.split()
@@ -67,7 +68,7 @@ def get_real_wer(output_path, source_file, start_index, end_index):
         total_distance += editdistance.eval(ref_words, hyp_words)
         total_ref_words += len(ref_words)

     return round(100.0 * total_distance / total_ref_words, 2)
-
+

 class SentenceLevelEvaluator(object):
     """
     Sentence Level evaluator. It iterates over sentence pairs and run evaluation.
@@ -122,15 +123,9 @@ def __init__(
         if args.eval_latency_unit == "spm":
             assert args.eval_latency_spm_model
             assert IS_IMPORT_SPM
-            self.target_spm_model = sentencepiece.SentencePieceProcessor(
-                model_file=args.eval_latency_spm_model
-            )
+            self.target_spm_model = sentencepiece.SentencePieceProcessor(model_file=args.eval_latency_spm_model)

-        if (
-            self.source_type is None
-            and self.target_type is None
-            and self.output is not None
-        ):
+        if self.source_type is None and self.target_type is None and self.output is not None:
             with open(self.output / "config.yaml") as f:
                 configs = yaml.safe_load(f)
             self.source_type = configs["source_type"]
@@ -148,18 +143,13 @@ def __init__(
                 default_flow_style=False,
             )

-        self.instance_class = INSTANCE_TYPE_DICT[
-            f"{self.source_type}-{self.target_type}"
-        ]
+        self.instance_class = INSTANCE_TYPE_DICT[f"{self.source_type}-{self.target_type}"]
         self.start_index = getattr(args, "start_index", 0)
         self.end_index = getattr(args, "end_index", -1)

         if not self.score_only:
             if self.output:
-                if (
-                    self.args.continue_unfinished
-                    and (self.output / "instances.log").exists()
-                ):
+                if self.args.continue_unfinished and (self.output / "instances.log").exists():
                     with open(self.output / "instances.log", "r") as f:
                         line = None
                         for line in f:  # noqa
@@ -233,17 +223,11 @@ def get_indices(self) -> Generator:

     @property
     def quality(self) -> Dict[str, float]:
-        return {
-            name: scorer(self.instances)
-            for name, scorer in self.quality_scorers.items()
-        }
+        return {name: scorer(self.instances) for name, scorer in self.quality_scorers.items()}

     @property
     def latency(self) -> Dict[str, float]:
-        return {
-            name: scorer(self.instances)
-            for name, scorer in self.latency_scorers.items()
-        }
+        return {name: scorer(self.instances) for name, scorer in self.latency_scorers.items()}

     @property
     def results(self):
@@ -261,12 +245,12 @@ def dump_results(self) -> None:
         results = self.results
         finish_t = time.time()
         results["TIME"] = finish_t - self.start_t
-
+
         parser = argparse.ArgumentParser()
         parser.add_argument("--source", type=str, default="")
         parser.add_argument("--target", type=str, default="")
         parser.add_argument("--background", type=str, default=None)
-        parser.add_argument("--use_api", action='store_true')
+        parser.add_argument("--use_api", action="store_true")
         parser.add_argument("--k", type=int, default=4)
         parser.add_argument("--dir", type=str, default=None)
         parser.add_argument("--output", type=str, default=None)
@@ -274,16 +258,25 @@ def dump_results(self) -> None:
         parser.add_argument("--start-index", type=int, default=None)
         parser.add_argument("--end-index", type=int, default=None)
         parser.add_argument("--source-segment-size", type=int, default=None)
-        parser.add_argument("--use_asr_api", action='store_true')
+        parser.add_argument("--use_asr_api", action="store_true")
         parser.add_argument("--asr_model_size", type=str, default=None)
         parser.add_argument("--prompt_id", type=int, default=0)
         custom_args, _ = parser.parse_known_args()

-        audio_file_list = custom_args.source.replace(".txt", "")
-        results["RTF1"] = get_RTF(results["TIME"], audio_file_list, custom_args.start_index, custom_args.end_index)
-
         if custom_args.asr_model_size is not None:
-            results["WER"] = get_real_wer(custom_args.output, custom_args.source, custom_args.start_index, custom_args.end_index)
+            audio_file_list = custom_args.source.replace(".txt", "")
+            results["RTF1"] = get_RTF(
+                results["TIME"],
+                audio_file_list,
+                custom_args.start_index,
+                custom_args.end_index,
+            )
+            results["WER"] = get_real_wer(
+                custom_args.output,
+                custom_args.source,
+                custom_args.start_index,
+                custom_args.end_index,
+            )
         else:
             results["WER"] = None
         results["k"] = custom_args.k
@@ -297,8 +290,7 @@ def dump_results(self) -> None:
         results["asr_model_size"] = custom_args.asr_model_size
         results["prompt_id"] = custom_args.prompt_id
         results["background"] = custom_args.background
-
-
+
         if self.output:
             results.to_csv(self.output / "scores.tsv", sep="\t", index=False)
@@ -317,15 +309,11 @@ def is_finished(self, instance) -> bool:
         return instance.finish_prediction

     def __call__(self, system):
-        with open(
-            self.output / "instances.log", "a"
-        ) if self.output else contextlib.nullcontext() as file:
+        with open(self.output / "instances.log", "a") if self.output else contextlib.nullcontext() as file:
             system.reset()
-            for sample in self.iterator: # "sample" is an input-output-(background) pair(triplet)
+            for sample in self.iterator:  # "sample" is an input-output-(background) pair(triplet)
                 instance = (
-                    self.instance_class(
-                        self.dataloader.cur_index, self.dataloader, self.args
-                    )
+                    self.instance_class(self.dataloader.cur_index, self.dataloader, self.args)
                     if isinstance(self.dataloader, IterableDataloader)
                     else sample
                 )
@@ -335,7 +323,7 @@ def __call__(self, system):
             while not self.is_finished(instance):
                 input_segment = instance.send_source(self.source_segment_size)
                 output_segment = system.pushpop(input_segment)
-                instance.receive_prediction(output_segment)
+                instance.receive_prediction(output_segment)
                 if instance.finish_prediction:
                     # if instance.finish_prediction where set by the reader,
                     # source_finished_reading will be set as well. If it is
diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py
index f6af654..68d0f0a 100644
--- a/simuleval/evaluator/scorers/quality_scorer.py
+++ b/simuleval/evaluator/scorers/quality_scorer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import re
 import logging
 import string
 import subprocess
@@ -45,6 +46,7 @@ def add_args(parser: ArgumentParser):
     def from_args(cls, args: Namespace):
         return cls()

+
 def add_sacrebleu_args(parser):
     parser.add_argument(
@@ -54,6 +56,7 @@ def add_sacrebleu_args(parser):
         help="Tokenizer in sacrebleu",
     )

+
 @register_quality_scorer("CHRF")
 class CHRFScorer(QualityScorer):
     """chrF
@@ -71,7 +74,8 @@ def __call__(self, instances) -> float:
             scores.append(chrf.score)
         return np.mean(scores)
-
+
+
 @register_quality_scorer("WER")
 class WERScorer(QualityScorer):
     """
@@ -89,9 +93,7 @@ def __init__(self, args) -> None:
             raise ImportError("Please install editdistance to use WER scorer")
         self.logger = logging.getLogger("simuleval.scorer.wer")
         self.logger.warning("WER scorer only support language with spaces.")
-        self.logger.warning(
-            "Current WER scorer is on raw text (un-tokenized with punctuations)."
-        )
+        self.logger.warning("Current WER scorer is on raw text (un-tokenized with punctuations).")
         self.ed = ed

     def __call__(self, instances: Dict) -> float:
@@ -155,6 +157,54 @@ def from_args(cls, args):
         return cls(args.sacrebleu_tokenizer)


+@register_quality_scorer("BLEUp")
+class SacreBLEUScorerPunkt(QualityScorer):
+    """
+    SacreBLEU Scorer with punctuation removed
+
+    Usage:
+        :code:`--quality-metrics BLEUp`
+
+    Additional command line arguments:
+
+    .. argparse::
+        :ref: simuleval.evaluator.scorers.quality_scorer.add_sacrebleu_args
+        :passparser:
+        :prog:
+    """
+
+    def __init__(self, tokenizer: str = "13a") -> None:
+        super().__init__()
+        self.logger = logging.getLogger("simuleval.scorer.bleu")
+        self.tokenizer = tokenizer
+
+    def _remove_punctuation(self, text):
+        return re.sub(r"[^\w\s]", "", text)
+
+    def __call__(self, instances: Dict) -> float:
+
+        try:
+            return (
+                BLEU(tokenize=self.tokenizer)
+                .corpus_score(
+                    [self._remove_punctuation(ins.prediction) for ins in instances.values()],
+                    [[self._remove_punctuation(ins.reference) for ins in instances.values()]],
+                )
+                .score
+            )
+        except Exception as e:
+            self.logger.error(str(e))
+            return 0
+
+    @staticmethod
+    def add_args(parser):
+        add_sacrebleu_args(parser)
+
+    @classmethod
+    def from_args(cls, args):
+        return cls(args.sacrebleu_tokenizer)
+
+
 @register_quality_scorer("ASR_BLEU")
 class ASRSacreBLEUScorer(QualityScorer):
     """
@@ -257,16 +307,12 @@ def from_args(cls, args):
         return cls(args.sacrebleu_tokenizer, args.target_speech_lang)


-PUNCTUATIONS_EXCLUDE_APOSTROPHE = (
-    string.punctuation.replace("'", "") + "¡¨«°³º»¿‘“”…♪♫ˆᵉ™,ʾ˚"
-)
+PUNCTUATIONS_EXCLUDE_APOSTROPHE = string.punctuation.replace("'", "") + "¡¨«°³º»¿‘“”…♪♫ˆᵉ™,ʾ˚"
 PUNCTUATIONS_TO_SPACE = "-/–·—•"


 def remove_punctuations(text, punctuations=string.punctuation):
-    text = text.translate(
-        str.maketrans(PUNCTUATIONS_TO_SPACE, " " * len(PUNCTUATIONS_TO_SPACE))
-    )
+    text = text.translate(str.maketrans(PUNCTUATIONS_TO_SPACE, " " * len(PUNCTUATIONS_TO_SPACE)))
     return text.translate(str.maketrans("", "", punctuations))


@@ -317,9 +363,7 @@ def __call__(self, instances: Dict) -> float:
         return score

     def asr_transcribe(self, instances):
-        self.logger.info(
-            "Evaluating speech output by ASR BLEU. whisper and sacrebleu are required."
-        )
+        self.logger.info("Evaluating speech output by ASR BLEU. whisper and sacrebleu are required.")
         self.logger.info("Configs:")
         self.logger.info(f"tokenizer = {self.tokenizer}")
         self.logger.info(f"target_lang = {self.target_lang}")

From 62034df70b21f2605ca2fd20cab11c8f3c519e89 Mon Sep 17 00:00:00 2001
From: RomanKoshkin
Date: Wed, 19 Jun 2024 15:30:46 +0900
Subject: [PATCH 5/5] added min_read_time and min_lag_words as options

---
 simuleval/evaluator/evaluator.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py
index 81def2c..cb9d7af 100644
--- a/simuleval/evaluator/evaluator.py
+++ b/simuleval/evaluator/evaluator.py
@@ -255,12 +255,16 @@ def dump_results(self) -> None:
         parser.add_argument("--dir", type=str, default=None)
         parser.add_argument("--output", type=str, default=None)
         parser.add_argument("--model_id", type=str, default=None)
+        parser.add_argument("--min_read_time", type=str, default=None)
+        parser.add_argument("--min_lag_words", type=int, default=None)
         parser.add_argument("--start-index", type=int, default=None)
         parser.add_argument("--end-index", type=int, default=None)
         parser.add_argument("--source-segment-size", type=int, default=None)
         parser.add_argument("--use_asr_api", action="store_true")
         parser.add_argument("--asr_model_size", type=str, default=None)
         parser.add_argument("--prompt_id", type=int, default=0)
+        parser.add_argument("--func_wrds", type=str, default="[]")
+        parser.add_argument("--priming", action="store_true")
         custom_args, _ = parser.parse_known_args()

         if custom_args.asr_model_size is not None:
@@ -277,6 +281,11 @@ def dump_results(self) -> None:
                 custom_args.start_index,
                 custom_args.end_index,
             )
+            results["min_read_time"] = custom_args.min_read_time
+            results["min_lag_words"] = custom_args.min_lag_words
+            results["src_seg_sz"] = custom_args.source_segment_size
+            results["use_asr_api"] = custom_args.use_asr_api
+            results["asr_model_size"] = custom_args.asr_model_size
         else:
             results["WER"] = None
         results["k"] = custom_args.k
@@ -285,14 +294,14 @@ def dump_results(self) -> None:
         results["use_api"] = custom_args.use_api
         results["model_id"] = custom_args.model_id
         results["end_index"] = custom_args.end_index
-        results["source_segment_size"] = custom_args.source_segment_size
-        results["use_asr_api"] = custom_args.use_asr_api
-        results["asr_model_size"] = custom_args.asr_model_size
         results["prompt_id"] = custom_args.prompt_id
         results["background"] = custom_args.background
+        results["func_wrds"] = custom_args.func_wrds
+        results["priming"] = custom_args.priming

         if self.output:
             results.to_csv(self.output / "scores.tsv", sep="\t", index=False)
+            results.to_json(self.output / "scores.json", orient="records")

         logger.info("Results:")
         print(results.to_string(index=False))
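
Note (editor's illustration, not part of any patch): every change in this
series that swaps parse_args() for parse_known_args() relies on the same
argparse behavior, which the ad-hoc parser in dump_results() then exploits
to pick the custom flags (--k, --use_api, --model_id, ...) back out of the
command line. A minimal, self-contained sketch of that mechanism, with flag
values invented for the example:

    import argparse

    # SimulEval-side parser: knows only its own options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str)

    # parse_args() would exit here with "unrecognized arguments";
    # parse_known_args() hands back the leftovers instead.
    cli = ["--output", "out_dir", "--k", "4", "--use_api"]
    args, leftover = parser.parse_known_args(cli)
    assert args.output == "out_dir"
    assert leftover == ["--k", "4", "--use_api"]

    # A second parser, like the one in dump_results(), then consumes the
    # custom flags from the same command line.
    custom = argparse.ArgumentParser()
    custom.add_argument("--k", type=int, default=4)
    custom.add_argument("--use_api", action="store_true")
    custom_args, _ = custom.parse_known_args(cli)
    assert custom_args.k == 4 and custom_args.use_api is True

The trade-off of this design is that argparse can no longer catch typos in
SimulEval's own flags, since a misspelled option is silently treated as a
custom one, which is worth keeping in mind when debugging.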