4 changes: 2 additions & 2 deletions simuleval/cli.py
@@ -87,7 +87,7 @@ def scoring():
options.add_evaluator_args(parser)
options.add_scorer_args(parser)
options.add_dataloader_args(parser)
args = parser.parse_args()
args, _ = parser.parse_known_args()
evaluator = SentenceLevelEvaluator.from_args(args)
print(evaluator.results)

@@ -98,7 +98,7 @@ def remote_evaluate():
options.add_dataloader_args(parser)
options.add_evaluator_args(parser)
options.add_scorer_args(parser)
args = parser.parse_args()
args, _ = parser.parse_known_args()
evaluator = build_remote_evaluator(args)

# evaluate system
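Both entry points now use parse_known_args so that extra flags (for example the custom --background and evaluation options re-parsed later in dump_results) no longer make the CLI abort. A minimal standalone argparse sketch of the difference, not SimulEval code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--source", type=str)

# parser.parse_args(["--source", "s.txt", "--background", "b.txt"])
#   -> would report "unrecognized arguments" and exit
args, extra = parser.parse_known_args(["--source", "s.txt", "--background", "b.txt"])
print(args.source)  # "s.txt"
print(extra)        # ["--background", "b.txt"], left for other parsers to consume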
13 changes: 12 additions & 1 deletion simuleval/data/dataloader/dataloader.py
@@ -28,7 +28,7 @@ def register_dataloader_class(name, cls):

class GenericDataloader:
"""
Load source and target data
Load source, target and background data

.. argparse::
:ref: simuleval.options.add_data_args
@@ -41,10 +41,12 @@ def __init__(
self,
source_list: List[str],
target_list: Union[List[str], List[None]],
background_list: Union[List[str], List[None]],
tgt_lang_list: Optional[List[str]] = None,
) -> None:
self.source_list = source_list
self.target_list = target_list
self.background_list = background_list
self.tgt_lang_list = tgt_lang_list
assert len(self.source_list) == len(self.target_list)

@@ -56,6 +58,9 @@ def get_source(self, index: int) -> Any:

def get_target(self, index: int) -> Any:
return self.preprocess_target(self.target_list[index])

def get_background(self, index: int) -> Any:
return self.background_list[index]

def get_tgt_lang(self, index: int) -> Optional[str]:
if getattr(self, "tgt_lang_list", None) is None or index >= len(
@@ -69,6 +74,7 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
return {
"source": self.get_source(index),
"target": self.get_target(index),
"background": self.get_background(index),
"tgt_lang": self.get_tgt_lang(index),
}

@@ -94,6 +100,11 @@ def add_args(parser: ArgumentParser):
type=str,
help="Target file.",
)
parser.add_argument(
"--background",
type=str,
help="Background info file.",
)
parser.add_argument(
"--source-type",
type=str,
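With the new background_list field, every item returned by __getitem__ carries a fourth key. A minimal sketch of the resulting item layout, using a toy subclass (not in the repo) that supplies the two preprocessing hooks; tgt_lang_list is left unset, so get_tgt_lang is expected to return None:

from simuleval.data.dataloader import GenericDataloader

class ToyDataloader(GenericDataloader):
    def preprocess_source(self, source):
        return source.split()

    def preprocess_target(self, target):
        return target

loader = ToyDataloader(
    source_list=["hello world"],
    target_list=["hallo welt"],
    background_list=["domain: greetings"],
)
print(loader[0])
# {'source': ['hello', 'world'], 'target': 'hallo welt',
#  'background': 'domain: greetings', 'tgt_lang': None}

Note that get_background returns the stored string verbatim; unlike get_source and get_target there is no preprocessing step.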
20 changes: 13 additions & 7 deletions simuleval/data/dataloader/s2t_dataloader.py
@@ -71,9 +71,9 @@ def __init__(
self,
source_list: List[str],
target_list: List[str],
tgt_lang_list: Optional[List[str]] = None,
bgd_info_list: Union[List[str], List[None]],
) -> None:
super().__init__(source_list, target_list, tgt_lang_list)
super().__init__(source_list, target_list, bgd_info_list)

def preprocess_source(self, source: Union[Path, str]) -> List[float]:
assert IS_IMPORT_SOUNDFILE, "Please make sure soundfile is properly installed."
@@ -95,21 +95,27 @@ def from_files(
cls,
source: Union[Path, str],
target: Union[Path, str],
tgt_lang: Union[Path, str],
background: Optional[Union[Path, str]]
) -> SpeechToTextDataloader:
source_list = load_list_from_file(source)
target_list = load_list_from_file(target)
tgt_lang_list = []
if tgt_lang is not None:
tgt_lang_list = load_list_from_file(tgt_lang)
dataloader = cls(source_list, target_list, tgt_lang_list)
# if tgt_lang is not None:
# tgt_lang_list = load_list_from_file(tgt_lang)
if background:
with open(background) as f:
background_list = f.readlines()
else:
background_list = [None for _ in source_list]

dataloader = cls(source_list, target_list, background_list)
return dataloader

@classmethod
def from_args(cls, args: Namespace):
args.source_type = "speech"
args.target_type = "text"
return cls.from_files(args.source, args.target, args.tgt_lang)
return cls.from_files(args.source, args.target, args.background)


@register_dataloader("speech-to-speech")
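Note that the tgt_lang handling is commented out rather than kept alongside the new argument, so from_args no longer forwards args.tgt_lang. A sketch of the updated call chain (file names are placeholders; each list file is expected to hold one entry per line, and the wav paths must exist for preprocess_source to load them):

from simuleval.data.dataloader.s2t_dataloader import SpeechToTextDataloader

dataloader = SpeechToTextDataloader.from_files(
    source="wav_list.txt",        # one .wav path per line
    target="references.txt",      # one reference sentence per line
    background="background.txt",  # one background-info line per utterance, or None
)
sample = dataloader[0]
# sample["source"] holds the decoded audio samples; sample["background"] is the raw
# background line (readlines() keeps the trailing newline).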
22 changes: 17 additions & 5 deletions simuleval/data/dataloader/t2t_dataloader.py
@@ -15,11 +15,15 @@
@register_dataloader("text-to-text")
class TextToTextDataloader(GenericDataloader):
def __init__(
self, source_list: List[str], target_list: Union[List[str], List[None]]
self,
source_list: List[str],
target_list: Union[List[str], List[None]],
bgd_info_list: Union[List[str], List[None]],
) -> None:
super().__init__(source_list, target_list)
super().__init__(source_list, target_list, bgd_info_list)
self.source_splitter = lambda x: x.split()
self.target_splitter = lambda x: x
self.bgd_splitter = lambda x: x

def set_source_splitter(self, function: Callable) -> None:
# TODO: make this configurable
@@ -33,7 +37,10 @@ def preprocess_target(self, target: str) -> List:

@classmethod
def from_files(
cls, source: Union[Path, str], target: Optional[Union[Path, str]]
cls,
source: Union[Path, str],
target: Optional[Union[Path, str]],
background: Optional[Union[Path, str]]
) -> TextToTextDataloader:
assert source
with open(source) as f:
@@ -43,11 +50,16 @@ def from_files(
target_list = f.readlines()
else:
target_list = [None for _ in source_list]
dataloader = cls(source_list, target_list)
if background:
with open(background) as f:
background_list = f.readlines()
else:
background_list = [None for _ in source_list]
dataloader = cls(source_list, target_list, background_list)
return dataloader

@classmethod
def from_args(cls, args: Namespace):
args.source_type = "text"
args.target_type = "text"
return cls.from_files(args.source, args.target)
return cls.from_files(args.source, args.target, args.background)
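The text-to-text loader gets the same plumbing; when --background is omitted, the list is padded with None so indexing stays aligned with the sources. A sketch with placeholder file names:

from simuleval.data.dataloader.t2t_dataloader import TextToTextDataloader

loader = TextToTextDataloader.from_files(
    source="src.txt",
    target="tgt.txt",
    background=None,  # allowed: background_list becomes [None] * len(source_list)
)
assert len(loader.background_list) == len(loader.source_list)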
141 changes: 108 additions & 33 deletions simuleval/evaluator/evaluator.py
@@ -3,19 +3,23 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# test
import contextlib
import json
import logging
import numbers
import os
import os, time
import argparse
from argparse import Namespace
from pathlib import Path
from typing import Dict, Generator, Optional
import editdistance
from pydub import AudioSegment

import pandas
import yaml
from simuleval.data.dataloader import GenericDataloader, build_dataloader
from simuleval.data.dataloader import GenericDataloader
from simuleval.data.dataloader import build_dataloader
from simuleval.data.dataloader.dataloader import IterableDataloader
from tqdm import tqdm

@@ -35,6 +39,37 @@
logger = logging.getLogger("simuleval.sentence_level_evaluator")


def get_audio_duration(wav_path):
audio = AudioSegment.from_wav(wav_path)
return len(audio) / 1000.0 # pydub provides length in milliseconds


def get_RTF(eval_took_s, audio_file_list, start_index, end_index):
with open(audio_file_list, "r") as f:
fnames = [i[:-1] for i in f.readlines()[start_index:end_index]]
dur = list(map(get_audio_duration, fnames))
return eval_took_s / sum(dur)


def get_real_wer(output_path, source_file, start_index, end_index):
"""Calculate the WER between the ASR output and the source text."""
total_distance = 0
total_ref_words = 0
with open(f"{output_path}/asr.log", "r") as f:
hyp_sentences = [i.replace("\n", "") for i in f.readlines()]
with open(f"{source_file}.txt", "r") as f:
ref_sentences = [i.replace("\n", "") for i in f.readlines()][start_index:end_index]

for hyp_sentence, ref_sentence in zip(hyp_sentences, ref_sentences):
hyp_words = hyp_sentence.split()
ref_words = ref_sentence.split()

total_distance += editdistance.eval(ref_words, hyp_words)
total_ref_words += len(ref_words)

return round(100.0 * total_distance / total_ref_words, 2)
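A sketch of how these two helpers are meant to be combined once a run finishes. The code appears to assume that --source names a wav-list file, that the reference transcripts sit next to it with a .txt suffix, and that asr.log lives in the output directory; paths and indices below are placeholders:

elapsed_seconds = 123.4                                 # wall-clock time of the evaluation
rtf = get_RTF(elapsed_seconds, "data/test", 0, 100)     # elapsed / total audio duration
wer = get_real_wer("exp/output", "data/test", 0, 100)   # compares exp/output/asr.log to data/test.txt
print(f"RTF={rtf:.3f} WER={wer}")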


class SentenceLevelEvaluator(object):
"""
Sentence Level evaluator. It iterates over sentence pairs and runs evaluation.
Expand Down Expand Up @@ -88,15 +123,9 @@ def __init__(
if args.eval_latency_unit == "spm":
assert args.eval_latency_spm_model
assert IS_IMPORT_SPM
self.target_spm_model = sentencepiece.SentencePieceProcessor(
model_file=args.eval_latency_spm_model
)
self.target_spm_model = sentencepiece.SentencePieceProcessor(model_file=args.eval_latency_spm_model)

if (
self.source_type is None
and self.target_type is None
and self.output is not None
):
if self.source_type is None and self.target_type is None and self.output is not None:
with open(self.output / "config.yaml") as f:
configs = yaml.safe_load(f)
self.source_type = configs["source_type"]
@@ -114,18 +143,13 @@ def __init__(
default_flow_style=False,
)

self.instance_class = INSTANCE_TYPE_DICT[
f"{self.source_type}-{self.target_type}"
]
self.instance_class = INSTANCE_TYPE_DICT[f"{self.source_type}-{self.target_type}"]
self.start_index = getattr(args, "start_index", 0)
self.end_index = getattr(args, "end_index", -1)

if not self.score_only:
if self.output:
if (
self.args.continue_unfinished
and (self.output / "instances.log").exists()
):
if self.args.continue_unfinished and (self.output / "instances.log").exists():
with open(self.output / "instances.log", "r") as f:
line = None
for line in f: # noqa
@@ -153,6 +177,7 @@ def __init__(
)
else:
self.iterator = iterable
self.start_t = time.time()

def write_log(self, instance):
if self.output is not None:
@@ -198,17 +223,11 @@ def get_indices(self) -> Generator:

@property
def quality(self) -> Dict[str, float]:
return {
name: scorer(self.instances)
for name, scorer in self.quality_scorers.items()
}
return {name: scorer(self.instances) for name, scorer in self.quality_scorers.items()}

@property
def latency(self) -> Dict[str, float]:
return {
name: scorer(self.instances)
for name, scorer in self.latency_scorers.items()
}
return {name: scorer(self.instances) for name, scorer in self.latency_scorers.items()}

@property
def results(self):
@@ -224,8 +243,65 @@ def results(self):

def dump_results(self) -> None:
results = self.results
finish_t = time.time()
results["TIME"] = finish_t - self.start_t

parser = argparse.ArgumentParser()
parser.add_argument("--source", type=str, default="")
parser.add_argument("--target", type=str, default="")
parser.add_argument("--background", type=str, default=None)
parser.add_argument("--use_api", action="store_true")
parser.add_argument("--k", type=int, default=4)
parser.add_argument("--dir", type=str, default=None)
parser.add_argument("--output", type=str, default=None)
parser.add_argument("--model_id", type=str, default=None)
parser.add_argument("--min_read_time", type=str, default=None)
parser.add_argument("--min_lag_words", type=int, default=None)
parser.add_argument("--start-index", type=int, default=None)
parser.add_argument("--end-index", type=int, default=None)
parser.add_argument("--source-segment-size", type=int, default=None)
parser.add_argument("--use_asr_api", action="store_true")
parser.add_argument("--asr_model_size", type=str, default=None)
parser.add_argument("--prompt_id", type=int, default=0)
parser.add_argument("--func_wrds", type=str, default="[]")
parser.add_argument("--priming", action="store_true")
custom_args, _ = parser.parse_known_args()

if custom_args.asr_model_size is not None:
audio_file_list = custom_args.source.replace(".txt", "")
results["RTF1"] = get_RTF(
results["TIME"],
audio_file_list,
custom_args.start_index,
custom_args.end_index,
)
results["WER"] = get_real_wer(
custom_args.output,
custom_args.source,
custom_args.start_index,
custom_args.end_index,
)
results["min_read_time"] = custom_args.min_read_time
results["min_lag_words"] = custom_args.min_lag_words
results["src_seg_sz"] = custom_args.source_segment_size
results["use_asr_api"] = custom_args.use_asr_api
results["asr_model_size"] = custom_args.asr_model_size
else:
results["WER"] = None
results["k"] = custom_args.k
results["dir"] = custom_args.dir
results["output"] = custom_args.output
results["use_api"] = custom_args.use_api
results["model_id"] = custom_args.model_id
results["end_index"] = custom_args.end_index
results["prompt_id"] = custom_args.prompt_id
results["background"] = custom_args.background
results["func_wrds"] = custom_args.func_wrds
results["priming"] = custom_args.priming

if self.output:
results.to_csv(self.output / "scores.tsv", sep="\t", index=False)
results.to_json(self.output / "scores.json", index=False, orient="records")

logger.info("Results:")
print(results.to_string(index=False))
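dump_results now re-parses sys.argv with parse_known_args to recover the run configuration and attaches it (plus TIME, and RTF1/WER when an ASR model size is given) as extra columns of the one-row results frame before writing scores.tsv and scores.json. A standalone pandas sketch of that column-append pattern; the scorer columns and values are made up:

import pandas

results = pandas.DataFrame([{"BLEU": 23.1, "AL": 1870.4}])  # placeholder scorer columns
results["TIME"] = 512.7
results["model_id"] = "my-model"  # hypothetical --model_id value
results.to_csv("scores.tsv", sep="\t", index=False)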
Expand All @@ -242,18 +318,17 @@ def is_finished(self, instance) -> bool:
return instance.finish_prediction

def __call__(self, system):
with open(
self.output / "instances.log", "a"
) if self.output else contextlib.nullcontext() as file:
with open(self.output / "instances.log", "a") if self.output else contextlib.nullcontext() as file:
system.reset()
for sample in self.iterator:
for sample in self.iterator:  # each "sample" is a source/target pair, plus background info when provided
instance = (
self.instance_class(
self.dataloader.cur_index, self.dataloader, self.args
)
self.instance_class(self.dataloader.cur_index, self.dataloader, self.args)
if isinstance(self.dataloader, IterableDataloader)
else sample
)
# update background info for the sentence
if self.args.background is not None:
system._set_background(sample.background)
while not self.is_finished(instance):
input_segment = instance.send_source(self.source_segment_size)
output_segment = system.pushpop(input_segment)
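The evaluation loop now pushes per-sentence background information into the system before streaming the source. A hypothetical minimal agent-side hook (not part of this diff) showing what _set_background is expected to do; a real system additionally needs the usual reset()/pushpop() interface:

class BackgroundAwareSystem:
    def __init__(self):
        self.background = None

    def _set_background(self, background):
        # stash the per-sentence background string for use in prompting/decoding
        self.background = background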
2 changes: 2 additions & 0 deletions simuleval/evaluator/instance.py
@@ -46,6 +46,8 @@ def __init__(
self.source = self.dataloader[self.index]["source"]
self.reference = self.dataloader[self.index]["target"]
self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
if args.background is not None:
self.background = self.dataloader[self.index]["background"]

self.reset()
if args is not None: