Changes from all commits (23 commits):
2e99d35 - Added Granular Tagger template (gsarti, Apr 13, 2023)
6a9a04a - feat: tags from edits /w tests (k4black, Apr 27, 2023)
9b49b46 - style: update tags_from_edits docs and some style fixes for typings (k4black, May 1, 2023)
9732403 - feat: tags to source and attention sim scores (k4black, May 4, 2023)
f9e375c - fix: filter None similarities when check deletions (k4black, May 4, 2023)
340bb8c - feat: function cache (k4black, May 5, 2023)
3926a2c - feat: function cache (k4black, May 5, 2023)
4e90560 - chore: add new comment for added material (k4black, May 5, 2023)
bce4ac1 - chore: cmd line argument to run augmentation (k4black, May 9, 2023)
1d69f9a - style: apply black and ruff (k4black, May 9, 2023)
fa7045d - chore: gitignore .ruff_cache/ (k4black, May 9, 2023)
7b44a98 - refactor: move taggers to separate module (k4black, May 9, 2023)
c13798e - fix: remove debug print in cache (k4black, May 9, 2023)
4233114 - style: add cache for wmt22 and fix style for tests (k4black, May 11, 2023)
2fdba17 - feat: add deletions (None, j) to _fill_deleted_inserted_tokens (k4black, May 12, 2023)
270491a - fix: make _fill_deleted_inserted_tokens use lists as input (k4black, May 12, 2023)
1734e31 - fix: _fill_deleted_inserted_tokens insertions error (k4black, May 12, 2023)
5aedb48 - style: apply black (k4black, May 12, 2023)
e577f92 - feat: optimize simalign to load models faster (much faster) (k4black, Aug 27, 2023)
01092bb - feat: save alignments (k4black, Aug 27, 2023)
b53d28c - fix: add some typings, fix cache for ald tags (k4black, Aug 27, 2023)
a90764c - feat: add analyze notebooks (k4black, Aug 27, 2023)
6448e1f - fix: refactor analysis notebook with faster data loading and rate plots (k4black, Sep 17, 2023)
3 changes: 3 additions & 0 deletions .gitignore
@@ -9,6 +9,8 @@ data/raw/vie/*/*

tmp/*
outputs/*
cache/
.cache/

.idea/

@@ -64,6 +66,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache/

# Translations
*.mo
1 change: 1 addition & 0 deletions README.md
@@ -73,6 +73,7 @@ python scripts/preprocess.py \
--add_extra \
--add_annotations \
--add_wmt22_quality_tags \
--add_name_tbd_quality_tags \
--output_single \
--output_merged_subjects \
--output_merged_languages
142 changes: 142 additions & 0 deletions divemt/cache_utils.py
@@ -0,0 +1,142 @@
"""
The hashing idea adapted from https://death.andgravity.com/stable-hashing
https://github.com/lemon24/reader/blob/1efcd38c78f70dcc4e0d279e0fa2a0276749111e/src/reader/_hash_utils.py
"""
import dataclasses
import datetime
import functools
import hashlib
import inspect
import json
import pickle
from collections.abc import Collection
from pathlib import Path
from typing import Any, Callable, Dict, Optional

import pandas as pd

_VERSION = 0
_EXCLUDE = "_hash_exclude_"


def _json_dumps(thing: object) -> str:
return json.dumps(
thing,
default=_json_default, # force formatting-related options to known values
ensure_ascii=False,
sort_keys=True,
indent=None,
separators=(",", ":"),
)


def _json_default(thing: object) -> Any:
try:
return _dataclass_dict(thing)
except TypeError:
pass
if isinstance(thing, datetime.datetime):
return thing.isoformat(timespec="microseconds")
raise TypeError(f"Object of type {type(thing).__name__} is not JSON serializable")


def _dataclass_dict(thing: object) -> Dict[str, Any]:
# we could have used dataclasses.asdict()
# with a dict_factory that drops empty values,
# but asdict() is recursive and we need to intercept and check
# the _hash_exclude_ of nested dataclasses;
# this way, json.dumps() does the recursion instead of asdict()

# raises TypeError for non-dataclasses
fields = dataclasses.fields(thing)
# ... but doesn't for dataclass *types*
if isinstance(thing, type):
raise TypeError("got type, expected instance")

exclude = getattr(thing, _EXCLUDE, ())

rv = {}
for field in fields:
if field.name in exclude:
continue

value = getattr(thing, field.name)
if value is None or (not value and isinstance(value, Collection)):
continue

rv[field.name] = value

return rv


def calc_obj_hash(obj: object) -> bytes:
"""Calculate hash of a single object"""
prefix = _VERSION.to_bytes(1, "big")
hash_object = hashlib.sha256()
hash_object.update(_json_dumps(obj).encode("utf-8"))
return prefix + hash_object.digest()


def calc_args_hash(*args: Any, **kwargs: Any) -> bytes:
"""Calculate hash of arguments to function"""
prefix = _VERSION.to_bytes(1, "big")
hash_object = hashlib.sha256()
for arg in args:
if isinstance(arg, (pd.DataFrame, pd.Series)):
hash_object.update(str(pd.util.hash_pandas_object(arg).sum()).encode("utf-8"))
else:
hash_object.update(_json_dumps(arg).encode("utf-8"))
for key, value in kwargs.items():
if isinstance(value, (pd.DataFrame, pd.Series)):
hash_object.update(key.encode("utf-8") + str(pd.util.hash_pandas_object(value).sum()).encode("utf-8"))
else:
hash_object.update(_json_dumps([key, value]).encode("utf-8"))
return prefix + hash_object.digest()


class CacheDecorator:
def __init__(self, cache_dir: Optional[Path] = None, version: int = 0, name: Optional[str] = None):
self.version = version
self.name = name or ''
self.cache_dir = cache_dir or Path(".cache")

@staticmethod
def _is_bound_method(function: Callable, arg: Any):
return inspect.ismethod(function) or (hasattr(arg, "__class__") and function.__name__ in dir(arg.__class__))

def __call__(self, function: Callable) -> Any:
cached_function_name = function.__qualname__.replace(".", "_")
if self.name:
cached_function_name = cached_function_name + "_" + self.name
@functools.wraps(function)
def wrapper(*args: Any, **kwargs: Any) -> Any:
cache_key_args = args[1:] if args and self._is_bound_method(function, args[0]) else args
hash_val = calc_args_hash(*cache_key_args, **kwargs)
cache_file = self.cache_dir / f"{cached_function_name}_v{self.version}_{hash_val.hex()}.pkl"

# TODO: add logging, not printing

if cache_file.exists():
print(f"LOADING CACHE: {cache_file}")
with open(cache_file, "rb") as f:
return pickle.load(f)
else:
result = function(*args, **kwargs)
print(f"CREATE CACHE: {cache_file}")
cache_file.parent.mkdir(parents=True, exist_ok=True)
with open(cache_file, "wb") as f:
pickle.dump(result, f)
return result

return wrapper

def __get__(self, instance, owner):
"""note: adapted from chat-gpt-4 =)"""
# Support method decorators for class instances
if instance is None:
return self

# Bind the decorated method to the instance
bound_method = functools.partial(self, instance)

return bound_method
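
For orientation, a minimal usage sketch of the new caching utilities. Only `CacheDecorator` and `calc_obj_hash` come from `divemt/cache_utils.py`; the decorated function, its arguments, and the sample data below are hypothetical.

```python
from pathlib import Path

import pandas as pd

from divemt.cache_utils import CacheDecorator, calc_obj_hash


# Hypothetical example: cache an expensive, deterministic computation over a DataFrame.
# The cache file name combines the function's qualified name, the optional decorator name,
# the version, and a hash of the arguments (pandas objects go through hash_pandas_object).
@CacheDecorator(cache_dir=Path(".cache"), version=0, name="demo")
def expensive_alignment(df: pd.DataFrame, lang: str) -> pd.DataFrame:
    # stands in for a slow alignment or tagging step
    return df.assign(lang=lang)


df = pd.DataFrame({"src_text": ["hello"], "mt_text": ["hallo"]})
result = expensive_alignment(df, "deu")  # first call: computes and pickles to .cache/
result = expensive_alignment(df, "deu")  # second call: loaded from the pickle

# calc_obj_hash gives a stable, version-prefixed digest for JSON-serializable objects and dataclasses
print(calc_obj_hash({"lang": "deu", "version": 1}).hex())
```

Repeated calls with identical inputs therefore skip recomputation, at the cost of stale results if the function body changes without bumping `version`.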
4 changes: 2 additions & 2 deletions divemt/cer.py
@@ -15,7 +15,7 @@
import ctypes
import itertools

import Levenshtein
import Levenshtein as levenshtein


class EditDistance:
@@ -86,7 +86,7 @@ def cer(hyp_words, ref_words, ed_wrapper):
if len(shifted_chars) == 0:
return 1.0

edit_cost = Levenshtein.distance(shifted_chars, ref_chars) + shift_cost
edit_cost = levenshtein.distance(shifted_chars, ref_chars) + shift_cost
return min(1.0, edit_cost / len(shifted_chars))


19 changes: 13 additions & 6 deletions divemt/parse_utils.py
@@ -20,7 +20,7 @@
from .cer import cer
from .tag_utils import clear_nlp_cache, texts2annotations, tokenize

from .qe_taggers import QETagger, WMT22QETagger # isort: skip <- due to circular import with tag_utils
from .qe_taggers import QETagger, WMT22QETagger, NameTBDTagger # isort: skip <- due to circular import with tag_utils

logger = logging.getLogger(__name__)

@@ -351,16 +351,20 @@ def texts2qe(
) -> pd.DataFrame:
"""Add quality tags to a dataframe."""
pe_texts = data.copy()[data.mt_text.notnull()]
src_tags, mt_tags = tagger.generate_tags(
src_tags, mt_tags, src_mt_alignments, mt_pe_alignments = tagger.generate_tags(
pe_texts["src_text"].tolist(),
pe_texts["mt_text"].tolist(),
pe_texts["tgt_text"].tolist(),
"eng",
pe_texts.unit_id.str.split("-").map(lambda x: x[2]),
pe_texts.unit_id.str.split("-").map(lambda x: x[2]).tolist(),
)
pe_texts[f"src_{tagger.ID}"] = src_tags
pe_texts[f"mt_{tagger.ID}"] = mt_tags
pe_texts = pe_texts[["unit_id", f"src_{tagger.ID}", f"mt_{tagger.ID}"]]
if src_mt_alignments:
pe_texts[f"src_mt_{tagger.ID}_alignments"] = src_mt_alignments
if mt_pe_alignments:
pe_texts[f"mt_pe_{tagger.ID}_alignments"] = mt_pe_alignments
data = data.join(pe_texts.set_index("unit_id"), on="unit_id")
return data

@@ -377,12 +381,12 @@ def parse_from_folder(
add_extra_information: bool = False,
add_annotations_information: bool = False,
add_wmt22_quality_tags: bool = False,
add_name_tbd_quality_tags: bool = False,
rounding: Optional[int] = None,
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]:
"""Parse all .per XML files in a folder and return a single dataframe containing all units."""
metrics_list_dfs = [per2metrics(os.path.join(path, f)) for f in os.listdir(path) if f.endswith(".per")]
metrics_df = pd.concat([df for df in metrics_list_dfs if df is not None], ignore_index=True)

if (
output_texts
or add_edit_information
Expand All @@ -407,9 +411,12 @@ def parse_from_folder(
if add_extra_information:
metrics_df = metrics2extra(metrics_df)
if add_annotations_information:
texts_df = texts2annotations(texts_df)
texts_df = texts2annotations(texts_df) # TODO: make cache optional
if add_wmt22_quality_tags:
tagger = WMT22QETagger()
tagger = WMT22QETagger() # TODO: make cache optional
texts_df = texts2qe(texts_df, tagger)
if add_name_tbd_quality_tags:
tagger = NameTBDTagger() # TODO: make cache optional
texts_df = texts2qe(texts_df, tagger)

if time_ordered:
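
For context, a hedged sketch of driving the new option from Python rather than the CLI flag shown in the README diff. The folder path is a placeholder, and the keyword names are read off the (partially truncated) signature above, so treat this as an assumption rather than documented usage.

```python
from divemt.parse_utils import parse_from_folder

# Hypothetical call: parse all .per files in one folder and attach both WMT22 and
# NameTBD word-level quality tags. Depending on the flags, the function returns either
# a single metrics dataframe or a (metrics, texts) tuple (see the Union return type above).
result = parse_from_folder(
    "data/raw/example_folder",  # placeholder path
    output_texts=True,
    add_wmt22_quality_tags=True,
    add_name_tbd_quality_tags=True,  # option introduced in this PR
)
```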
13 changes: 13 additions & 0 deletions divemt/qe_taggers/__init__.py
@@ -0,0 +1,13 @@
from .base import QETagger, TAlignment, TTag
from .name_tbd_tagger import NameTBDGeneralTags, NameTBDTagger
from .wmt22_tagger import WMT22QETagger, WMT22QETags

__all__ = [
"QETagger",
"TTag",
"TAlignment",
"NameTBDGeneralTags",
"NameTBDTagger",
"WMT22QETags",
"WMT22QETagger",
]
104 changes: 104 additions & 0 deletions divemt/qe_taggers/base.py
@@ -0,0 +1,104 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Set, Tuple, Union

from ..parse_utils import tokenize

TTag = Union[str, Set[str]]
TAlignment = Union[Tuple[Optional[int], Optional[int]], Tuple[Optional[int], Optional[int], Optional[float]]]


class QETagger(ABC):
"""An abstract class to produce quality estimation tags from src-mt-pe triplets."""

ID = "qe"

def align_source_mt(
self,
src_tokens: List[List[str]],
mt_tokens: List[List[str]],
**align_source_mt_kwargs: Any,
) -> List[List[TAlignment]]:
"""Align source and machine translation tokens."""
raise NotImplementedError(f"{self.__class__.__name__} does not implement align_source_mt()")

def align_source_pe(
self,
src_tokens: List[List[str]],
pe_tokens: List[List[str]],
**align_source_pe_kwargs: Any,
) -> List[List[TAlignment]]:
"""Align source and post-edited tokens."""
raise NotImplementedError(f"{self.__class__.__name__} does not implement align_source_pe()")

@abstractmethod
def align_mt_pe(
self,
mt_tokens: List[List[str]],
pe_tokens: List[List[str]],
**align_mt_pe_kwargs: Any,
) -> List[List[TAlignment]]:
"""Align machine translation and post-editing tokens."""
pass

@staticmethod
@abstractmethod
def tags_from_edits(
mt_tokens: List[List[str]],
pe_tokens: List[List[str]],
alignments: List[List[TAlignment]],
**mt_tagging_kwargs: Any,
) -> List[List[TTag]]:
"""Produce tags on MT tokens from edits found in the PE tokens."""
pass

@staticmethod
@abstractmethod
def tags_to_source(
src_tokens: List[List[str]],
tgt_tokens: List[List[str]],
**src_tagging_kwargs: Any,
) -> List[List[TTag]]:
"""Propagate tags from MT to source."""
pass

@staticmethod
def get_tokenized(
sents: List[str], lang: Union[str, List[str]]
) -> Tuple[List[List[str]], Union[List[str], List[List[str]]]]:
"""Tokenize sentences."""
if isinstance(lang, str):
lang = [lang] * len(sents)
tok: List[List[str]] = [tokenize(sent, curr_lang, keep_tokens=True) for sent, curr_lang in zip(sents, lang)]
assert len(tok) == len(lang)
return tok, lang

@abstractmethod
def generate_tags(
self,
srcs: List[str],
mts: List[str],
pes: List[str],
src_langs: Union[str, List[str]],
tgt_langs: Union[str, List[str]],
) -> Tuple[List[TTag], List[TTag], List[TAlignment], List[TAlignment]]:
"""Generate word-level quality estimation tags from source-mt-pe triplets.

Args:
srcs (`List[str]`):
List of untokenized source sentences.
mts (`List[str]`):
List of untokenized machine translated sentences.
pes (`List[str]`):
List of untokenized post-edited sentences.
src_langs (`Union[str, List[str]]`):
Either a single language code for all source sentences or a list of language codes
(one per source sentence).
tgt_langs (`Union[str, List[str]]`):
Either a single language code for all target sentences or a list of language codes
(one per machine translation).

Returns:
`Tuple[List[TTag], List[TTag], List[TAlignment], List[TAlignment]]`: A tuple containing the lists of quality
tags for the source and machine translation sentences, followed by the source-MT and MT-PE alignments.
"""
pass
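
To make the abstract contract concrete, here is a minimal hypothetical subclass. `ExactMatchTagger` and its naive positional alignment are invented for illustration only; they are not among the taggers shipped in this PR (`WMT22QETagger`, `NameTBDTagger`).

```python
from typing import Any, List, Tuple, Union

from divemt.qe_taggers import QETagger, TAlignment, TTag


class ExactMatchTagger(QETagger):
    """Toy tagger: an MT token is BAD if it does not survive into the post-edited sentence."""

    ID = "exact_match"

    def align_mt_pe(
        self, mt_tokens: List[List[str]], pe_tokens: List[List[str]], **kwargs: Any
    ) -> List[List[TAlignment]]:
        # naive positional alignment, padded with None where the PE sentence is shorter
        return [
            [(i, i if i < len(pe) else None) for i in range(len(mt))]
            for mt, pe in zip(mt_tokens, pe_tokens)
        ]

    @staticmethod
    def tags_from_edits(
        mt_tokens: List[List[str]],
        pe_tokens: List[List[str]],
        alignments: List[List[TAlignment]],
        **kwargs: Any,
    ) -> List[List[TTag]]:
        return [["OK" if tok in pe else "BAD" for tok in mt] for mt, pe in zip(mt_tokens, pe_tokens)]

    @staticmethod
    def tags_to_source(
        src_tokens: List[List[str]], tgt_tokens: List[List[str]], **kwargs: Any
    ) -> List[List[TTag]]:
        # trivially mark every source token OK
        return [["OK"] * len(src) for src in src_tokens]

    def generate_tags(
        self,
        srcs: List[str],
        mts: List[str],
        pes: List[str],
        src_langs: Union[str, List[str]],
        tgt_langs: Union[str, List[str]],
    ) -> Tuple[List[TTag], List[TTag], List[TAlignment], List[TAlignment]]:
        src_tok, src_langs = self.get_tokenized(srcs, src_langs)
        mt_tok, tgt_langs = self.get_tokenized(mts, tgt_langs)
        pe_tok, _ = self.get_tokenized(pes, tgt_langs)
        mt_pe_align = self.align_mt_pe(mt_tok, pe_tok)
        mt_tags = self.tags_from_edits(mt_tok, pe_tok, mt_pe_align)
        src_tags = self.tags_to_source(src_tok, mt_tok)
        # this toy tagger produces no src-mt alignments, hence the empty list
        return src_tags, mt_tags, [], mt_pe_align
```

The shape of the returned tuple mirrors how `texts2qe` unpacks `tagger.generate_tags(...)` in `divemt/parse_utils.py`.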