From 0a0fb0f69a360c7e7e4ae155bd24bce05b3b5426 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Mon, 23 Oct 2023 21:00:48 +0000 Subject: [PATCH 1/8] Add SPM and slurm array --- simuleval/cli.py | 9 +++- simuleval/data/dataloader/dataloader.py | 2 +- simuleval/evaluator/evaluator.py | 15 +++++++ simuleval/evaluator/instance.py | 19 ++++++-- simuleval/options.py | 12 +++-- simuleval/utils/slurm.py | 59 +++++++++++++++++++++---- 6 files changed, 99 insertions(+), 17 deletions(-) diff --git a/simuleval/cli.py b/simuleval/cli.py index 5191ce56..b54fe756 100644 --- a/simuleval/cli.py +++ b/simuleval/cli.py @@ -58,8 +58,13 @@ def main(): def evaluate(system_class: GenericAgent, config_dict: dict = {}): EVALUATION_SYSTEM_LIST.append(system_class) - - if check_argument("slurm", config_dict): + just_for_arg_check = {} + for key, value in config_dict.items(): + if isinstance(value, list): + just_for_arg_check[key] = value[0] + else: + just_for_arg_check[key] = value + if check_argument("slurm", just_for_arg_check): submit_slurm_job(config_dict) return diff --git a/simuleval/data/dataloader/dataloader.py b/simuleval/data/dataloader/dataloader.py index 7ec1a575..fb33c587 100644 --- a/simuleval/data/dataloader/dataloader.py +++ b/simuleval/data/dataloader/dataloader.py @@ -57,7 +57,7 @@ def get_target(self, index: int) -> Any: return self.preprocess_target(self.target_list[index]) def get_tgt_lang(self, index: int) -> Optional[str]: - if self.tgt_lang_list is None or index >= len(self.tgt_lang_list): + if getattr(self, "tgt_lang_list", None) is None or index >= len(self.tgt_lang_list): return None else: return self.tgt_lang_list[index] diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py index 6b766f1c..5a9f50c2 100644 --- a/simuleval/evaluator/evaluator.py +++ b/simuleval/evaluator/evaluator.py @@ -21,6 +21,13 @@ from pathlib import Path from simuleval.data.dataloader import GenericDataloader, build_dataloader +try: + import sentencepiece + + IS_IMPORT_SPM = True +except Exception: + IS_IMPORT_SPM = False + logger = logging.getLogger("simuleval.sentence_level_evaluator") @@ -73,6 +80,12 @@ def __init__( self.source_type = getattr(args, "source_type", None) self.target_type = getattr(args, "target_type", None) + self.target_spm_model = None + if args.eval_latency_unit == "spm": + assert args.eval_latency_spm_model + assert IS_IMPORT_SPM + self.target_spm_model = sentencepiece.SentencePieceProcessor(model_file=args.eval_latency_spm_model) + if ( self.source_type is None and self.target_type is None @@ -150,10 +163,12 @@ def build_instances_from_log(self): for line in f: instance = LogInstance(line.strip()) self.instances[instance.index] = instance + self.instances[instance.index].set_target_spm_model(self.target_spm_model) def build_instances_from_dataloader(self): for i in self.get_indices(): self.instances[i] = self.instance_class(i, self.dataloader, self.args) + self.instances[i].set_target_spm_model(self.target_spm_model) def __len__(self) -> int: return self.end_index - self.start_index diff --git a/simuleval/evaluator/instance.py b/simuleval/evaluator/instance.py index e6cc8766..b41088f7 100644 --- a/simuleval/evaluator/instance.py +++ b/simuleval/evaluator/instance.py @@ -52,6 +52,11 @@ def __init__( if args is not None: self.args = args self.latency_unit = args.eval_latency_unit + + self.target_spm_model = None + + def set_target_spm_model(self, spm_model): + self.target_spm_model = spm_model def reset(self): self.step = 0 @@ -119,11 +124,14 @@ def reference_length(self) -> int: return len(self.reference.split(" ")) elif self.latency_unit == "char": return len(self.reference.strip()) + elif self.latency_unit == "spm": + assert self.target_spm_model is not None + return len(self.target_spm_model.encode(self.reference, out_type=str)) else: raise NotImplementedError def summarize(self): - return { + return_dict = { "index": self.index, "prediction": self.prediction, "delays": self.delays, @@ -133,6 +141,9 @@ def summarize(self): "source": self.source_info, "source_length": self.source_length, } + if self.latency_unit == "spm": + return_dict["prediction_spm"] = self.prediction_list + return return_dict @classmethod def from_json(cls, json_string): @@ -196,7 +207,7 @@ def receive_prediction(self, prediction: TextSegment): current_time = time.time() - if self.latency_unit == "word": + if self.latency_unit in ["word", "spm"]: prediction_list = prediction.content.strip().split() elif self.latency_unit == "char": prediction_list = list(prediction.content.replace(" ", "")) @@ -212,7 +223,7 @@ def receive_prediction(self, prediction: TextSegment): @property def target_length_latency(self): - if self.latency_unit == "word": + if self.latency_unit in ["word", "spm"]: return len(self.reference.split(" ")) elif self.latency_unit == "char": return len(self.reference) @@ -225,6 +236,8 @@ def prediction(self) -> str: return " ".join(list(self.prediction_list)) elif self.latency_unit == "char": return "".join(list(self.prediction_list)) + elif self.latency_unit == "spm": + return "".join(list(self.prediction_list)).replace("▁", " ").strip() else: raise NotImplementedError diff --git a/simuleval/options.py b/simuleval/options.py index 2659785f..ae6c16af 100644 --- a/simuleval/options.py +++ b/simuleval/options.py @@ -62,10 +62,16 @@ def add_evaluator_args(parser: argparse.ArgumentParser): "--eval-latency-unit", type=str, default="word", - choices=["word", "char"], + choices=["word", "char", "spm"], help="Basic unit used for latency calculation, choose from " "words (detokenized) and characters.", ) + parser.add_argument( + "--eval-latency-spm-model", + type=str, + default=None, + help="Pass the spm model path if the eval_latency_unit is spm." + ) parser.add_argument( "--remote-address", default="localhost", @@ -183,7 +189,7 @@ def general_parser(): def add_slurm_args(parser): parser.add_argument( - "--slurm-partition", default="learnaccel,ust", help="Slurm partition." + "--slurm-partition", default="learnaccel,learnfair,seamless", help="Slurm partition." ) parser.add_argument("--slurm-job-name", default="simuleval", help="Slurm job name.") - parser.add_argument("--slurm-time", default="10:00:00", help="Slurm partition.") + parser.add_argument("--slurm-time", default="2:00:00", help="Slurm partition.") diff --git a/simuleval/utils/slurm.py b/simuleval/utils/slurm.py index 6031dd8f..f0944340 100644 --- a/simuleval/utils/slurm.py +++ b/simuleval/utils/slurm.py @@ -9,10 +9,11 @@ import sys import logging import subprocess -from typing import Optional, Dict +from typing import Optional, Dict, List from simuleval import options from simuleval.utils.arguments import cli_argument_list from simuleval.utils.agent import get_agent_class +import itertools logger = logging.getLogger("simuleval.slurm") @@ -31,8 +32,22 @@ def mkdir_output_dir(path: str) -> bool: def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: if config_dict is not None and "slurm" in config_dict: raise RuntimeError("--slurm is only available as a CLI argument") - parser = options.general_parser() + + sweep_options = [ + [[key, v] for v in value] + for key, value in config_dict.items() if isinstance(value, list) + ] + sweep_config_dict_list = [] + if len(sweep_options) > 0: + for option_list in itertools.product(*sweep_options): + sweep_config_dict_list.append({k: v for k, v in option_list}) + + for x in sweep_options: + if x[0][0] in config_dict: + del config_dict[x[0][0]] + cli_arguments = cli_argument_list(config_dict) + parser = options.general_parser() options.add_evaluator_args(parser) options.add_scorer_args(parser, cli_arguments) options.add_slurm_args(parser) @@ -65,10 +80,33 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: r"[^\"'\s]+\.py", f"{os.path.abspath(args.output)}/agent.py", command ).strip() + sweep_command = "" + sbatch_job_array_head = "" + + if len(sweep_config_dict_list) > 0: + job_array_configs="declare -A JobArrayConfigs\n" + for i, sub_config_dict in enumerate(sweep_config_dict_list): + sub_config_string = " ".join([f"--{k.replace('_', '-')} {v}" for k, v in sub_config_dict.items()]) + job_array_configs += f'JobArrayConfigs[{i}]="{sub_config_string}"\n' + + job_array_configs += "\ndeclare -A JobArrayString\n" + for i, sub_config_dict in enumerate(sweep_config_dict_list): + sub_config_string = ".".join([str(v) for k, v in sub_config_dict.items()]) + job_array_configs += f'JobArrayString[{i}]="{sub_config_string}"\n' + + sweep_command = "${JobArrayConfigs[$SLURM_ARRAY_TASK_ID]}" + sbatch_job_array_head = f"#SBATCH --array=0-{len(sweep_config_dict_list) - 1}" + output_dir = f"{args.output}" + "/results/${JobArrayString[$SLURM_ARRAY_TASK_ID]}" + log_path = f"{output_dir}/slurm-%A_%a.log" + + else: + output_dir = args.output + log_path = f"{args.output}/slurm-%j.log" + if "--output" in command: - command = re.sub(r"--output\s+\S+", f"--output {args.output}", command).strip() + command = re.sub(r"--output\s+\S+", f"--output {output_dir}", command).strip() else: - command += f" --output {args.output}" + command += f" --output {output_dir}" command = command.replace("--", "\\\n\t--") script = f"""#!/bin/bash @@ -77,9 +115,13 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: #SBATCH --nodes=1 #SBATCH --gpus-per-node=1 #SBATCH --ntasks-per-node=8 -#SBATCH --output="{args.output}/slurm-%j.log" +#SBATCH --output="{args.output}/logs/slurm-%j.log" #SBATCH --job-name="{args.slurm_job_name}" +{sbatch_job_array_head} +{job_array_configs} + +mkdir -p {args.output}/logs cd {os.path.abspath(args.output)} GPU_ID=$SLURM_LOCALID @@ -87,8 +129,9 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: # Change to local a gpu id for debugging, e.g. # GPU_ID=0 -CUDA_VISIBLE_DEVICES=$GPU_ID {command} - """ + +CUDA_VISIBLE_DEVICES=$GPU_ID {command} {sweep_command} +""" script_file = os.path.join(args.output, "script.sh") with open(script_file, "w") as f: f.writelines(script) @@ -103,4 +146,4 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: logger.info(f"sbatch stdout: {stdout.decode('utf-8').strip()}") stderr = stderr.decode("utf-8").strip() if len(stderr) > 0: - logger.info(f"sbatch stderr: {stderr.decode('utf-8').strip()}") + logger.info(f"sbatch stderr: {stderr.strip()}") From 9c93ce0d845b22bf7faba38c3d173d11ae211c36 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Mon, 23 Oct 2023 21:41:55 +0000 Subject: [PATCH 2/8] Update slurm log path --- simuleval/utils/slurm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simuleval/utils/slurm.py b/simuleval/utils/slurm.py index f0944340..1b6b7d66 100644 --- a/simuleval/utils/slurm.py +++ b/simuleval/utils/slurm.py @@ -97,7 +97,7 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: sweep_command = "${JobArrayConfigs[$SLURM_ARRAY_TASK_ID]}" sbatch_job_array_head = f"#SBATCH --array=0-{len(sweep_config_dict_list) - 1}" output_dir = f"{args.output}" + "/results/${JobArrayString[$SLURM_ARRAY_TASK_ID]}" - log_path = f"{output_dir}/slurm-%A_%a.log" + log_path = f"{args.output}/logs/slurm-%A_%a.log" else: output_dir = args.output @@ -115,7 +115,7 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: #SBATCH --nodes=1 #SBATCH --gpus-per-node=1 #SBATCH --ntasks-per-node=8 -#SBATCH --output="{args.output}/logs/slurm-%j.log" +#SBATCH --output="{log_path}" #SBATCH --job-name="{args.slurm_job_name}" {sbatch_job_array_head} From 71d121568c0ce01947fd1e682059759ab27390d0 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Thu, 26 Oct 2023 18:01:23 +0000 Subject: [PATCH 3/8] Address comments --- simuleval/options.py | 2 +- simuleval/utils/slurm.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/simuleval/options.py b/simuleval/options.py index ae6c16af..417376ca 100644 --- a/simuleval/options.py +++ b/simuleval/options.py @@ -189,7 +189,7 @@ def general_parser(): def add_slurm_args(parser): parser.add_argument( - "--slurm-partition", default="learnaccel,learnfair,seamless", help="Slurm partition." + "--slurm-partition", default="", help="Slurm partition." ) parser.add_argument("--slurm-job-name", default="simuleval", help="Slurm job name.") parser.add_argument("--slurm-time", default="2:00:00", help="Slurm partition.") diff --git a/simuleval/utils/slurm.py b/simuleval/utils/slurm.py index 1b6b7d66..dd0bb693 100644 --- a/simuleval/utils/slurm.py +++ b/simuleval/utils/slurm.py @@ -34,18 +34,18 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: raise RuntimeError("--slurm is only available as a CLI argument") sweep_options = [ - [[key, v] for v in value] + [[key, v] for v in value] for key, value in config_dict.items() if isinstance(value, list) - ] + ] sweep_config_dict_list = [] if len(sweep_options) > 0: for option_list in itertools.product(*sweep_options): sweep_config_dict_list.append({k: v for k, v in option_list}) - + for x in sweep_options: if x[0][0] in config_dict: del config_dict[x[0][0]] - + cli_arguments = cli_argument_list(config_dict) parser = options.general_parser() options.add_evaluator_args(parser) @@ -82,18 +82,19 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: sweep_command = "" sbatch_job_array_head = "" + job_array_configs = "" if len(sweep_config_dict_list) > 0: job_array_configs="declare -A JobArrayConfigs\n" for i, sub_config_dict in enumerate(sweep_config_dict_list): sub_config_string = " ".join([f"--{k.replace('_', '-')} {v}" for k, v in sub_config_dict.items()]) job_array_configs += f'JobArrayConfigs[{i}]="{sub_config_string}"\n' - + job_array_configs += "\ndeclare -A JobArrayString\n" for i, sub_config_dict in enumerate(sweep_config_dict_list): sub_config_string = ".".join([str(v) for k, v in sub_config_dict.items()]) job_array_configs += f'JobArrayString[{i}]="{sub_config_string}"\n' - + sweep_command = "${JobArrayConfigs[$SLURM_ARRAY_TASK_ID]}" sbatch_job_array_head = f"#SBATCH --array=0-{len(sweep_config_dict_list) - 1}" output_dir = f"{args.output}" + "/results/${JobArrayString[$SLURM_ARRAY_TASK_ID]}" From 9601ba3560c39bbde29b91be71b0081277a90188 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Thu, 26 Oct 2023 19:32:15 +0000 Subject: [PATCH 4/8] black --- simuleval/data/dataloader/dataloader.py | 4 +++- simuleval/evaluator/evaluator.py | 8 ++++++-- simuleval/evaluator/instance.py | 8 ++++---- simuleval/options.py | 6 ++---- simuleval/utils/slurm.py | 13 +++++++++---- 5 files changed, 24 insertions(+), 15 deletions(-) diff --git a/simuleval/data/dataloader/dataloader.py b/simuleval/data/dataloader/dataloader.py index fb33c587..a58feceb 100644 --- a/simuleval/data/dataloader/dataloader.py +++ b/simuleval/data/dataloader/dataloader.py @@ -57,7 +57,9 @@ def get_target(self, index: int) -> Any: return self.preprocess_target(self.target_list[index]) def get_tgt_lang(self, index: int) -> Optional[str]: - if getattr(self, "tgt_lang_list", None) is None or index >= len(self.tgt_lang_list): + if getattr(self, "tgt_lang_list", None) is None or index >= len( + self.tgt_lang_list + ): return None else: return self.tgt_lang_list[index] diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py index 5a9f50c2..a0e7e598 100644 --- a/simuleval/evaluator/evaluator.py +++ b/simuleval/evaluator/evaluator.py @@ -84,7 +84,9 @@ def __init__( if args.eval_latency_unit == "spm": assert args.eval_latency_spm_model assert IS_IMPORT_SPM - self.target_spm_model = sentencepiece.SentencePieceProcessor(model_file=args.eval_latency_spm_model) + self.target_spm_model = sentencepiece.SentencePieceProcessor( + model_file=args.eval_latency_spm_model + ) if ( self.source_type is None @@ -163,7 +165,9 @@ def build_instances_from_log(self): for line in f: instance = LogInstance(line.strip()) self.instances[instance.index] = instance - self.instances[instance.index].set_target_spm_model(self.target_spm_model) + self.instances[instance.index].set_target_spm_model( + self.target_spm_model + ) def build_instances_from_dataloader(self): for i in self.get_indices(): diff --git a/simuleval/evaluator/instance.py b/simuleval/evaluator/instance.py index b41088f7..89c674d5 100644 --- a/simuleval/evaluator/instance.py +++ b/simuleval/evaluator/instance.py @@ -52,9 +52,9 @@ def __init__( if args is not None: self.args = args self.latency_unit = args.eval_latency_unit - + self.target_spm_model = None - + def set_target_spm_model(self, spm_model): self.target_spm_model = spm_model @@ -131,7 +131,7 @@ def reference_length(self) -> int: raise NotImplementedError def summarize(self): - return_dict = { + return_dict = { "index": self.index, "prediction": self.prediction, "delays": self.delays, @@ -142,7 +142,7 @@ def summarize(self): "source_length": self.source_length, } if self.latency_unit == "spm": - return_dict["prediction_spm"] = self.prediction_list + return_dict["prediction_spm"] = self.prediction_list return return_dict @classmethod diff --git a/simuleval/options.py b/simuleval/options.py index 417376ca..9236a100 100644 --- a/simuleval/options.py +++ b/simuleval/options.py @@ -70,7 +70,7 @@ def add_evaluator_args(parser: argparse.ArgumentParser): "--eval-latency-spm-model", type=str, default=None, - help="Pass the spm model path if the eval_latency_unit is spm." + help="Pass the spm model path if the eval_latency_unit is spm.", ) parser.add_argument( "--remote-address", @@ -188,8 +188,6 @@ def general_parser(): def add_slurm_args(parser): - parser.add_argument( - "--slurm-partition", default="", help="Slurm partition." - ) + parser.add_argument("--slurm-partition", default="", help="Slurm partition.") parser.add_argument("--slurm-job-name", default="simuleval", help="Slurm job name.") parser.add_argument("--slurm-time", default="2:00:00", help="Slurm partition.") diff --git a/simuleval/utils/slurm.py b/simuleval/utils/slurm.py index dd0bb693..99ff6f28 100644 --- a/simuleval/utils/slurm.py +++ b/simuleval/utils/slurm.py @@ -35,7 +35,8 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: sweep_options = [ [[key, v] for v in value] - for key, value in config_dict.items() if isinstance(value, list) + for key, value in config_dict.items() + if isinstance(value, list) ] sweep_config_dict_list = [] if len(sweep_options) > 0: @@ -85,9 +86,11 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: job_array_configs = "" if len(sweep_config_dict_list) > 0: - job_array_configs="declare -A JobArrayConfigs\n" + job_array_configs = "declare -A JobArrayConfigs\n" for i, sub_config_dict in enumerate(sweep_config_dict_list): - sub_config_string = " ".join([f"--{k.replace('_', '-')} {v}" for k, v in sub_config_dict.items()]) + sub_config_string = " ".join( + [f"--{k.replace('_', '-')} {v}" for k, v in sub_config_dict.items()] + ) job_array_configs += f'JobArrayConfigs[{i}]="{sub_config_string}"\n' job_array_configs += "\ndeclare -A JobArrayString\n" @@ -97,7 +100,9 @@ def submit_slurm_job(config_dict: Optional[Dict] = None) -> None: sweep_command = "${JobArrayConfigs[$SLURM_ARRAY_TASK_ID]}" sbatch_job_array_head = f"#SBATCH --array=0-{len(sweep_config_dict_list) - 1}" - output_dir = f"{args.output}" + "/results/${JobArrayString[$SLURM_ARRAY_TASK_ID]}" + output_dir = ( + f"{args.output}" + "/results/${JobArrayString[$SLURM_ARRAY_TASK_ID]}" + ) log_path = f"{args.output}/logs/slurm-%A_%a.log" else: From 700e2f7a01cf36660d512445a1780dead9471455 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Thu, 26 Oct 2023 23:25:41 +0000 Subject: [PATCH 5/8] black --- simuleval/evaluator/scorers/quality_scorer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/simuleval/evaluator/scorers/quality_scorer.py b/simuleval/evaluator/scorers/quality_scorer.py index cdb0ed42..5afb3d75 100644 --- a/simuleval/evaluator/scorers/quality_scorer.py +++ b/simuleval/evaluator/scorers/quality_scorer.py @@ -316,7 +316,9 @@ def asr_transcribe(self, instances): wav_path = wav_dir / f"{index}_pred.wav" if wav_path.exists(): result = model.transcribe( - wav_path.as_posix(), language=self.target_lang, temperature=self.temperature + wav_path.as_posix(), + language=self.target_lang, + temperature=self.temperature, ) text = result["text"] assert type(text) == str From cd67038c3109c956dc882abdb4a49dcf7c8d3ccc Mon Sep 17 00:00:00 2001 From: Anna Sun <13106449+annasun28@users.noreply.github.com> Date: Fri, 3 Nov 2023 12:09:41 -0700 Subject: [PATCH 6/8] Add segment config, pass to all agents in pipeline --- simuleval/agents/agent.py | 1 + simuleval/agents/pipeline.py | 16 ++++++++++++++-- simuleval/agents/states.py | 7 +++++++ simuleval/data/segments.py | 1 + simuleval/evaluator/evaluator.py | 8 +++++++- 5 files changed, 30 insertions(+), 3 deletions(-) diff --git a/simuleval/agents/agent.py b/simuleval/agents/agent.py index a86a0af3..93d18507 100644 --- a/simuleval/agents/agent.py +++ b/simuleval/agents/agent.py @@ -100,6 +100,7 @@ def push( states.upstream_states = upstream_states + states.update_config(source_segment.config) states.update_source(source_segment) def pop(self, states: Optional[AgentStates] = None) -> Segment: diff --git a/simuleval/agents/pipeline.py b/simuleval/agents/pipeline.py index 1631c3f8..7f848c83 100644 --- a/simuleval/agents/pipeline.py +++ b/simuleval/agents/pipeline.py @@ -8,6 +8,10 @@ from typing import List, Optional, Dict, Set, Type, Union from simuleval.data.segments import Segment from .agent import GenericAgent, AgentStates +import time +import logging + +logger = logging.getLogger(__name__) class AgentPipeline(GenericAgent): @@ -69,11 +73,13 @@ def push( upstream_states = [] for index, module in enumerate(self.module_list[:-1]): + config = segment.config segment = module.pushpop( segment, states[index], upstream_states=upstream_states + states_list[:index], ) + segment.config = config self.module_list[-1].push( segment, states[-1], upstream_states=upstream_states + states_list[:index] ) @@ -274,12 +280,18 @@ def push_impl( # DFS over the tree children = self.module_dict[module] if len(children) == 0: # leaf node - module.push(segment, states[module]) + module.push(segment, states[module]) # , upstream_states) + # upstream_states[len(upstream_states)] = states[module] + # TODO: ? return [] + config = segment.config segment = module.pushpop(segment, states[module], upstream_states) + segment.config = config assert len(upstream_states) not in upstream_states - upstream_states[len(upstream_states)] = states[module] + upstream_states[len(upstream_states)] = ( + states[module] if states[module] is not None else module.states + ) for child in children: self.push_impl(child, segment, states, upstream_states) diff --git a/simuleval/agents/states.py b/simuleval/agents/states.py index dd71a61b..df605f53 100644 --- a/simuleval/agents/states.py +++ b/simuleval/agents/states.py @@ -31,6 +31,13 @@ def reset(self) -> None: self.target_sample_rate = 0 self.tgt_lang = None self.upstream_states = [] + self.config = {} + + def update_config(self, config: dict): + for k in config.keys(): + if k not in self.config: + # only update with new keys within each utterance + self.config[k] = config[k] def update_source(self, segment: Segment): """ diff --git a/simuleval/data/segments.py b/simuleval/data/segments.py index 29bc4347..a07f3376 100644 --- a/simuleval/data/segments.py +++ b/simuleval/data/segments.py @@ -17,6 +17,7 @@ class Segment: is_empty: bool = False data_type: str = None tgt_lang: str = None + config: dict = field(default_factory=dict) def json(self) -> str: info_dict = {attribute: value for attribute, value in self.__dict__.items()} diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py index a0e7e598..8656a5cd 100644 --- a/simuleval/evaluator/evaluator.py +++ b/simuleval/evaluator/evaluator.py @@ -12,6 +12,7 @@ from .scorers import get_scorer_class from .scorers.latency_scorer import LatencyScorer from .scorers.quality_scorer import QualityScorer +from simuleval.data.segments import Segment from .instance import INSTANCE_TYPE_DICT, LogInstance import yaml @@ -236,7 +237,12 @@ def __call__(self, system): system.reset() for instance in self.instance_iterator: while not self.is_finished(instance): - input_segment = instance.send_source(self.source_segment_size) + input_segment: Segment = instance.send_source(self.source_segment_size) + # TODO: cleanup after testing + # TODO: test behavior of changing from non-expr --> expr mid-utterance + # it should be a no-op, and not take effect until the next utterance + # @xutaima: uncomment this to test dual agent w/expr + # input_segment.config["use_expr"] = True output_segment = system.pushpop(input_segment) instance.receive_prediction(output_segment) if instance.finish_prediction: From 7b6142caf309be99911b2a58976f1d5f2522f824 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Mon, 6 Nov 2023 13:58:04 -0800 Subject: [PATCH 7/8] add branched agent pipeline --- simuleval/agents/branch.py | 157 +++++++++++++++++++++++++++++++ simuleval/agents/pipeline.py | 38 ++++++-- simuleval/evaluator/evaluator.py | 2 +- 3 files changed, 186 insertions(+), 11 deletions(-) create mode 100644 simuleval/agents/branch.py diff --git a/simuleval/agents/branch.py b/simuleval/agents/branch.py new file mode 100644 index 00000000..ee8892a1 --- /dev/null +++ b/simuleval/agents/branch.py @@ -0,0 +1,157 @@ +from typing import Dict, List, Optional +from argparse import Namespace +from simuleval.agents.pipeline import AgentPipeline +from simuleval.data.segments import Segment +from simuleval.agents.agent import GenericAgent, AgentStates + + +class BranchedAgentPipelineStates(AgentStates): + def __init__(self, states_dict: Dict[str, AgentStates]): + self.states_dict = states_dict + super().__init__() + + def reset(self) -> None: + super().reset() + for states in self.states_dict.values(): + for s in states: + s.reset() + + def update_source(self, segment: Segment): + pass + + def update_target(self, segment: Segment): + pass + + +class BranchedAgentPipeline(AgentPipeline): + """ + Select different agent branch to use + + Args: + pipeline_dict (dict): dictionary of agents can be select from different branch + """ + + branches = {} + name = "branch" + + def __init__( + self, + pipeline_dict: Dict[str, GenericAgent], + ): + self.pipeline_dict = pipeline_dict + for pipeline in self.pipeline_dict.values(): + assert isinstance(pipeline, AgentPipeline) + # the default branch model is the first one + self.default_branch_name = list(self.pipeline_dict.keys())[0] + self.states = self.build_states() + # Don't check the type for Now + + @property + def source_type(self) -> Optional[str]: + source_type = list( + set(pipeline.source_type for pipeline in self.pipeline_dict.values()) + ) + assert len(source_type) == 1, "source type should be the same for all branches" + return source_type[0] + + @property + def target_type(self) -> Optional[str]: + target_type = list( + set(pipeline.target_type for pipeline in self.pipeline_dict.values()) + ) + assert len(target_type) == 1, "target type should be the same for all branches" + return target_type[0] + + def push( + self, + segment: Segment, + states: BranchedAgentPipelineStates | None = None, + upstream_states: List[AgentStates | None] | None = None, + ) -> None: + is_stateless = True + if states is None: + states = self.states + is_stateless = False + + states.update_config(segment.config) + branch_name = self.get_branch_from_states(states) + branch_states = states.states_dict[branch_name] + + return super().push( + segment, + branch_states if states == is_stateless else None, + upstream_states, + module_list=self.pipeline_dict[branch_name].module_list, + ) + + def pop(self, states: BranchedAgentPipelineStates | None = None): + is_stateless = True + if states is None: + states = self.states + is_stateless = False + + branch_name = self.get_branch_from_states(states) + branch_states = states.states_dict[branch_name] + + return super().pop( + branch_states if states == is_stateless else None, + module_list=self.pipeline_dict[branch_name].module_list, + ) + + def build_states(self) -> BranchedAgentPipelineStates: + return BranchedAgentPipelineStates( + { + key: pipeline.build_states() + for key, pipeline in self.pipeline_dict.items() + } + ) + + def reset(self) -> None: + for agent in self.pipeline_dict.values(): + agent.reset() + + def get_branch_from_states(self, states): + if states is None: + # stateful agent + states = self.states + + branch_name = states.config.get(self.name, self.default_branch_name) + assert branch_name in self.pipeline_dict + return branch_name + + @classmethod + def from_args(cls, arg: Namespace): + pipeline_dict = {} + for branch_name, pipeline_class_or_list in cls.branches.items(): + if isinstance(pipeline_class_or_list, list): + pipeline_dict[branch_name] = AgentPipeline.from_pipeline_args( + pipeline_class_or_list, arg + ) + elif isinstance(pipeline_class_or_list, AgentPipeline): + pipeline_dict[branch_name] = pipeline_class_or_list.from_args(arg) + else: + raise NotImplementedError + + return cls(pipeline_dict) + + @classmethod + def add_args(cls, parser) -> None: + for pipeline_class_or_list in cls.branches.values(): + if isinstance(pipeline_class_or_list, list): + for agent in pipeline_class_or_list: + agent.add_args(parser) + else: + pipeline_class_or_list.add_args(parser) + + def __repr__(self) -> str: + # TODO, Display here is not correct + string_list = [] + for branch_name, pipeline in self.pipeline_dict.items(): + string_list.append(f"{branch_name}:\n\t\t{pipeline}") + string = ",\n".join( + [ + f"\t{branch_name}:{pipeline}" + for branch_name, pipeline in self.pipeline_dict.items() + ] + ) + return f"{self.__class__.__name__}(\n{string}\n)" diff --git a/simuleval/agents/pipeline.py b/simuleval/agents/pipeline.py index 7f848c83..64ac1970 100644 --- a/simuleval/agents/pipeline.py +++ b/simuleval/agents/pipeline.py @@ -59,20 +59,26 @@ def push( segment: Segment, states: Optional[List[Optional[AgentStates]]] = None, upstream_states: Optional[List[Optional[AgentStates]]] = None, + module_list: Optional[List[GenericAgent]] = None, ) -> None: + if module_list is None: + module_list = self.module_list + if states is None: # stateful agent - states = [None for _ in self.module_list] - states_list = [module.states for module in self.module_list] + states = [None for _ in module_list] + states_list = [module.states for module in module_list] else: # stateless agent - assert len(states) == len(self.module_list) + assert len(states) == len(module_list) states_list = states if upstream_states is None: upstream_states = [] - for index, module in enumerate(self.module_list[:-1]): + index = 0 + + for index, module in enumerate(module_list[:-1]): config = segment.config segment = module.pushpop( segment, @@ -80,18 +86,25 @@ def push( upstream_states=upstream_states + states_list[:index], ) segment.config = config - self.module_list[-1].push( + + module_list[-1].push( segment, states[-1], upstream_states=upstream_states + states_list[:index] ) - def pop(self, states: Optional[List[Optional[AgentStates]]] = None) -> Segment: + def pop( + self, + states: Optional[List[Optional[AgentStates]]] = None, + module_list: Optional[List[GenericAgent]] = None, + ) -> Segment: + if module_list is None: + module_list = self.module_list if states is None: last_states = None else: - assert len(states) == len(self.module_list) + assert len(states) == len(module_list) last_states = states[-1] - return self.module_list[-1].pop(last_states) + return module_list[-1].pop(last_states) @classmethod def add_args(cls, parser) -> None: @@ -103,11 +116,16 @@ def from_args(cls, args): assert len(cls.pipeline) > 0 return cls([module_class.from_args(args) for module_class in cls.pipeline]) + @classmethod + def from_pipeline_args(cls, pipeline, args): + assert len(pipeline) > 0 + return cls([module_class.from_args(args) for module_class in pipeline]) + def __repr__(self) -> str: - pipline_str = "\n\t".join( + pipeline_str = "\n\t".join( "\t".join(str(module).splitlines(True)) for module in self.module_list ) - return f"{self.__class__.__name__}(\n\t{pipline_str}\n)" + return f"{self.__class__.__name__}(\n\t\t{pipeline_str})" def __str__(self) -> str: return self.__repr__() diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py index 8656a5cd..5284be58 100644 --- a/simuleval/evaluator/evaluator.py +++ b/simuleval/evaluator/evaluator.py @@ -242,7 +242,7 @@ def __call__(self, system): # TODO: test behavior of changing from non-expr --> expr mid-utterance # it should be a no-op, and not take effect until the next utterance # @xutaima: uncomment this to test dual agent w/expr - # input_segment.config["use_expr"] = True + # input_segment.config["synthesizer_mode"] = "expressive" output_segment = system.pushpop(input_segment) instance.receive_prediction(output_segment) if instance.finish_prediction: From a0a22276ff125910555dfd1f578522a7bfd4d135 Mon Sep 17 00:00:00 2001 From: Xutai Ma Date: Thu, 9 Nov 2023 13:39:03 -0800 Subject: [PATCH 8/8] Update segment finish flag --- simuleval/agents/branch.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/simuleval/agents/branch.py b/simuleval/agents/branch.py index ee8892a1..8b4a8013 100644 --- a/simuleval/agents/branch.py +++ b/simuleval/agents/branch.py @@ -17,10 +17,14 @@ def reset(self) -> None: s.reset() def update_source(self, segment: Segment): - pass + self.source_finished = segment.finished + for states in self.states_dict.values(): + states[0] = segment.finished def update_target(self, segment: Segment): - pass + self.target_finished = segment.finished + for states in self.states_dict.values(): + states[1] = segment.finished class BranchedAgentPipeline(AgentPipeline): @@ -144,7 +148,7 @@ def add_args(cls, parser) -> None: pipeline_class_or_list.add_args(parser) def __repr__(self) -> str: - # TODO, Display here is not correct + # TODO, indent here is not correct string_list = [] for branch_name, pipeline in self.pipeline_dict.items(): string_list.append(f"{branch_name}:\n\t\t{pipeline}")