From ff96c4ce45aee1de54fb0fdeabe65a8aee333541 Mon Sep 17 00:00:00 2001 From: Jason Rich Darmawan <63768126+jasonrichdarmawan@users.noreply.github.com> Date: Tue, 26 Aug 2025 07:14:54 +0000 Subject: [PATCH 1/3] feat: add multilanguage support to text2vec and vec2text --- sonar/inference_pipelines/text.py | 125 +++++++++++++++++++++++------- 1 file changed, 99 insertions(+), 26 deletions(-) diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py index 34ab79b..5fc8b7a 100644 --- a/sonar/inference_pipelines/text.py +++ b/sonar/inference_pipelines/text.py @@ -174,7 +174,7 @@ def __init__( def predict( self, input: Union[Path, Sequence[str]], - source_lang: str, + source_lang: Union[Path, Sequence[str]], batch_size: Optional[int] = 5, batch_max_tokens: Optional[int] = None, max_seq_len: Optional[int] = None, @@ -196,9 +196,23 @@ def predict( if batch_size is not None and batch_size <= 0: raise ValueError("`batch_size` should be strictly positive") - tokenizer_encoder = self.tokenizer.create_encoder( - lang=source_lang, device=self.device - ) + if isinstance(source_lang, str): + tokenizer_encoder = self.tokenizer.create_encoder( + lang=source_lang, + device=self.device, + ) + def encode_fn(text: str): + return tokenizer_encoder(text) + else: + # Multiple languages + def encode_fn(x: tuple[str, str]): + text, lang = x + tokenizer_encoder = self.tokenizer.create_encoder( + lang=lang, + device=self.device, + ) + return tokenizer_encoder(text) + model_max_len = cast(int | None, self.model.encoder_frontend.pos_encoder.max_seq_len) # type: ignore[union-attr] if max_seq_len is None: max_seq_len = model_max_len @@ -221,15 +235,34 @@ def truncate(x: torch.Tensor) -> torch.Tensor: if isinstance(input, (str, Path)): pipeline_builder = read_text(Path(input)) sorting_index = None + if not isinstance(source_lang, str): + raise ValueError( + "If input is a file, source_lang must be a single string." + ) else: - # so it should a list - sorting_index = torch.argsort(torch.tensor(list(map(len, input)))) - pipeline_builder = read_sequence(list(sorting_index.cpu())).map( - input.__getitem__ + # input is a list + if isinstance(source_lang, str): + items = input + else: + if len(input) != len(source_lang): + raise ValueError( + "Length of input and source_lang must match." 
+ ) + items = list(zip(input, source_lang)) + sorting_index = torch.argsort( + torch.tensor(list(map( + lambda x: len(x[0]) + if not isinstance(source_lang, str) + else len(x), + items + ))) ) + pipeline_builder = read_sequence( + list(sorting_index.cpu()) + ).map(items.__getitem__) pipeline: Iterable = ( - pipeline_builder.map(tokenizer_encoder) + pipeline_builder.map(encode_fn) .map(truncate) .dynamic_bucket( batch_max_tokens or 2**31, @@ -306,7 +339,7 @@ def __init__( def predict( self, inputs: torch.Tensor, - target_lang: str, + target_lang: Sequence[str], batch_size: int = 5, progress_bar: bool = False, sampler: Optional[Sampler] = None, @@ -319,25 +352,65 @@ def predict( else: generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs) - converter = SequenceToTextConverter( - generator, - self.tokenizer, - task="translation", - target_lang=target_lang, - ) + if isinstance(target_lang, str): + converter = SequenceToTextConverter( + generator, + self.tokenizer, + task="translation", + target_lang=target_lang, + ) + + def _do_translate(src_tensors: List[torch.Tensor]) -> List[str]: + texts, _ = converter.batch_convert( + torch.stack(src_tensors).to(self.device), + None + ) + return texts - def _do_translate(src_tensors: List[torch.Tensor]) -> List[str]: - texts, _ = converter.batch_convert( - torch.stack(src_tensors).to(self.device), None + pipeline: Iterable = ( + read_sequence(list(inputs)) + .bucket(batch_size) + .map(_do_translate) + .and_return() + ) + else: + # target_lang is a sequence, + # so decode each sample with its + # language + if len(target_lang) != len(inputs): + raise ValueError( + "input and target_lang must have the same length for multi-language decoding." + ) + + def _do_translate( + x: tuple[torch.Tensor, str] + ): + tensor, lang = x + converter = SequenceToTextConverter( + generator, + self.tokenizer, + task="translation", + target_lang=lang, + ) + texts, _ = converter.batch_convert( + torch.stack([tensor]).to(self.device), + None, + ) + return texts + + pipeline: Iterable = ( + read_sequence( + list( + zip( + list(inputs), + list(target_lang), + ) + ) + ) + .map(_do_translate) + .and_return() ) - return texts - pipeline: Iterable = ( - read_sequence(list(inputs)) - .bucket(batch_size) - .map(_do_translate) - .and_return() - ) if progress_bar: pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size) with precision_context(self.model.dtype): From f021dd809a7564e270ebefaa4bcebf1cac66fe34 Mon Sep 17 00:00:00 2001 From: Jason Rich Darmawan <63768126+jasonrichdarmawan@users.noreply.github.com> Date: Tue, 26 Aug 2025 11:25:39 +0000 Subject: [PATCH 2/3] fix: backward compatibility and guard clause --- sonar/inference_pipelines/text.py | 110 +++++++++++++----------------- 1 file changed, 46 insertions(+), 64 deletions(-) diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py index 5fc8b7a..0ae48fb 100644 --- a/sonar/inference_pipelines/text.py +++ b/sonar/inference_pipelines/text.py @@ -174,7 +174,7 @@ def __init__( def predict( self, input: Union[Path, Sequence[str]], - source_lang: Union[Path, Sequence[str]], + source_lang: Union[str, Sequence[str]], batch_size: Optional[int] = 5, batch_max_tokens: Optional[int] = None, max_seq_len: Optional[int] = None, @@ -201,15 +201,17 @@ def predict( lang=source_lang, device=self.device, ) + def encode_fn(text: str): return tokenizer_encoder(text) + else: # Multiple languages def encode_fn(x: tuple[str, str]): text, lang = x tokenizer_encoder = 
self.tokenizer.create_encoder( - lang=lang, - device=self.device, + lang=lang, + device=self.device, ) return tokenizer_encoder(text) @@ -245,21 +247,25 @@ def truncate(x: torch.Tensor) -> torch.Tensor: items = input else: if len(input) != len(source_lang): - raise ValueError( - "Length of input and source_lang must match." - ) + raise ValueError("Length of input and source_lang must match.") items = list(zip(input, source_lang)) sorting_index = torch.argsort( - torch.tensor(list(map( - lambda x: len(x[0]) - if not isinstance(source_lang, str) - else len(x), - items - ))) + torch.tensor( + list( + map( + lambda x: ( + len(x[0]) + if not isinstance(source_lang, str) + else len(x) + ), + items, + ) + ) + ) + ) + pipeline_builder = read_sequence(list(sorting_index.cpu())).map( + items.__getitem__ ) - pipeline_builder = read_sequence( - list(sorting_index.cpu()) - ).map(items.__getitem__) pipeline: Iterable = ( pipeline_builder.map(encode_fn) @@ -339,7 +345,7 @@ def __init__( def predict( self, inputs: torch.Tensor, - target_lang: Sequence[str], + target_lang: Union[str, Sequence[str]], batch_size: int = 5, progress_bar: bool = False, sampler: Optional[Sampler] = None, @@ -353,63 +359,39 @@ def predict( generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs) if isinstance(target_lang, str): + target_lang = [target_lang] * len(inputs) + + if len(target_lang) != len(inputs): + raise ValueError( + "input and target_lang must have the same length for multi-language decoding." + ) + + def _do_translate(x: tuple[torch.Tensor, str]): + tensor, lang = x converter = SequenceToTextConverter( generator, self.tokenizer, task="translation", - target_lang=target_lang, + target_lang=lang, ) - - def _do_translate(src_tensors: List[torch.Tensor]) -> List[str]: - texts, _ = converter.batch_convert( - torch.stack(src_tensors).to(self.device), - None - ) - return texts - - pipeline: Iterable = ( - read_sequence(list(inputs)) - .bucket(batch_size) - .map(_do_translate) - .and_return() + texts, _ = converter.batch_convert( + torch.stack([tensor]).to(self.device), + None, ) - else: - # target_lang is a sequence, - # so decode each sample with its - # language - if len(target_lang) != len(inputs): - raise ValueError( - "input and target_lang must have the same length for multi-language decoding." 
- ) - - def _do_translate( - x: tuple[torch.Tensor, str] - ): - tensor, lang = x - converter = SequenceToTextConverter( - generator, - self.tokenizer, - task="translation", - target_lang=lang, - ) - texts, _ = converter.batch_convert( - torch.stack([tensor]).to(self.device), - None, - ) - return texts - - pipeline: Iterable = ( - read_sequence( - list( - zip( - list(inputs), - list(target_lang), - ) + return texts + + pipeline: Iterable = ( + read_sequence( + list( + zip( + list(inputs), + list(target_lang), ) ) - .map(_do_translate) - .and_return() ) + .map(_do_translate) + .and_return() + ) if progress_bar: pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size) From 5eaf513e4ee76a913eedee92adf0a6eb075d4224 Mon Sep 17 00:00:00 2001 From: Jason Rich Darmawan <63768126+jasonrichdarmawan@users.noreply.github.com> Date: Fri, 5 Sep 2025 12:08:02 +0000 Subject: [PATCH 3/3] chore: fix static type checker fail --- sonar/inference_pipelines/text.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py index 0ae48fb..0b8ccc0 100644 --- a/sonar/inference_pipelines/text.py +++ b/sonar/inference_pipelines/text.py @@ -196,18 +196,17 @@ def predict( if batch_size is not None and batch_size <= 0: raise ValueError("`batch_size` should be strictly positive") - if isinstance(source_lang, str): - tokenizer_encoder = self.tokenizer.create_encoder( - lang=source_lang, - device=self.device, - ) - - def encode_fn(text: str): - return tokenizer_encoder(text) - - else: - # Multiple languages - def encode_fn(x: tuple[str, str]): + def encode_fn(x: Union[str, tuple[str, str]]) -> torch.Tensor: + if isinstance(source_lang, str): + assert isinstance(x, str) + tokenizer_encoder = self.tokenizer.create_encoder( + lang=source_lang, + device=self.device, + ) + return tokenizer_encoder(x) + else: + # Multiple languages + assert isinstance(x, tuple) text, lang = x tokenizer_encoder = self.tokenizer.create_encoder( lang=lang, @@ -248,7 +247,7 @@ def truncate(x: torch.Tensor) -> torch.Tensor: else: if len(input) != len(source_lang): raise ValueError("Length of input and source_lang must match.") - items = list(zip(input, source_lang)) + items = list(zip(input, source_lang)) # type: ignore[arg-type] sorting_index = torch.argsort( torch.tensor( list( @@ -263,9 +262,8 @@ def truncate(x: torch.Tensor) -> torch.Tensor: ) ) ) - pipeline_builder = read_sequence(list(sorting_index.cpu())).map( - items.__getitem__ - ) + sorted_items = [items[i] for i in sorting_index.tolist()] + pipeline_builder = read_sequence(sorted_items) pipeline: Iterable = ( pipeline_builder.map(encode_fn)
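
Usage sketch for the API these patches introduce: a minimal example, not taken from the patches themselves. It assumes the public SONAR model cards ("text_sonar_basic_encoder", "text_sonar_basic_decoder") and FLORES-200 language codes; the single-code-vs-list behavior follows the diffs above, while the pipeline constructors are standard SONAR usage and are not verified against this branch.

    from sonar.inference_pipelines.text import (
        EmbeddingToTextModelPipeline,
        TextToEmbeddingModelPipeline,
    )

    t2vec = TextToEmbeddingModelPipeline(
        encoder="text_sonar_basic_encoder",
        tokenizer="text_sonar_basic_encoder",
    )
    vec2text = EmbeddingToTextModelPipeline(
        decoder="text_sonar_basic_decoder",
        tokenizer="text_sonar_basic_encoder",
    )

    sentences = ["My name is SONAR.", "Je m'appelle SONAR."]

    # Backward compatible: a single language code applies to every sentence.
    embeddings = t2vec.predict(sentences, source_lang="eng_Latn")

    # New: one language code per sentence; predict() raises ValueError
    # when len(input) != len(source_lang).
    embeddings = t2vec.predict(sentences, source_lang=["eng_Latn", "fra_Latn"])

    # Decoding mirrors this: one target language per embedding,
    # with the same length check.
    translations = vec2text.predict(embeddings, target_lang=["fra_Latn", "eng_Latn"])

One design note visible in the second patch: the decode pipeline now maps over (embedding, language) pairs one at a time (torch.stack([tensor]) builds a batch of one, and the .bucket(batch_size) stage is gone), so batch_size only drives the progress bar. If decode throughput matters, grouping inputs by target language before decoding could restore batched generation.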