From 2fffca8f8cd955757bf9458dd0341adfbfbbfcf9 Mon Sep 17 00:00:00 2001 From: Denis Bezykornov Date: Mon, 3 Nov 2025 00:16:54 +0300 Subject: [PATCH] Add LLM translator support alongside Yandex --- backend/app/documents/query.py | 10 +- backend/app/models.py | 16 +- backend/app/settings.py | 11 + backend/app/translators/common.py | 11 + backend/app/translators/llm.py | 125 ++++++ backend/app/translators/yandex.py | 12 +- backend/requirements.txt | 1 + .../tests/routers/test_routes_documents.py | 8 +- backend/tests/test_llm.py | 409 ++++++++++++++++++ backend/tests/test_worker.py | 17 +- backend/worker.py | 79 ++-- 11 files changed, 643 insertions(+), 56 deletions(-) create mode 100644 backend/app/translators/common.py create mode 100644 backend/app/translators/llm.py create mode 100644 backend/tests/test_llm.py diff --git a/backend/app/documents/query.py b/backend/app/documents/query.py index ed8b52a..38ff0b8 100644 --- a/backend/app/documents/query.py +++ b/backend/app/documents/query.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Iterable -from sqlalchemy import case, func, select, Row +from sqlalchemy import Row, case, func, select from sqlalchemy.orm import Session from app.base.exceptions import BaseQueryException @@ -96,7 +96,7 @@ def get_document_records_paged( repetitions_subquery = ( select( DocumentRecord.source, - func.count(DocumentRecord.id).label('repetitions_count') + func.count(DocumentRecord.id).label("repetitions_count"), ) .filter(DocumentRecord.document_id == doc.id) .group_by(DocumentRecord.source) @@ -106,12 +106,14 @@ def get_document_records_paged( return self.__db.execute( select( DocumentRecord, - func.coalesce(repetitions_subquery.c.repetitions_count, 0).label('repetitions_count') + func.coalesce(repetitions_subquery.c.repetitions_count, 0).label( + "repetitions_count" + ), ) .filter(DocumentRecord.document_id == doc.id) .outerjoin( repetitions_subquery, - DocumentRecord.source == repetitions_subquery.c.source + 
DocumentRecord.source == repetitions_subquery.c.source, ) .order_by(DocumentRecord.id) .offset(page_records * page) diff --git a/backend/app/models.py b/backend/app/models.py index 800d16f..fd4e2b9 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Literal from pydantic import BaseModel, ConfigDict, EmailStr @@ -27,14 +28,21 @@ def get_values(cls): return tuple(role.value for role in cls) -class MachineTranslationSettings(BaseModel): - # Yandex only for now - # source_language: str - # target_language: str +class LlmTranslatorSettings(BaseModel): + type: Literal["llm"] + api_key: str + # base_api: str # reserved for universal OpenAI translation + + +class YandexTranslatorSettings(BaseModel): + type: Literal["yandex"] folder_id: str oauth_token: str +MachineTranslationSettings = LlmTranslatorSettings | YandexTranslatorSettings + + class StatusMessage(BaseModel): message: str diff --git a/backend/app/settings.py b/backend/app/settings.py index 1abf74b..85229d2 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -1,3 +1,4 @@ +import base64 from typing import Literal from pydantic_settings import BaseSettings @@ -17,5 +18,15 @@ class Settings(BaseSettings): "http://localhost:8000", ) + llm_base_api: str | None = None + llm_model: str | None = None + llm_base64_prompt: str | None = None + + @property + def llm_prompt(self): + if not self.llm_base64_prompt: + return "" + return base64.decodebytes(self.llm_base64_prompt.encode()).decode() + settings = Settings() diff --git a/backend/app/translators/common.py b/backend/app/translators/common.py new file mode 100644 index 0000000..3097b85 --- /dev/null +++ b/backend/app/translators/common.py @@ -0,0 +1,11 @@ +class TranslationError(Exception): + """ + An error raised when Machine Translator API returns an error. 
+ """ + + +Original = str +Translation = str +GlossaryPair = tuple[Original, Translation] +GlossaryPairs = list[GlossaryPair] +LineWithGlossaries = tuple[str, GlossaryPairs] diff --git a/backend/app/translators/llm.py b/backend/app/translators/llm.py new file mode 100644 index 0000000..e63e999 --- /dev/null +++ b/backend/app/translators/llm.py @@ -0,0 +1,125 @@ +# This is a great idea to make a universal translator using OpenAI API, instead +# of specific translators for every service, but this will require much effort +# to prepare and manage a library of prompts for every context size, so it is +# postponed for the future. + +import logging +import re + +from openai import OpenAI + +from app.settings import settings +from app.translators.common import LineWithGlossaries + + +def generate_prompt_prologue() -> str: + return settings.llm_prompt or "" + + +def generate_prompt_ctx( + lines: list[LineWithGlossaries], offset: int, ctx_size: int +) -> str: + start = max(offset - ctx_size, 0) + ctx_lines = [f"{line[0]}" for line in lines[start:offset]] + return f"{'\n'.join(ctx_lines)}" + + +def generate_prompt_glossary(lines: list[LineWithGlossaries]) -> str: + terms = {} + for _, glossary_term in lines: + for original, translation in glossary_term: + terms[original] = ( + f"{original}{translation}" + ) + + return f"{'\n'.join(terms.values())}" + + +def generate_prompt_task(lines: list[LineWithGlossaries]) -> str: + segments = [f"{line[0]}" for line in lines] + return f"{'\n'.join(segments)}" + + +def generate_prompt( + lines: list[LineWithGlossaries], offset: int, ctx_size: int, task_size: int +) -> tuple[str, int]: + task_lines = lines[offset : offset + task_size] + parts = [ + generate_prompt_prologue(), + generate_prompt_ctx(lines, offset, ctx_size), + generate_prompt_glossary(task_lines), + generate_prompt_task(task_lines), + ] + return "\n\n".join(parts), len(task_lines) + + +def parse_lines(network_out: str, expected_size: int) -> tuple[list[str], bool]: + output 
= [] + + split = network_out.strip().splitlines() + if len(split) != expected_size: + logging.warning("Unexpected LLM output, not enough lines returned %s", split) + return [], False + + for line in split: + m = re.match(r"(.*)", line) + if not m: + logging.warning("Unexpected LLM output, no match found in %s", line) + return [], False + output.append(m.group(1)) + + return output, True + + +def translate_lines( + lines: list[LineWithGlossaries], api_key: str +) -> tuple[list[str], bool]: + """ + Translate lines of text using LLM translation. + + Args: + lines: A list of strings to be translated. + api_key: API key for the OpenAI-compatible LLM endpoint. + + Returns: + A tuple of the translated strings and a failure flag (True when the + LLM settings are not configured). + """ + # when settings are not set, it fails immediately + if not settings.llm_base_api or not settings.llm_model: + return [], True + + client = OpenAI(api_key=api_key, base_url=settings.llm_base_api) + + output: list[str] = [] + + task_size = 40 + ctx_size = 40 + for offset in range(0, len(lines), task_size): + for attempt in range(3): + prompt, actual_size = generate_prompt(lines, offset, ctx_size, task_size) + completion = client.chat.completions.create( + model=settings.llm_model, + messages=[ + { + "role": "system", + # TODO: make it configurable + "content": "You are a smart translator from English to Russian.", + }, + {"role": "user", "content": prompt}, + ], + extra_body={"thinking": {"type": "disabled"}}, + ) + # parse output of the network + batch_lines, result = parse_lines( + completion.choices[0].message.content or "", actual_size + ) + if result: + output += batch_lines + break + logging.warning("Failed to get answer from LLM, attempt %s", attempt + 1) + else: + logging.error("Was unable to get answer from LLM, returning empty list") + for _ in range(task_size): + output.append("") + + return output, False diff --git a/backend/app/translators/yandex.py b/backend/app/translators/yandex.py index 4801498..9f8e305 100644 --- 
a/backend/app/translators/yandex.py +++ b/backend/app/translators/yandex.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, PositiveInt, ValidationError from app.settings import settings +from app.translators.common import LineWithGlossaries, TranslationError class YandexTranslatorResponse(BaseModel): @@ -20,17 +21,6 @@ class YandexTranslatorResponse(BaseModel): translations: list[dict[str, str]] -class TranslationError(Exception): - """ - An error raised when Yandex Translator API returns an error. - """ - - -GlossaryPair = tuple[str, str] -GlossaryPairs = list[GlossaryPair] -LineWithGlossaries = tuple[str, GlossaryPairs] - - # Currently Yandex rejects requests larger than 10k symbols and more than # 50 records in glossary. def iterate_batches( diff --git a/backend/requirements.txt b/backend/requirements.txt index 90cbba0..cdac695 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,5 +9,6 @@ sqlalchemy==2.0.* psycopg2-binary==2.9.* requests==2.32.* itsdangerous==2.2.* +openai==2.6.* openpyxl==3.1.* nltk==3.9.* diff --git a/backend/tests/routers/test_routes_documents.py b/backend/tests/routers/test_routes_documents.py index 3e423c9..90e52b5 100644 --- a/backend/tests/routers/test_routes_documents.py +++ b/backend/tests/routers/test_routes_documents.py @@ -917,7 +917,9 @@ def test_setting_glossaries_returns_404_for_non_existing_glossaries( assert response.status_code == 404 -def test_get_doc_records_with_repetitions(user_logged_client: TestClient, session: Session): +def test_get_doc_records_with_repetitions( + user_logged_client: TestClient, session: Session +): """Test that document records endpoint returns repetition counts""" with session as s: records = [ @@ -947,7 +949,9 @@ def test_get_doc_records_with_repetitions(user_logged_client: TestClient, sessio # Check that repetition counts are correct # "Hello World" appears 3 times, others appear once - record_counts = {record["source"]: record["repetitions_count"] for record in 
response_json} + record_counts = { + record["source"]: record["repetitions_count"] for record in response_json + } assert record_counts["Hello World"] == 3 assert record_counts["Goodbye"] == 1 assert record_counts["Test"] == 1 diff --git a/backend/tests/test_llm.py b/backend/tests/test_llm.py new file mode 100644 index 0000000..4f515a9 --- /dev/null +++ b/backend/tests/test_llm.py @@ -0,0 +1,409 @@ +import logging +from unittest.mock import Mock, patch + +import pytest + +from app.settings import Settings +from app.translators import llm + +# pylint: disable=C0116 + + +@pytest.fixture(autouse=True) +def mock_llm_settings_autouse(monkeypatch): + """Mock the settings to provide test values for LLM functionality.""" + # Mock the settings object with test values + test_settings = Settings( + llm_base_api="https://api.test.com/v4", + llm_model="test-model", + llm_base64_prompt=None, + ) + + # Mock the llm_prompt property to return the expected test prompt + def mock_llm_prompt(): + return "You need to do the translation task from English to Russian. The prompt includes section, section, section, tags, tags, tags, and tags." + + # Replace the settings object and mock the llm_prompt property on the class + monkeypatch.setattr("app.settings.settings", test_settings) + monkeypatch.setattr( + Settings, "llm_prompt", property(lambda self: mock_llm_prompt()) + ) + + # Also need to mock the settings import in llm module + monkeypatch.setattr("app.translators.llm.settings", test_settings) + + yield test_settings + + +def test_generate_prompt_prologue(): + """Test that prologue generates correct system prompt.""" + result = llm.generate_prompt_prologue() + + assert "You need to do the translation task from English to Russian." 
in result + assert "" in result + assert "" in result + assert "" in result + assert "" in result + assert "" in result + assert "" in result + assert "" in result + + +def test_generate_prompt_ctx_middle_offset(): + """Test with offset in the middle""" + lines = [(f"line{i}", []) for i in range(10)] + + result = llm.generate_prompt_ctx(lines, 5, 3) + expected_lines = ["line2", "line3", "line4"] + assert result == "" + "\n".join(expected_lines) + "" + + +def test_generate_prompt_ctx_beginning(): + """Test with offset at the beginning""" + lines = [(f"line{i}", []) for i in range(10)] + + result = llm.generate_prompt_ctx(lines, 1, 3) + expected_lines = ["line0"] + assert result == "" + "\n".join(expected_lines) + "" + + +def test_generate_prompt_ctx_empty(): + """Test with no context""" + lines = [(f"line{i}", []) for i in range(10)] + + result = llm.generate_prompt_ctx(lines, 0, 3) + assert result == "" + + +def test_generate_prompt_glossary(): + """Test glossary XML generation.""" + lines = [ + ("line1", [("hello", "привет"), ("world", "мир")]), + ("line2", [("test", "тест"), ("hello", "привет")]), # duplicate term + ] + + result = llm.generate_prompt_glossary(lines) + + assert "" in result + assert "" in result + assert "helloпривет" in result + assert "worldмир" in result + assert "testтест" in result + # Should not have duplicates + assert result.count("helloпривет") == 1 + + +def test_generate_prompt_glossary_empty(): + """Test glossary generation with no glossary terms.""" + lines = [("line1", []), ("line2", [])] + + result = llm.generate_prompt_glossary(lines) + + assert result == "" + + +def test_generate_prompt_task(): + """Test task XML generation.""" + lines = [("line1", []), ("line2", [])] + + result = llm.generate_prompt_task(lines) + + assert "" in result + assert "" in result + assert "line1" in result + assert "line2" in result + + +def test_generate_prompt(): + """Test full prompt generation.""" + lines = [(f"line{i}", [("test", "тест")]) for i in 
range(10)] + + prompt, actual_size = llm.generate_prompt(lines, 5, 3, 2) + + assert "You need to do the translation task from English to Russian." in prompt + # The prologue contains "" text, so we expect 2 occurrences: one in prologue description, one in actual context + assert prompt.count("") == 2 + # The prologue contains "" text, so we expect 2 occurrences: one in prologue description, one in actual task + assert prompt.count("") == 2 + # The prologue contains "" text, so we expect 2 occurrences: one in prologue description, one in actual glossary + assert prompt.count("") == 2 + assert actual_size == 2 + + # Check that task lines are correct + assert "line5" in prompt + assert "line6" in prompt + assert "line7" not in prompt + + +def test_generate_prompt_ctx_negative_offset(): + """Test context generation with negative offset.""" + lines = [(f"line{i}", []) for i in range(5)] + + result = llm.generate_prompt_ctx(lines, -1, 3) + + # With negative offset, start becomes 0, so it includes all available lines + expected_lines = [ + "line0", + "line1", + "line2", + "line3", + ] + assert result == "" + "\n".join(expected_lines) + "" + + +def test_generate_prompt_ctx_large_context(): + """Test context generation when requested context is larger than available.""" + lines = [(f"line{i}", []) for i in range(3)] + + result = llm.generate_prompt_ctx(lines, 2, 10) + + # Should only include available lines before offset + expected_lines = ["line0", "line1"] + assert result == "" + "\n".join(expected_lines) + "" + + +def test_parse_lines_success(): + """Test successful parsing of LLM output.""" + network_out = ( + "translation1\ntranslation2\ntranslation3" + ) + expected_size = 3 + + result, success = llm.parse_lines(network_out, expected_size) + + assert success + assert result == ["translation1", "translation2", "translation3"] + + +def test_parse_lines_wrong_number(): + """Test parsing when number of lines doesn't match expected.""" + network_out = 
"translation1\ntranslation2" + expected_size = 3 + + result, success = llm.parse_lines(network_out, expected_size) + + assert not success + assert result == [] + + +def test_parse_lines_invalid_format(): + """Test parsing when lines don't match expected format.""" + network_out = "invalid line\ntranslation2" + expected_size = 2 + + result, success = llm.parse_lines(network_out, expected_size) + + assert not success + assert result == [] + + +def test_parse_lines_empty_content(): + """Test parsing when seg content is empty.""" + network_out = "\ntranslation2" + expected_size = 2 + + result, success = llm.parse_lines(network_out, expected_size) + + assert success + assert result == ["", "translation2"] + + +@patch("app.translators.llm.OpenAI") +def test_translate_lines_success(mock_openai): + """Test successful translation with mocked API.""" + # Mock the OpenAI client and response + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[ + 0 + ].message.content = "translation1\ntranslation2" + mock_client.chat.completions.create.return_value = mock_response + + lines = [("line1", []), ("line2", [])] + + result, has_error = llm.translate_lines(lines, "test_api_key") + + assert not has_error + assert result == ["translation1", "translation2"] + mock_openai.assert_called_once_with( + api_key="test_api_key", base_url="https://api.test.com/v4" + ) + assert mock_client.chat.completions.create.call_count == 1 + + +@patch("app.translators.llm.OpenAI") +def test_translate_lines_with_glossaries(mock_openai): + """Test translation with glossary terms.""" + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "translation1" + mock_client.chat.completions.create.return_value = mock_response + + lines = [("hello world", [("hello", "привет")])] + + result, has_error = 
llm.translate_lines(lines, "test_api_key") + + assert not has_error + assert result == ["translation1"] + + # Check that glossary was included in the prompt + call_args = mock_client.chat.completions.create.call_args + prompt = call_args[1]["messages"][1]["content"] + assert "" in prompt + assert "helloпривет" in prompt + + +@patch("app.translators.llm.OpenAI") +def test_translate_lines_api_error_retry(mock_openai, caplog): + """Test translation with API error and retry logic.""" + mock_client = Mock() + mock_openai.return_value = mock_client + + # First call fails, second succeeds + mock_response_fail = Mock() + mock_response_fail.choices = [Mock()] + mock_response_fail.choices[0].message.content = "invalid output" + + mock_response_success = Mock() + mock_response_success.choices = [Mock()] + mock_response_success.choices[0].message.content = "translation1" + + mock_client.chat.completions.create.side_effect = [ + mock_response_fail, + mock_response_success, + ] + + lines = [("line1", [])] + + with caplog.at_level(logging.WARNING): + result, has_error = llm.translate_lines(lines, "test_api_key") + + assert not has_error + assert result == ["translation1"] + assert mock_client.chat.completions.create.call_count == 2 + assert "Failed to get answer from LLM, attempt 1" in caplog.text + + +@patch("app.translators.llm.OpenAI") +def test_translate_lines_all_attempts_fail(mock_openai, caplog): + """Test translation when all API attempts fail.""" + mock_client = Mock() + mock_openai.return_value = mock_client + + # All calls return invalid output + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "invalid output" + mock_client.chat.completions.create.return_value = mock_response + + lines = [("line1", []), ("line2", [])] + + with caplog.at_level(logging.ERROR): + result, has_error = llm.translate_lines(lines, "test_api_key") + + # Should return empty strings for all lines (task_size = 40) + assert len(result) == 40 # 
40 empty strings from task_size + assert all(line == "" for line in result) # All should be empty + assert ( + not has_error + ) # Note: llm.translate_lines returns False even when all attempts fail + assert mock_client.chat.completions.create.call_count == 3 # 3 attempts + assert "Was unable to get answer from LLM, returning empty list" in caplog.text + + +@patch("app.translators.llm.OpenAI") +def test_translate_lines_large_batch(mock_openai): + """Test translation with large batch that gets split.""" + mock_client = Mock() + mock_openai.return_value = mock_client + + # Create simple successful responses for each batch + mock_response1 = Mock() + mock_response1.choices = [Mock()] + mock_response1.choices[0].message.content = "\n".join( + [f"translation{i}" for i in range(40)] + ) + + mock_response2 = Mock() + mock_response2.choices = [Mock()] + mock_response2.choices[0].message.content = "\n".join( + [f"translation{i}" for i in range(40)] + ) + + mock_response3 = Mock() + mock_response3.choices = [Mock()] + mock_response3.choices[0].message.content = "translation80" + + mock_client.chat.completions.create.side_effect = [ + mock_response1, + mock_response2, + mock_response3, + ] + + # Create 81 lines (should be split into 3 batches: 40, 40, 1) + lines = [(f"line{i}", []) for i in range(81)] + + result, has_error = llm.translate_lines(lines, "test_api_key") + + assert not has_error + assert len(result) == 81 + assert mock_client.chat.completions.create.call_count == 3 + + +@patch("app.translators.llm.OpenAI") +def test_translate_lines_empty_content(mock_openai): + """Test translation when API returns None content.""" + mock_client = Mock() + mock_openai.return_value = mock_client + + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = None + mock_client.chat.completions.create.return_value = mock_response + + lines = [("line1", [])] + + result, has_error = llm.translate_lines(lines, "test_api_key") + + # Should 
handle None content gracefully - returns empty strings for all failed attempts + assert not has_error + assert len(result) == 40 # task_size = 40, all failed so 40 empty strings + assert all(line == "" for line in result) # All should be empty + assert mock_client.chat.completions.create.call_count == 3 # Should retry 3 times + + +@patch("app.translators.llm.OpenAI") +def test_translate_lines_context_generation(mock_openai): + """Test that context is properly included in the prompt.""" + mock_client = Mock() + mock_openai.return_value = mock_client + + # Create simple successful response for the batch + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "\n".join( + [f"translation{i}" for i in range(10)] + ) + mock_client.chat.completions.create.return_value = mock_response + + lines = [(f"line{i}", []) for i in range(10)] + + result, has_error = llm.translate_lines(lines, "test_api_key") + + assert not has_error + assert len(result) == 10 + assert "translation0" in result + + # Check that context was included in the prompt + call_args = mock_client.chat.completions.create.call_args + prompt = call_args[1]["messages"][1]["content"] + assert "" in prompt + # Should include previous lines as context + assert "line0" in prompt + assert "line1" in prompt diff --git a/backend/tests/test_worker.py b/backend/tests/test_worker.py index 3e1da8b..60d7e86 100644 --- a/backend/tests/test_worker.py +++ b/backend/tests/test_worker.py @@ -21,7 +21,10 @@ DocumentTaskDescription, ) from app.glossary.models import Glossary, GlossaryRecord -from app.models import DocumentStatus, MachineTranslationSettings +from app.models import ( + DocumentStatus, + YandexTranslatorSettings, +) from app.schema import DocumentTask from app.translation_memory.models import TranslationMemory, TranslationMemoryRecord from app.translation_memory.schema import TranslationMemoryUsage @@ -53,7 +56,7 @@ def create_task( type_: Literal["xliff", "txt"] = 
"xliff", usage: TranslationMemoryUsage = TranslationMemoryUsage.NEWEST, substitute_numbers: bool = False, - mt_settings: MachineTranslationSettings | None = None, + mt_settings: YandexTranslatorSettings | None = None, ): return DocumentTask( data=DocumentTaskDescription( @@ -273,7 +276,7 @@ def test_process_task_uses_correct_tm_ids(session: Session): create_doc(name="small.xliff", type_=DocumentType.xliff), create_xliff_doc(file_data), create_task(), - DocMemoryAssociation(doc_id=1, tm_id=2, mode='read') + DocMemoryAssociation(doc_id=1, tm_id=2, mode="read"), ] ) s.commit() @@ -318,8 +321,8 @@ def test_process_task_uses_tm_mode(mode: str, trans_result: str, session: Sessio create_doc(name="small.xliff", type_=DocumentType.xliff), create_xliff_doc(file_data), create_task(usage=TranslationMemoryUsage(mode)), - DocMemoryAssociation(doc_id=1, tm_id=1, mode='read'), - DocMemoryAssociation(doc_id=1, tm_id=2, mode='read') + DocMemoryAssociation(doc_id=1, tm_id=1, mode="read"), + DocMemoryAssociation(doc_id=1, tm_id=2, mode="read"), ] ) s.commit() @@ -425,8 +428,8 @@ def test_process_task_puts_doc_in_error_state(monkeypatch, session: Session): create_doc(name="small.xliff", type_=DocumentType.xliff), create_xliff_doc(file_data), create_task( - mt_settings=MachineTranslationSettings( - folder_id="12345", oauth_token="fake" + mt_settings=YandexTranslatorSettings( + type="yandex", folder_id="12345", oauth_token="fake" ), ), ] diff --git a/backend/worker.py b/backend/worker.py index b4950f2..465cf6c 100644 --- a/backend/worker.py +++ b/backend/worker.py @@ -29,7 +29,8 @@ from app.translation_memory.models import TranslationMemoryRecord from app.translation_memory.query import TranslationMemoryQuery from app.translation_memory.schema import TranslationMemoryUsage -from app.translators import yandex +from app.translators import llm, yandex +from app.translators.common import LineWithGlossaries def segment_needs_processing(segment: BaseSegment) -> bool: @@ -106,7 +107,9 @@ def 
process_document( ) start_time = time.time() - translate_indices = substitute_segments(settings, session, segments, tm_ids, glossary_ids) + translate_indices = substitute_segments( + settings, session, segments, tm_ids, glossary_ids + ) logging.info( "Segments substitution time: %.2f seconds, speed: %.2f segment/second, segments: %d/%d", time.time() - start_time, @@ -116,12 +119,16 @@ def process_document( ) start_time = time.time() - mt_result = translate_segments( - segments, - translate_indices, - glossary_ids, - settings.machine_translation_settings, - session, + mt_result = ( + translate_segments( + segments, + translate_indices, + glossary_ids, + settings.machine_translation_settings, + session, + ) + if settings.machine_translation_settings is not None + else True ) logging.info( "Machine translation time: %.2f seconds, speed: %.2f segment/second, segments: %d/%d", @@ -193,33 +200,45 @@ def translate_segments( segments: Sequence[BaseSegment], translate_indices: Sequence[int], glossary_ids: list[int], - mt_settings: MachineTranslationSettings | None, + mt_settings: MachineTranslationSettings, session: Session, ) -> bool: - # TODO: it is better to make solution more translation service agnostic + if not translate_indices: + return True + mt_failed = False - if mt_settings and translate_indices: - try: - original_segments = [segments[idx].original for idx in translate_indices] - data_to_translate: list[yandex.LineWithGlossaries] = [] - for segment in original_segments: - glossary_records = GlossaryQuery( - session - ).get_glossary_records_for_segment(segment, glossary_ids) - data_to_translate.append( - (segment, [(x.source, x.target) for x in glossary_records]) - ) + try: + # TODO: this might be harmful with LLM translation as it is loses + # the connectivity of the context + segments_to_translate = [segments[idx].original for idx in translate_indices] + data_to_translate: list[LineWithGlossaries] = [] + for segment in segments_to_translate: + glossary_records 
= GlossaryQuery(session).get_glossary_records_for_segment( + segment, glossary_ids + ) + data_to_translate.append( + (segment, [(x.source, x.target) for x in glossary_records]) + ) + if mt_settings.type == "yandex": translated, mt_failed = yandex.translate_lines( data_to_translate, oauth_token=mt_settings.oauth_token, folder_id=mt_settings.folder_id, ) - for idx, translated_line in enumerate(translated): - segments[translate_indices[idx]].translation = translated_line - # TODO: handle specific exceptions instead of a generic one - except Exception as e: - logging.error("Yandex translation error %s", e) - return False + elif mt_settings.type == "llm": + translated, mt_failed = llm.translate_lines( + data_to_translate, api_key=mt_settings.api_key + ) + else: + logging.fatal("Unknown translation API") + raise RuntimeError("Unknown translation API") + for idx, translated_line in enumerate(translated): + segments[translate_indices[idx]].translation = translated_line + # TODO: handle specific exceptions instead of a generic one + except Exception as e: + logging.error("Machine translation error %s", e) + return False + return not mt_failed @@ -317,7 +336,11 @@ def process_task(session: Session, task: DocumentTask) -> bool: def main(): - logging.basicConfig(level=logging.INFO, format="%(message)s") + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) logging.info("Starting document processing") session = next(get_db())