Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions backend/app/routers/translation_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_memory_by_id(db: Session, memory_id: int):
doc = TranslationMemoryQuery(db).get_memory(memory_id)
if not doc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Document not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Memory not found"
)
return doc

Expand Down Expand Up @@ -53,20 +53,41 @@ def get_memory_records(
tm_id: int,
db: Annotated[Session, Depends(get_db)],
page: Annotated[int | None, Query(ge=0)] = None,
) -> list[schema.TranslationMemoryRecord]:
query: Annotated[str | None, Query()] = None,
) -> schema.TranslationMemoryListResponse:
page_records: Final = 100
if not page:
page = 0

get_memory_by_id(db, tm_id)
return [
schema.TranslationMemoryRecord(
id=record.id, source=record.source, target=record.target
)
for record in TranslationMemoryQuery(db).get_memory_records_paged(
tm_id, page, page_records
)
]
records, count = TranslationMemoryQuery(db).get_memory_records_paged(
tm_id, page, page_records, query
)
return schema.TranslationMemoryListResponse(
records=records,
page=page,
total_records=count,
)


@router.get("/{tm_id}/records/similar")
def get_memory_records_similar(
    tm_id: int,
    db: Annotated[Session, Depends(get_db)],
    query: Annotated[str, Query()],
) -> schema.TranslationMemoryListSimilarResponse:
    """List the records of a translation memory most similar to ``query``.

    The lookup is not paginated: at most ``max_results`` records are
    returned in a single response, which always reports ``page=0``.

    Args:
        tm_id: Identifier of the translation memory to search in.
        db: Database session injected by FastAPI.
        query: Text the record sources are compared against.

    Returns:
        The matching records together with paging metadata.
    """
    max_results: Final = 20

    # Called only for its side effect: it raises when the memory is missing.
    get_memory_by_id(db, tm_id)

    similar = TranslationMemoryQuery(db).get_memory_records_paged_similar(
        tm_id, max_results, query
    )
    # The total is just the size of this capped result set, not the number
    # of all similar records in the memory; acceptable while the cap is 20.
    found = len(similar)
    return schema.TranslationMemoryListSimilarResponse(
        records=similar,
        page=0,
        total_records=found,
    )


@router.post("/upload")
Expand Down
82 changes: 70 additions & 12 deletions backend/app/translation_memory/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from sqlalchemy import func, select, text
from sqlalchemy.orm import Session

from app.translation_memory import schema

from .models import TranslationMemory, TranslationMemoryRecord
from .schema import MemorySubstitution


class TranslationMemoryQuery:
Expand Down Expand Up @@ -35,23 +36,80 @@ def get_memory_records_count(self, memory_id: int) -> int:
).scalar_one()

def get_memory_records_paged(
    self,
    memory_id: int,
    page: int,
    page_records: int,
    query: str | None = None,
) -> tuple[list[schema.TranslationMemoryRecord], int]:
    """Return one page of a memory's records plus the total match count.

    Args:
        memory_id: Identifier of the translation memory to read from.
        page: Zero-based page index.
        page_records: Maximum number of records per page.
        query: Optional text; when non-empty, only records whose source
            contains it (case-insensitively) are considered.

    Returns:
        A ``(records, count)`` tuple where ``records`` is the requested page
        ordered by record id and ``count`` is the number of records matching
        the filters across all pages.
    """
    filters = [TranslationMemoryRecord.document_id == memory_id]
    if query:
        # Escape LIKE metacharacters so that "%" and "_" typed by the user
        # are matched literally instead of acting as wildcards. The escape
        # character itself has to be doubled first.
        escaped = (
            query.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        filters.append(
            TranslationMemoryRecord.source.ilike(f"%{escaped}%", escape="\\")
        )

    count = self.__db.execute(
        select(func.count(TranslationMemoryRecord.id)).filter(*filters)
    ).scalar_one()

    rows = self.__db.execute(
        select(TranslationMemoryRecord)
        .filter(*filters)
        .order_by(TranslationMemoryRecord.id)
        .offset(page_records * page)
        .limit(page_records)
    ).scalars()

    records = [
        schema.TranslationMemoryRecord(
            id=row.id, source=row.source, target=row.target
        )
        for row in rows
    ]
    return records, count

def get_memory_records_paged_similar(
    self,
    memory_id: int,
    page_records: int,
    query: str,
) -> list[schema.TranslationMemoryRecordWithSimilarity]:
    """Return the records of a memory whose source is most similar to ``query``.

    Follows the same trigram-similarity approach as ``get_substitutions``,
    with a lower threshold.

    Args:
        memory_id: Identifier of the translation memory to search in.
        page_records: Maximum number of records to return.
        query: Text the record sources are compared against.

    Returns:
        Matching records with their similarity score, best match first.
    """
    # The pg_trgm "%" operator used in the filter below matches against this
    # threshold, so it has to be set (to 25%) before the search runs.
    self.__db.execute(
        text("SET pg_trgm.similarity_threshold TO :threshold"),
        {"threshold": 0.25},
    )

    score = func.similarity(TranslationMemoryRecord.source, query)
    statement = (
        select(
            TranslationMemoryRecord.id,
            TranslationMemoryRecord.source,
            TranslationMemoryRecord.target,
            score,
        )
        .filter(
            TranslationMemoryRecord.document_id == memory_id,
            TranslationMemoryRecord.source.op("%")(query),
        )
        .order_by(score.desc())
        .limit(page_records)
    )
    rows = self.__db.execute(statement).all()

    return [
        schema.TranslationMemoryRecordWithSimilarity(
            id=row.id,
            source=row.source,
            target=row.target,
            similarity=row.similarity,
        )
        for row in rows
    ]

def get_substitutions(
self,
source: str,
tm_ids: list[int],
threshold: float = 0.75,
count: int = 10,
) -> list[MemorySubstitution]:
) -> list[schema.MemorySubstitution]:
similarity_func = func.similarity(TranslationMemoryRecord.source, source)
self.__db.execute(
text("SET pg_trgm.similarity_threshold TO :threshold"),
Expand All @@ -72,7 +130,7 @@ def get_substitutions(
).all()

return [
MemorySubstitution(
schema.MemorySubstitution(
source=record.source, target=record.target, similarity=record.similarity
)
for record in records
Expand Down
16 changes: 16 additions & 0 deletions backend/app/translation_memory/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,21 @@ class TranslationMemoryRecord(Identified):
target: str


class TranslationMemoryListResponse(BaseModel):
    """Response body for a paged listing of translation memory records."""

    # Records contained in the returned page.
    records: list[TranslationMemoryRecord]
    # Index of the returned page.
    page: int
    # Number of records reported for the listing as a whole.
    total_records: int


class TranslationMemoryRecordWithSimilarity(TranslationMemoryRecord):
    """A translation memory record annotated with a similarity score."""

    # Similarity score attached to the record by a similarity search.
    similarity: float


class TranslationMemoryListSimilarResponse(BaseModel):
    """Response body for a similarity search over translation memory records."""

    # Records returned by the search, each carrying its similarity score.
    records: list[TranslationMemoryRecordWithSimilarity]
    # Index of the returned page.
    page: int
    # Number of records reported for the search.
    total_records: int


class TranslationMemoryCreationSettings(BaseModel):
name: str = Field(min_length=1)
18 changes: 12 additions & 6 deletions backend/app/translators/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def generate_prompt_prologue() -> str:
    """Return the LLM prompt prologue taken from the application settings.

    A missing prompt is reported once through ``logging.error``; the (falsy)
    setting value is still returned unchanged in that case.
    """
    if not settings.llm_prompt:
        logging.error("No LLM prompt configured")
    return settings.llm_prompt


Expand Down Expand Up @@ -55,22 +55,28 @@ def generate_prompt(
return "\n\n".join(parts), len(task_lines)


# Matches one translated segment per output line; compiled once at import.
SEG_MATCHER = re.compile(r"<seg>(.*)</seg>")


def parse_lines(network_out: str, expected_size: int) -> tuple[list[str], bool]:
    """Extract translated segments from raw LLM output.

    The output is expected to contain exactly ``expected_size`` lines, each
    starting with a ``<seg>...</seg>`` element.

    Args:
        network_out: Raw text returned by the model.
        expected_size: Number of lines (segments) that were requested.

    Returns:
        A ``(segments, ok)`` tuple. When the line count is wrong the result
        is ``([], False)``. Otherwise ``segments`` has one entry per line,
        with an empty string standing in for every line that could not be
        parsed, and ``ok`` is ``True`` only if every line parsed.
    """
    split = network_out.strip().splitlines()
    if len(split) != expected_size:
        # The count can be off in either direction, so report both numbers.
        logging.warning(
            "Unexpected LLM output, expected %d lines but got %d: %s",
            expected_size,
            len(split),
            split,
        )
        return [], False

    output: list[str] = []
    failed = False
    for line in split:
        m = SEG_MATCHER.match(line)
        if not m:
            logging.warning("Unexpected LLM output, no match found in %s", line)
            # Keep positions aligned with the input lines so the segments
            # that did parse can still be used by the caller.
            output.append("")
            failed = True
            continue
        output.append(m.group(1))

    return output, not failed


def translate_lines(
Expand Down
75 changes: 68 additions & 7 deletions backend/tests/routers/test_routes_tms.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@ def test_can_get_tm_records(user_logged_client: TestClient, session: Session):

response = user_logged_client.get("/translation_memory/1/records")
assert response.status_code == 200
assert response.json() == [
{"id": 1, "source": "Regional Effects", "target": "Translation"},
{"id": 2, "source": "User Interface", "target": "UI"},
]
assert response.json() == {
"records": [
{"id": 1, "source": "Regional Effects", "target": "Translation"},
{"id": 2, "source": "User Interface", "target": "UI"},
],
"page": 0,
"total_records": 2,
}


def test_can_get_tm_records_with_page(user_logged_client: TestClient, session: Session):
Expand All @@ -86,8 +90,9 @@ def test_can_get_tm_records_with_page(user_logged_client: TestClient, session: S
"/translation_memory/1/records", params={"page": "1"}
)
assert response.status_code == 200
assert len(response.json()) == 50
assert response.json()[0] == {"id": 101, "source": "line100", "target": "line100"}
json = response.json()
assert len(json["records"]) == 50
assert json["records"][0] == {"id": 101, "source": "line100", "target": "line100"}


def test_tm_records_are_empty_for_too_large_page(
Expand All @@ -109,7 +114,63 @@ def test_tm_records_are_empty_for_too_large_page(
"/translation_memory/1/records", params={"page": "20"}
)
assert response.status_code == 200
assert response.json() == []
assert response.json()["records"] == []


def test_tm_records_exact_match(user_logged_client: TestClient, session: Session):
    """Searching by ``query`` returns only records whose source contains it."""
    tm_records = [
        TranslationMemoryRecord(source="Hello world", target="Hola mundo"),
        TranslationMemoryRecord(source="Goodbye world", target="Adiós mundo"),
        TranslationMemoryRecord(source="Welcome home", target="Bienvenido a casa"),
    ]
    with session as s:
        s.add(TranslationMemory(name="test_doc.tmx", records=tm_records, created_by=1))
        s.commit()

    # Only "query" is sent: the records route declares no "query_mode"
    # parameter, so passing one would have no effect on the result.
    response = user_logged_client.get(
        "/translation_memory/1/records",
        params={"query": "world"},
    )
    assert response.status_code == 200

    body = response.json()
    assert len(body["records"]) == 2
    assert body["total_records"] == 2
    # Should return records containing "world" in source
    sources = [record["source"] for record in body["records"]]
    assert "Hello world" in sources
    assert "Goodbye world" in sources
    # Records from the plain listing must not carry a similarity field
    assert all("similarity" not in record for record in body["records"])


def test_tm_records_exact_match_in_nonexistent_tm(user_logged_client: TestClient):
    """Searching the records of an unknown translation memory yields a 404."""
    # Only "query" is sent: the records route declares no "query_mode"
    # parameter, so passing one would have no effect on the result.
    response = user_logged_client.get(
        "/translation_memory/999/records",
        params={"query": "test"},
    )
    assert response.status_code == 404
    assert "Memory not found" in response.json()["detail"]


def test_search_no_results(user_logged_client: TestClient, session: Session):
    """A query that matches no record source yields an empty record list."""
    tm_records = [
        TranslationMemoryRecord(source="Hello world", target="Hola mundo"),
    ]
    with session as s:
        s.add(TranslationMemory(name="test_doc.tmx", records=tm_records, created_by=1))
        s.commit()

    # Search for something that doesn't exist. Only "query" is sent: the
    # records route declares no "query_mode" parameter.
    response = user_logged_client.get(
        "/translation_memory/1/records",
        params={"query": "nonexistent"},
    )
    assert response.status_code == 200

    results = response.json()["records"]
    assert len(results) == 0


def test_tm_records_returns_404_for_nonexistent_document(
Expand Down
2 changes: 1 addition & 1 deletion backend/tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_parse_lines_invalid_format():
result, success = llm.parse_lines(network_out, expected_size)

assert not success
assert result == []
assert result == ["", "translation2"]


def test_parse_lines_empty_content():
Expand Down
32 changes: 24 additions & 8 deletions frontend/mocks/tmMocks.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
import {http, HttpResponse} from 'msw'
import {AwaitedReturnType} from './utils'
import {getMemories} from '../src/client/services/TmsService'
import {getMemories, getMemory} from '../src/client/services/TmsService'
import {TranslationMemoryWithRecordsCount} from '../src/client/schemas/TranslationMemoryWithRecordsCount'

// Fixture data shared by the translation memory mock handlers in this file.
const tms: TranslationMemoryWithRecordsCount[] = [
  {
    id: 42,
    created_by: 12,
    name: 'Some TM',
    records_count: 5,
  },
]

// MSW request handlers that mock the translation memory endpoints.
export const tmMocks = [
  // Listing: answers with every fixture memory.
  http.get('http://localhost:8000/translation_memory/', () =>
    HttpResponse.json<AwaitedReturnType<typeof getMemories>>(tms)
  ),
  // Single memory: answers with the fixture that has the requested id,
  // or with an empty 404 response when there is none.
  http.get<{id: string}>(
    'http://localhost:8000/translation_memory/:id',
    ({params}) => {
      // Path parameters arrive as strings; convert once so the lookup can
      // use strict equality against the numeric fixture ids.
      const id = Number(params.id)
      const tm = tms.find((t) => t.id === id)
      if (!tm) {
        return new HttpResponse(null, {status: 404})
      }
      return HttpResponse.json<AwaitedReturnType<typeof getMemory>>(tm)
    }
  ),
]
9 changes: 9 additions & 0 deletions frontend/src/client/schemas/TranslationMemoryListResponse.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// This file is autogenerated, do not edit directly.

import {TranslationMemoryRecord} from './TranslationMemoryRecord'

export interface TranslationMemoryListResponse {
records: TranslationMemoryRecord[]
page: number
total_records: number
}
Loading