14 changes: 8 additions & 6 deletions configs/tokenization_context.gin
@@ -1,18 +1,20 @@
expected_size=128
output_base_dir="/lnet/work/home-students-external/farhan/troja/outputs/128_finetuning_damuel/"
output_base_dir_mewsli="/lnet/work/home-students-external/farhan/troja/outputs/128_finetuning_mewsli/"
expected_size=64
output_base_dir="/lnet/work/home-students-external/farhan/troja/outputs/finetuning_damuel_2"
output_base_dir_mewsli="/lnet/work/home-students-external/farhan/troja/outputs/finetuning_mewsli_2/"
languages = [
"af", "be", "ca", "da", "el", "es", "eu", "fi", "ga", "gl", "hi", "hu", "id",
"ja", "la", "lv", "mt", "nn", "pt", "ru", "sk", "sr", "ta", "tr", "uk", "vi",
"zh", "ar", "bg", "cs", "de", "en", "et", "fa", "fr", "gd", "he", "hr", "hy",
"it", "ko", "lt", "mr", "nl", "pl", "ro", "se", "sl", "sv", "te", "ug", "ur",
"wo"
]
# languages = ["af"]
damuel_base_path="/lnet/work/home-students-external/farhan/damuel/1.0-xz"
#damuel_base_path="/lnet/work/home-students-external/farhan/troja/damuel/1.0"
label_token="[M]"
compress=True
remainder_mod=128
num_processes=90
compress=False
remainder_mod=90
num_processes=60

run_damuel_description_context.expected_size=%expected_size
run_damuel_description_context.output_base_dir=%output_base_dir
20 changes: 20 additions & 0 deletions configs/train_NO_distribution.gin
@@ -0,0 +1,20 @@
training_batch_size=2688
#epochs=300
epochs=100
logit_multiplier=20
negative_sampling_type="top"
learning_rate=0.0001
negatives_per_link=3
context_size=64
STEPS_PER_EPOCH=1000

generate.BATCH_SIZE=%training_batch_size
generate.EPOCHS=%epochs
generate.STEPS_PER_EPOCH=%STEPS_PER_EPOCH
generate.NEG=%negatives_per_link
generate.CONTEXT_SIZE=%context_size
generate.NEGATIVE_SAMPLING_TYPE=%negative_sampling_type

train_ddp.EPOCHS=%epochs
train_ddp.LR=%learning_rate
train_ddp.LOGIT_MULTIPLIER=%logit_multiplier
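
For reference, a minimal sketch (not from this repository) of how gin-config consumes bindings like generate.BATCH_SIZE=%training_batch_size: the macro is defined once and substituted into the decorated function's parameters. The generate function and its parameters below are illustrative stand-ins.

import gin

# Hypothetical stand-in for the project's generate(); only the gin mechanics matter here.
@gin.configurable
def generate(BATCH_SIZE=gin.REQUIRED, EPOCHS=gin.REQUIRED):
    print(BATCH_SIZE, EPOCHS)

gin.parse_config("""
training_batch_size = 2688
generate.BATCH_SIZE = %training_batch_size
generate.EPOCHS = 100
""")

generate()  # prints: 2688 100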
12 changes: 8 additions & 4 deletions src/data_processors/tokens/tokens_cutter.py
@@ -8,7 +8,10 @@


def fast_token_mention_span(all_tokens, label_token_id):
@nb.njit
# Premature optimization is the root of all evil, and here that proved painfully true.
# The numba njit version crashes with OOM, possibly because it does not play well with multiprocessing.
# TODO: We should investigate why it is broken.
# @nb.njit
def _fast_token_mention_span(all_tokens, label_token_id):
mention_start_idx, mention_end_idx = None, None
for i, token in enumerate(all_tokens):
@@ -157,6 +160,7 @@ def _more_on_left_cut(self):
)

def _warn_about_padding(self):
_logger.warning(
"Padding tokens are present in the input text. This means that input text is shorter than expected."
)
pass
# _logger.warning(
# "Padding tokens are present in the input text. This means that input text is shorter than expected."
# )
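
For orientation, a hedged, minimal sketch of what a mention-span helper like _fast_token_mention_span computes, assuming the mention is delimited by the first and last occurrences of the label token; the exact semantics in the repository may differ.

import numpy as np

def mention_span(all_tokens: np.ndarray, label_token_id: int) -> tuple[int, int]:
    # Positions of the [M] label token; the mention is assumed to lie between the first and last.
    positions = np.flatnonzero(all_tokens == label_token_id)
    if len(positions) < 2:
        raise ValueError("Expected the label token to appear at least twice.")
    return int(positions[0]), int(positions[-1])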
5 changes: 4 additions & 1 deletion src/finetunings/generate_epochs/generate.py
@@ -98,8 +98,11 @@ def generate(

negative_sampler_kwargs = {}
if "distribution" in NEGATIVE_SAMPLING_TYPE:
# using sqrt seems like a sane default
negative_sampler_kwargs["qids_distribution"] = (
calculate_qids_distribution_from_links(LINKS_EMBS_DIR, index_qids)
calculate_qids_distribution_from_links(
LINKS_EMBS_DIR, index_qids, lambda x: np.sqrt(x)
)
)
negative_sampler_kwargs["randomly_sampled_cnt"] = 1

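The sqrt transform passed to calculate_qids_distribution_from_links makes the negative-sampling distribution less dominated by very frequent QIDs. A toy illustration with made-up counts:

import numpy as np

counts = np.array([1000.0, 100.0, 10.0, 1.0])         # hypothetical per-QID link counts
raw = counts / counts.sum()                           # ~[0.90, 0.09, 0.009, 0.0009]
flattened = np.sqrt(counts) / np.sqrt(counts).sum()   # ~[0.69, 0.22, 0.07, 0.02]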
13 changes: 9 additions & 4 deletions src/models/negative_sampler.py
@@ -1,9 +1,9 @@
from collections import Counter
from enum import Enum
import logging

import numba as nb
import numpy as np
from scipy.stats.sampling import DiscreteGuideTable

from models.searchers.searcher import Searcher

@@ -143,8 +143,14 @@ def __init__(
self.sample_f = _get_sampler(sampling_type)
self.qids_distribution = qids_distribution
self.randomly_sampled_cnt = randomly_sampled_cnt

self._validate()

if self._should_sample_randomly():
# self.urng = np.random.default_rng()
# DiscreteAliasUrn seems slightly faster, but it raises a UNU.RAN error that I don't want to debug,
# and DiscreteGuideTable looks OK-ish.
self.rng = DiscreteGuideTable(self.qids_distribution)

def sample(
self, batch_embs: np.ndarray, batch_qids: np.ndarray, negative_cnts: int
@@ -178,13 +184,12 @@ def _sample_randomly(self, batch_qids):
batch_qid = batch_qids[0]
batch_qids = set(batch_qids)
result = np.empty((batch_size, self.randomly_sampled_cnt), dtype=np.int32)
# TODO: Is sampling of size 1 really needed? These nested loops seem incredibly slow.
for i in range(batch_size):
for j in range(self.randomly_sampled_cnt):
qid_to_add = batch_qid
while qid_to_add in batch_qids:
qid_idx = np.random.choice(
self.returned_indices, size=1, p=self.qids_distribution
)[0]
qid_idx = self.returned_indices[self.rng.rvs(1)[0]]
qid_to_add = self.qids[qid_idx]
result[i][j] = qid_idx
return result
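A minimal sketch of the scipy sampler adopted above, with a made-up probability vector. DiscreteGuideTable is set up once and then queried repeatedly with rvs, which is presumably why it replaces the per-call np.random.choice(..., p=...) in _sample_randomly.

import numpy as np
from scipy.stats.sampling import DiscreteGuideTable

qids = np.array([101, 202, 303, 404])             # hypothetical QIDs
distribution = np.array([0.5, 0.3, 0.15, 0.05])   # toy probabilities, must sum to 1

rng = DiscreteGuideTable(distribution)            # built once
sampled_indices = rng.rvs(10)                     # indices into the probability vector
sampled_qids = qids[sampled_indices]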
31 changes: 24 additions & 7 deletions src/multilingual_dataset/creator.py
@@ -26,7 +26,7 @@ def __init__(
self.dest_links_dir.mkdir(parents=True, exist_ok=True)

self.single_mixer = Mixer(buffer_size=1)
self.standard_mixer = Mixer(buffer_size=200)
self.standard_mixer = Mixer(buffer_size=10)

def run(self) -> None:
"""Gathers links from all languages and writes them to dest_dir.
@@ -49,7 +49,7 @@ def run(self) -> None:
out_file_paths.append(out_file_path)

self.single_mixer.mix(out_file_paths, n_of_mixings=1, compress_output=False)
self.standard_mixer.mix(out_file_paths, n_of_mixings=5, compress_output=True)
self.standard_mixer.mix(out_file_paths, n_of_mixings=20, compress_output=True)

def _copy_files(
self, source_file_paths: Iterable[Path], dest_file_path: Path
@@ -93,6 +93,7 @@ def __init__(

def run(self) -> None:
qid_lang_mapping = self._get_qid_lang_mapping()
assert all(len(v) <= self.langs_per_qid for v in qid_lang_mapping.values())
lang_qid_lists = self._group_qids_by_lang(qid_lang_mapping)
self._copy_chosen_pages(lang_qid_lists)

@@ -115,10 +116,20 @@ def _copy_chosen_pages(self, lang_qid_lists: dict[str, list[int]]) -> None:
def _copy_chosen_pages_from_lang(
self, wanted_qids: list[int], filepaths: list[Path], lang: str
) -> None:
# wanted_qids = list(wanted_qids)
processed_qids = set()
for i, descs_file_path in enumerate(filepaths):
tokens, qids = load_mentions(descs_file_path)
_, unique_qids_index = np.unique(qids, return_index=True)
qids = qids[unique_qids_index]
tokens = tokens[unique_qids_index]

index = np.isin(qids, wanted_qids)
# Checking the processed_qids set here is likely better than another numpy isin, which needs a linear scan.
for i, flag in enumerate(index):
if flag and qids[i] in processed_qids:
index[i] = False

index = np.isin(qids, list(wanted_qids))
chosen_tokens = tokens[index]
chosen_qids = qids[index]

@@ -130,6 +141,7 @@ def _copy_chosen_pages_from_lang(
tokens=chosen_tokens,
qids=chosen_qids,
)
processed_qids.update(chosen_qids)

def _group_qids_by_lang(
self, qid_lang_mapping: dict[int, str]
@@ -194,6 +206,11 @@ def _get_mapping_from_counts_and_lang_sizes(
desc="Mapping QIDs to languages",
total=len(qid_lang_counts),
):
if qid in qid_lang_mapping:
_logger.warning(
f"QID {qid} already has a language mapping, not adding it again. This might happen when using qids remap dict. If that is the case, ignore this warning."
)
continue
items_by_importance = sorted(
lang_counts.items(), key=lambda x: (-x[1], -lang_sizes[x[0]])
)
@@ -235,14 +252,14 @@ def __init__(
)

def run(self) -> None:
_logger.info("Starting to create KB")
self._kb_creator.run()
_logger.info("Finished creating KB")

_logger.info("Starting to create links")
self._links_creator.run()
_logger.info("Finished creating links")

_logger.info("Starting to create KB")
self._kb_creator.run()
_logger.info("Finished creating KB")


def create_multilingual_dataset(
source_dir: Union[str, Path],
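A compact, hypothetical restatement of the de-duplication logic added in _copy_chosen_pages_from_lang: a row survives only if its QID is wanted and has not already been copied from an earlier file.

import numpy as np

def select_rows(qids: np.ndarray, wanted_qids: set, processed_qids: set) -> np.ndarray:
    index = np.isin(qids, list(wanted_qids))
    for i, flag in enumerate(index):
        # Drop rows whose QID was already copied from a previously processed file.
        if flag and qids[i] in processed_qids:
            index[i] = False
    return index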
2 changes: 1 addition & 1 deletion src/run_action_gin.py
@@ -2,6 +2,7 @@

logging.basicConfig(level=logging.INFO)

from fire import Fire
import gin
import wandb

@@ -17,7 +18,6 @@
embed_links_for_generation,
)
from finetunings.generate_epochs.generate import generate
from fire import Fire
from multilingual_dataset.combine_embs import combine_embs_by_qid
from multilingual_dataset.creator import create_multilingual_dataset, run_kb_creator
from tokenization.runner import (
5 changes: 5 additions & 0 deletions src/scripts/multilingual/multilingual.sh
@@ -0,0 +1,5 @@
#!/bin/bash

# An example of how to build the multilingual dataset

python run_action_gin.py ../configs/general.gin create_multilingual_dataset --source_dir=/lnet/work/home-students-external/farhan/troja/outputs/finetuning_damuel_2 --langs=[af,be,ca,da,el,es,eu,fi,ga,gl,hi,hu,id,ja,la,lv,mt,nn,pt,ru,sk,sr,ta,tr,uk,vi,zh,ar,bg,cs,de,en,et,fa,fr,gd,he,hr,hy,it,ko,lt,mr,nl,pl,ro,se,sl,sv,te,ug,ur,wo] --dest_dir=/lnet/work/home-students-external/farhan/troja/outputs/all2/
21 changes: 15 additions & 6 deletions src/scripts/train/evaluate_no_slurm.sh
@@ -8,17 +8,26 @@ echo "Current directory: $(pwd)"
MODEL_CONFIG_PATH="../configs/lealla_m.gin"
TRAIN_CONFIG_PATH="../configs/train.gin"

# DAMUEL_FOR_INDEX_NEW_DIR="$OUTPUTS/workdirs/all/damuel_for_index_8"
DAMUEL_FOR_INDEX_NEW_DIR="$OUTPUTS/triplets/combined"
MEWSLI_TOKENS_RAW="$OUTPUTS/tokens_mewsli_finetuning"
WORKDIR="$OUTPUTS/workdirs/all"
DAMUEL_FOR_INDEX_NEW_DIR="$OUTPUTS/all2/index"
# DAMUEL_FOR_INDEX_NEW_DIR="$OUTPUTS/triplets/combined"
MEWSLI_TOKENS_RAW="$OUTPUTS/finetuning_mewsli_2"
DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/all2/descs_pages"
WORKDIR="$OUTPUTS/all2"
ROUND_ID=7
MODELS_DIR="$WORKDIR/models_$ROUND_ID"
MODELS_DIR="$OUTPUTS/workdirs/all/models_$ROUND_ID"

ACTION_SCRIPT="run_action_gin.py $MODEL_CONFIG_PATH $TRAIN_CONFIG_PATH"

LANGUAGES=("ar" "de" "en" "es" "ja" "fa" "sr" "ta" "tr")

if [ ! "$(ls -A $DAMUEL_FOR_INDEX_NEW_DIR)" ]; then
echo "Running embs generating for damuel"
../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \
--source_path="$DAMUEL_DESCS_TOKENS_RAW" \
--dest_path="$DAMUEL_FOR_INDEX_NEW_DIR" \
--state_dict_path="$MODELS_DIR/final.pth"
fi

for LANG in "${LANGUAGES[@]}"; do
echo "Processing language: $LANG"

@@ -39,4 +48,4 @@

echo "Completed processing for language: $LANG"
echo "----------------------------------------"
done
done
8 changes: 5 additions & 3 deletions src/tokenization/pipeline/loaders/damuel.py
@@ -130,14 +130,14 @@ def _process_with_context(
title = self._extract_title(damuel_entry)
if title is None:
continue
original_titled = title
original_title = title
title_wrapped = self._wrap_title(title, self.label_token)

description = self._extract_description(damuel_entry)
if description is None:
description = ""
text = self.construct_text_from_title_and_description(
title_wrapped, description, original_titled
title_wrapped, description, original_title
)

qid = parse_qid(damuel_entry["qid"])
@@ -175,7 +175,9 @@ def _wrap_title(self, title: str, label_token: str) -> str:
def construct_text_from_title_and_description(
cls, title: str, description: str, original_title: str | None = None
) -> str:
if original_title is not None and description.startswith(original_title):
if original_title is not None and description.strip().startswith(
original_title.strip()
):
return f"{title}\n{description[len(original_title):]}"
return f"{title}\n{description}"

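For illustration, hypothetical inputs showing the effect of the prefix handling above: when the description repeats the page title, the duplicated prefix is cut so the wrapped title appears only once.

# Hypothetical values; the label token is assumed to be "[M]".
wrapped_title = "[M] Prague [M]"
original_title = "Prague"
description = "Prague is the capital of the Czech Republic."
# construct_text_from_title_and_description(wrapped_title, description, original_title)
# returns "[M] Prague [M]\n is the capital of the Czech Republic."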
4 changes: 3 additions & 1 deletion src/tokenization/pipeline/pipelines.py
@@ -2,7 +2,7 @@

from .loaders import DaMuELDescriptionLoader, DaMuELLinkLoader, MewsliLoader
from .loggers import LoggerStep, StatisticsLogger
from .savers import NPZSaver
from .savers import NPZSaver, NPZSaverIncremental
from .tokenizers import CuttingTokenizer, SimpleTokenizer


@@ -102,6 +102,7 @@ def __init__(
remainder: int = None,
mod: int = None,
require_link_wiki_origin: bool = True,
save_every: int = 1000000,
):
super().__init__()
self.add(
@@ -115,6 +116,7 @@
)
self.add(CuttingTokenizer(tokenizer, expected_size, label_token))
self.add(NPZSaver(output_filename, compress))
# self.add(NPZSaverIncremental(output_filename, compress, save_every))


class DamuelLinkMentionPipeline(Pipeline):
4 changes: 2 additions & 2 deletions src/tokenization/pipeline/savers/__init__.py
@@ -1,3 +1,3 @@
from .npz import NPZSaver
from .npz import NPZSaver, NPZSaverIncremental

__all__ = ["NPZSaver"]
__all__ = ["NPZSaver", "NPZSaverIncremental"]
16 changes: 13 additions & 3 deletions src/utils/calculate_qids_distribution.py
@@ -3,19 +3,29 @@

import numpy as np

from .loaders import load_qids


def calculate_qids_distribution_from_links(
links_dir: Path, index_qids: np.ndarray, transform_fn: callable = lambda x: x
) -> np.ndarray:
qid_to_cnt: dict[int, int] = defaultdict(int)
index_qids = set(index_qids)
for file in links_dir.iterdir():
if not file.suffix == ".npz":
continue
d = np.load(file)
qids = d["qids"]
qids = load_qids(file)
for qid in qids:
if qid not in index_qids:
continue
qid_to_cnt[qid] += 1
for key in qid_to_cnt:
qid_to_cnt[key] = transform_fn(qid_to_cnt[key])
print(len(list(qid_to_cnt.keys())))
print(len(set(qid_to_cnt.keys()) & set(index_qids)))
qids_observed_cnt = sum(qid_to_cnt.values())
return np.array([qid_to_cnt[qid] / qids_observed_cnt for qid in index_qids])
res = np.array(
[qid_to_cnt[qid] / qids_observed_cnt for qid in index_qids], dtype=np.float64
)
print(sum(res))
return res