diff --git a/simuleval/agents/agent.py b/simuleval/agents/agent.py
index a86a0af3..93d18507 100644
--- a/simuleval/agents/agent.py
+++ b/simuleval/agents/agent.py
@@ -100,6 +100,7 @@ def push(
 
         states.upstream_states = upstream_states
 
+        states.update_config(source_segment.config)
         states.update_source(source_segment)
 
     def pop(self, states: Optional[AgentStates] = None) -> Segment:
diff --git a/simuleval/agents/branch.py b/simuleval/agents/branch.py
new file mode 100644
index 00000000..8b4a8013
--- /dev/null
+++ b/simuleval/agents/branch.py
@@ -0,0 +1,210 @@
+from typing import Dict, List, Optional
+from argparse import Namespace
+from simuleval.agents.pipeline import AgentPipeline
+from simuleval.data.segments import Segment
+from simuleval.agents.agent import GenericAgent, AgentStates
+
+
+class BranchedAgentPipelineStates(AgentStates):
+    """
+    States of a BranchedAgentPipeline, one list of module states per branch
+
+    Args:
+        states_dict (dict): branch name to the list of states of the modules
+            in that branch
+    """
+
+    def __init__(self, states_dict: Dict[str, List[AgentStates]]):
+        # set before super().__init__() in case the base constructor calls reset()
+        self.states_dict = states_dict
+        super().__init__()
+
+    def reset(self) -> None:
+        """Reset the shared states and the states of every module of every branch"""
+        super().reset()
+        for states in self.states_dict.values():
+            for s in states:
+                s.reset()
+
+    def update_source(self, segment: Segment):
+        """Mirror the finished flag of the source segment on every branch"""
+        self.source_finished = segment.finished
+        for states in self.states_dict.values():
+            # set the flag on the states object instead of overwriting the
+            # list entry with a bool, which would break the next reset()
+            # NOTE(review): assumes the first module owns the source flag
+            states[0].source_finished = segment.finished
+
+    def update_target(self, segment: Segment):
+        """Mirror the finished flag of the target segment on every branch"""
+        self.target_finished = segment.finished
+        for states in self.states_dict.values():
+            # NOTE(review): assumes the last module owns the target flag
+            states[-1].target_finished = segment.finished
+
+
+class BranchedAgentPipeline(AgentPipeline):
+    """
+    Select different agent branch to use
+
+    The branch is picked per segment from the ``name`` key of the config
+    stored in the states; the first branch is used when the key is missing.
+
+    Args:
+        pipeline_dict (dict): dictionary of agents can be select from different branch
+    """
+
+    branches = {}
+    name = "branch"
+
+    def __init__(
+        self,
+        pipeline_dict: Dict[str, GenericAgent],
+    ):
+        if not pipeline_dict:
+            raise ValueError("pipeline_dict should contain at least one branch")
+        self.pipeline_dict = pipeline_dict
+        for pipeline in self.pipeline_dict.values():
+            assert isinstance(pipeline, AgentPipeline)
+        # the default branch model is the first one
+        self.default_branch_name = next(iter(self.pipeline_dict))
+        self.states = self.build_states()
+        # Don't check the type for now
+
+    @property
+    def source_type(self) -> Optional[str]:
+        """Source type shared by all the branches"""
+        source_type = list(
+            set(pipeline.source_type for pipeline in self.pipeline_dict.values())
+        )
+        assert len(source_type) == 1, "source type should be the same for all branches"
+        return source_type[0]
+
+    @property
+    def target_type(self) -> Optional[str]:
+        """Target type shared by all the branches"""
+        target_type = list(
+            set(pipeline.target_type for pipeline in self.pipeline_dict.values())
+        )
+        assert len(target_type) == 1, "target type should be the same for all branches"
+        return target_type[0]
+
+    def push(
+        self,
+        segment: Segment,
+        states: Optional[BranchedAgentPipelineStates] = None,
+        upstream_states: Optional[List[Optional[AgentStates]]] = None,
+    ) -> None:
+        """
+        Push a segment to the branch selected by the config
+
+        Args:
+            segment (Segment): input segment, its config is merged into the states
+            states (BranchedAgentPipelineStates, optional): states to use for
+                stateless calls, self.states is used when None
+            upstream_states (list, optional): states of the upstream agents
+        """
+        # stateless: the caller owns the states, stateful: use self.states
+        is_stateless = states is not None
+        if states is None:
+            states = self.states
+
+        states.update_config(segment.config)
+        branch_name = self.get_branch_from_states(states)
+        branch_states = states.states_dict[branch_name]
+
+        return super().push(
+            segment,
+            branch_states if is_stateless else None,
+            upstream_states,
+            module_list=self.pipeline_dict[branch_name].module_list,
+        )
+
+    def pop(self, states: Optional[BranchedAgentPipelineStates] = None) -> Segment:
+        """
+        Pop a segment from the branch selected by the config
+
+        Args:
+            states (BranchedAgentPipelineStates, optional): states to use for
+                stateless calls, self.states is used when None
+        """
+        is_stateless = states is not None
+        if states is None:
+            states = self.states
+
+        branch_name = self.get_branch_from_states(states)
+        branch_states = states.states_dict[branch_name]
+
+        return super().pop(
+            branch_states if is_stateless else None,
+            module_list=self.pipeline_dict[branch_name].module_list,
+        )
+
+    def build_states(self) -> BranchedAgentPipelineStates:
+        """Build fresh states for every branch"""
+        return BranchedAgentPipelineStates(
+            {
+                key: pipeline.build_states()
+                for key, pipeline in self.pipeline_dict.items()
+            }
+        )
+
+    def reset(self) -> None:
+        """Reset every branch and the config used to select the branch"""
+        for agent in self.pipeline_dict.values():
+            agent.reset()
+        # the config only accepts new keys, clear it so that the next
+        # utterance can select a different branch
+        self.states.reset()
+
+    def get_branch_from_states(self, states):
+        """Return the name of the branch requested by the config in states"""
+        if states is None:
+            # stateful agent
+            states = self.states
+
+        branch_name = states.config.get(self.name, self.default_branch_name)
+        assert branch_name in self.pipeline_dict, f"unknown branch {branch_name}"
+        return branch_name
+
+    @classmethod
+    def from_args(cls, arg: Namespace):
+        """Build the pipeline of every entry of cls.branches from the arguments"""
+        pipeline_dict = {}
+        for branch_name, pipeline_class_or_list in cls.branches.items():
+            if isinstance(pipeline_class_or_list, list):
+                pipeline_dict[branch_name] = AgentPipeline.from_pipeline_args(
+                    pipeline_class_or_list, arg
+                )
+            elif isinstance(pipeline_class_or_list, AgentPipeline) or (
+                isinstance(pipeline_class_or_list, type)
+                and issubclass(pipeline_class_or_list, AgentPipeline)
+            ):
+                # branches usually hold pipeline classes, not instances
+                pipeline_dict[branch_name] = pipeline_class_or_list.from_args(arg)
+            else:
+                raise NotImplementedError(
+                    f"unsupported branch type: {type(pipeline_class_or_list)}"
+                )
+
+        return cls(pipeline_dict)
+
+    @classmethod
+    def add_args(cls, parser) -> None:
+        """Add the arguments of every agent of every branch to the parser"""
+        for pipeline_class_or_list in cls.branches.values():
+            if isinstance(pipeline_class_or_list, list):
+                for agent in pipeline_class_or_list:
+                    agent.add_args(parser)
+            else:
+                pipeline_class_or_list.add_args(parser)
+
+    def __repr__(self) -> str:
+        # TODO, indent here is not correct
+        string = ",\n".join(
+            [
+                f"\t{branch_name}:{pipeline}"
+                for branch_name, pipeline in self.pipeline_dict.items()
+            ]
+        )
+        return f"{self.__class__.__name__}(\n{string}\n)"
diff --git a/simuleval/agents/pipeline.py b/simuleval/agents/pipeline.py
index 1631c3f8..64ac1970 100644
--- a/simuleval/agents/pipeline.py
+++ b/simuleval/agents/pipeline.py
@@ -8,6 +8,10 @@ from typing import List, Optional, Dict, Set, Type, Union
 
 from simuleval.data.segments import Segment
 from .agent import GenericAgent, AgentStates
+import time
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class AgentPipeline(GenericAgent):
@@ -55,37 +59,52 @@ def push(
         segment: Segment,
         states: Optional[List[Optional[AgentStates]]] = None,
         upstream_states: Optional[List[Optional[AgentStates]]] = None,
+        module_list: Optional[List[GenericAgent]] = None,
     ) -> None:
+        if module_list is None:
+            module_list = self.module_list
+
         if states is None:
             # stateful agent
-            states = [None for _ in self.module_list]
-            states_list = [module.states for module in self.module_list]
+            states = [None for _ in module_list]
+            states_list = [module.states for module in module_list]
         else:
             # stateless agent
-            assert len(states) == len(self.module_list)
+            assert len(states) == len(module_list)
             states_list = states
 
         if upstream_states is None:
             upstream_states = []
 
-        for index, module in enumerate(self.module_list[:-1]):
+        index = 0
+
+        for index, module in enumerate(module_list[:-1]):
+            config = segment.config
             segment = module.pushpop(
                 segment,
                 states[index],
                 upstream_states=upstream_states + states_list[:index],
             )
-        self.module_list[-1].push(
+            segment.config = config
+
+        module_list[-1].push(
             segment, states[-1], upstream_states=upstream_states + states_list[:index]
         )
 
-    def pop(self, states: Optional[List[Optional[AgentStates]]] = None) -> Segment:
+    def pop(
+        self,
+        states: Optional[List[Optional[AgentStates]]] = None,
+        module_list: Optional[List[GenericAgent]] = None,
+    ) -> Segment:
+        if module_list is None:
+            module_list = self.module_list
         if states is None:
             last_states = None
         else:
-            assert len(states) == len(self.module_list)
+            assert len(states) == len(module_list)
             last_states = states[-1]
 
-        return self.module_list[-1].pop(last_states)
+        return module_list[-1].pop(last_states)
 
     @classmethod
     def add_args(cls, parser) -> None:
@@ -97,11 +116,16 @@ def from_args(cls, args):
         assert len(cls.pipeline) > 0
         return cls([module_class.from_args(args) for module_class in cls.pipeline])
 
+    @classmethod
+    def from_pipeline_args(cls, pipeline, args):
+        assert len(pipeline) > 0
+        return cls([module_class.from_args(args) for module_class in pipeline])
+
     def __repr__(self) -> str:
-        pipline_str = "\n\t".join(
+        pipeline_str = "\n\t".join(
             "\t".join(str(module).splitlines(True)) for module in self.module_list
         )
-        return f"{self.__class__.__name__}(\n\t{pipline_str}\n)"
+        return f"{self.__class__.__name__}(\n\t\t{pipeline_str})"
 
     def __str__(self) -> str:
         return self.__repr__()
@@ -274,12 +298,18 @@ def push_impl(
         # DFS over the tree
         children = self.module_dict[module]
         if len(children) == 0:  # leaf node
-            module.push(segment, states[module])
+            module.push(segment, states[module])  # , upstream_states)
+            # upstream_states[len(upstream_states)] = states[module]
+            # TODO: ?
             return []
 
+        config = segment.config
         segment = module.pushpop(segment, states[module], upstream_states)
+        segment.config = config
         assert len(upstream_states) not in upstream_states
-        upstream_states[len(upstream_states)] = states[module]
+        upstream_states[len(upstream_states)] = (
+            states[module] if states[module] is not None else module.states
+        )
 
         for child in children:
             self.push_impl(child, segment, states, upstream_states)
diff --git a/simuleval/agents/states.py b/simuleval/agents/states.py
index dd71a61b..df605f53 100644
--- a/simuleval/agents/states.py
+++ b/simuleval/agents/states.py
@@ -31,6 +31,18 @@ def reset(self) -> None:
         self.target_sample_rate = 0
         self.tgt_lang = None
         self.upstream_states = []
+        self.config = {}
+
+    def update_config(self, config: dict):
+        """
+        Merge a segment config into the states config
+
+        Args:
+            config (dict): config carried by the incoming segment
+        """
+        for key, value in config.items():
+            # only update with new keys within each utterance
+            self.config.setdefault(key, value)
 
     def update_source(self, segment: Segment):
         """
diff --git a/simuleval/data/segments.py b/simuleval/data/segments.py
index 29bc4347..a07f3376 100644
--- a/simuleval/data/segments.py
+++ b/simuleval/data/segments.py
@@ -17,6 +17,7 @@ class Segment:
     is_empty: bool = False
     data_type: str = None
     tgt_lang: str = None
+    config: dict = field(default_factory=dict)
 
     def json(self) -> str:
         info_dict = {attribute: value for attribute, value in self.__dict__.items()}
diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py
index a0e7e598..5284be58 100644
--- a/simuleval/evaluator/evaluator.py
+++ b/simuleval/evaluator/evaluator.py
@@ -12,6 +12,7 @@
 from .scorers import get_scorer_class
 from .scorers.latency_scorer import LatencyScorer
 from .scorers.quality_scorer import QualityScorer
+from simuleval.data.segments import Segment
 
 from .instance import INSTANCE_TYPE_DICT, LogInstance
 import yaml
@@ -236,7 +237,12 @@ def __call__(self, system):
         system.reset()
         for instance in self.instance_iterator:
             while not self.is_finished(instance):
-                input_segment = instance.send_source(self.source_segment_size)
+                input_segment: Segment = instance.send_source(self.source_segment_size)
+                # TODO: cleanup after testing
+                # TODO: test behavior of changing from non-expr --> expr mid-utterance
+                # it should be a no-op, and not take effect until the next utterance
+                # @xutaima: uncomment this to test dual agent w/expr
+                # input_segment.config["synthesizer_mode"] = "expressive"
                 output_segment = system.pushpop(input_segment)
                 instance.receive_prediction(output_segment)
                 if instance.finish_prediction: