Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions backend/app/documents/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Iterable

from sqlalchemy import case, func, select, Row
from sqlalchemy import Row, case, func, select
from sqlalchemy.orm import Session

from app.base.exceptions import BaseQueryException
Expand Down Expand Up @@ -96,7 +96,7 @@ def get_document_records_paged(
repetitions_subquery = (
select(
DocumentRecord.source,
func.count(DocumentRecord.id).label('repetitions_count')
func.count(DocumentRecord.id).label("repetitions_count"),
)
.filter(DocumentRecord.document_id == doc.id)
.group_by(DocumentRecord.source)
Expand All @@ -106,12 +106,14 @@ def get_document_records_paged(
return self.__db.execute(
select(
DocumentRecord,
func.coalesce(repetitions_subquery.c.repetitions_count, 0).label('repetitions_count')
func.coalesce(repetitions_subquery.c.repetitions_count, 0).label(
"repetitions_count"
),
)
.filter(DocumentRecord.document_id == doc.id)
.outerjoin(
repetitions_subquery,
DocumentRecord.source == repetitions_subquery.c.source
DocumentRecord.source == repetitions_subquery.c.source,
)
.order_by(DocumentRecord.id)
.offset(page_records * page)
Expand Down
16 changes: 12 additions & 4 deletions backend/app/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from typing import Literal

from pydantic import BaseModel, ConfigDict, EmailStr

Expand Down Expand Up @@ -27,14 +28,21 @@ def get_values(cls):
return tuple(role.value for role in cls)


class MachineTranslationSettings(BaseModel):
# Yandex only for now
# source_language: str
# target_language: str
class LlmTranslatorSettings(BaseModel):
    """Per-user settings for the LLM-backed machine translator."""

    # Discriminator value selecting the LLM translator backend.
    type: Literal["llm"]
    # API key for the OpenAI-compatible endpoint.
    api_key: str
    # base_api: str # reserved for universal OpenAI translation


class YandexTranslatorSettings(BaseModel):
    """Per-user settings for the Yandex machine translator."""

    # Discriminator value selecting the Yandex translator backend.
    type: Literal["yandex"]
    # Yandex Cloud folder id.
    folder_id: str
    # OAuth token for the Yandex API.
    oauth_token: str


# Union kept under the pre-existing name so callers and annotations that
# used MachineTranslationSettings keep working; the "type" field tells the
# two settings shapes apart.
MachineTranslationSettings = LlmTranslatorSettings | YandexTranslatorSettings


class StatusMessage(BaseModel):
message: str

Expand Down
11 changes: 11 additions & 0 deletions backend/app/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from typing import Literal

from pydantic_settings import BaseSettings
Expand All @@ -17,5 +18,15 @@ class Settings(BaseSettings):
"http://localhost:8000",
)

llm_base_api: str | None = None
llm_model: str | None = None
llm_base64_prompt: str | None = None

@property
def llm_prompt(self):
if not self.llm_base64_prompt:
return ""
base64.decodebytes(self.llm_base64_prompt.encode()).decode()


settings = Settings()
11 changes: 11 additions & 0 deletions backend/app/translators/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class TranslationError(Exception):
    """
    An error raised when a Machine Translator API returns an error.
    """


# Shared type aliases used by all machine-translator backends.
Original = str  # source-language text of a glossary entry
Translation = str  # target-language text of a glossary entry
GlossaryPair = tuple[Original, Translation]
GlossaryPairs = list[GlossaryPair]
# A line to translate paired with the glossary entries that apply to it.
LineWithGlossaries = tuple[str, GlossaryPairs]
125 changes: 125 additions & 0 deletions backend/app/translators/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# NOTE: A single universal translator built on the OpenAI API (replacing the
# per-service translators) would be a good idea, but it would require much
# effort to prepare and manage a library of prompts for every context size,
# so that generalization is postponed; this module is a fixed LLM translator.

import logging
import re

from openai import OpenAI

from app.settings import settings
from app.translators.common import LineWithGlossaries


def generate_prompt_prologue() -> str:
    # The instruction prologue comes from deployment settings (stored
    # base64-encoded there); fall back to "" when it is unset.
    return settings.llm_prompt or ""


def generate_prompt_ctx(
    lines: list[LineWithGlossaries], offset: int, ctx_size: int
) -> str:
    """Render up to ctx_size lines preceding offset as a <context> section."""
    first = max(offset - ctx_size, 0)
    rendered = "\n".join(f"<seg>{text}</seg>" for text, _ in lines[first:offset])
    return f"<context>{rendered}</context>"


def generate_prompt_glossary(lines: list[LineWithGlossaries]) -> str:
    """Render the glossary entries of the given lines as a <glossary> section.

    Entries are deduplicated by original term; a later translation for the
    same original overwrites an earlier one.
    """
    entries = {
        original: f"<term><orig>{original}</orig><trans>{translation}</trans></term>"
        for _, pairs in lines
        for original, translation in pairs
    }
    joined = "\n".join(entries.values())
    return f"<glossary>{joined}</glossary>"


def generate_prompt_task(lines: list[LineWithGlossaries]) -> str:
    """Render the lines to be translated as the <task> section of the prompt."""
    body = "\n".join(f"<seg>{text}</seg>" for text, _ in lines)
    return f"<task>{body}</task>"


def generate_prompt(
    lines: list[LineWithGlossaries], offset: int, ctx_size: int, task_size: int
) -> tuple[str, int]:
    """Assemble the full prompt for one translation batch.

    Returns the prompt text and the number of lines actually included in
    the task (the last batch may be shorter than task_size).
    """
    batch = lines[offset : offset + task_size]
    prologue = generate_prompt_prologue()
    context = generate_prompt_ctx(lines, offset, ctx_size)
    glossary = generate_prompt_glossary(batch)
    task = generate_prompt_task(batch)
    return "\n\n".join((prologue, context, glossary, task)), len(batch)


def parse_lines(network_out: str, expected_size: int) -> tuple[list[str], bool]:
    """
    Parse LLM output into a list of translated segments.

    Args:
        network_out: Raw model output; expected to be exactly one
            ``<seg>...</seg>`` line per translated segment.
        expected_size: Number of segments the prompt asked for.

    Returns:
        A tuple of (segments, success). On any mismatch — wrong line count
        or a line without ``<seg>`` markup — returns ``([], False)`` so the
        caller can retry.
    """
    output = []

    split = network_out.strip().splitlines()
    if len(split) != expected_size:
        # The model must return exactly one line per requested segment;
        # too few OR too many makes alignment with the input impossible.
        logging.warning(
            "Unexpected LLM output, expected %s lines, got %s: %s",
            expected_size,
            len(split),
            split,
        )
        return [], False

    for line in split:
        m = re.match(r"<seg>(.*)</seg>", line)
        if not m:
            logging.warning("Unexpected LLM output, no match found in %s", line)
            return [], False
        output.append(m.group(1))

    return output, True


def translate_lines(
    lines: list[LineWithGlossaries], api_key: str
) -> tuple[list[str], bool]:
    """
    Translate lines of text using an OpenAI-compatible LLM endpoint.

    Args:
        lines: Pairs of (source line, applicable glossary pairs).
        api_key: API key for the LLM service.

    Returns:
        A tuple of (translations, skipped): ``([], True)`` when the LLM
        backend is not configured in settings; otherwise
        ``(translations, False)``, where any batch that failed all retries
        is filled with empty strings so output stays aligned with input.
    """
    # When settings are not set, it fails immediately.
    if not settings.llm_base_api or not settings.llm_model:
        return [], True

    client = OpenAI(api_key=api_key, base_url=settings.llm_base_api)

    output: list[str] = []

    task_size = 40
    ctx_size = 40
    for offset in range(0, len(lines), task_size):
        # The prompt is identical across retries, so build it once per batch.
        prompt, actual_size = generate_prompt(lines, offset, ctx_size, task_size)
        for attempt in range(3):
            completion = client.chat.completions.create(
                model=settings.llm_model,
                messages=[
                    {
                        "role": "system",
                        # TODO: make it configurable
                        "content": "You are a smart translator from English to Russian.",
                    },
                    {"role": "user", "content": prompt},
                ],
                extra_body={"thinking": {"type": "disabled"}},
            )
            # Parse output of the network; retry on malformed responses.
            batch_lines, result = parse_lines(
                completion.choices[0].message.content or "", actual_size
            )
            if result:
                output += batch_lines
                break
            logging.warning("Failed to get answer from LLM, attempt %s", attempt + 1)
        else:
            logging.error("Was unable to get answer from LLM, returning empty list")
            # BUG FIX: pad with the actual batch size, not task_size — the
            # last batch may be shorter, and over-padding would desynchronize
            # the output from the input lines.
            for _ in range(actual_size):
                output.append("")

    return output, False
12 changes: 1 addition & 11 deletions backend/app/translators/yandex.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel, PositiveInt, ValidationError

from app.settings import settings
from app.translators.common import LineWithGlossaries, TranslationError


class YandexTranslatorResponse(BaseModel):
Expand All @@ -20,17 +21,6 @@ class YandexTranslatorResponse(BaseModel):
translations: list[dict[str, str]]


class TranslationError(Exception):
"""
An error raised when Yandex Translator API returns an error.
"""


GlossaryPair = tuple[str, str]
GlossaryPairs = list[GlossaryPair]
LineWithGlossaries = tuple[str, GlossaryPairs]


# Currently Yandex rejects requests larger than 10k symbols and more than
# 50 records in glossary.
def iterate_batches(
Expand Down
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ sqlalchemy==2.0.*
psycopg2-binary==2.9.*
requests==2.32.*
itsdangerous==2.2.*
openai==2.6.*
openpyxl==3.1.*
nltk==3.9.*
8 changes: 6 additions & 2 deletions backend/tests/routers/test_routes_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,9 @@ def test_setting_glossaries_returns_404_for_non_existing_glossaries(
assert response.status_code == 404


def test_get_doc_records_with_repetitions(user_logged_client: TestClient, session: Session):
def test_get_doc_records_with_repetitions(
user_logged_client: TestClient, session: Session
):
"""Test that document records endpoint returns repetition counts"""
with session as s:
records = [
Expand Down Expand Up @@ -947,7 +949,9 @@ def test_get_doc_records_with_repetitions(user_logged_client: TestClient, sessio

# Check that repetition counts are correct
# "Hello World" appears 3 times, others appear once
record_counts = {record["source"]: record["repetitions_count"] for record in response_json}
record_counts = {
record["source"]: record["repetitions_count"] for record in response_json
}
assert record_counts["Hello World"] == 3
assert record_counts["Goodbye"] == 1
assert record_counts["Test"] == 1
Loading