diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py
index 34ab79b..0b8ccc0 100644
--- a/sonar/inference_pipelines/text.py
+++ b/sonar/inference_pipelines/text.py
@@ -174,7 +174,7 @@ def __init__(
     def predict(
         self,
         input: Union[Path, Sequence[str]],
-        source_lang: str,
+        source_lang: Union[str, Sequence[str]],
         batch_size: Optional[int] = 5,
         batch_max_tokens: Optional[int] = None,
         max_seq_len: Optional[int] = None,
@@ -196,9 +196,24 @@ def predict(
         if batch_size is not None and batch_size <= 0:
             raise ValueError("`batch_size` should be strictly positive")
 
-        tokenizer_encoder = self.tokenizer.create_encoder(
-            lang=source_lang, device=self.device
-        )
+        def encode_fn(x: Union[str, tuple[str, str]]) -> torch.Tensor:
+            if isinstance(source_lang, str):
+                assert isinstance(x, str)
+                tokenizer_encoder = self.tokenizer.create_encoder(
+                    lang=source_lang,
+                    device=self.device,
+                )
+                return tokenizer_encoder(x)
+            else:
+                # Multiple languages
+                assert isinstance(x, tuple)
+                text, lang = x
+                tokenizer_encoder = self.tokenizer.create_encoder(
+                    lang=lang,
+                    device=self.device,
+                )
+                return tokenizer_encoder(text)
+
         model_max_len = cast(int | None, self.model.encoder_frontend.pos_encoder.max_seq_len)  # type: ignore[union-attr]
         if max_seq_len is None:
             max_seq_len = model_max_len
@@ -221,15 +236,37 @@ def truncate(x: torch.Tensor) -> torch.Tensor:
         if isinstance(input, (str, Path)):
             pipeline_builder = read_text(Path(input))
             sorting_index = None
+            if not isinstance(source_lang, str):
+                raise ValueError(
+                    "If input is a file, source_lang must be a single string."
+                )
         else:
-            # so it should a list
-            sorting_index = torch.argsort(torch.tensor(list(map(len, input))))
-            pipeline_builder = read_sequence(list(sorting_index.cpu())).map(
-                input.__getitem__
+            # input is a list
+            if isinstance(source_lang, str):
+                items = input
+            else:
+                if len(input) != len(source_lang):
+                    raise ValueError("Length of input and source_lang must match.")
+                items = list(zip(input, source_lang))  # type: ignore[arg-type]
+            sorting_index = torch.argsort(
+                torch.tensor(
+                    list(
+                        map(
+                            lambda x: (
+                                len(x[0])
+                                if not isinstance(source_lang, str)
+                                else len(x)
+                            ),
+                            items,
+                        )
+                    )
+                )
             )
+            sorted_items = [items[i] for i in sorting_index.tolist()]
+            pipeline_builder = read_sequence(sorted_items)
 
         pipeline: Iterable = (
-            pipeline_builder.map(tokenizer_encoder)
+            pipeline_builder.map(encode_fn)
             .map(truncate)
             .dynamic_bucket(
                 batch_max_tokens or 2**31,
@@ -306,7 +343,7 @@ def __init__(
     def predict(
         self,
         inputs: torch.Tensor,
-        target_lang: str,
+        target_lang: Union[str, Sequence[str]],
         batch_size: int = 5,
         progress_bar: bool = False,
         sampler: Optional[Sampler] = None,
@@ -319,25 +356,41 @@ def predict(
         else:
             generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs)
 
-        converter = SequenceToTextConverter(
-            generator,
-            self.tokenizer,
-            task="translation",
-            target_lang=target_lang,
-        )
+        if isinstance(target_lang, str):
+            target_lang = [target_lang] * len(inputs)
+
+        if len(target_lang) != len(inputs):
+            raise ValueError(
+                "input and target_lang must have the same length for multi-language decoding."
+            )
 
-        def _do_translate(src_tensors: List[torch.Tensor]) -> List[str]:
+        def _do_translate(x: tuple[torch.Tensor, str]):
+            tensor, lang = x
+            converter = SequenceToTextConverter(
+                generator,
+                self.tokenizer,
+                task="translation",
+                target_lang=lang,
+            )
             texts, _ = converter.batch_convert(
-                torch.stack(src_tensors).to(self.device), None
+                torch.stack([tensor]).to(self.device),
+                None,
             )
             return texts
 
         pipeline: Iterable = (
-            read_sequence(list(inputs))
-            .bucket(batch_size)
+            read_sequence(
+                list(
+                    zip(
+                        list(inputs),
+                        list(target_lang),
+                    )
+                )
+            )
             .map(_do_translate)
             .and_return()
         )
+
         if progress_bar:
             pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size)
 
         with precision_context(self.model.dtype):