Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
c47f845
perf(models): :zap: speed up bruteforce searcher by adding torch comp…
Yokto13 Apr 23, 2025
eeb0773
feat(tokens): :sparkles: add DamuelPageTypeLoader
Yokto13 Apr 25, 2025
9ad17e2
feat(utils): :sparkles: add qids_filter decorator which can be used t…
Yokto13 Apr 29, 2025
1943dbe
refactor: :construction: codebase overhaul
Yokto13 Jul 20, 2025
f71d26e
fix(utils): :bug: address failing qid filter tests
Yokto13 Jul 20, 2025
a889ef0
feat(utils): :sparkles: scripts for calculating stats about dataset
Yokto13 Jul 24, 2025
270e983
perf(train): :zap: speed up generate with amp
Yokto13 Jul 24, 2025
ab9e002
perf(train): :zap: adding compile, iterable dataset to make prefetch …
Yokto13 Jul 27, 2025
f50b087
fix(train): :bug: fix double loading of the same data
Yokto13 Aug 3, 2025
9fcb6d5
refactor(train): :recycle: return back LightWeightDataset version
Yokto13 Aug 3, 2025
857e758
perf(utils): :zap: small performance improvements
Yokto13 Aug 3, 2025
e5e0c10
feat(train): :sparkles: add option to run all recalls from one python…
Yokto13 Aug 3, 2025
3a27aeb
test(train): :white_check_mark: evaluate tests
Yokto13 Aug 3, 2025
6f23a48
feat(utils): :sparkles: add script for qid occurrence analysis
Yokto13 Aug 5, 2025
c2cab82
feat(multiling): :sparkles: option to limit link qids, faster mixing
Yokto13 Aug 7, 2025
c8f5110
refactor(multiling): :recycle: better defaults
Yokto13 Aug 9, 2025
e4b724c
perf(train): :zap: decrease redundant work in evaluation
Yokto13 Sep 10, 2025
0a6d8d6
feat(train): :fire: attempts to improve model and kill performance :D
Yokto13 Sep 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions configs/general.gin
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ remap_qids_decorator.json_path=%qids_remap_json
qids_remap.old_to_new_qids_path=%qids_remap_json
#qids_remap_json="/net/projects/damuel/dev/damuel_1.1-dev_qid_redirects.json"
qids_remap_json="/lnet/work/home-students-external/farhan/damuel/dev/damuel_2.0-dev_qid_redirects.json"

#filter_qids_npy_path="/lnet/work/home-students-external/farhan/troja/filtered_qids2.npy"
#qid_filter.filter_path=%filter_qids_npy_path
2 changes: 1 addition & 1 deletion configs/lealla_m.gin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base"
inference_batch_size=380000
inference_batch_size=120000
output_type="pooler_output"

embs_from_tokens_and_model_name_at.model_name=%model_path
Expand Down
8 changes: 8 additions & 0 deletions configs/multilingual_dataset.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
source_dir="/lnet/work/home-students-external/farhan/troja/outputs/finetuning_damuel2"
langs=["af", "be", "ca", "da", "el", "es", "eu", "fi", "ga", "gl", "hi", "hu", "id", "ja", "la", "lv", "mt", "nn", "pt", "ru", "sk", "sr", "ta", "tr", "uk", "vi", "zh", "ar", "bg", "cs", "de", "en", "et", "fa", "fr", "gd", "he", "hr", "hy", "it", "ko", "lt", "mr", "nl", "pl", "ro", "se", "sl", "sv", "te", "ug", "ur", "wo"]
dest_dir="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal/"


create_multilingual_dataset.source_dir=%source_dir
create_multilingual_dataset.langs=%langs
create_multilingual_dataset.dest_dir=%dest_dir
38 changes: 28 additions & 10 deletions src/finetunings/evaluation/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,41 @@
from finetunings.evaluation.find_recall import find_recall
import logging

from finetunings.evaluation.find_recall import (
find_recall_with_searcher,
load_embs_and_qids_with_normalization,
)
from models.searchers.brute_force_searcher import BruteForceSearcher

_RECALLS = [1, 10, 100]

_logger = logging.getLogger("finetuning.evaluation.evaluate")


def _construct_mewsli_path(root_dir: str, finetuning_round: int, lang: str) -> str:
return f"{root_dir}/mewsli_embs_{lang}_{finetuning_round}"


def _construct_damuel_path(root_dir: str, finetuning_round: int) -> str:
next_finetuning_round = finetuning_round + 1
return f"{root_dir}/damuel_for_index_{next_finetuning_round}"


def run_recall_calculation(damuel_dir, mewsli_dir, recall=None):
    """Run recall evaluation of Mewsli embeddings against a DaMuEL index.

    When ``recall`` is None, all default cutoffs in ``_RECALLS`` are
    evaluated; otherwise only the single given cutoff is used.
    """
    recalls = _RECALLS if recall is None else [recall]
    # NOTE(review): the import diff for this module removes
    # `from finetunings.evaluation.find_recall import find_recall` in favor of
    # find_recall_with_searcher — confirm `find_recall` is still in scope here,
    # otherwise this call raises NameError at runtime.
    find_recall(damuel_dir, mewsli_dir, recalls)


def evaluate(
    root_dir: str,
    finetuning_round: int,
    langs: tuple[str, ...] = ("ar", "de", "en", "es", "ja", "fa", "sr", "ta", "tr"),
):
    """Evaluate retrieval recall for one finetuning round.

    Builds a brute-force index once from the DaMuEL embeddings produced for
    the next round, then computes and logs recall at the cutoffs in
    ``_RECALLS`` for the Mewsli embeddings of every language in ``langs``.

    Args:
        root_dir: Directory holding the per-round embedding dumps.
        finetuning_round: Round whose embeddings are evaluated.
        langs: Languages to evaluate. The default is an immutable tuple
            (only iterated here) to avoid the shared-mutable-default pitfall.
    """
    damuel_path = _construct_damuel_path(root_dir, finetuning_round)

    # Build the searcher once and reuse it for all languages.
    damuel_embs, damuel_qids = load_embs_and_qids_with_normalization(damuel_path)
    searcher = BruteForceSearcher(damuel_embs, damuel_qids)

    for lang in langs:
        mewsli_path = _construct_mewsli_path(root_dir, finetuning_round, lang)
        _logger.info(f"Calculating recall for {lang}")
        find_recall_with_searcher(searcher, mewsli_path, _RECALLS)
17 changes: 17 additions & 0 deletions src/finetunings/evaluation/find_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def find_recall(
_logger.info(f"Recall at {R}: {recall}")


@paths_exist(path_arg_ids=[1])
def find_recall_with_searcher(
    searcher: BruteForceSearcher,
    mewsli: str,
    recalls: list[int],
) -> None:
    """Compute and log recall@k on a Mewsli dump using a prebuilt searcher."""
    embs, qids = load_embs_and_qids_with_normalization(mewsli)

    calculator = RecallCalculator(searcher)

    for cutoff in recalls:
        _logger.info("Calculating recall...")
        score = calculator.recall(embs, qids, cutoff)
        # Mirror each score to wandb and the local log.
        wandb.log({f"recall_at_{cutoff}": score})
        _logger.info(f"Recall at {cutoff}: {score}")


def find_candidates(
damuel_entities: str, candidates_path: str, mewsli: str, recall: int
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions src/finetunings/file_processing/gathers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def move_tokens(source, dest, m=1, r=0, max_to_copy=float("inf")):
source = Path(source)
dest = Path(dest)
already_copied = 0
print(
f"Moving tokens from {source} to {dest} with m={m}, r={r}, max_to_copy={max_to_copy}"
)
print(os.listdir(source))
for fn in sorted(os.listdir(source)):
if not _wanted_fn(fn, m, r):
continue
Expand Down
54 changes: 49 additions & 5 deletions src/finetunings/finetune_model/data.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
from dataclasses import dataclass
import logging
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.utils.data import Dataset
from torch.utils.data import Dataset, IterableDataset

_logger = logging.getLogger("finetuning.finetune_model.data")


@dataclass
class SaveInformation:
output_path: Path
is_final: bool
epoch: int = None
recall: int = None
epoch: int | None = None
recall: int | None = None
name: str | None = None


def _load_epoch_npz(path: Path, epoch: int | str) -> tuple:
Expand All @@ -31,7 +35,8 @@ def construct_non_final_name():


def _save_final_model(model: nn.Module, save_information: SaveInformation) -> None:
    """Persist the model's state dict into ``save_information.output_path``.

    The file name comes from ``save_information.name`` when set; otherwise a
    default name is used.
    """
    # NOTE(review): the fallback "finals.pth" differs from the previous
    # default "final.pth" — confirm downstream loaders expect this name.
    name = save_information.name if save_information.name else "finals.pth"
    torch.save(model.state_dict(), f"{save_information.output_path}/{name}")


def save_model(model: nn.Module, save_information: SaveInformation) -> None:
Expand Down Expand Up @@ -75,7 +80,7 @@ def __init__(
self._rank = rank
self._dataset_dir = dataset_dir
self._epoch = epoch
self._data = self._load()
# self._data = self._load()
self._links_cnt = None
self._descriptions_cnt = None
self._len = None
Expand All @@ -96,6 +101,9 @@ def descriptions_cnt(self) -> int:
return self._descriptions_cnt

def _load(self) -> Any:
_logger.info(
f"Loading dataset from {self._dataset_dir} for epoch {self._epoch}, rank {self._rank}, world size {self._world_size}"
)
self._set_cnts()
this_share_start, this_share_end = self._get_share_bounds()
if this_share_end <= self.links_cnt:
Expand Down Expand Up @@ -143,3 +151,39 @@ def _set_cnts(self) -> None:
def _get_data_obj(self) -> Any:
d = np.load(self._dataset_dir / f"epoch_{self._epoch}.npz")
return d


class LightWeightIterableDataset(IterableDataset):
    """Streams items from consecutive per-epoch LightWeightDataset shards.

    Starting at ``epoch``, every item of the current shard is yielded in
    index order, then the next epoch's shard is loaded; iteration ends when
    loading the next shard raises FileNotFoundError.
    """

    def __init__(
        self, dataset_dir: Path, epoch: int, rank: int = 1, world_size: int = 1
    ) -> None:
        # NOTE(review): rank defaults to 1 while world_size defaults to 1 —
        # confirm this is intended (0 is the usual single-process rank).
        super().__init__()
        self._world_size = world_size
        self._rank = rank
        self._dataset_dir = dataset_dir
        self._epoch = epoch
        # First shard is created eagerly; unlike __iter__, this is not
        # guarded against FileNotFoundError. NOTE(review): confirm a missing
        # first epoch file should propagate out of the constructor.
        self._dataset: LightWeightDataset | None = self._load_next()

    def __iter__(self):
        # Stateful/consuming iteration: once the shards are exhausted,
        # self._dataset stays None and later iterations yield nothing.
        while self._dataset is not None:
            for i in range(len(self._dataset)):
                yield self._dataset[i]
            try:
                self._dataset = self._load_next()
            except FileNotFoundError:
                # No file for the next epoch -> end of stream.
                self._dataset = None

    @property
    def links_cnt(self) -> int:
        # Delegates to the currently loaded shard; fails with AttributeError
        # once the stream is exhausted (self._dataset is None).
        return self._dataset.links_cnt

    @property
    def descriptions_cnt(self) -> int:
        # Valid only while a shard is loaded; see links_cnt.
        return self._dataset.descriptions_cnt

    def _load_next(self) -> LightWeightDataset:
        """Construct the shard for the current epoch and advance the counter."""
        dataset = LightWeightDataset(
            self._dataset_dir, self._epoch, self._rank, self._world_size
        )
        self._epoch += 1
        return dataset
Loading
Loading