diff --git a/examples/quick_start/spm_detokenizer_agent.py b/examples/quick_start/spm_detokenizer_agent.py index a5357538..31d64d64 100644 --- a/examples/quick_start/spm_detokenizer_agent.py +++ b/examples/quick_start/spm_detokenizer_agent.py @@ -2,12 +2,14 @@ from fairseq.data.encoders import build_bpe +from simuleval.utils import entrypoint from simuleval.agents import TextToTextAgent from simuleval.agents.actions import ReadAction, WriteAction from simuleval.agents.pipeline import AgentPipeline from simuleval.agents.states import AgentStates +@entrypoint class DummySegmentAgent(TextToTextAgent): """ This agent just splits on space @@ -70,13 +72,20 @@ def policy(self, states: AgentStates): possible_full_words = self.spm_processor.decode( " ".join([x for x in states.source]) ) - + # issue is when the starting word is in the previous segment if self.detokenize_only and len(states.source) > 0: + start_word = "▁" + source_text = states.source[0] states.source = [] if len(possible_full_words) == 0 and not states.source_finished: return ReadAction() else: - return WriteAction(possible_full_words, states.source_finished) + incomplete_word = False + if start_word not in source_text[0]: + incomplete_word = True + return WriteAction( + possible_full_words, states.source_finished, incomplete_word + ) if states.source_finished: return WriteAction(possible_full_words, True) diff --git a/examples/quick_start/spm_source.txt b/examples/quick_start/spm_source.txt index fe099b99..13ccecc6 100644 --- a/examples/quick_start/spm_source.txt +++ b/examples/quick_start/spm_source.txt @@ -1 +1 @@ -▁Let ' s ▁do ▁it ▁with out ▁hesitation . \ No newline at end of file +▁Let ' s ▁do ▁it ▁with out ▁hesitation . diff --git a/examples/quick_start/spm_target.txt b/examples/quick_start/spm_target.txt index e3b2ae21..c228cb4c 100644 --- a/examples/quick_start/spm_target.txt +++ b/examples/quick_start/spm_target.txt @@ -1 +1 @@ -Let's do it without hesitation. \ No newline at end of file +Let's do it without hesitation. diff --git a/examples/quick_start/tokenizer.model b/examples/quick_start/tokenizer.model new file mode 100644 index 00000000..e6149ec4 Binary files /dev/null and b/examples/quick_start/tokenizer.model differ diff --git a/simuleval/agents/actions.py b/simuleval/agents/actions.py index 81e3a4c5..cc251530 100644 --- a/simuleval/agents/actions.py +++ b/simuleval/agents/actions.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Union, List +from typing import Optional, Union, List from dataclasses import dataclass from simuleval.data.segments import Segment @@ -52,6 +52,7 @@ class WriteAction(Action): content: Union[str, List[float], Segment] finished: bool + incomplete_word: Optional[bool] = None def is_read(self) -> bool: return False diff --git a/simuleval/agents/agent.py b/simuleval/agents/agent.py index c1e7b183..99343a45 100644 --- a/simuleval/agents/agent.py +++ b/simuleval/agents/agent.py @@ -127,6 +127,10 @@ def pop(self, states: Optional[AgentStates] = None) -> Segment: segment = SEGMENT_TYPE_DICT[self.target_type]( index=0, content=action.content, finished=action.finished ) + + if isinstance(segment, TextSegment) and action.incomplete_word is not None: + segment.incomplete_word = action.incomplete_word + states.update_target(segment) return segment diff --git a/simuleval/data/segments.py b/simuleval/data/segments.py index c823695e..8961914c 100644 --- a/simuleval/data/segments.py +++ b/simuleval/data/segments.py @@ -34,6 +34,7 @@ class EmptySegment(Segment): class TextSegment(Segment): content: str = "" data_type: str = "text" + incomplete_word: bool = None @dataclass diff --git a/simuleval/evaluator/instance.py b/simuleval/evaluator/instance.py index c08fd978..2e77ed93 100644 --- a/simuleval/evaluator/instance.py +++ b/simuleval/evaluator/instance.py @@ -201,6 +201,16 @@ def receive_prediction(self, prediction: TextSegment): else: raise NotImplementedError + if prediction.incomplete_word: + first_half = self.prediction_list[-1] + second_half = prediction_list[0] + complete_word = first_half + second_half + self.prediction_list.pop() + self.delays.pop() + self.elapsed.pop() + prediction_list.pop(0) + prediction_list.insert(0, complete_word) + self.prediction_list += prediction_list self.elapsed += [self.step_to_elapsed(self.step, current_time)] * len(