diff --git a/backend/app/documents/query.py b/backend/app/documents/query.py
index ed8b52a..38ff0b8 100644
--- a/backend/app/documents/query.py
+++ b/backend/app/documents/query.py
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Iterable
-from sqlalchemy import case, func, select, Row
+from sqlalchemy import Row, case, func, select
from sqlalchemy.orm import Session
from app.base.exceptions import BaseQueryException
@@ -96,7 +96,7 @@ def get_document_records_paged(
repetitions_subquery = (
select(
DocumentRecord.source,
- func.count(DocumentRecord.id).label('repetitions_count')
+ func.count(DocumentRecord.id).label("repetitions_count"),
)
.filter(DocumentRecord.document_id == doc.id)
.group_by(DocumentRecord.source)
@@ -106,12 +106,14 @@ def get_document_records_paged(
return self.__db.execute(
select(
DocumentRecord,
- func.coalesce(repetitions_subquery.c.repetitions_count, 0).label('repetitions_count')
+ func.coalesce(repetitions_subquery.c.repetitions_count, 0).label(
+ "repetitions_count"
+ ),
)
.filter(DocumentRecord.document_id == doc.id)
.outerjoin(
repetitions_subquery,
- DocumentRecord.source == repetitions_subquery.c.source
+ DocumentRecord.source == repetitions_subquery.c.source,
)
.order_by(DocumentRecord.id)
.offset(page_records * page)
diff --git a/backend/app/models.py b/backend/app/models.py
index 800d16f..fd4e2b9 100644
--- a/backend/app/models.py
+++ b/backend/app/models.py
@@ -1,4 +1,5 @@
from enum import Enum
+from typing import Literal
from pydantic import BaseModel, ConfigDict, EmailStr
@@ -27,14 +28,21 @@ def get_values(cls):
return tuple(role.value for role in cls)
-class MachineTranslationSettings(BaseModel):
- # Yandex only for now
- # source_language: str
- # target_language: str
+class LlmTranslatorSettings(BaseModel):
+ type: Literal["llm"]
+ api_key: str
+ # base_api: str # reserved for universal OpenAI translation
+
+
+class YandexTranslatorSettings(BaseModel):
+ type: Literal["yandex"]
folder_id: str
oauth_token: str
+MachineTranslationSettings = LlmTranslatorSettings | YandexTranslatorSettings
+
+
class StatusMessage(BaseModel):
message: str
diff --git a/backend/app/settings.py b/backend/app/settings.py
index 1abf74b..85229d2 100644
--- a/backend/app/settings.py
+++ b/backend/app/settings.py
@@ -1,3 +1,4 @@
+import base64
from typing import Literal
from pydantic_settings import BaseSettings
@@ -17,5 +18,15 @@ class Settings(BaseSettings):
"http://localhost:8000",
)
+ llm_base_api: str | None = None
+ llm_model: str | None = None
+ llm_base64_prompt: str | None = None
+
+ @property
+ def llm_prompt(self):
+ if not self.llm_base64_prompt:
+ return ""
+        return base64.decodebytes(self.llm_base64_prompt.encode()).decode()
+
settings = Settings()
diff --git a/backend/app/translators/common.py b/backend/app/translators/common.py
new file mode 100644
index 0000000..3097b85
--- /dev/null
+++ b/backend/app/translators/common.py
@@ -0,0 +1,11 @@
+class TranslationError(Exception):
+ """
+ An error raised when Machine Translator API returns an error.
+ """
+
+
+Original = str
+Translation = str
+GlossaryPair = tuple[Original, Translation]
+GlossaryPairs = list[GlossaryPair]
+LineWithGlossaries = tuple[str, GlossaryPairs]
diff --git a/backend/app/translators/llm.py b/backend/app/translators/llm.py
new file mode 100644
index 0000000..e63e999
--- /dev/null
+++ b/backend/app/translators/llm.py
@@ -0,0 +1,125 @@
+# This is a great idea to make a universal translator using OpenAI API, instead
+# of specific translators for every service, but this will require much effort
+# to prepare and manage a library of prompts for every context size, so it is
+# postponed for the future.
+
+import logging
+import re
+
+from openai import OpenAI
+
+from app.settings import settings
+from app.translators.common import LineWithGlossaries
+
+
+def generate_prompt_prologue() -> str:
+ return settings.llm_prompt or ""
+
+
+def generate_prompt_ctx(
+ lines: list[LineWithGlossaries], offset: int, ctx_size: int
+) -> str:
+ start = max(offset - ctx_size, 0)
+ ctx_lines = [f"{line[0]}" for line in lines[start:offset]]
+ return f"{'\n'.join(ctx_lines)}"
+
+
+def generate_prompt_glossary(lines: list[LineWithGlossaries]) -> str:
+ terms = {}
+ for _, glossary_term in lines:
+ for original, translation in glossary_term:
+ terms[original] = (
+ f"{original}{translation}"
+ )
+
+ return f"{'\n'.join(terms.values())}"
+
+
+def generate_prompt_task(lines: list[LineWithGlossaries]) -> str:
+ segments = [f"{line[0]}" for line in lines]
+ return f"{'\n'.join(segments)}"
+
+
+def generate_prompt(
+ lines: list[LineWithGlossaries], offset: int, ctx_size: int, task_size: int
+) -> tuple[str, int]:
+ task_lines = lines[offset : offset + task_size]
+ parts = [
+ generate_prompt_prologue(),
+ generate_prompt_ctx(lines, offset, ctx_size),
+ generate_prompt_glossary(task_lines),
+ generate_prompt_task(task_lines),
+ ]
+ return "\n\n".join(parts), len(task_lines)
+
+
+def parse_lines(network_out: str, expected_size: int) -> tuple[list[str], bool]:
+ output = []
+
+ split = network_out.strip().splitlines()
+ if len(split) != expected_size:
+        logging.warning("Unexpected LLM output, wrong number of lines returned %s", split)
+ return [], False
+
+ for line in split:
+ m = re.match(r"(.*)", line)
+ if not m:
+            logging.warning("Unexpected LLM output, no match found in %s", line)
+ return [], False
+ output.append(m.group(1))
+
+ return output, True
+
+
+def translate_lines(
+ lines: list[LineWithGlossaries], api_key: str
+) -> tuple[list[str], bool]:
+ """
+ Translate lines of text using LLM translation.
+
+ Args:
+ lines: A list of strings to be translated.
+ settings: An object containing LLM translation settings.
+
+ Returns:
+ A list of translated strings.
+ """
+ # when settings are not set, it fails immediately
+ if not settings.llm_base_api or not settings.llm_model:
+ return [], True
+
+ client = OpenAI(api_key=api_key, base_url=settings.llm_base_api)
+
+ output: list[str] = []
+
+ task_size = 40
+ ctx_size = 40
+ for offset in range(0, len(lines), task_size):
+ for attempt in range(3):
+ prompt, actual_size = generate_prompt(lines, offset, ctx_size, task_size)
+ completion = client.chat.completions.create(
+ model=settings.llm_model,
+ messages=[
+ {
+ "role": "system",
+ # TODO: make it configurable
+ "content": "You are a smart translator from English to Russian.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ extra_body={"thinking": {"type": "disabled"}},
+ )
+ # parse output of the network
+ batch_lines, result = parse_lines(
+ completion.choices[0].message.content or "", actual_size
+ )
+ if result:
+ output += batch_lines
+ break
+ logging.warning("Failed to get answer from LLM, attempt %s", attempt + 1)
+ else:
+ logging.error("Was unable to get answer from LLM, returning empty list")
+ for _ in range(task_size):
+ output.append("")
+
+ return output, False
diff --git a/backend/app/translators/yandex.py b/backend/app/translators/yandex.py
index 4801498..9f8e305 100644
--- a/backend/app/translators/yandex.py
+++ b/backend/app/translators/yandex.py
@@ -7,6 +7,7 @@
from pydantic import BaseModel, PositiveInt, ValidationError
from app.settings import settings
+from app.translators.common import LineWithGlossaries, TranslationError
class YandexTranslatorResponse(BaseModel):
@@ -20,17 +21,6 @@ class YandexTranslatorResponse(BaseModel):
translations: list[dict[str, str]]
-class TranslationError(Exception):
- """
- An error raised when Yandex Translator API returns an error.
- """
-
-
-GlossaryPair = tuple[str, str]
-GlossaryPairs = list[GlossaryPair]
-LineWithGlossaries = tuple[str, GlossaryPairs]
-
-
# Currently Yandex rejects requests larger than 10k symbols and more than
# 50 records in glossary.
def iterate_batches(
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 90cbba0..cdac695 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -9,5 +9,6 @@ sqlalchemy==2.0.*
psycopg2-binary==2.9.*
requests==2.32.*
itsdangerous==2.2.*
+openai==2.6.*
openpyxl==3.1.*
nltk==3.9.*
diff --git a/backend/tests/routers/test_routes_documents.py b/backend/tests/routers/test_routes_documents.py
index 3e423c9..90e52b5 100644
--- a/backend/tests/routers/test_routes_documents.py
+++ b/backend/tests/routers/test_routes_documents.py
@@ -917,7 +917,9 @@ def test_setting_glossaries_returns_404_for_non_existing_glossaries(
assert response.status_code == 404
-def test_get_doc_records_with_repetitions(user_logged_client: TestClient, session: Session):
+def test_get_doc_records_with_repetitions(
+ user_logged_client: TestClient, session: Session
+):
"""Test that document records endpoint returns repetition counts"""
with session as s:
records = [
@@ -947,7 +949,9 @@ def test_get_doc_records_with_repetitions(user_logged_client: TestClient, sessio
# Check that repetition counts are correct
# "Hello World" appears 3 times, others appear once
- record_counts = {record["source"]: record["repetitions_count"] for record in response_json}
+ record_counts = {
+ record["source"]: record["repetitions_count"] for record in response_json
+ }
assert record_counts["Hello World"] == 3
assert record_counts["Goodbye"] == 1
assert record_counts["Test"] == 1
diff --git a/backend/tests/test_llm.py b/backend/tests/test_llm.py
new file mode 100644
index 0000000..4f515a9
--- /dev/null
+++ b/backend/tests/test_llm.py
@@ -0,0 +1,409 @@
+import logging
+from unittest.mock import Mock, patch
+
+import pytest
+
+from app.settings import Settings
+from app.translators import llm
+
+# pylint: disable=C0116
+
+
+@pytest.fixture(autouse=True)
+def mock_llm_settings_autouse(monkeypatch):
+ """Mock the settings to provide test values for LLM functionality."""
+ # Mock the settings object with test values
+ test_settings = Settings(
+ llm_base_api="https://api.test.com/v4",
+ llm_model="test-model",
+ llm_base64_prompt=None,
+ )
+
+ # Mock the llm_prompt property to return the expected test prompt
+ def mock_llm_prompt():
+ return "You need to do the translation task from English to Russian. The prompt includes section, section, section, tags, tags, tags, and tags."
+
+ # Replace the settings object and mock the llm_prompt property on the class
+ monkeypatch.setattr("app.settings.settings", test_settings)
+ monkeypatch.setattr(
+ Settings, "llm_prompt", property(lambda self: mock_llm_prompt())
+ )
+
+ # Also need to mock the settings import in llm module
+ monkeypatch.setattr("app.translators.llm.settings", test_settings)
+
+ yield test_settings
+
+
+def test_generate_prompt_prologue():
+ """Test that prologue generates correct system prompt."""
+ result = llm.generate_prompt_prologue()
+
+ assert "You need to do the translation task from English to Russian." in result
+ assert "" in result
+ assert "" in result
+ assert "" in result
+ assert "" in result
+ assert "" in result
+ assert "" in result
+ assert "" in result
+
+
+def test_generate_prompt_ctx_middle_offset():
+ """Test with offset in the middle"""
+ lines = [(f"line{i}", []) for i in range(10)]
+
+ result = llm.generate_prompt_ctx(lines, 5, 3)
+ expected_lines = ["line2", "line3", "line4"]
+ assert result == "" + "\n".join(expected_lines) + ""
+
+
+def test_generate_prompt_ctx_beginning():
+ """Test with offset at the beginning"""
+ lines = [(f"line{i}", []) for i in range(10)]
+
+ result = llm.generate_prompt_ctx(lines, 1, 3)
+ expected_lines = ["line0"]
+ assert result == "" + "\n".join(expected_lines) + ""
+
+
+def test_generate_prompt_ctx_empty():
+ """Test with no context"""
+ lines = [(f"line{i}", []) for i in range(10)]
+
+ result = llm.generate_prompt_ctx(lines, 0, 3)
+ assert result == ""
+
+
+def test_generate_prompt_glossary():
+ """Test glossary XML generation."""
+ lines = [
+ ("line1", [("hello", "привет"), ("world", "мир")]),
+ ("line2", [("test", "тест"), ("hello", "привет")]), # duplicate term
+ ]
+
+ result = llm.generate_prompt_glossary(lines)
+
+ assert "" in result
+ assert "" in result
+ assert "helloпривет" in result
+ assert "worldмир" in result
+ assert "testтест" in result
+ # Should not have duplicates
+ assert result.count("helloпривет") == 1
+
+
+def test_generate_prompt_glossary_empty():
+ """Test glossary generation with no glossary terms."""
+ lines = [("line1", []), ("line2", [])]
+
+ result = llm.generate_prompt_glossary(lines)
+
+ assert result == ""
+
+
+def test_generate_prompt_task():
+ """Test task XML generation."""
+ lines = [("line1", []), ("line2", [])]
+
+ result = llm.generate_prompt_task(lines)
+
+ assert "" in result
+ assert "" in result
+ assert "line1" in result
+ assert "line2" in result
+
+
+def test_generate_prompt():
+ """Test full prompt generation."""
+ lines = [(f"line{i}", [("test", "тест")]) for i in range(10)]
+
+ prompt, actual_size = llm.generate_prompt(lines, 5, 3, 2)
+
+ assert "You need to do the translation task from English to Russian." in prompt
+ # The prologue contains "" text, so we expect 2 occurrences: one in prologue description, one in actual context
+ assert prompt.count("") == 2
+ # The prologue contains "" text, so we expect 2 occurrences: one in prologue description, one in actual task
+ assert prompt.count("") == 2
+ # The prologue contains "" text, so we expect 2 occurrences: one in prologue description, one in actual glossary
+ assert prompt.count("") == 2
+ assert actual_size == 2
+
+ # Check that task lines are correct
+ assert "line5" in prompt
+ assert "line6" in prompt
+ assert "line7" not in prompt
+
+
+def test_generate_prompt_ctx_negative_offset():
+ """Test context generation with negative offset."""
+ lines = [(f"line{i}", []) for i in range(5)]
+
+ result = llm.generate_prompt_ctx(lines, -1, 3)
+
+ # With negative offset, start becomes 0, so it includes all available lines
+ expected_lines = [
+ "line0",
+ "line1",
+ "line2",
+ "line3",
+ ]
+ assert result == "" + "\n".join(expected_lines) + ""
+
+
+def test_generate_prompt_ctx_large_context():
+ """Test context generation when requested context is larger than available."""
+ lines = [(f"line{i}", []) for i in range(3)]
+
+ result = llm.generate_prompt_ctx(lines, 2, 10)
+
+ # Should only include available lines before offset
+ expected_lines = ["line0", "line1"]
+ assert result == "" + "\n".join(expected_lines) + ""
+
+
+def test_parse_lines_success():
+ """Test successful parsing of LLM output."""
+ network_out = (
+ "translation1\ntranslation2\ntranslation3"
+ )
+ expected_size = 3
+
+ result, success = llm.parse_lines(network_out, expected_size)
+
+ assert success
+ assert result == ["translation1", "translation2", "translation3"]
+
+
+def test_parse_lines_wrong_number():
+ """Test parsing when number of lines doesn't match expected."""
+ network_out = "translation1\ntranslation2"
+ expected_size = 3
+
+ result, success = llm.parse_lines(network_out, expected_size)
+
+ assert not success
+ assert result == []
+
+
+def test_parse_lines_invalid_format():
+ """Test parsing when lines don't match expected format."""
+ network_out = "invalid line\ntranslation2"
+ expected_size = 2
+
+ result, success = llm.parse_lines(network_out, expected_size)
+
+ assert not success
+ assert result == []
+
+
+def test_parse_lines_empty_content():
+ """Test parsing when seg content is empty."""
+ network_out = "\ntranslation2"
+ expected_size = 2
+
+ result, success = llm.parse_lines(network_out, expected_size)
+
+ assert success
+ assert result == ["", "translation2"]
+
+
+@patch("app.translators.llm.OpenAI")
+def test_translate_lines_success(mock_openai):
+ """Test successful translation with mocked API."""
+ # Mock the OpenAI client and response
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+
+ mock_response = Mock()
+ mock_response.choices = [Mock()]
+ mock_response.choices[
+ 0
+ ].message.content = "translation1\ntranslation2"
+ mock_client.chat.completions.create.return_value = mock_response
+
+ lines = [("line1", []), ("line2", [])]
+
+ result, has_error = llm.translate_lines(lines, "test_api_key")
+
+ assert not has_error
+ assert result == ["translation1", "translation2"]
+ mock_openai.assert_called_once_with(
+ api_key="test_api_key", base_url="https://api.test.com/v4"
+ )
+ assert mock_client.chat.completions.create.call_count == 1
+
+
+@patch("app.translators.llm.OpenAI")
+def test_translate_lines_with_glossaries(mock_openai):
+ """Test translation with glossary terms."""
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+
+ mock_response = Mock()
+ mock_response.choices = [Mock()]
+ mock_response.choices[0].message.content = "translation1"
+ mock_client.chat.completions.create.return_value = mock_response
+
+ lines = [("hello world", [("hello", "привет")])]
+
+ result, has_error = llm.translate_lines(lines, "test_api_key")
+
+ assert not has_error
+ assert result == ["translation1"]
+
+ # Check that glossary was included in the prompt
+ call_args = mock_client.chat.completions.create.call_args
+ prompt = call_args[1]["messages"][1]["content"]
+ assert "" in prompt
+ assert "helloпривет" in prompt
+
+
+@patch("app.translators.llm.OpenAI")
+def test_translate_lines_api_error_retry(mock_openai, caplog):
+ """Test translation with API error and retry logic."""
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+
+ # First call fails, second succeeds
+ mock_response_fail = Mock()
+ mock_response_fail.choices = [Mock()]
+ mock_response_fail.choices[0].message.content = "invalid output"
+
+ mock_response_success = Mock()
+ mock_response_success.choices = [Mock()]
+ mock_response_success.choices[0].message.content = "translation1"
+
+ mock_client.chat.completions.create.side_effect = [
+ mock_response_fail,
+ mock_response_success,
+ ]
+
+ lines = [("line1", [])]
+
+ with caplog.at_level(logging.WARNING):
+ result, has_error = llm.translate_lines(lines, "test_api_key")
+
+ assert not has_error
+ assert result == ["translation1"]
+ assert mock_client.chat.completions.create.call_count == 2
+ assert "Failed to get answer from LLM, attempt 1" in caplog.text
+
+
+@patch("app.translators.llm.OpenAI")
+def test_translate_lines_all_attempts_fail(mock_openai, caplog):
+ """Test translation when all API attempts fail."""
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+
+ # All calls return invalid output
+ mock_response = Mock()
+ mock_response.choices = [Mock()]
+ mock_response.choices[0].message.content = "invalid output"
+ mock_client.chat.completions.create.return_value = mock_response
+
+ lines = [("line1", []), ("line2", [])]
+
+ with caplog.at_level(logging.ERROR):
+ result, has_error = llm.translate_lines(lines, "test_api_key")
+
+ # Should return empty strings for all lines (task_size = 40)
+ assert len(result) == 40 # 40 empty strings from task_size
+ assert all(line == "" for line in result) # All should be empty
+ assert (
+ not has_error
+ ) # Note: llm.translate_lines returns False even when all attempts fail
+ assert mock_client.chat.completions.create.call_count == 3 # 3 attempts
+ assert "Was unable to get answer from LLM, returning empty list" in caplog.text
+
+
+@patch("app.translators.llm.OpenAI")
+def test_translate_lines_large_batch(mock_openai):
+ """Test translation with large batch that gets split."""
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+
+ # Create simple successful responses for each batch
+ mock_response1 = Mock()
+ mock_response1.choices = [Mock()]
+ mock_response1.choices[0].message.content = "\n".join(
+ [f"translation{i}" for i in range(40)]
+ )
+
+ mock_response2 = Mock()
+ mock_response2.choices = [Mock()]
+ mock_response2.choices[0].message.content = "\n".join(
+ [f"translation{i}" for i in range(40)]
+ )
+
+ mock_response3 = Mock()
+ mock_response3.choices = [Mock()]
+ mock_response3.choices[0].message.content = "translation80"
+
+ mock_client.chat.completions.create.side_effect = [
+ mock_response1,
+ mock_response2,
+ mock_response3,
+ ]
+
+ # Create 81 lines (should be split into 3 batches: 40, 40, 1)
+ lines = [(f"line{i}", []) for i in range(81)]
+
+ result, has_error = llm.translate_lines(lines, "test_api_key")
+
+ assert not has_error
+ assert len(result) == 81
+ assert mock_client.chat.completions.create.call_count == 3
+
+
+@patch("app.translators.llm.OpenAI")
+def test_translate_lines_empty_content(mock_openai):
+ """Test translation when API returns None content."""
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+
+ mock_response = Mock()
+ mock_response.choices = [Mock()]
+ mock_response.choices[0].message.content = None
+ mock_client.chat.completions.create.return_value = mock_response
+
+ lines = [("line1", [])]
+
+ result, has_error = llm.translate_lines(lines, "test_api_key")
+
+ # Should handle None content gracefully - returns empty strings for all failed attempts
+ assert not has_error
+ assert len(result) == 40 # task_size = 40, all failed so 40 empty strings
+ assert all(line == "" for line in result) # All should be empty
+ assert mock_client.chat.completions.create.call_count == 3 # Should retry 3 times
+
+
+@patch("app.translators.llm.OpenAI")
+def test_translate_lines_context_generation(mock_openai):
+ """Test that context is properly included in the prompt."""
+ mock_client = Mock()
+ mock_openai.return_value = mock_client
+
+ # Create simple successful response for the batch
+ mock_response = Mock()
+ mock_response.choices = [Mock()]
+ mock_response.choices[0].message.content = "\n".join(
+ [f"translation{i}" for i in range(10)]
+ )
+ mock_client.chat.completions.create.return_value = mock_response
+
+ lines = [(f"line{i}", []) for i in range(10)]
+
+ result, has_error = llm.translate_lines(lines, "test_api_key")
+
+ assert not has_error
+ assert len(result) == 10
+ assert "translation0" in result
+
+ # Check that context was included in the prompt
+ call_args = mock_client.chat.completions.create.call_args
+ prompt = call_args[1]["messages"][1]["content"]
+ assert "" in prompt
+ # Should include previous lines as context
+ assert "line0" in prompt
+ assert "line1" in prompt
diff --git a/backend/tests/test_worker.py b/backend/tests/test_worker.py
index 3e1da8b..60d7e86 100644
--- a/backend/tests/test_worker.py
+++ b/backend/tests/test_worker.py
@@ -21,7 +21,10 @@
DocumentTaskDescription,
)
from app.glossary.models import Glossary, GlossaryRecord
-from app.models import DocumentStatus, MachineTranslationSettings
+from app.models import (
+ DocumentStatus,
+ YandexTranslatorSettings,
+)
from app.schema import DocumentTask
from app.translation_memory.models import TranslationMemory, TranslationMemoryRecord
from app.translation_memory.schema import TranslationMemoryUsage
@@ -53,7 +56,7 @@ def create_task(
type_: Literal["xliff", "txt"] = "xliff",
usage: TranslationMemoryUsage = TranslationMemoryUsage.NEWEST,
substitute_numbers: bool = False,
- mt_settings: MachineTranslationSettings | None = None,
+ mt_settings: YandexTranslatorSettings | None = None,
):
return DocumentTask(
data=DocumentTaskDescription(
@@ -273,7 +276,7 @@ def test_process_task_uses_correct_tm_ids(session: Session):
create_doc(name="small.xliff", type_=DocumentType.xliff),
create_xliff_doc(file_data),
create_task(),
- DocMemoryAssociation(doc_id=1, tm_id=2, mode='read')
+ DocMemoryAssociation(doc_id=1, tm_id=2, mode="read"),
]
)
s.commit()
@@ -318,8 +321,8 @@ def test_process_task_uses_tm_mode(mode: str, trans_result: str, session: Sessio
create_doc(name="small.xliff", type_=DocumentType.xliff),
create_xliff_doc(file_data),
create_task(usage=TranslationMemoryUsage(mode)),
- DocMemoryAssociation(doc_id=1, tm_id=1, mode='read'),
- DocMemoryAssociation(doc_id=1, tm_id=2, mode='read')
+ DocMemoryAssociation(doc_id=1, tm_id=1, mode="read"),
+ DocMemoryAssociation(doc_id=1, tm_id=2, mode="read"),
]
)
s.commit()
@@ -425,8 +428,8 @@ def test_process_task_puts_doc_in_error_state(monkeypatch, session: Session):
create_doc(name="small.xliff", type_=DocumentType.xliff),
create_xliff_doc(file_data),
create_task(
- mt_settings=MachineTranslationSettings(
- folder_id="12345", oauth_token="fake"
+ mt_settings=YandexTranslatorSettings(
+ type="yandex", folder_id="12345", oauth_token="fake"
),
),
]
diff --git a/backend/worker.py b/backend/worker.py
index b4950f2..465cf6c 100644
--- a/backend/worker.py
+++ b/backend/worker.py
@@ -29,7 +29,8 @@
from app.translation_memory.models import TranslationMemoryRecord
from app.translation_memory.query import TranslationMemoryQuery
from app.translation_memory.schema import TranslationMemoryUsage
-from app.translators import yandex
+from app.translators import llm, yandex
+from app.translators.common import LineWithGlossaries
def segment_needs_processing(segment: BaseSegment) -> bool:
@@ -106,7 +107,9 @@ def process_document(
)
start_time = time.time()
- translate_indices = substitute_segments(settings, session, segments, tm_ids, glossary_ids)
+ translate_indices = substitute_segments(
+ settings, session, segments, tm_ids, glossary_ids
+ )
logging.info(
"Segments substitution time: %.2f seconds, speed: %.2f segment/second, segments: %d/%d",
time.time() - start_time,
@@ -116,12 +119,16 @@ def process_document(
)
start_time = time.time()
- mt_result = translate_segments(
- segments,
- translate_indices,
- glossary_ids,
- settings.machine_translation_settings,
- session,
+ mt_result = (
+ translate_segments(
+ segments,
+ translate_indices,
+ glossary_ids,
+ settings.machine_translation_settings,
+ session,
+ )
+ if settings.machine_translation_settings is not None
+ else True
)
logging.info(
"Machine translation time: %.2f seconds, speed: %.2f segment/second, segments: %d/%d",
@@ -193,33 +200,45 @@ def translate_segments(
segments: Sequence[BaseSegment],
translate_indices: Sequence[int],
glossary_ids: list[int],
- mt_settings: MachineTranslationSettings | None,
+ mt_settings: MachineTranslationSettings,
session: Session,
) -> bool:
- # TODO: it is better to make solution more translation service agnostic
+ if not translate_indices:
+ return True
+
mt_failed = False
- if mt_settings and translate_indices:
- try:
- original_segments = [segments[idx].original for idx in translate_indices]
- data_to_translate: list[yandex.LineWithGlossaries] = []
- for segment in original_segments:
- glossary_records = GlossaryQuery(
- session
- ).get_glossary_records_for_segment(segment, glossary_ids)
- data_to_translate.append(
- (segment, [(x.source, x.target) for x in glossary_records])
- )
+ try:
+ # TODO: this might be harmful with LLM translation as it is loses
+ # the connectivity of the context
+ segments_to_translate = [segments[idx].original for idx in translate_indices]
+ data_to_translate: list[LineWithGlossaries] = []
+ for segment in segments_to_translate:
+ glossary_records = GlossaryQuery(session).get_glossary_records_for_segment(
+ segment, glossary_ids
+ )
+ data_to_translate.append(
+ (segment, [(x.source, x.target) for x in glossary_records])
+ )
+ if mt_settings.type == "yandex":
translated, mt_failed = yandex.translate_lines(
data_to_translate,
oauth_token=mt_settings.oauth_token,
folder_id=mt_settings.folder_id,
)
- for idx, translated_line in enumerate(translated):
- segments[translate_indices[idx]].translation = translated_line
- # TODO: handle specific exceptions instead of a generic one
- except Exception as e:
- logging.error("Yandex translation error %s", e)
- return False
+ elif mt_settings.type == "llm":
+ translated, mt_failed = llm.translate_lines(
+ data_to_translate, api_key=mt_settings.api_key
+ )
+ else:
+ logging.fatal("Unknown translation API")
+ raise RuntimeError("Unknown translation API")
+ for idx, translated_line in enumerate(translated):
+ segments[translate_indices[idx]].translation = translated_line
+ # TODO: handle specific exceptions instead of a generic one
+ except Exception as e:
+ logging.error("Machine translation error %s", e)
+ return False
+
return not mt_failed
@@ -317,7 +336,11 @@ def process_task(session: Session, task: DocumentTask) -> bool:
def main():
- logging.basicConfig(level=logging.INFO, format="%(message)s")
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(levelname)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
logging.info("Starting document processing")
session = next(get_db())