Skip to content
This repository was archived by the owner on Sep 18, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions simuleval/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def push(

states.upstream_states = upstream_states

states.update_config(source_segment.config)
states.update_source(source_segment)

def pop(self, states: Optional[AgentStates] = None) -> Segment:
Expand Down
161 changes: 161 additions & 0 deletions simuleval/agents/branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import Dict, List, Optional
from argparse import Namespace
from simuleval.agents.pipeline import AgentPipeline
from simuleval.data.segments import Segment
from simuleval.agents.agent import GenericAgent, AgentStates


class BranchedAgentPipelineStates(AgentStates):
def __init__(self, states_dict: Dict[str, AgentStates]):
self.states_dict = states_dict
super().__init__()

def reset(self) -> None:
super().reset()
for states in self.states_dict.values():
for s in states:
s.reset()

def update_source(self, segment: Segment):
self.source_finished = segment.finished
for states in self.states_dict.values():
states[0] = segment.finished

def update_target(self, segment: Segment):
self.target_finished = segment.finished
for states in self.states_dict.values():
states[1] = segment.finished


class BranchedAgentPipeline(AgentPipeline):
"""
Select different agent branch to use

Args:
pipeline_dict (dict): dictionary of agents can be select from different branch
"""

branches = {}
name = "branch"

def __init__(
self,
pipeline_dict: Dict[str, GenericAgent],
):
self.pipeline_dict = pipeline_dict
for pipeline in self.pipeline_dict.values():
assert isinstance(pipeline, AgentPipeline)
# the default branch model is the first one
self.default_branch_name = list(self.pipeline_dict.keys())[0]
self.states = self.build_states()
# Don't check the type for Now

@property
def source_type(self) -> Optional[str]:
source_type = list(
set(pipeline.source_type for pipeline in self.pipeline_dict.values())
)
assert len(source_type) == 1, "source type should be the same for all branches"
return source_type[0]

@property
def target_type(self) -> Optional[str]:
target_type = list(
set(pipeline.target_type for pipeline in self.pipeline_dict.values())
)
assert len(target_type) == 1, "target type should be the same for all branches"
return target_type[0]

def push(
self,
segment: Segment,
states: BranchedAgentPipelineStates | None = None,
upstream_states: List[AgentStates | None] | None = None,
) -> None:
is_stateless = True
if states is None:
states = self.states
is_stateless = False

states.update_config(segment.config)
branch_name = self.get_branch_from_states(states)
branch_states = states.states_dict[branch_name]

return super().push(
segment,
branch_states if states == is_stateless else None,
upstream_states,
module_list=self.pipeline_dict[branch_name].module_list,
)

def pop(self, states: BranchedAgentPipelineStates | None = None):
is_stateless = True
if states is None:
states = self.states
is_stateless = False

branch_name = self.get_branch_from_states(states)
branch_states = states.states_dict[branch_name]

return super().pop(
branch_states if states == is_stateless else None,
module_list=self.pipeline_dict[branch_name].module_list,
)

def build_states(self) -> BranchedAgentPipelineStates:
return BranchedAgentPipelineStates(
{
key: pipeline.build_states()
for key, pipeline in self.pipeline_dict.items()
}
)

def reset(self) -> None:
for agent in self.pipeline_dict.values():
agent.reset()

def get_branch_from_states(self, states):
if states is None:
# stateful agent
states = self.states

branch_name = states.config.get(self.name, self.default_branch_name)
assert branch_name in self.pipeline_dict
return branch_name

@classmethod
def from_args(cls, arg: Namespace):
pipeline_dict = {}
for branch_name, pipeline_class_or_list in cls.branches.items():
if isinstance(pipeline_class_or_list, list):
pipeline_dict[branch_name] = AgentPipeline.from_pipeline_args(
pipeline_class_or_list, arg
)
elif isinstance(pipeline_class_or_list, AgentPipeline):
pipeline_dict[branch_name] = pipeline_class_or_list.from_args(arg)
else:
raise NotImplementedError

return cls(pipeline_dict)

@classmethod
def add_args(cls, parser) -> None:
for pipeline_class_or_list in cls.branches.values():
if isinstance(pipeline_class_or_list, list):
for agent in pipeline_class_or_list:
agent.add_args(parser)
else:
pipeline_class_or_list.add_args(parser)

def __repr__(self) -> str:
# TODO, indent here is not correct
string_list = []
for branch_name, pipeline in self.pipeline_dict.items():
string_list.append(f"{branch_name}:\n\t\t{pipeline}")
string = ",\n".join(
[
f"\t{branch_name}:{pipeline}"
for branch_name, pipeline in self.pipeline_dict.items()
]
)
return f"{self.__class__.__name__}(\n{string}\n)"
54 changes: 42 additions & 12 deletions simuleval/agents/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from typing import List, Optional, Dict, Set, Type, Union
from simuleval.data.segments import Segment
from .agent import GenericAgent, AgentStates
import time
import logging

logger = logging.getLogger(__name__)


class AgentPipeline(GenericAgent):
Expand Down Expand Up @@ -55,37 +59,52 @@ def push(
segment: Segment,
states: Optional[List[Optional[AgentStates]]] = None,
upstream_states: Optional[List[Optional[AgentStates]]] = None,
module_list: Optional[List[GenericAgent]] = None,
) -> None:
if module_list is None:
module_list = self.module_list

if states is None:
# stateful agent
states = [None for _ in self.module_list]
states_list = [module.states for module in self.module_list]
states = [None for _ in module_list]
states_list = [module.states for module in module_list]
else:
# stateless agent
assert len(states) == len(self.module_list)
assert len(states) == len(module_list)
states_list = states

if upstream_states is None:
upstream_states = []

for index, module in enumerate(self.module_list[:-1]):
index = 0

for index, module in enumerate(module_list[:-1]):
config = segment.config
segment = module.pushpop(
segment,
states[index],
upstream_states=upstream_states + states_list[:index],
)
self.module_list[-1].push(
segment.config = config

module_list[-1].push(
segment, states[-1], upstream_states=upstream_states + states_list[:index]
)

def pop(self, states: Optional[List[Optional[AgentStates]]] = None) -> Segment:
def pop(
self,
states: Optional[List[Optional[AgentStates]]] = None,
module_list: Optional[List[GenericAgent]] = None,
) -> Segment:
if module_list is None:
module_list = self.module_list
if states is None:
last_states = None
else:
assert len(states) == len(self.module_list)
assert len(states) == len(module_list)
last_states = states[-1]

return self.module_list[-1].pop(last_states)
return module_list[-1].pop(last_states)

@classmethod
def add_args(cls, parser) -> None:
Expand All @@ -97,11 +116,16 @@ def from_args(cls, args):
assert len(cls.pipeline) > 0
return cls([module_class.from_args(args) for module_class in cls.pipeline])

@classmethod
def from_pipeline_args(cls, pipeline, args):
assert len(pipeline) > 0
return cls([module_class.from_args(args) for module_class in pipeline])

def __repr__(self) -> str:
pipline_str = "\n\t".join(
pipeline_str = "\n\t".join(
"\t".join(str(module).splitlines(True)) for module in self.module_list
)
return f"{self.__class__.__name__}(\n\t{pipline_str}\n)"
return f"{self.__class__.__name__}(\n\t\t{pipeline_str})"

def __str__(self) -> str:
return self.__repr__()
Expand Down Expand Up @@ -274,12 +298,18 @@ def push_impl(
# DFS over the tree
children = self.module_dict[module]
if len(children) == 0: # leaf node
module.push(segment, states[module])
module.push(segment, states[module]) # , upstream_states)
# upstream_states[len(upstream_states)] = states[module]
# TODO: ?
return []

config = segment.config
segment = module.pushpop(segment, states[module], upstream_states)
segment.config = config
assert len(upstream_states) not in upstream_states
upstream_states[len(upstream_states)] = states[module]
upstream_states[len(upstream_states)] = (
states[module] if states[module] is not None else module.states
)

for child in children:
self.push_impl(child, segment, states, upstream_states)
Expand Down
7 changes: 7 additions & 0 deletions simuleval/agents/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def reset(self) -> None:
self.target_sample_rate = 0
self.tgt_lang = None
self.upstream_states = []
self.config = {}

def update_config(self, config: dict):
for k in config.keys():
if k not in self.config:
# only update with new keys within each utterance
self.config[k] = config[k]

def update_source(self, segment: Segment):
"""
Expand Down
1 change: 1 addition & 0 deletions simuleval/data/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Segment:
is_empty: bool = False
data_type: str = None
tgt_lang: str = None
config: dict = field(default_factory=dict)

def json(self) -> str:
info_dict = {attribute: value for attribute, value in self.__dict__.items()}
Expand Down
8 changes: 7 additions & 1 deletion simuleval/evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .scorers import get_scorer_class
from .scorers.latency_scorer import LatencyScorer
from .scorers.quality_scorer import QualityScorer
from simuleval.data.segments import Segment

from .instance import INSTANCE_TYPE_DICT, LogInstance
import yaml
Expand Down Expand Up @@ -236,7 +237,12 @@ def __call__(self, system):
system.reset()
for instance in self.instance_iterator:
while not self.is_finished(instance):
input_segment = instance.send_source(self.source_segment_size)
input_segment: Segment = instance.send_source(self.source_segment_size)
# TODO: cleanup after testing
# TODO: test behavior of changing from non-expr --> expr mid-utterance
# it should be a no-op, and not take effect until the next utterance
# @xutaima: uncomment this to test dual agent w/expr
# input_segment.config["synthesizer_mode"] = "expressive"
output_segment = system.pushpop(input_segment)
instance.receive_prediction(output_segment)
if instance.finish_prediction:
Expand Down