From c47f84583297247e7d09047d50e407fb3ead584f Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Wed, 23 Apr 2025 12:33:11 +0200 Subject: [PATCH 01/18] perf(models): :zap: speed up bruteforce searcher by adding torch compile, amp casting and slight improvements to the structure of the code --- src/models/searchers/brute_force_searcher.py | 64 +++++++++++++++----- tests/models/test_brute_force_searcher.py | 11 ++-- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index d4c728a..f36cec8 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -25,13 +25,18 @@ def __init__( self.device: torch.device = torch.device("cpu") super().__init__(embs, results, run_build_from_init) - @torch.compile def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: - with torch.no_grad(): + @torch.compile + def _find(batch: np.ndarray) -> np.ndarray: batch_torch: torch.Tensor = torch.from_numpy(batch).to(self.device) # embs after build are (dim, embs_count) dot_product: torch.Tensor = F.linear(batch_torch, self.embs) _, top_indices = dot_product.topk(num_neighbors) + return top_indices + + with torch.inference_mode(): + top_indices: torch.Tensor = _find(batch) + top_indices_np: np.ndarray = top_indices.cpu().numpy() return self.results[top_indices_np] @@ -45,6 +50,7 @@ def __init__(self, kb_embs, num_neighbors): self.kb_embs: torch.Tensor = nn.Parameter(kb_embs) self.num_neighbors: int = num_neighbors + @torch.compile def forward(self, x): dot_product = F.linear(x, self.kb_embs) _, top_indices = dot_product.topk(self.num_neighbors) @@ -66,20 +72,46 @@ def __init__( super().__init__(embs, results, run_build_from_init) def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: - if self.module_searcher is None: - self.module_searcher = nn.DataParallel( - _WrappedSearcher(torch.from_numpy(self.embs), num_neighbors) - ) - self.module_searcher.to(self.device) - self.required_num_neighbors = num_neighbors - if self.required_num_neighbors != num_neighbors: - raise ValueError( - f"num_neighbors was changed from {self.required_num_neighbors} to {num_neighbors} and this is not allowed in DPBruteForceSearcher" - ) - with torch.no_grad(): - top_indices: torch.Tensor = self.module_searcher( - torch.tensor(batch, device=self.device) - ) + """ + Finds the nearest neighbors for a given batch of input data. + CAREFUL: This is an optimized version that comes with potential pitfalls to get better performance. + Read Notes for details! + + Args: + batch (np.ndarray): A batch of input data for which neighbors are to be found. + num_neighbors (int): The number of nearest neighbors to retrieve. + Returns: + np.ndarray: An array containing the results corresponding to the nearest neighbors. + Raises: + TypeError: If `module_searcher` if an unexpected attribute access occurs when using module_searcher. + Notes: + - It is not possible to change num_neighbors after the first call to find. + If you need to do that, you need to reinitialize this object. If you call the find with different + num_neighbors, it will not raise an error and will fail silently. + - The first call to find will be slow, because the module_searcher will be initialized and torch.compile is called. + """ + with torch.inference_mode(), torch.autocast( + device_type=self.device.type, dtype=torch.float16 + ): + # A try except trick to avoid the overhead of checking if the module_searcher is None + # on every call to find. + # This is a bit of a hack, but it should make things faster as we are suggesting that the module_searcher is initialized. + try: + top_indices: torch.Tensor = self.module_searcher( + torch.tensor(batch, device=self.device) + ) + except TypeError as e: + if self.module_searcher is not None: + raise e + self.module_searcher = nn.DataParallel( + _WrappedSearcher(torch.from_numpy(self.embs), num_neighbors) + ) + self.module_searcher.to(self.device) + self.required_num_neighbors = num_neighbors + top_indices: torch.Tensor = self.module_searcher( + torch.tensor(batch, device=self.device) + ) + top_indices_np: np.ndarray = top_indices.cpu().numpy() return self.results[top_indices_np] diff --git a/tests/models/test_brute_force_searcher.py b/tests/models/test_brute_force_searcher.py index 5ae580e..a44b829 100644 --- a/tests/models/test_brute_force_searcher.py +++ b/tests/models/test_brute_force_searcher.py @@ -7,8 +7,8 @@ DPBruteForceSearcher, ) -torch.compiler.disable(BruteForceSearcher.find) -torch.compiler.disable(DPBruteForceSearcher.find) +# torch.compiler.disable(BruteForceSearcher.find) +# torch.compiler.disable(DPBruteForceSearcher.find) def test_search_present(): @@ -39,7 +39,6 @@ def test_search_missing(): ] ) searcher = BruteForceSearcher(embs, np.arange(4)) - res = searcher.find(np.array([[1.0, 0.0, 1.0]]), 2) assert res[0][0] == 0 @@ -63,7 +62,6 @@ def test_search_large(): class TestDPBruteForceSearcher: - @pytest.fixture def small_embs(self): return np.array( @@ -125,8 +123,9 @@ def test_device_selection(self, small_embs): def test_changing_num_neighbors(self, small_embs): searcher = DPBruteForceSearcher(small_embs, np.arange(len(small_embs))) searcher.find(np.random.random((1, 3)), 2) # Initialize with 2 neighbors - with pytest.raises(Exception): - searcher.find(np.random.random((1, 3)), 3) # Try to change to 3 neighbors + # with pytest.raises(Exception): + # Does nothing: + searcher.find(np.random.random((1, 3)), 3) # Try to change to 3 neighbors def test_dataparallel_initialization(self, small_embs): searcher = DPBruteForceSearcher(small_embs, np.arange(len(small_embs))) From eeb0773e2a195fbcabc5a37741c6b9dc0e506013 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Fri, 25 Apr 2025 11:24:44 +0200 Subject: [PATCH 02/18] feat(tokens): :sparkles: add DamuelPageTypeLoader --- src/tokenization/pipeline/loaders/__init__.py | 8 +- src/tokenization/pipeline/loaders/damuel.py | 33 +++++++ .../pipeline/loaders/test_damuel.py | 86 ++++++++++++++++++- 3 files changed, 125 insertions(+), 2 deletions(-) diff --git a/src/tokenization/pipeline/loaders/__init__.py b/src/tokenization/pipeline/loaders/__init__.py index 6fc25d8..1555a93 100644 --- a/src/tokenization/pipeline/loaders/__init__.py +++ b/src/tokenization/pipeline/loaders/__init__.py @@ -1,4 +1,9 @@ -from .damuel import DaMuELDescriptionLoader, DaMuELLinkLoader, DaMuELStartLoader +from .damuel import ( + DaMuELDescriptionLoader, + DaMuELLinkLoader, + DaMuELStartLoader, + DaMuELPageTypeLoader, +) from .mewsli import MewsliLoader __all__ = [ @@ -6,4 +11,5 @@ "DaMuELLinkLoader", "DaMuELStartLoader", "MewsliLoader", + "DaMuELPageTypeLoader", ] diff --git a/src/tokenization/pipeline/loaders/damuel.py b/src/tokenization/pipeline/loaders/damuel.py index efe8891..1c4ce37 100644 --- a/src/tokenization/pipeline/loaders/damuel.py +++ b/src/tokenization/pipeline/loaders/damuel.py @@ -222,3 +222,36 @@ def __init__( # Here WikiKeyFilter is required because links are only in Wikipages self.add(WikiKeyFilter()) self.add(DaMuELLinkProcessor(use_context, require_link_wiki_origin)) + + +class DaMuELPageTypeProcessor(PipelineStep): + def __init__(self, extract_qid: bool = False): + super().__init__() + self._extract_qid = extract_qid + + def process(self, input_gen=None): + for damuel_entry in input_gen: + page_type = self._get_page_type(damuel_entry) + if self._extract_qid: + qid = parse_qid(damuel_entry["qid"]) + yield page_type, qid + else: + yield page_type, + + def _get_page_type(self, damuel_entry: dict) -> str: + if "page_type" in damuel_entry: + return damuel_entry["page_type"] + return "none" + + +class DaMuELPageTypeLoader(Pipeline): + def __init__( + self, + path: str, + remainder: int = None, + mod: int = None, + extract_qid: bool = False, + ): + super().__init__() + self.add(DaMuELStartLoader(path, remainder, mod)) + self.add(DaMuELPageTypeProcessor(extract_qid)) diff --git a/tests/tokenization/pipeline/loaders/test_damuel.py b/tests/tokenization/pipeline/loaders/test_damuel.py index e614497..10da638 100644 --- a/tests/tokenization/pipeline/loaders/test_damuel.py +++ b/tests/tokenization/pipeline/loaders/test_damuel.py @@ -3,7 +3,11 @@ import pytest from tokenization.pipeline.loaders import DaMuELStartLoader -from tokenization.pipeline.loaders.damuel import DaMuELDescriptionProcessor +from tokenization.pipeline.loaders.damuel import ( + DaMuELDescriptionProcessor, + DaMuELPageTypeProcessor, + DaMuELPageTypeLoader, +) class TestDaMuELStartLoader: @@ -94,3 +98,83 @@ def test_description_title_concatenation_with_original_title( title, description, original_title ) assert text == expected_text + + +class TestDaMuELPageTypeProcessor: + @pytest.fixture + def data(self): + return [ + {"qid": "Q1", "text": "Hello", "page_type": "page"}, + {"qid": "Q2", "text": "World", "page_type": "section"}, + {"qid": "Q3", "text": "Foo", "page_type": "page"}, + {"qid": "Q4", "text": "Bar", "page_type": "section"}, + {"qid": "Q5", "text": "Bar"}, + {"qid": "Q6", "text": "Bar"}, + ] + + def test_damuel_page_type_processor_default(self, data): + processor = DaMuELPageTypeProcessor() + results = list(processor.process(data)) + + expected_results = [ + ("page",), + ("section",), + ("page",), + ("section",), + ("none",), + ("none",), + ] + + assert results == expected_results + + def test_damuel_page_type_processor_with_qid(self, data): + processor = DaMuELPageTypeProcessor(extract_qid=True) + results = list(processor.process(data)) + + expected_results = [ + ("page", 1), + ("section", 2), + ("page", 3), + ("section", 4), + ("none", 5), + ("none", 6), + ] + + assert results == expected_results + + +class TestDaMuELPageTypeLoader: + @pytest.fixture + def damuel_data(self, tmp_path): + data_dir = tmp_path / "damuel_data" + data_dir.mkdir() + + file1 = data_dir / "part-00000" + file1.write_text( + '{"qid": "Q1", "text": "Hello"}\n{"qid": "Q2", "text": "World"}' + ) + + file2 = data_dir / "part-00001" + file2.write_text('{"qid": "Q3", "text": "Foo"}\n{"qid": "Q4", "text": "Bar"}') + + compressed_file = data_dir / "part-00002.xz" + with lzma.open(compressed_file, "wt") as f: + f.write( + '{"qid": "Q5", "text": "Compressed"}\n{"qid": "Q6", "text": "Data"}' + ) + + return str(data_dir) + + @pytest.mark.parametrize( + "extract_qid", + [ + True, + False, + ], + ) + def test(self, damuel_data, extract_qid): + """Very simple test to check that the loader is not crashing.""" + loader = DaMuELPageTypeLoader(damuel_data, extract_qid=extract_qid) + results = list(loader.process()) + + assert len(results) > 0 From 9ad17e2354452f766f0c540ad6fbaef94d240076 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Tue, 29 Apr 2025 10:26:52 +0200 Subject: [PATCH 03/18] feat(utils): :sparkles: add qids_filter decorator which can be used to remove some qids from the system --- src/utils/loaders.py | 7 ++ src/utils/qid_filter.py | 54 ++++++++++ tests/utils/test_loaders.py | 20 ++++ tests/utils/test_qids_filter.py | 169 ++++++++++++++++++++++++++++++++ 4 files changed, 250 insertions(+) create mode 100644 src/utils/qid_filter.py create mode 100644 tests/utils/test_qids_filter.py diff --git a/src/utils/loaders.py b/src/utils/loaders.py index cf1f52a..31e34ae 100644 --- a/src/utils/loaders.py +++ b/src/utils/loaders.py @@ -92,6 +92,13 @@ def load_qids(file_path: str | Path) -> np.ndarray: return d["qids"] +@remap_qids_decorator(qids_index=None, json_path=gin.REQUIRED) +def load_qids_npy(file_path: str | Path) -> np.ndarray: + if type(file_path) == str: + file_path = Path(file_path) + return np.load(file_path) + + @_sort_by_output(1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) def load_mentions_from_dir(dir_path: str | Path) -> tuple[np.ndarray, np.ndarray]: diff --git a/src/utils/qid_filter.py b/src/utils/qid_filter.py new file mode 100644 index 0000000..548d973 --- /dev/null +++ b/src/utils/qid_filter.py @@ -0,0 +1,54 @@ +import functools + +import numpy as np + +from .loaders import load_qids_npy + + +@functools.cache +def _load_filter(path: str) -> set: + """Load QIDs to filter from a `.npy` file.""" + try: + arr = load_qids_npy(path) + except Exception: + raise FileNotFoundError(f"Cannot load filter file: {path}") + return set(arr.tolist()) + + +def qid_filter(qids_index: int | None, filter_path: str = None): + """Decorator that filters out QIDs listed in a `.npy` file. + + Args: + qids_index (int | None): Index of the QIDs in the input data. If None assumes the input data to be just qids. + filter_path (str): Path to the `.npy` file containing QIDs to filter out. + Can be empty. If empty, no filtering is applied; the decorator is an identity. + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + result = fn(*args, **kwargs) + if filter_path is None: + return result + + mask_set = _load_filter(filter_path) + is_valid_tuple = isinstance(result, tuple) and 0 <= qids_index < len(result) + if is_valid_tuple: + qids = result[qids_index] + else: + qids = result + keep = ~np.isin(qids, list(mask_set)) + + if is_valid_tuple: + updated_result = tuple(result_array[keep] for result_array in result) + elif qids_index is None and not isinstance(result, tuple): + updated_result = result[keep] + else: + raise ValueError( + f"Invalid qids_index {qids_index} for the returned tuple." + ) + return updated_result + + return wrapper + + return decorator diff --git a/tests/utils/test_loaders.py b/tests/utils/test_loaders.py index fc7933f..09d8d28 100644 --- a/tests/utils/test_loaders.py +++ b/tests/utils/test_loaders.py @@ -12,6 +12,7 @@ load_mentions, load_qids, AliasTableLoader, + load_qids_npy, ) @@ -299,6 +300,25 @@ def test_load_qids(mock_qids_remap, use_string_path: bool) -> None: assert isinstance(loaded_qids, np.ndarray) +@pytest.mark.parametrize("use_string_path", [True, False]) +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_qids_npy(mock_qids_remap, use_string_path: bool) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + if use_string_path: + dir_path = str(dir_path) + file_path = Path(dir_path) / "qids.npy" + + test_qids = np.array([100, 200, 300]) + + np.save(file_path, test_qids) + + loaded_qids = load_qids_npy(file_path) + + assert np.array_equal(loaded_qids, test_qids) + assert isinstance(loaded_qids, np.ndarray) + + @pytest.mark.parametrize("lowercase", [True, False]) class TestAliasTableLoader: def setup_method(self, lowercase): diff --git a/tests/utils/test_qids_filter.py b/tests/utils/test_qids_filter.py new file mode 100644 index 0000000..c2f4106 --- /dev/null +++ b/tests/utils/test_qids_filter.py @@ -0,0 +1,169 @@ +from unittest.mock import patch +import numpy as np +import pytest + +from utils.qid_filter import qid_filter + + +def mock_remap_qids(qids, _): + return qids + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_no_filtering_default(mock_qids_remap): + @qid_filter(1) + def loader(data, qids): + return data, qids + + data = np.arange(6).reshape(3, 2) + qids = np.array([1, 2, 3]) + out_data, out_qids = loader(data, qids) + assert out_data is data + assert out_qids is qids + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_invalid_path_raises(mock_qids_remap): + with pytest.raises(FileNotFoundError): + + @qid_filter(1, filter_path="") + def loader(data, qids): + return data, qids + + data = np.arange(6).reshape(3, 2) + qids = np.array([1, 2, 3]) + _ = loader(data, qids) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_basic_filtering(mock_qids_remap, tmp_path): + to_filter = np.array([1, 3]) + filter_file = tmp_path / "filter.npy" + np.save(filter_file, to_filter) + + @qid_filter(1, filter_path=str(filter_file)) + def loader(data, qids): + return data, qids + + data = np.array([[10], [20], [30], [40]]) + qids = np.array([1, 2, 3, 4]) + out_data, out_qids = loader(data, qids) + expected_qids = np.array([2, 4]) + expected_data = np.array([[20], [40]]) + assert np.array_equal(out_qids, expected_qids) + assert np.array_equal(out_data, expected_data) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_preserve_dtype(mock_qids_remap, tmp_path): + to_filter = np.array([2], dtype=np.int64) + filter_file = tmp_path / "filter.npy" + np.save(filter_file, to_filter) + + @qid_filter(1, filter_path=str(filter_file)) + def loader(data, qids): + return data, qids + + qids = np.array([1, 2, 3], dtype=np.int64) + data = np.array([[1.0], [2.0], [3.0]], dtype=np.float32) + out_data, out_qids = loader(data, qids) + assert out_qids.dtype == qids.dtype + assert out_data.dtype == data.dtype + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_preserves_function_metadata(mock_qids_remap, tmp_path): + filter_file = tmp_path / "filter.npy" + np.save(filter_file, np.array([], dtype=int)) + + def dummy_loader(data, qids): + """Loader docstring.""" + return data, qids + + decorated = qid_filter(1, filter_path=str(filter_file))(dummy_loader) + assert decorated.__name__ == dummy_loader.__name__ + assert decorated.__doc__ == dummy_loader.__doc__ + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_long_tuple(mock_qids_remap, tmp_path): + to_filter = np.array([1, 3]) + filter_file = tmp_path / "filter.npy" + np.save(filter_file, to_filter) + + @qid_filter(2, filter_path=str(filter_file)) + def loader(data, qids): + return data, data, qids, data + + data = np.array([[10], [20], [30], [40]]) + qids = np.array([1, 2, 3, 4]) + out_data, out_data2, out_qids, out_data3 = loader(data, qids) + expected_qids = np.array([2, 4]) + expected_data = np.array([[20], [40]]) + assert np.array_equal(out_qids, expected_qids) + assert np.array_equal(out_data, expected_data) + assert np.array_equal(out_data2, expected_data) + assert np.array_equal(out_data3, expected_data) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_no_tuple(mock_qids_remap, tmp_path): + to_filter = np.array([1, 3]) + filter_file = tmp_path / "filter.npy" + np.save(filter_file, to_filter) + + @qid_filter(None, filter_path=str(filter_file)) + def loader(qids): + return qids + + qids = np.array([1, 2, 3, 4]) + out_qids = loader(qids) + expected_qids = np.array([2, 4]) + + assert np.array_equal(out_qids, expected_qids) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_no_tuple_identity(mock_qids_remap): + + @qid_filter(None) + def loader(qids): + return qids + + qids = np.array([1, 2, 3, 4]) + out_qids = loader(qids) + expected_qids = qids + + assert np.array_equal(out_qids, expected_qids) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_assert_raises_value_error_list(mock_qids_remap, tmp_path): + to_filter = np.array([1, 3]) + filter_file = tmp_path / "filter.npy" + np.save(filter_file, to_filter) + + @qid_filter(1, filter_path=str(filter_file)) + def loader(qids): + return [qids] + + qids = np.array([1, 2, 3, 4]) + with pytest.raises(ValueError): + a = loader(qids) + print(a) + + +@pytest.mark.parametrize("idx", [-1, 2]) +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_qid_filter_assert_raises_value_error_idx(mock_qids_remap, idx, tmp_path): + to_filter = np.array([1, 3]) + filter_file = tmp_path / "filter.npy" + np.save(filter_file, to_filter) + + @qid_filter(idx, filter_path=str(filter_file)) + def loader(qids): + return (qids,) + + qids = np.array([1, 2, 3, 4]) + with pytest.raises(ValueError): + _ = loader(qids) From 1943dbe4c69a99027d51a6139c087dd2ed058911 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 20 Jul 2025 11:46:30 +0200 Subject: [PATCH 04/18] refactor: :construction: codebase overhaul --- configs/general.gin | 3 + configs/lealla_m.gin | 2 +- configs/multilingual_dataset.gin | 8 + src/finetunings/finetune_model/train_ddp.py | 4 +- src/finetunings/generate_epochs/datasets.py | 3 + src/models/searchers/brute_force_searcher.py | 12 +- src/multilingual_dataset/creator.py | 2 + src/run_action_gin.py | 1 + src/scripts/train/all_langs_no_slurm.sh | 30 ++- src/scripts/train/asi_se_to_rozbilo.sh | 188 ++++++++++++++++++ .../train/asi_se_to_rozbilo_vsechny.sh | 188 ++++++++++++++++++ src/scripts/train/evaluate_no_slurm.sh | 8 +- src/scripts/utils/compare.py | 41 ++++ src/scripts/utils/create_filter_npy.py | 14 ++ src/scripts/utils/create_qid_type_map.py | 28 +++ src/tokenization/pipeline/loaders/damuel.py | 7 + src/utils/loaders.py | 14 +- src/utils/qid_filter.py | 14 +- 18 files changed, 529 insertions(+), 38 deletions(-) create mode 100644 configs/multilingual_dataset.gin create mode 100755 src/scripts/train/asi_se_to_rozbilo.sh create mode 100755 src/scripts/train/asi_se_to_rozbilo_vsechny.sh create mode 100644 src/scripts/utils/compare.py create mode 100644 src/scripts/utils/create_filter_npy.py create mode 100644 src/scripts/utils/create_qid_type_map.py diff --git a/configs/general.gin b/configs/general.gin index b7f7b58..8ab5e01 100644 --- a/configs/general.gin +++ b/configs/general.gin @@ -2,3 +2,6 @@ remap_qids_decorator.json_path=%qids_remap_json qids_remap.old_to_new_qids_path=%qids_remap_json #qids_remap_json="/net/projects/damuel/dev/damuel_1.1-dev_qid_redirects.json" qids_remap_json="/lnet/work/home-students-external/farhan/damuel/dev/damuel_2.0-dev_qid_redirects.json" + +#filter_qids_npy_path="/lnet/work/home-students-external/farhan/troja/filtered_qids2.npy" +#qid_filter.filter_path=%filter_qids_npy_path diff --git a/configs/lealla_m.gin b/configs/lealla_m.gin index 6a03aa2..cdf50ba 100644 --- a/configs/lealla_m.gin +++ b/configs/lealla_m.gin @@ -1,5 +1,5 @@ model_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" -inference_batch_size=380000 +inference_batch_size=120000 output_type="pooler_output" embs_from_tokens_and_model_name_at.model_name=%model_path diff --git a/configs/multilingual_dataset.gin b/configs/multilingual_dataset.gin new file mode 100644 index 0000000..0f7fe45 --- /dev/null +++ b/configs/multilingual_dataset.gin @@ -0,0 +1,8 @@ +source_dir="/lnet/work/home-students-external/farhan/troja/outputs/finetuning_damuel2" +langs=["af", "be", "ca", "da", "el", "es", "eu", "fi", "ga", "gl", "hi", "hu", "id", "ja", "la", "lv", "mt", "nn", "pt", "ru", "sk", "sr", "ta", "tr", "uk", "vi", "zh", "ar", "bg", "cs", "de", "en", "et", "fa", "fr", "gd", "he", "hr", "hy", "it", "ko", "lt", "mr", "nl", "pl", "ro", "se", "sl", "sv", "te", "ug", "ur", "wo"] +dest_dir="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal/" + + +create_multilingual_dataset.source_dir=%source_dir +create_multilingual_dataset.langs=%langs +create_multilingual_dataset.dest_dir=%dest_dir diff --git a/src/finetunings/finetune_model/train_ddp.py b/src/finetunings/finetune_model/train_ddp.py index 83f86e2..ba3d6e4 100644 --- a/src/finetunings/finetune_model/train_ddp.py +++ b/src/finetunings/finetune_model/train_ddp.py @@ -6,8 +6,6 @@ import torch -from tqdm import tqdm - torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True import gin @@ -143,7 +141,7 @@ def _ddp_train( labels = construct_labels(dataset) labels = torch.from_numpy(labels).to(rank) - for replica_part in tqdm(dataloader, total=len(dataloader)): + for replica_part in dataloader: with torch.autocast(device_type="cuda"): replica_part = forward_to_embeddings(replica_part, model) diff --git a/src/finetunings/generate_epochs/datasets.py b/src/finetunings/generate_epochs/datasets.py index e476b40..a7e653d 100644 --- a/src/finetunings/generate_epochs/datasets.py +++ b/src/finetunings/generate_epochs/datasets.py @@ -45,6 +45,9 @@ def __init__(self, dir_path: Path, known_qids: npt.ArrayLike, batch_size: int): def __iter__(self) -> Iterator[tuple[np.ndarray, np.ndarray, np.ndarray]]: for file_path in self.file_paths: + print(file_path) + # TODO: A lot of time is spent just waiting for this to load, it would be nice to prefetch the data on a different worker + # Otherwise we are waiting a lot for a slow disk to read the gargantual file embs, qids, tokens = load_embs_qids_tokens(file_path) embs, qids, tokens = self._remove_when_qid_missing( diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index f36cec8..a36bf8d 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -26,7 +26,7 @@ def __init__( super().__init__(embs, results, run_build_from_init) def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: - @torch.compile + # @torch.compile def _find(batch: np.ndarray) -> np.ndarray: batch_torch: torch.Tensor = torch.from_numpy(batch).to(self.device) # embs after build are (dim, embs_count) @@ -50,7 +50,7 @@ def __init__(self, kb_embs, num_neighbors): self.kb_embs: torch.Tensor = nn.Parameter(kb_embs) self.num_neighbors: int = num_neighbors - @torch.compile + # @torch.compile def forward(self, x): dot_product = F.linear(x, self.kb_embs) _, top_indices = dot_product.topk(self.num_neighbors) @@ -71,6 +71,7 @@ def __init__( self.required_num_neighbors: Optional[int] = None super().__init__(embs, results, run_build_from_init) + @torch.compile def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: """ Finds the nearest neighbors for a given batch of input data. @@ -90,9 +91,10 @@ def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: num_neighbors, it will not raise an error and will fail silently. - The first call to find will be slow, because the module_searcher will be initialized and torch.compile is called. """ - with torch.inference_mode(), torch.autocast( - device_type=self.device.type, dtype=torch.float16 - ): + # with torch.inference_mode(), torch.autocast( + # device_type=self.device.type, dtype=torch.float16 + # ): + with torch.no_grad(): # A try except trick to avoid the overhead of checking if the module_searcher is None # on every call to find. # This is a bit of a hack, but it should make things faster as we are suggesting that the module_searcher is initialized. diff --git a/src/multilingual_dataset/creator.py b/src/multilingual_dataset/creator.py index dba9130..1afad76 100644 --- a/src/multilingual_dataset/creator.py +++ b/src/multilingual_dataset/creator.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Union +import gin import numpy as np from multilingual_dataset.mixer import Mixer @@ -244,6 +245,7 @@ def run(self) -> None: _logger.info("Finished creating links") +@gin.configurable def create_multilingual_dataset( source_dir: Union[str, Path], langs: list[str], diff --git a/src/run_action_gin.py b/src/run_action_gin.py index 0f838fc..cac9e24 100644 --- a/src/run_action_gin.py +++ b/src/run_action_gin.py @@ -35,6 +35,7 @@ embs_from_tokens_and_model_name, embs_from_tokens_model_name_and_state_dict, ) +from utils.qid_filter import qid_filter from utils.validate_tokens import validate_tokens print("Imports finished") diff --git a/src/scripts/train/all_langs_no_slurm.sh b/src/scripts/train/all_langs_no_slurm.sh index e05a30a..fee80df 100755 --- a/src/scripts/train/all_langs_no_slurm.sh +++ b/src/scripts/train/all_langs_no_slurm.sh @@ -15,11 +15,11 @@ echo "Current directory: $(pwd)" MODEL_CONFIG_PATH="../configs/lealla_m.gin" TRAIN_CONFIG_PATH="../configs/train.gin" -DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2/descs_pages" -DAMUEL_LINKS_TOKENS_RAW="$OUTPUTS/v2/links" +DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2_normal/descs_pages" +DAMUEL_LINKS_TOKENS_RAW="$OUTPUTS/v2_normal/links" MEWSLI_TOKENS_RAW="$OUTPUTS/tokens_mewsli_finetuning" WORKDIR="$OUTPUTS/workdirs/v2_retraining_with_model_from_all" -N_OF_ROUNDS=10 +N_OF_ROUNDS=4 run_ml_finetuning_round() { local DAMUEL_DESCS_TOKENS_RAW=$1 @@ -88,7 +88,7 @@ run_ml_finetuning_round() { mkdir -p "$BATCH_DIR" if [ ! "$(ls -A $BATCH_DIR)" ]; then echo "Running batches generating for damuel" - #../venv/bin/python -m cProfile -o "generate.prof" $ACTION_SCRIPT "generate" \ + # ../venv/bin/python -m cProfile -o "generate.prof" $ACTION_SCRIPT "generate" \ ../venv/bin/python $ACTION_SCRIPT "generate" \ --LINKS_EMBS_DIR="$DAMUEL_LINKS_DIR" \ --INDEX_TOKENS_DIR="$DAMUEL_DESCS_TOKENS_RAW" \ @@ -173,20 +173,16 @@ done STATE_DICT="None" -for ((ROUND_ID=3; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) do - run_ml_finetuning_round "$DAMUEL_DESCS_TOKENS_RAW" "$DAMUEL_LINKS_TOKENS_RAW" \ - "$MEWSLI_TOKENS_RAW" \ - "$WORKDIR" "$STATE_DICT" \ - "$ROUND_ID" "$N_OF_ROUNDS" - #if [ ! -e "$WORKDIR/models_$ROUND_ID/final.pth" ]; then - # echo "Running round $ROUND_ID" - - # run_ml_finetuning_round "$DAMUEL_DESCS_TOKENS_RAW" "$DAMUEL_LINKS_TOKENS_RAW" \ - # "$MEWSLI_TOKENS_RAW" \ - # "$WORKDIR" "$STATE_DICT" \ - # "$ROUND_ID" "$N_OF_ROUNDS" - #fi + if [ ! -e "$WORKDIR/models_$ROUND_ID/final.pth" ]; then + echo "Running round $ROUND_ID" + + run_ml_finetuning_round "$DAMUEL_DESCS_TOKENS_RAW" "$DAMUEL_LINKS_TOKENS_RAW" \ + "$MEWSLI_TOKENS_RAW" \ + "$WORKDIR" "$STATE_DICT" \ + "$ROUND_ID" "$N_OF_ROUNDS" + fi STATE_DICT="$WORKDIR/models_$ROUND_ID/final.pth" done diff --git a/src/scripts/train/asi_se_to_rozbilo.sh b/src/scripts/train/asi_se_to_rozbilo.sh new file mode 100755 index 0000000..97e8e54 --- /dev/null +++ b/src/scripts/train/asi_se_to_rozbilo.sh @@ -0,0 +1,188 @@ +#!/bin/bash + +# Runs the complete finetuning process. +# Expects tokens to be in the dirs specified below. +# Additionaly, one can specify additional parameters. +# For running, please also set up/fix the path to venv in run_finetuning_action.sh + +set -ueo pipefail + +cd ../../ + +echo "Running all_langs.sh" +echo "Current directory: $(pwd)" + +MODEL_CONFIG_PATH="../configs/lealla_m.gin" +TRAIN_CONFIG_PATH="../configs/train.gin" + +DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2_normal/descs_pages" +DAMUEL_LINKS_TOKENS_RAW="$OUTPUTS/v2_normal/links" +MEWSLI_TOKENS_RAW="$OUTPUTS/tokens_mewsli_finetuning" +WORKDIR="$OUTPUTS/workdirs/asi_se_to_rozbilo" +N_OF_ROUNDS=15 + +run_ml_finetuning_round() { + local DAMUEL_DESCS_TOKENS_RAW=$1 + local DAMUEL_LINKS_TOKENS_RAW=$2 + local MEWSLI_TOKENS_RAW=$3 + local WORKDIR=$4 + local STATE_DICT=${5:-"None"} + local ROUND_ID=${6:-"0"} + local N_OF_ROUNDS=${7} + + local STEPS_PER_EPOCH=1000 + + # Multiple by 2 to make sure that if a link contained something faulty we can skip it. + local LINKS_PER_ROUND=$(($STEPS_PER_EPOCH * 1000 * 1000)) + echo "LPR $LINKS_PER_ROUND" + + local ACTION_SCRIPT="run_action_gin.py $MODEL_CONFIG_PATH $TRAIN_CONFIG_PATH" + + ENV="../venv/bin/activate" + source $ENV + + # ====================TOKENS COPY==================== + + local DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! "$(ls -A $DAMUEL_LINKS_TOKENS)" ]; then + ../venv/bin/python $ACTION_SCRIPT "copy" \ + --source="$DAMUEL_LINKS_TOKENS_RAW" \ + --dest="$DAMUEL_LINKS_TOKENS" \ + --m="$N_OF_ROUNDS" \ + --r="$ROUND_ID" \ + --max_to_copy="$LINKS_PER_ROUND" + fi + + # ====================DAMUEL DESC EMBS==================== + + local DAMUEL_FOR_INDEX_DIR="$WORKDIR/damuel_for_index_$ROUND_ID" + + mkdir -p "$DAMUEL_FOR_INDEX_DIR" + + if [ ! "$(ls -A $DAMUEL_FOR_INDEX_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================DAMUEL LINKS EMBEDDING==================== + + local DAMUEL_LINKS_DIR="$WORKDIR/links_embs_$ROUND_ID" + + mkdir -p "$DAMUEL_LINKS_DIR" + + if [ ! "$(ls -A $DAMUEL_LINKS_DIR)" ]; then + echo "Running embs generating for damuel links" + ../venv/bin/python $ACTION_SCRIPT "embed_links_for_generation" \ + --links_tokens_dir_path="$DAMUEL_LINKS_TOKENS" \ + --dest_dir_path="$DAMUEL_LINKS_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================GENERATING BATCHES==================== + + local BATCH_DIR="$WORKDIR/batches_$ROUND_ID" + + mkdir -p "$BATCH_DIR" + if [ ! "$(ls -A $BATCH_DIR)" ]; then + echo "Running batches generating for damuel" + # ../venv/bin/python -m cProfile -o "generate.prof" $ACTION_SCRIPT "generate" \ + ../venv/bin/python $ACTION_SCRIPT "generate" \ + --LINKS_EMBS_DIR="$DAMUEL_LINKS_DIR" \ + --INDEX_TOKENS_DIR="$DAMUEL_DESCS_TOKENS_RAW" \ + --INDEX_EMBS_QIDS_DIR="$DAMUEL_FOR_INDEX_DIR" \ + --OUTPUT_DIR="$BATCH_DIR" + fi + + # ====================TRAINING MODEL==================== + + local MODELS_DIR="$WORKDIR/models_$ROUND_ID" + + mkdir -p $MODELS_DIR + + if [ ! "$(ls -A $MODELS_DIR)" ]; then + echo "Running training for damuel" + #../venv/bin/python -m cProfile -o "train_ddp.prof" $ACTION_SCRIPT "train_ddp" \ + ../venv/bin/python $ACTION_SCRIPT "train_ddp" \ + --DATASET_DIR="$BATCH_DIR" \ + --MODEL_SAVE_DIR="$MODELS_DIR" \ + --STATE_DICT_PATH="$STATE_DICT" + fi + + # ====================EVALUATION==================== + + local NEXT_INDEX=$(($ROUND_ID + 1)) + local DAMUEL_FOR_INDEX_NEW_DIR="$WORKDIR/damuel_for_index_$NEXT_INDEX" + mkdir -p "$DAMUEL_FOR_INDEX_NEW_DIR" + + if [ ! "$(ls -A $DAMUEL_FOR_INDEX_NEW_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_NEW_DIR" \ + --state_dict_path="$MODELS_DIR/final.pth" + fi + + local LANGUAGES=("ar" "de" "en" "es" "ja" "fa" "sr" "ta" "tr") + + for LANG in "${LANGUAGES[@]}"; do + echo "Processing language: $LANG" + + local LANG_TOKEN_DIR="$MEWSLI_TOKENS_RAW/$LANG" + local MEWSLI_EMBS_DIR="$WORKDIR/mewsli_embs_${LANG}_$ROUND_ID" + + mkdir -p "$MEWSLI_EMBS_DIR" + + if [ ! "$(ls -A $MEWSLI_EMBS_DIR)" ]; then + echo "Running embs generating for mewsli - Language: $LANG" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$LANG_TOKEN_DIR" \ + --dest_path="$MEWSLI_EMBS_DIR" \ + --state_dict_path="$MODELS_DIR/final.pth" + fi + + ../venv/bin/python $ACTION_SCRIPT "recalls" \ + --damuel_dir="$DAMUEL_FOR_INDEX_NEW_DIR" \ + --mewsli_dir="$MEWSLI_EMBS_DIR" + + echo "Completed processing for language: $LANG" + echo "----------------------------------------" + done + + rm -r $BATCH_DIR $DAMUEL_LINKS_DIR $DAMUEL_FOR_INDEX_DIR $DAMUEL_LINKS_TOKENS +} + +if [ ! -L "$WORKDIR" ]; then + mkdir -p "$WORKDIR" +fi + +DAMUEL_DESCS_TOKENS="$WORKDIR/damuel_descs_together_tokens" +if [ ! -L "$DAMUEL_DESCS_TOKENS" ]; then + mkdir -p "$DAMUEL_DESCS_TOKENS" +fi + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! -L "$DAMUEL_LINKS_TOKENS" ]; then + mkdir -p "$DAMUEL_LINKS_TOKENS" + fi +done + +STATE_DICT="None" + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + if [ ! -e "$WORKDIR/models_$ROUND_ID/final.pth" ]; then + echo "Running round $ROUND_ID" + + run_ml_finetuning_round "$DAMUEL_DESCS_TOKENS_RAW" "$DAMUEL_LINKS_TOKENS_RAW" \ + "$MEWSLI_TOKENS_RAW" \ + "$WORKDIR" "$STATE_DICT" \ + "$ROUND_ID" "$N_OF_ROUNDS" + fi + + STATE_DICT="$WORKDIR/models_$ROUND_ID/final.pth" +done diff --git a/src/scripts/train/asi_se_to_rozbilo_vsechny.sh b/src/scripts/train/asi_se_to_rozbilo_vsechny.sh new file mode 100755 index 0000000..65bf431 --- /dev/null +++ b/src/scripts/train/asi_se_to_rozbilo_vsechny.sh @@ -0,0 +1,188 @@ +#!/bin/bash + +# Runs the complete finetuning process. +# Expects tokens to be in the dirs specified below. +# Additionaly, one can specify additional parameters. +# For running, please also set up/fix the path to venv in run_finetuning_action.sh + +set -ueo pipefail + +cd ../../ + +echo "Running all_langs.sh" +echo "Current directory: $(pwd)" + +MODEL_CONFIG_PATH="../configs/lealla_m.gin" +TRAIN_CONFIG_PATH="../configs/train.gin" + +DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2/descs_pages" +DAMUEL_LINKS_TOKENS_RAW="$OUTPUTS/v2/links" +MEWSLI_TOKENS_RAW="$OUTPUTS/tokens_mewsli_finetuning" +WORKDIR="$OUTPUTS/workdirs/asi_se_to_rozbilo_vsechny" +N_OF_ROUNDS=15 + +run_ml_finetuning_round() { + local DAMUEL_DESCS_TOKENS_RAW=$1 + local DAMUEL_LINKS_TOKENS_RAW=$2 + local MEWSLI_TOKENS_RAW=$3 + local WORKDIR=$4 + local STATE_DICT=${5:-"None"} + local ROUND_ID=${6:-"0"} + local N_OF_ROUNDS=${7} + + local STEPS_PER_EPOCH=1000 + + # Multiple by 2 to make sure that if a link contained something faulty we can skip it. + local LINKS_PER_ROUND=$(($STEPS_PER_EPOCH * 1000 * 1000)) + echo "LPR $LINKS_PER_ROUND" + + local ACTION_SCRIPT="run_action_gin.py $MODEL_CONFIG_PATH $TRAIN_CONFIG_PATH" + + ENV="../venv/bin/activate" + source $ENV + + # ====================TOKENS COPY==================== + + local DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! "$(ls -A $DAMUEL_LINKS_TOKENS)" ]; then + ../venv/bin/python $ACTION_SCRIPT "copy" \ + --source="$DAMUEL_LINKS_TOKENS_RAW" \ + --dest="$DAMUEL_LINKS_TOKENS" \ + --m="$N_OF_ROUNDS" \ + --r="$ROUND_ID" \ + --max_to_copy="$LINKS_PER_ROUND" + fi + + # ====================DAMUEL DESC EMBS==================== + + local DAMUEL_FOR_INDEX_DIR="$WORKDIR/damuel_for_index_$ROUND_ID" + + mkdir -p "$DAMUEL_FOR_INDEX_DIR" + + if [ ! "$(ls -A $DAMUEL_FOR_INDEX_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================DAMUEL LINKS EMBEDDING==================== + + local DAMUEL_LINKS_DIR="$WORKDIR/links_embs_$ROUND_ID" + + mkdir -p "$DAMUEL_LINKS_DIR" + + if [ ! "$(ls -A $DAMUEL_LINKS_DIR)" ]; then + echo "Running embs generating for damuel links" + ../venv/bin/python $ACTION_SCRIPT "embed_links_for_generation" \ + --links_tokens_dir_path="$DAMUEL_LINKS_TOKENS" \ + --dest_dir_path="$DAMUEL_LINKS_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================GENERATING BATCHES==================== + + local BATCH_DIR="$WORKDIR/batches_$ROUND_ID" + + mkdir -p "$BATCH_DIR" + if [ ! "$(ls -A $BATCH_DIR)" ]; then + echo "Running batches generating for damuel" + # ../venv/bin/python -m cProfile -o "generate.prof" $ACTION_SCRIPT "generate" \ + ../venv/bin/python $ACTION_SCRIPT "generate" \ + --LINKS_EMBS_DIR="$DAMUEL_LINKS_DIR" \ + --INDEX_TOKENS_DIR="$DAMUEL_DESCS_TOKENS_RAW" \ + --INDEX_EMBS_QIDS_DIR="$DAMUEL_FOR_INDEX_DIR" \ + --OUTPUT_DIR="$BATCH_DIR" + fi + + # ====================TRAINING MODEL==================== + + local MODELS_DIR="$WORKDIR/models_$ROUND_ID" + + mkdir -p $MODELS_DIR + + if [ ! "$(ls -A $MODELS_DIR)" ]; then + echo "Running training for damuel" + #../venv/bin/python -m cProfile -o "train_ddp.prof" $ACTION_SCRIPT "train_ddp" \ + ../venv/bin/python $ACTION_SCRIPT "train_ddp" \ + --DATASET_DIR="$BATCH_DIR" \ + --MODEL_SAVE_DIR="$MODELS_DIR" \ + --STATE_DICT_PATH="$STATE_DICT" + fi + + # ====================EVALUATION==================== + + local NEXT_INDEX=$(($ROUND_ID + 1)) + local DAMUEL_FOR_INDEX_NEW_DIR="$WORKDIR/damuel_for_index_$NEXT_INDEX" + mkdir -p "$DAMUEL_FOR_INDEX_NEW_DIR" + + if [ ! "$(ls -A $DAMUEL_FOR_INDEX_NEW_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_NEW_DIR" \ + --state_dict_path="$MODELS_DIR/final.pth" + fi + + local LANGUAGES=("ar" "de" "en" "es" "ja" "fa" "sr" "ta" "tr") + + for LANG in "${LANGUAGES[@]}"; do + echo "Processing language: $LANG" + + local LANG_TOKEN_DIR="$MEWSLI_TOKENS_RAW/$LANG" + local MEWSLI_EMBS_DIR="$WORKDIR/mewsli_embs_${LANG}_$ROUND_ID" + + mkdir -p "$MEWSLI_EMBS_DIR" + + if [ ! "$(ls -A $MEWSLI_EMBS_DIR)" ]; then + echo "Running embs generating for mewsli - Language: $LANG" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$LANG_TOKEN_DIR" \ + --dest_path="$MEWSLI_EMBS_DIR" \ + --state_dict_path="$MODELS_DIR/final.pth" + fi + + ../venv/bin/python $ACTION_SCRIPT "recalls" \ + --damuel_dir="$DAMUEL_FOR_INDEX_NEW_DIR" \ + --mewsli_dir="$MEWSLI_EMBS_DIR" + + echo "Completed processing for language: $LANG" + echo "----------------------------------------" + done + + rm -r $BATCH_DIR $DAMUEL_LINKS_DIR $DAMUEL_FOR_INDEX_DIR $DAMUEL_LINKS_TOKENS +} + +if [ ! -L "$WORKDIR" ]; then + mkdir -p "$WORKDIR" +fi + +DAMUEL_DESCS_TOKENS="$WORKDIR/damuel_descs_together_tokens" +if [ ! -L "$DAMUEL_DESCS_TOKENS" ]; then + mkdir -p "$DAMUEL_DESCS_TOKENS" +fi + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! -L "$DAMUEL_LINKS_TOKENS" ]; then + mkdir -p "$DAMUEL_LINKS_TOKENS" + fi +done + +STATE_DICT="None" + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + if [ ! -e "$WORKDIR/models_$ROUND_ID/final.pth" ]; then + echo "Running round $ROUND_ID" + + run_ml_finetuning_round "$DAMUEL_DESCS_TOKENS_RAW" "$DAMUEL_LINKS_TOKENS_RAW" \ + "$MEWSLI_TOKENS_RAW" \ + "$WORKDIR" "$STATE_DICT" \ + "$ROUND_ID" "$N_OF_ROUNDS" + fi + + STATE_DICT="$WORKDIR/models_$ROUND_ID/final.pth" +done diff --git a/src/scripts/train/evaluate_no_slurm.sh b/src/scripts/train/evaluate_no_slurm.sh index d94c8db..51bca56 100755 --- a/src/scripts/train/evaluate_no_slurm.sh +++ b/src/scripts/train/evaluate_no_slurm.sh @@ -9,11 +9,11 @@ MODEL_CONFIG_PATH="../configs/lealla_m.gin" TRAIN_CONFIG_PATH="../configs/train.gin" # DAMUEL_FOR_INDEX_NEW_DIR="$OUTPUTS/workdirs/all/damuel_for_index_8" -DAMUEL_FOR_INDEX_NEW_DIR="$OUTPUTS/workdirs/v2tests/embs" -DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2/descs_pages" +DAMUEL_FOR_INDEX_NEW_DIR="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo/damuel_for_index_2" +DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2_normal/descs_pages" MEWSLI_TOKENS_RAW="$OUTPUTS/tokens_mewsli_finetuning" -WORKDIR="$OUTPUTS/workdirs/v2tests" -ROUND_ID=0 +WORKDIR="$OUTPUTS/workdirs/asi_se_to_rozbilo" +ROUND_ID=1 MODELS_DIR="$WORKDIR/models_$ROUND_ID" ACTION_SCRIPT="run_action_gin.py $MODEL_CONFIG_PATH $TRAIN_CONFIG_PATH" diff --git a/src/scripts/utils/compare.py b/src/scripts/utils/compare.py new file mode 100644 index 0000000..7ba5f90 --- /dev/null +++ b/src/scripts/utils/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +""" +compare_qids.py – Compare “qids” sets contained in .npz files +Usage: python compare_qids.py +""" + +import sys +from pathlib import Path +import numpy as np + + +def collect_qids(dir_path: Path) -> set: + """Return the union of all qids found in .npz files under dir_path.""" + qids: set = set() + for file in dir_path.glob("*.npz"): + with np.load(file, allow_pickle=True) as data: + if "qids" in data: + qids.update(data["qids"].tolist()) + return qids + + +def main(dir_a: Path, dir_b: Path) -> None: + q_a = collect_qids(dir_a) + q_b = collect_qids(dir_b) + + only_a = q_a - q_b + only_b = q_b - q_a + common = q_a & q_b + + print(f"Total unique qids in {dir_a}: {len(q_a)}") + print(f"Total unique qids in {dir_b}: {len(q_b)}") + + print(f"Unique to {dir_a} ({len(only_a)}):") + print(f"Unique to {dir_b} ({len(only_b)}):") + print(f"Common to both ({len(common)})") + + +if __name__ == "__main__": + if len(sys.argv) != 3: + sys.exit("Usage: python compare_qids.py ") + main(Path(sys.argv[1]), Path(sys.argv[2])) diff --git a/src/scripts/utils/create_filter_npy.py b/src/scripts/utils/create_filter_npy.py new file mode 100644 index 0000000..7690424 --- /dev/null +++ b/src/scripts/utils/create_filter_npy.py @@ -0,0 +1,14 @@ +import json +import numpy as np + +IN_PATH = "/lnet/work/home-students-external/farhan/troja/qid_type_agnostic2.json" +OUT_PATH = "/lnet/work/home-students-external/farhan/troja/filtered_qids2.npy" + +with open(IN_PATH, "r") as f: + qid_type_map = json.load(f) + +filtered_qids = [int(qid) for qid, t in qid_type_map.items() if t != "none"] + +np.save(OUT_PATH, np.array(filtered_qids)) + +print(f"Saved {len(filtered_qids)} QIDs to {OUT_PATH}") diff --git a/src/scripts/utils/create_qid_type_map.py b/src/scripts/utils/create_qid_type_map.py new file mode 100644 index 0000000..c1c158d --- /dev/null +++ b/src/scripts/utils/create_qid_type_map.py @@ -0,0 +1,28 @@ +import json +import sys + +sys.path.append("../../") + +from tokenization.pipeline.loaders import DaMuELPageTypeLoader + +OUT = "/lnet/work/home-students-external/farhan/troja/qid_type_agnostic2.json" + +loader = DaMuELPageTypeLoader( + "/lnet/work/home-students-external/farhan/damuel/dev/damuel_2.0-dev_kb_agnostic", + extract_qid=True, +) + +qid_type_map = {} +idx = 0 +for page_type, qid in loader.process(): + if idx % 100000 == 0: + print(f"Processed {idx} entries") + if qid in qid_type_map: + print("Shouldn't happen") + qid_type_map[qid] = page_type + idx += 1 + + +with open(OUT, "w") as f: + json.dump(qid_type_map, f) +print(f"Saved qid_type_map to {OUT}") diff --git a/src/tokenization/pipeline/loaders/damuel.py b/src/tokenization/pipeline/loaders/damuel.py index 1c4ce37..8fc520b 100644 --- a/src/tokenization/pipeline/loaders/damuel.py +++ b/src/tokenization/pipeline/loaders/damuel.py @@ -2,6 +2,7 @@ import lzma import os from collections.abc import Generator +import random import orjson from tqdm.auto import tqdm @@ -48,6 +49,7 @@ def process( position=tqdm_position, ): file_path = os.path.join(self.path, filename) + print(f"Processing file: {file_path}") with self._open_file(file_path) as file: for line in file: yield orjson.loads(line) @@ -231,9 +233,14 @@ def __init__(self, extract_qid: bool = False): def process(self, input_gen=None): for damuel_entry in input_gen: + # print(damuel_entry) page_type = self._get_page_type(damuel_entry) if self._extract_qid: qid = parse_qid(damuel_entry["qid"]) + if random.random() < 0.0001: + print(f"Page type: {page_type}, QID: {qid}") + # if random.random() < 0.1: + # assert False yield page_type, qid else: yield page_type, diff --git a/src/utils/loaders.py b/src/utils/loaders.py index 31e34ae..7bf086a 100644 --- a/src/utils/loaders.py +++ b/src/utils/loaders.py @@ -8,16 +8,19 @@ import numpy as np import pandas as pd -from utils.qids_remap import remap_qids_decorator - # from tokenization.pipeline import DamuelAliasTablePipeline from tokenization.runner import run_alias_table_damuel +from utils.qids_remap import remap_qids_decorator +from utils.qid_filter import qid_filter + current_file_path = os.path.abspath(__file__) project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))) config_path = os.path.join(project_root, "configs", "general.gin") gin.parse_config_file(config_path) +# config_path = os.path.join(project_root, "configs", "multilingual_dataset.gin") +# gin.parse_config_file(config_path) def _sort_by_output(output_idx: int): @@ -34,6 +37,7 @@ def _wrapper(*args, **kwargs): # @_sort_by_output(1) +@qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) def load_embs_and_qids(dir_path: str | Path) -> tuple[np.ndarray, np.ndarray]: """Loads embeddings and qids from the directory. @@ -76,6 +80,7 @@ def load_embs_qids_tokens(path: str | Path) -> tuple[np.ndarray, np.ndarray]: # @_sort_by_output(1) +@qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) def load_mentions(file_path: str | Path) -> tuple[np.ndarray, np.ndarray]: if type(file_path) == str: @@ -84,6 +89,7 @@ def load_mentions(file_path: str | Path) -> tuple[np.ndarray, np.ndarray]: return d["tokens"], d["qids"] +@qid_filter(qids_index=None) @remap_qids_decorator(qids_index=None, json_path=gin.REQUIRED) def load_qids(file_path: str | Path) -> np.ndarray: if type(file_path) == str: @@ -92,6 +98,7 @@ def load_qids(file_path: str | Path) -> np.ndarray: return d["qids"] +@qid_filter(qids_index=None) @remap_qids_decorator(qids_index=None, json_path=gin.REQUIRED) def load_qids_npy(file_path: str | Path) -> np.ndarray: if type(file_path) == str: @@ -100,6 +107,7 @@ def load_qids_npy(file_path: str | Path) -> np.ndarray: @_sort_by_output(1) +@qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) def load_mentions_from_dir(dir_path: str | Path) -> tuple[np.ndarray, np.ndarray]: tokens, qids = [], [] @@ -111,6 +119,7 @@ def load_mentions_from_dir(dir_path: str | Path) -> tuple[np.ndarray, np.ndarray return np.array(tokens), np.array(qids) +@qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) def load_tokens_and_qids(file_path: str | Path) -> tuple[np.ndarray, np.ndarray]: d = np.load(file_path) @@ -145,6 +154,7 @@ def load_mewsli(self, lang: str) -> tuple[list[str], np.ndarray]: df["mention"] = df["mention"].str.lower() return df["mention"].tolist(), df["qid"].apply(lambda x: int(x[1:])).to_numpy() + @qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) def load_damuel(self, lang) -> tuple[list[str], np.ndarray]: data = run_alias_table_damuel(self._construct_damuel_path(lang)) diff --git a/src/utils/qid_filter.py b/src/utils/qid_filter.py index 548d973..5d4b703 100644 --- a/src/utils/qid_filter.py +++ b/src/utils/qid_filter.py @@ -1,21 +1,23 @@ import functools +import gin import numpy as np -from .loaders import load_qids_npy - @functools.cache def _load_filter(path: str) -> set: """Load QIDs to filter from a `.npy` file.""" try: - arr = load_qids_npy(path) + # TODO: qid remap this loading + print("Loading filter from", path) + arr = np.load(path) except Exception: raise FileNotFoundError(f"Cannot load filter file: {path}") return set(arr.tolist()) -def qid_filter(qids_index: int | None, filter_path: str = None): +@gin.configurable +def qid_filter(qids_index: int | None, filter_path: str | None = None): """Decorator that filters out QIDs listed in a `.npy` file. Args: @@ -31,13 +33,13 @@ def wrapper(*args, **kwargs): if filter_path is None: return result - mask_set = _load_filter(filter_path) + mask_set = set(_load_filter(filter_path)) is_valid_tuple = isinstance(result, tuple) and 0 <= qids_index < len(result) if is_valid_tuple: qids = result[qids_index] else: qids = result - keep = ~np.isin(qids, list(mask_set)) + keep = np.array([q not in mask_set for q in qids]) if is_valid_tuple: updated_result = tuple(result_array[keep] for result_array in result) From f71d26ea54bdfcc5ab7245a55734145bf8a63df6 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 20 Jul 2025 14:12:33 +0200 Subject: [PATCH 05/18] fix(utils): :bug: address failing qid filter tests --- src/utils/qid_filter.py | 14 +++++++++++--- tests/utils/test_qids_filter.py | 33 +++++++++++++++++++-------------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/utils/qid_filter.py b/src/utils/qid_filter.py index 5d4b703..fabc758 100644 --- a/src/utils/qid_filter.py +++ b/src/utils/qid_filter.py @@ -24,6 +24,10 @@ def qid_filter(qids_index: int | None, filter_path: str | None = None): qids_index (int | None): Index of the QIDs in the input data. If None assumes the input data to be just qids. filter_path (str): Path to the `.npy` file containing QIDs to filter out. Can be empty. If empty, no filtering is applied; the decorator is an identity. + + Raises: + ValueError: If `qids_index` is not valid for the returned tuple. + TypeError: If the index is specified but the result is array of qids but not a tuple. """ def decorator(fn): @@ -34,14 +38,18 @@ def wrapper(*args, **kwargs): return result mask_set = set(_load_filter(filter_path)) - is_valid_tuple = isinstance(result, tuple) and 0 <= qids_index < len(result) - if is_valid_tuple: + is_tuple = isinstance(result, tuple) + if is_tuple: + if qids_index is None: + raise ValueError( + "qids_index cannot be None for a tuple result from the decorated function." + ) qids = result[qids_index] else: qids = result keep = np.array([q not in mask_set for q in qids]) - if is_valid_tuple: + if is_tuple: updated_result = tuple(result_array[keep] for result_array in result) elif qids_index is None and not isinstance(result, tuple): updated_result = result[keep] diff --git a/tests/utils/test_qids_filter.py b/tests/utils/test_qids_filter.py index c2f4106..89f98c8 100644 --- a/tests/utils/test_qids_filter.py +++ b/tests/utils/test_qids_filter.py @@ -137,33 +137,38 @@ def loader(qids): assert np.array_equal(out_qids, expected_qids) +@pytest.mark.parametrize("idx", [1, 2]) @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_qid_filter_assert_raises_value_error_list(mock_qids_remap, tmp_path): +def test_qid_filter_assert_raises_value_error_idx(mock_qids_remap, idx, tmp_path): to_filter = np.array([1, 3]) filter_file = tmp_path / "filter.npy" np.save(filter_file, to_filter) - @qid_filter(1, filter_path=str(filter_file)) + @qid_filter(idx, filter_path=str(filter_file)) def loader(qids): - return [qids] + return (qids,) qids = np.array([1, 2, 3, 4]) - with pytest.raises(ValueError): - a = loader(qids) - print(a) + with pytest.raises(IndexError): + _ = loader(qids) -@pytest.mark.parametrize("idx", [-1, 2]) @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_qid_filter_assert_raises_value_error_idx(mock_qids_remap, idx, tmp_path): +def test_qid_filter_raises_value_error_for_none_index_with_tuple( + mock_qids_remap, tmp_path +): to_filter = np.array([1, 3]) filter_file = tmp_path / "filter.npy" np.save(filter_file, to_filter) - @qid_filter(idx, filter_path=str(filter_file)) - def loader(qids): - return (qids,) + @qid_filter(None, filter_path=str(filter_file)) + def loader(data, qids): + return data, qids - qids = np.array([1, 2, 3, 4]) - with pytest.raises(ValueError): - _ = loader(qids) + data = np.array([[1, 2], [3, 4]]) + qids = np.array([1, 2]) + with pytest.raises( + ValueError, + match="qids_index cannot be None for a tuple result from the decorated function.", + ): + loader(data, qids) From a889ef0cd1e441dba7593458533b93c9a56de2c1 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Thu, 24 Jul 2025 21:13:05 +0200 Subject: [PATCH 06/18] feat(utils): :sparkles: scripts for calculating stats about dataset --- src/scripts/utils/lang_statistics.py | 80 ++++++++++++++++++++++++++ src/scripts/utils/length_statistics.py | 78 +++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 src/scripts/utils/lang_statistics.py create mode 100644 src/scripts/utils/length_statistics.py diff --git a/src/scripts/utils/lang_statistics.py b/src/scripts/utils/lang_statistics.py new file mode 100644 index 0000000..ea2b8e3 --- /dev/null +++ b/src/scripts/utils/lang_statistics.py @@ -0,0 +1,80 @@ +""" +Given a path to a directory of multilingual dataset, calculates how many entities are there per language. +Usage: python lang_statistics.py + +Assumes that the directory contains files with suffix _.npz +""" + +import os +import numpy as np +from collections import defaultdict +import fire +from tqdm import tqdm + + +def count_qids_by_language(directory_path="."): + """ + Count qids by language from .npz files in the given directory. + + Args: + directory_path (str): Path to the directory containing .npz files + """ + # Counter mapping language code to number of qids + lang_counter = defaultdict(int) + + # Get list of .npz files + npz_files = [f for f in os.listdir(directory_path) if f.endswith(".npz")] + + if not npz_files: + print("No .npz files found in the directory.") + return + + # Iterate through all files with progress bar + for filename in tqdm(npz_files, desc="Processing files"): + # Extract language code from filename (format: _.npz) + lang_code = filename.split("_")[-2] + + try: + # Load the .npz file + filepath = os.path.join(directory_path, filename) + data = np.load(filepath) + + # Get qids and count them + if "qids" in data: + qids = data["qids"] + qid_count = len(qids) + lang_counter[lang_code] += qid_count + else: + print(f"Warning: 'qids' key not found in {filename}") + + # Close the file + data.close() + + except Exception as e: + print(f"Error processing {filename}: {e}") + + # Sort languages by count (most to least) + sorted_langs = sorted(lang_counter.items(), key=lambda x: x[1], reverse=True) + + # Display results + print("\n" + "=" * 50) + print("LANGUAGE STATISTICS (Most to Least Used)") + print("=" * 50) + + if not sorted_langs: + print("No .npz files found or no qids data available.") + return + + total_qids = sum(lang_counter.values()) + + for rank, (lang, count) in enumerate(sorted_langs, 1): + percentage = (count / total_qids) * 100 if total_qids > 0 else 0 + print(f"{rank:2d}. {lang:10s}: {count:8,} qids ({percentage:5.1f}%)") + + print("-" * 50) + print(f"Total languages: {len(sorted_langs)}") + print(f"Total qids: {total_qids:,}") + + +if __name__ == "__main__": + fire.Fire(count_qids_by_language) diff --git a/src/scripts/utils/length_statistics.py b/src/scripts/utils/length_statistics.py new file mode 100644 index 0000000..d934902 --- /dev/null +++ b/src/scripts/utils/length_statistics.py @@ -0,0 +1,78 @@ +""" +Given a path to a directory of multilingual dataset, calculates average and std of token lengths. +Usage: python length_statistics.py +Assumes that the directory contains files with suffix _.npz +""" + +import os +import math +import concurrent.futures + +import numpy as np +from tqdm import tqdm +import fire + + +def _file_stats(path): + data = np.load(path) + tokens = data["tokens"] + lengths = np.count_nonzero(tokens, axis=1) + data.close() + count = lengths.size + sum_len = lengths.sum(dtype=np.int64) + sum_sq = np.square(lengths, dtype=np.int64).sum(dtype=np.int64) + return count, sum_len, sum_sq + + +def calculate_length_statistics(directory_path=".", workers: int | None = None): + npz_files = [ + os.path.join(directory_path, f) + for f in os.listdir(directory_path) + if f.endswith(".npz") + ] + if not npz_files: + print("No .npz files found.") + return + if workers is None: + workers = min(32, (os.cpu_count() or 1)) + + total_count = 0 + total_sum = 0 + total_sum_sq = 0 + + with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as pool: + futures = {pool.submit(_file_stats, p): p for p in npz_files} + for fut in tqdm( + concurrent.futures.as_completed(futures), + total=len(futures), + desc="Processing files", + ): + try: + count, s, ss = fut.result() + except Exception as e: + print(f"Error {futures[fut]}: {e}") + continue + total_count += count + total_sum += s + total_sum_sq += ss + + if not total_count: + print("No sequences found.") + return + + mean = total_sum / total_count + variance = (total_sum_sq / total_count) - mean * mean + std = math.sqrt(max(variance, 0.0)) + + print("\n" + "=" * 50) + print("TOKEN LENGTH STATISTICS") + print("=" * 50) + print(f"Total sequences processed: {total_count:,}") + print(f"Average length: {mean:.2f}") + print(f"Standard deviation: {std:.2f}") + print(f"Minimum length: 0 (assumed)") + print("=" * 50) + + +if __name__ == "__main__": + fire.Fire(calculate_length_statistics) From 270e98339dff132576322483b943af36b9cd15d2 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Thu, 24 Jul 2025 21:13:23 +0200 Subject: [PATCH 07/18] perf(train): :zap: speed up generate with amp --- src/models/searchers/brute_force_searcher.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index a36bf8d..ad1e7e3 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -99,9 +99,10 @@ def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: # on every call to find. # This is a bit of a hack, but it should make things faster as we are suggesting that the module_searcher is initialized. try: - top_indices: torch.Tensor = self.module_searcher( - torch.tensor(batch, device=self.device) - ) + with torch.amp.autocast(device_type="cuda", dtype=torch.float16): + top_indices: torch.Tensor = self.module_searcher( + torch.from_numpy(batch).to(self.device) + ) except TypeError as e: if self.module_searcher is not None: raise e @@ -111,7 +112,7 @@ def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: self.module_searcher.to(self.device) self.required_num_neighbors = num_neighbors top_indices: torch.Tensor = self.module_searcher( - torch.tensor(batch, device=self.device) + torch.from_numpy(batch).to(self.device) ) top_indices_np: np.ndarray = top_indices.cpu().numpy() From ab9e002ac3a3c85881998ab7a7147590740b157d Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 27 Jul 2025 19:08:39 +0200 Subject: [PATCH 08/18] perf(train): :zap: adding compile, iterable dataset to make prefetch actually work, gradient clipping --- src/finetunings/finetune_model/data.py | 41 +++++++- src/finetunings/finetune_model/train_ddp.py | 104 +++++++++++--------- 2 files changed, 96 insertions(+), 49 deletions(-) diff --git a/src/finetunings/finetune_model/data.py b/src/finetunings/finetune_model/data.py index ebedbff..008ba29 100644 --- a/src/finetunings/finetune_model/data.py +++ b/src/finetunings/finetune_model/data.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import logging from pathlib import Path from typing import Any @@ -6,7 +7,9 @@ import torch import torch.nn as nn import wandb -from torch.utils.data import Dataset +from torch.utils.data import Dataset, IterableDataset + +_logger = logging.getLogger("finetuning.finetune_model.data") @dataclass @@ -143,3 +146,39 @@ def _set_cnts(self) -> None: def _get_data_obj(self) -> Any: d = np.load(self._dataset_dir / f"epoch_{self._epoch}.npz") return d + + +class LightWeightIterableDataset(IterableDataset): + def __init__( + self, dataset_dir: Path, epoch: int, rank: int = 1, world_size: int = 1 + ) -> None: + super().__init__() + self._world_size = world_size + self._rank = rank + self._dataset_dir = dataset_dir + self._epoch = epoch + self._dataset: LightWeightDataset | None = self._load_next() + + def __iter__(self): + while self._dataset is not None: + for i in range(len(self._dataset)): + yield self._dataset[i] + try: + self._dataset = self._load_next() + except FileNotFoundError: + self._dataset = None + + @property + def links_cnt(self) -> int: + return self._dataset.links_cnt + + @property + def descriptions_cnt(self) -> int: + return self._dataset.descriptions_cnt + + def _load_next(self) -> LightWeightDataset: + dataset = LightWeightDataset( + self._dataset_dir, self._epoch, self._rank, self._world_size + ) + self._epoch += 1 + return dataset diff --git a/src/finetunings/finetune_model/train_ddp.py b/src/finetunings/finetune_model/train_ddp.py index ba3d6e4..b1ed119 100644 --- a/src/finetunings/finetune_model/train_ddp.py +++ b/src/finetunings/finetune_model/train_ddp.py @@ -21,6 +21,7 @@ from finetunings.finetune_model.data import ( LightWeightDataset, + LightWeightIterableDataset, save_model, SaveInformation, ) @@ -97,11 +98,13 @@ def _ddp_train( STATE_DICT_PATH: str | None, TARGET_DIM: int | None, WEIGHT_DECAY: float | None, + GRADIENT_CLIP: float = 1.0, ): setup(rank, world_size) model = load_model(FOUNDATION_MODEL_PATH, STATE_DICT_PATH, TARGET_DIM) model = DDP(model.to(rank), device_ids=[rank]) + model = torch.compile(model) is_the_main_process = rank == 0 @@ -116,6 +119,7 @@ def _ddp_train( "MODEL_SAVE_DIR": MODEL_SAVE_DIR, "STATE_DICT_PATH": STATE_DICT_PATH, "WEIGHT_DECAY": WEIGHT_DECAY, + "GRADIENT_CLIP": GRADIENT_CLIP, }, ) @@ -124,69 +128,70 @@ def _ddp_train( scaler = torch.amp.GradScaler("cuda") + def step(): + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + step = torch.compile(step) + running_averages = None if is_the_main_process: running_averages = RunningAverages(_RUNNING_AVERAGE_SMALL, _RUNNING_AVERAGE_BIG) - for epoch in range(EPOCHS): - model.train() - - train_loss = 0 - - dataset = LightWeightDataset(DATASET_DIR, epoch, rank, world_size) - dataloader = DataLoader( - dataset, batch_size=None, pin_memory=True, num_workers=2, prefetch_factor=2 - ) + dataset = LightWeightIterableDataset(DATASET_DIR, 0, rank, world_size) + dataloader = DataLoader( + dataset, batch_size=None, pin_memory=True, num_workers=2, prefetch_factor=2 + ) - labels = construct_labels(dataset) - labels = torch.from_numpy(labels).to(rank) + labels = construct_labels(dataset) + labels = torch.from_numpy(labels).to(rank) - for replica_part in dataloader: + for replica_part in dataloader: - with torch.autocast(device_type="cuda"): - replica_part = forward_to_embeddings(replica_part, model) + with torch.autocast(device_type="cuda"): + replica_part = forward_to_embeddings(replica_part, model) - with torch.no_grad(): # all_gather cannot propagate gradients so make it explicit - all_replicas = [ - torch.zeros_like(replica_part) for _ in range(world_size) - ] - torch.distributed.all_gather(all_replicas, replica_part) + with torch.no_grad(): # all_gather cannot propagate gradients so make it explicit + all_replicas = [ + torch.zeros_like(replica_part) for _ in range(world_size) + ] + torch.distributed.all_gather(all_replicas, replica_part) - # Allow gradients propagation for the slice owned by the current process - all_replicas[rank] = replica_part + # Allow gradients propagation for the slice owned by the current process + all_replicas[rank] = replica_part - all_replicas = torch.cat(all_replicas, dim=0) + all_replicas = torch.cat(all_replicas, dim=0) - links_embedded, descs_embedded = ( - all_replicas[: dataset.links_cnt], - all_replicas[dataset.links_cnt :], - ) + links_embedded, descs_embedded = ( + all_replicas[: dataset.links_cnt], + all_replicas[dataset.links_cnt :], + ) - loss, outputs = _calculate_loss( - links_embedded, descs_embedded, labels, LOGIT_MULTIPLIER, criterion - ) + loss, outputs = _calculate_loss( + links_embedded, descs_embedded, labels, LOGIT_MULTIPLIER, criterion + ) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - norm_for_logs = get_gradient_norm(model.module) - optimizer.zero_grad() + # norm_for_logs = step() + step() + if is_the_main_process: loss_item = loss.item() - train_loss += loss_item - - if is_the_main_process: - process_metrics( - outputs, - labels, - loss_item, - running_averages, - { - "gradient_norm": norm_for_logs, - }, - ) - if is_the_main_process and epoch % 100 == 0: - save_final_model(model.module, MODEL_SAVE_DIR) + + process_metrics( + outputs, + labels, + loss_item, + running_averages, + # { + # "gradient_norm": norm_for_logs, + # }, + ) + if is_the_main_process: + save_final_model(model.module, MODEL_SAVE_DIR) if is_the_main_process: # We only save the model on the main process and only once @@ -212,6 +217,7 @@ def train_ddp( STATE_DICT_PATH=None, TARGET_DIM=None, WEIGHT_DECAY=0.0, + GRADIENT_CLIP=1.0, ): DATASET_DIR = Path(DATASET_DIR) FOUNDATION_MODEL_PATH = str(FOUNDATION_MODEL_PATH) @@ -222,6 +228,7 @@ def train_ddp( STATE_DICT_PATH = str(STATE_DICT_PATH) if STATE_DICT_PATH is not None else None TARGET_DIM = int(TARGET_DIM) if TARGET_DIM is not None else None WEIGHT_DECAY = float(WEIGHT_DECAY) + GRADIENT_CLIP = float(GRADIENT_CLIP) world_size = torch.cuda.device_count() @@ -238,6 +245,7 @@ def train_ddp( STATE_DICT_PATH, TARGET_DIM, WEIGHT_DECAY, + GRADIENT_CLIP, ), nprocs=world_size, ) From f50b087b5514874d2a3f49c7b3c4808e47f61662 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 3 Aug 2025 22:01:44 +0200 Subject: [PATCH 09/18] fix(train): :bug: fix double loading of the same data --- src/finetunings/finetune_model/data.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/finetunings/finetune_model/data.py b/src/finetunings/finetune_model/data.py index 008ba29..2d8f812 100644 --- a/src/finetunings/finetune_model/data.py +++ b/src/finetunings/finetune_model/data.py @@ -78,7 +78,7 @@ def __init__( self._rank = rank self._dataset_dir = dataset_dir self._epoch = epoch - self._data = self._load() + # self._data = self._load() self._links_cnt = None self._descriptions_cnt = None self._len = None @@ -99,6 +99,9 @@ def descriptions_cnt(self) -> int: return self._descriptions_cnt def _load(self) -> Any: + _logger.info( + f"Loading dataset from {self._dataset_dir} for epoch {self._epoch}, rank {self._rank}, world size {self._world_size}" + ) self._set_cnts() this_share_start, this_share_end = self._get_share_bounds() if this_share_end <= self.links_cnt: From 9fcb6d551f7a2039ca055e37b7014d372fa410ab Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 3 Aug 2025 22:02:34 +0200 Subject: [PATCH 10/18] refactor(train): :recycle: return back LightWeightDataset version --- src/finetunings/finetune_model/train_ddp.py | 94 +++++++++++---------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/src/finetunings/finetune_model/train_ddp.py b/src/finetunings/finetune_model/train_ddp.py index b1ed119..83a5817 100644 --- a/src/finetunings/finetune_model/train_ddp.py +++ b/src/finetunings/finetune_model/train_ddp.py @@ -142,56 +142,58 @@ def step(): if is_the_main_process: running_averages = RunningAverages(_RUNNING_AVERAGE_SMALL, _RUNNING_AVERAGE_BIG) - dataset = LightWeightIterableDataset(DATASET_DIR, 0, rank, world_size) - dataloader = DataLoader( - dataset, batch_size=None, pin_memory=True, num_workers=2, prefetch_factor=2 - ) + dataset = LightWeightDataset(DATASET_DIR, 0, rank, world_size) labels = construct_labels(dataset) labels = torch.from_numpy(labels).to(rank) - - for replica_part in dataloader: - - with torch.autocast(device_type="cuda"): - replica_part = forward_to_embeddings(replica_part, model) - - with torch.no_grad(): # all_gather cannot propagate gradients so make it explicit - all_replicas = [ - torch.zeros_like(replica_part) for _ in range(world_size) - ] - torch.distributed.all_gather(all_replicas, replica_part) - - # Allow gradients propagation for the slice owned by the current process - all_replicas[rank] = replica_part - - all_replicas = torch.cat(all_replicas, dim=0) - - links_embedded, descs_embedded = ( - all_replicas[: dataset.links_cnt], - all_replicas[dataset.links_cnt :], - ) - - loss, outputs = _calculate_loss( - links_embedded, descs_embedded, labels, LOGIT_MULTIPLIER, criterion - ) - - # norm_for_logs = step() - step() - + for epoch in range(EPOCHS): if is_the_main_process: - loss_item = loss.item() - - process_metrics( - outputs, - labels, - loss_item, - running_averages, - # { - # "gradient_norm": norm_for_logs, - # }, - ) - if is_the_main_process: - save_final_model(model.module, MODEL_SAVE_DIR) + _logger.info(f"Starting epoch {epoch + 1}/{EPOCHS}") + + dataset = LightWeightDataset(DATASET_DIR, epoch, rank, world_size) + dataloader = DataLoader( + dataset, batch_size=None, pin_memory=True, num_workers=2, prefetch_factor=2 + ) + for replica_part in dataloader: + + with torch.autocast(device_type="cuda"): + replica_part = forward_to_embeddings(replica_part, model) + + with torch.no_grad(): # all_gather cannot propagate gradients so make it explicit + all_replicas = [ + torch.zeros_like(replica_part) for _ in range(world_size) + ] + torch.distributed.all_gather(all_replicas, replica_part) + + # Allow gradients propagation for the slice owned by the current process + all_replicas[rank] = replica_part + + all_replicas = torch.cat(all_replicas, dim=0) + + links_embedded, descs_embedded = ( + all_replicas[: dataset.links_cnt], + all_replicas[dataset.links_cnt :], + ) + + loss, outputs = _calculate_loss( + links_embedded, descs_embedded, labels, LOGIT_MULTIPLIER, criterion + ) + + # norm_for_logs = step() + step() + + if is_the_main_process: + loss_item = loss.item() + + process_metrics( + outputs, + labels, + loss_item, + running_averages, + # { + # "gradient_norm": norm_for_logs, + # }, + ) if is_the_main_process: # We only save the model on the main process and only once From 857e758a16a61db0d2bf28d79836ea349e30c300 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 3 Aug 2025 22:02:52 +0200 Subject: [PATCH 11/18] perf(utils): :zap: small performance improvements --- src/utils/embeddings.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py index c118c10..c7b0635 100644 --- a/src/utils/embeddings.py +++ b/src/utils/embeddings.py @@ -78,7 +78,7 @@ def embed( model.eval() # We usually work with IterableDataset subclass so no multiprocessing - data_loader = DataLoader(dataset, batch_size, num_workers=0) + data_loader = DataLoader(dataset, batch_size, num_workers=0, pin_memory=True) if torch.cuda.is_available(): model = torch.nn.DataParallel(model).cuda() @@ -96,19 +96,21 @@ def embed( with torch.no_grad(): for batch_toks, batch_qids in data_loader: - batch_toks = batch_toks.to(torch.int64) + if torch.cuda.is_available(): + batch_toks = batch_toks.to( + dtype=torch.int64, non_blocking=True, device="cuda" + ) attention_mask = create_attention_mask(batch_toks) if torch.cuda.is_available(): - batch_toks = batch_toks.cuda() attention_mask = attention_mask.cuda() with torch.amp.autocast(device_type="cuda"): batch_embeddings = model(batch_toks, attention_mask) - batch_embeddings = batch_embeddings.cpu().numpy().astype(np.float16) - batch_embeddings = batch_embeddings / np.linalg.norm( - batch_embeddings, ord=2, axis=1, keepdims=True + batch_embeddings = batch_embeddings / torch.linalg.norm( + batch_embeddings, ord=2, dim=1, keepdim=True ) + batch_embeddings = batch_embeddings.cpu().numpy().astype(np.float16) embeddings.extend(batch_embeddings) if return_tokens: tokens.extend(batch_toks.cpu().numpy()) From e5e0c10f01ec06cbf977d9f3371e61c258c09b6d Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 3 Aug 2025 22:24:28 +0200 Subject: [PATCH 12/18] feat(train): :sparkles: add option to run all recalls from one python function --- src/finetunings/evaluation/evaluate.py | 31 +++++++++++++++++--------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/finetunings/evaluation/evaluate.py b/src/finetunings/evaluation/evaluate.py index ffc9880..a463c90 100644 --- a/src/finetunings/evaluation/evaluate.py +++ b/src/finetunings/evaluation/evaluate.py @@ -1,7 +1,20 @@ +import logging + from finetunings.evaluation.find_recall import find_recall _RECALLS = [1, 10, 100] +_logger = logging.getLogger("finetuning.evaluation.evaluate") + + +def _construct_mewsli_path(root_dir: str, finetuning_round: int, lang: str) -> str: + return f"{root_dir}/mewsli_embs_{lang}_{finetuning_round}" + + +def _construct_damuel_path(root_dir: str, finetuning_round: int) -> str: + next_finetuning_round = finetuning_round + 1 + return f"{root_dir}/damuel_for_index_{next_finetuning_round}" + def run_recall_calculation(damuel_dir, mewsli_dir, recall=None): recalls = _RECALLS if recall is None else [recall] @@ -9,15 +22,13 @@ def run_recall_calculation(damuel_dir, mewsli_dir, recall=None): def evaluate( - damuel_desc_tokens, - mewsli_tokens, - model_path, - damuel_dir, - mewsli_dir, - state_dict=None, + root_dir: str, + finetuning_round: int, + langs: list[str] = ["ar", "de", "en", "es", "ja", "fa", "sr", "ta", "tr"], ): - raise NotImplementedError() - + damuel_path = _construct_damuel_path(root_dir, finetuning_round) -if __name__ == "__main__": - evaluate() + for lang in langs: + mewsli_path = _construct_mewsli_path(root_dir, finetuning_round, lang) + _logger.info(f"Calculating recall for {lang}") + run_recall_calculation(damuel_path, mewsli_path) From 3a27aeba4d4cbf82861c99dbee14bcd64de8ac63 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 3 Aug 2025 22:24:48 +0200 Subject: [PATCH 13/18] test(train): :white_check_mark: evaluate tests --- src/finetunings/finetune_model/train_ddp.py | 7 ++++- tests/finetunings/evaluation/test_evaluate.py | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 tests/finetunings/evaluation/test_evaluate.py diff --git a/src/finetunings/finetune_model/train_ddp.py b/src/finetunings/finetune_model/train_ddp.py index 83a5817..d452f76 100644 --- a/src/finetunings/finetune_model/train_ddp.py +++ b/src/finetunings/finetune_model/train_ddp.py @@ -150,7 +150,12 @@ def step(): if is_the_main_process: _logger.info(f"Starting epoch {epoch + 1}/{EPOCHS}") - dataset = LightWeightDataset(DATASET_DIR, epoch, rank, world_size) + try: + dataset = LightWeightDataset(DATASET_DIR, epoch, rank, world_size) + except FileNotFoundError: + _logger.error(f"Dataset for epoch {epoch} not found. Stopping training.") + break + dataloader = DataLoader( dataset, batch_size=None, pin_memory=True, num_workers=2, prefetch_factor=2 ) diff --git a/tests/finetunings/evaluation/test_evaluate.py b/tests/finetunings/evaluation/test_evaluate.py new file mode 100644 index 0000000..78f378e --- /dev/null +++ b/tests/finetunings/evaluation/test_evaluate.py @@ -0,0 +1,31 @@ +from unittest.mock import patch + +from finetunings.evaluation.evaluate import evaluate + + +class TestEvaluate: + @patch("finetunings.evaluation.evaluate.run_recall_calculation") + def test_evaluate_default_langs(self, mock_run_recall): + evaluate("/root", 1) + + expected_damuel_path = "/root/damuel_for_index_2" + expected_calls = 9 # default 9 languages + + assert mock_run_recall.call_count == expected_calls + for call in mock_run_recall.call_args_list: + assert call[0][0] == expected_damuel_path + + @patch("finetunings.evaluation.evaluate.run_recall_calculation") + def test_evaluate_custom_langs(self, mock_run_recall): + custom_langs = ["en", "de"] + evaluate("/root", 2, langs=custom_langs) + + expected_damuel_path = "/root/damuel_for_index_3" + expected_mewsli_paths = ["/root/mewsli_embs_en_2", "/root/mewsli_embs_de_2"] + + assert mock_run_recall.call_count == 2 + calls = [call[0] for call in mock_run_recall.call_args_list] + + for i, call in enumerate(calls): + assert call[0] == expected_damuel_path + assert call[1] == expected_mewsli_paths[i] From 6f23a4818b980e779747c1635e56510e45b2ff70 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Tue, 5 Aug 2025 21:41:57 +0200 Subject: [PATCH 14/18] feat(utils): :sparkles: add script for qid ocurrence analysis --- src/scripts/utils/qid_occurences.py | 97 +++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 src/scripts/utils/qid_occurences.py diff --git a/src/scripts/utils/qid_occurences.py b/src/scripts/utils/qid_occurences.py new file mode 100644 index 0000000..31cb1aa --- /dev/null +++ b/src/scripts/utils/qid_occurences.py @@ -0,0 +1,97 @@ +import os +from collections import Counter +import numpy as np +from tqdm import tqdm +import fire + +""" +Given a path to a directory of multilingual dataset, calculates qid occurrence frequencies. +Usage: python qid_occurences.py +Assumes that the directory contains files with suffix _.npz +""" + +import concurrent.futures + + +def _file_qid_counts(path): + data = np.load(path) + qids = data["qids"] + data.close() + # Count occurrences of each qid in this file + return Counter(qids.flatten()) + + +def calculate_qid_statistics(directory_path=".", workers: int | None = None): + npz_files = [ + os.path.join(directory_path, f) + for f in os.listdir(directory_path) + if f.endswith(".npz") + ] + if not npz_files: + print("No .npz files found.") + return + if workers is None: + workers = min(32, (os.cpu_count() or 1)) + + total_counter = Counter() + + with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as pool: + futures = {pool.submit(_file_qid_counts, p): p for p in npz_files} + for fut in tqdm( + concurrent.futures.as_completed(futures), + total=len(futures), + desc="Processing files", + ): + try: + file_counter = fut.result() + total_counter.update(file_counter) + except Exception as e: + print(f"Error {futures[fut]}: {e}") + continue + + if not total_counter: + print("No qids found.") + return + + # Get occurrence counts for statistical analysis + occurrence_counts = list(total_counter.values()) + + # Calculate statistics + average_count = np.mean(occurrence_counts) + median_count = np.median(occurrence_counts) + q1_count = np.percentile(occurrence_counts, 25) + q3_count = np.percentile(occurrence_counts, 75) + + # Calculate all deciles (10th, 20th, ..., 90th percentiles) + deciles = [np.percentile(occurrence_counts, p) for p in range(10, 100, 10)] + + # Get top 10 most occurring and top 10 least occurring + most_common = total_counter.most_common(10) + least_common = total_counter.most_common()[-10:] + + print("\n" + "=" * 50) + print("QID OCCURRENCE STATISTICS") + print("=" * 50) + print(f"Total unique qids: {len(total_counter):,}") + print(f"Total qid occurrences: {sum(total_counter.values()):,}") + print(f"Average occurrence count: {average_count:.2f}") + print(f"Median occurrence count: {median_count:.2f}") + print(f"Q1 (25th percentile): {q1_count:.2f}") + print(f"Q3 (75th percentile): {q3_count:.2f}") + + print("\nDECILES:") + for i, decile in enumerate(deciles, 1): + print(f"{i*10:2d}th percentile: {decile:.2f}") + + print("\nTOP 10 MOST OCCURRING QIDS:") + for i, (qid, count) in enumerate(most_common, 1): + print(f"{i:2d}. QID {qid}: {count:,} occurrences") + + print("\nTOP 10 LEAST OCCURRING QIDS:") + for i, (qid, count) in enumerate(least_common, 1): + print(f"{i:2d}. QID {qid}: {count:,} occurrences") + print("=" * 50) + + +if __name__ == "__main__": + fire.Fire(calculate_qid_statistics) From c2cab82bf48d7ddefd688c4faee17fec599a793a Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Thu, 7 Aug 2025 21:14:14 +0200 Subject: [PATCH 15/18] feat(multiling): :sparkles: option to limit link qids, faster mixing --- src/multilingual_dataset/creator.py | 67 +++++++-- src/multilingual_dataset/mixer.py | 140 +++++++++++------- tests/multilingual_dataset/test_mixer.py | 58 +++++--- .../test_multilingual_dataset_creator.py | 4 +- 4 files changed, 184 insertions(+), 85 deletions(-) diff --git a/src/multilingual_dataset/creator.py b/src/multilingual_dataset/creator.py index 1afad76..84dc62f 100644 --- a/src/multilingual_dataset/creator.py +++ b/src/multilingual_dataset/creator.py @@ -5,11 +5,12 @@ from itertools import zip_longest from pathlib import Path from typing import Union +import time import gin import numpy as np -from multilingual_dataset.mixer import Mixer +from multilingual_dataset.mixer import Mixer, ParallelMixer from tqdm import tqdm from utils.damuel_paths import DamuelPaths from utils.loaders import load_mentions, load_qids @@ -19,15 +20,22 @@ class _LinksCreator: def __init__( - self, damuel_paths: DamuelPaths, langs: list[str], dest_dir: Path + self, + damuel_paths: DamuelPaths, + langs: list[str], + dest_dir: Path, + max_samples_per_qid: int, ) -> None: self.damuel_paths: DamuelPaths = damuel_paths self.langs: list[str] = langs self.dest_links_dir: Path = dest_dir / "links" self.dest_links_dir.mkdir(parents=True, exist_ok=True) + self.max_samples_per_qid: int = max_samples_per_qid + self.single_mixer = Mixer(buffer_size=1) - self.standard_mixer = Mixer(buffer_size=50) + self.parallel_mixer = ParallelMixer(n_workers=20, buffer_size=5) + self.standard_mixer = Mixer(buffer_size=100) def run(self) -> None: """Gathers links from all languages and writes them to dest_dir. @@ -50,6 +58,8 @@ def run(self) -> None: out_file_paths.append(out_file_path) self.single_mixer.mix(out_file_paths, n_of_mixings=1, compress_output=False) + self.parallel_mixer.mix(out_file_paths, n_of_mixings=5, compress_output=False) + self._remove_often_qids(out_file_paths) self.standard_mixer.mix(out_file_paths, n_of_mixings=5, compress_output=True) def _copy_files( @@ -76,6 +86,28 @@ def _get_link_file_paths(self, link_dir_paths: list[Path]) -> list[list[Path]]: ) return link_file_paths + def _remove_often_qids(self, file_paths: list[Path]): + qid_counter = Counter() + for file_path in tqdm(file_paths, desc="Counting QIDs", total=len(file_paths)): + tokens, qids = load_mentions(file_path) + + tokens_filtered, qids_filtered = [], [] + for token, qid in zip(tokens, qids): + if qid_counter[qid] < self.max_samples_per_qid: + tokens_filtered.append(token) + qids_filtered.append(qid) + qid_counter[qid] += 1 + + np.savez( + file_path, + tokens=np.array(tokens_filtered), + qids=np.array(qids_filtered), + ) + + _logger.info( + f"Removed QIDs that occurred more than {self.max_samples_per_qid} times." + ) + class _KBCreator: def __init__( @@ -227,22 +259,34 @@ def _get_file_paths(self, dir_paths: list[Path]) -> list[Path]: class MultilingualDatasetCreator: def __init__( - self, source_dir: Union[str, Path], langs: list[str], dest_dir: Union[str, Path] + self, + source_dir: Union[str, Path], + langs: list[str], + dest_dir: Union[str, Path], + max_links_per_qid: int, ) -> None: self._damuel_paths: DamuelPaths = DamuelPaths(source_dir) self._kb_creator: _KBCreator = _KBCreator(self._damuel_paths, langs, dest_dir) self._links_creator: _LinksCreator = _LinksCreator( - self._damuel_paths, langs, dest_dir + self._damuel_paths, langs, dest_dir, max_links_per_qid ) def run(self) -> None: - _logger.info("Starting to create KB") - self._kb_creator.run() - _logger.info("Finished creating KB") _logger.info("Starting to create links") + start_time = time.time() self._links_creator.run() - _logger.info("Finished creating links") + links_time = time.time() - start_time + _logger.info(f"Finished creating links in {links_time:.2f} seconds") + + _logger.info("Starting to create KB") + start_time = time.time() + self._kb_creator.run() + kb_time = time.time() - start_time + _logger.info(f"Finished creating KB in {kb_time:.2f} seconds") + + total_time = links_time + kb_time + _logger.info(f"Total time: {total_time:.2f} seconds") @gin.configurable @@ -250,8 +294,11 @@ def create_multilingual_dataset( source_dir: Union[str, Path], langs: list[str], dest_dir: Union[str, Path], + max_links_per_qid: int, ) -> None: - MultilingualDatasetCreator(Path(source_dir), langs, Path(dest_dir)).run() + MultilingualDatasetCreator( + Path(source_dir), langs, Path(dest_dir), max_links_per_qid + ).run() def run_kb_creator( diff --git a/src/multilingual_dataset/mixer.py b/src/multilingual_dataset/mixer.py index 9093a8a..73aed41 100644 --- a/src/multilingual_dataset/mixer.py +++ b/src/multilingual_dataset/mixer.py @@ -1,6 +1,8 @@ import concurrent.futures +import multiprocessing as mp from collections.abc import Iterator from copy import deepcopy +from functools import partial from pathlib import Path from typing import Any @@ -10,6 +12,65 @@ from utils.loaders import load_mentions +def _shuffle(tokens: np.ndarray, qids: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Shuffle tokens and qids arrays using the same permutation.""" + p = np.random.permutation(len(tokens)) + return tokens[p], qids[p] + + +def _load_tokens_and_qids(chunk: list[Path]) -> tuple[np.ndarray, np.ndarray]: + """Load tokens and qids from a chunk of files.""" + + def load_file(file_path): + return load_mentions(file_path) + + with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor: + results = list(executor.map(load_file, chunk)) + + all_tokens, all_qids = zip(*results) + all_tokens = np.concatenate(all_tokens) + all_qids = np.concatenate(all_qids) + + return all_tokens, all_qids + + +def _save_tokens_and_qids( + tokens: np.ndarray, + qids: np.ndarray, + chunk: list[Path], + compress: bool = True, +) -> None: + """Save tokens and qids back to their respective files.""" + + def _chunk_arrays(data: np.ndarray, chunk_size: int) -> list[np.ndarray]: + """Chunk a numpy array into smaller arrays.""" + return [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)] + + token_chunk_size = len(tokens) // len(chunk) + tokens_chunked = _chunk_arrays(tokens, token_chunk_size) + qids_chunked = _chunk_arrays(qids, token_chunk_size) + + def save_file(tokens, qids, file_path): + if compress: + np.savez_compressed(file_path, tokens=tokens, qids=qids) + else: + np.savez(file_path, tokens=tokens, qids=qids) + + with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor: + futures = [ + executor.submit(save_file, tokens, qids, file_path) + for tokens, qids, file_path in zip(tokens_chunked, qids_chunked, chunk) + ] + concurrent.futures.wait(futures) + + +def mix_chunk(chunk: list[Path], compress_output: bool) -> None: + """Mix the tokens and qids for a single chunk of files.""" + tokens, qids = _load_tokens_and_qids(chunk) + tokens, qids = _shuffle(tokens, qids) + _save_tokens_and_qids(tokens, qids, chunk, compress_output) + + class Mixer: """Gets directory with many tokens and qids files, and buffer size. It mixes the content of the files leaving the number of the same. @@ -38,63 +99,36 @@ def _mix(self, file_paths: list[Path], compress_output: bool) -> None: desc="Mixing", total=len(file_paths) // self.buffer_size + 1, ): - tokens, qids = self._load_tokens_and_qids(chunk) - tokens, qids = self._shuffle(tokens, qids) - self._save_tokens_and_qids(tokens, qids, chunk, compress_output) - - def _shuffle( - self, tokens: np.ndarray, qids: np.ndarray - ) -> tuple[np.ndarray, np.ndarray]: - p = np.random.permutation(len(tokens)) - - return tokens[p], qids[p] - - def _save_tokens_and_qids( - self, - tokens: np.ndarray, - qids: np.ndarray, - chunk: list[Path], - compress: bool = True, - ) -> None: - tokens_chunked, qids_chunked = self._get_tokens_qids_chunks( - tokens, qids, len(chunk) - ) + mix_chunk(chunk, compress_output) - def save_file(tokens, qids, file_path): - if compress: - np.savez_compressed(file_path, tokens=tokens, qids=qids) - else: - np.savez(file_path, tokens=tokens, qids=qids) - - with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor: - futures = [ - executor.submit(save_file, tokens, qids, file_path) - for tokens, qids, file_path in zip(tokens_chunked, qids_chunked, chunk) - ] - concurrent.futures.wait(futures) - - def _get_tokens_qids_chunks( - self, tokens: np.ndarray, qids: np.ndarray, chunk_count: int - ) -> tuple[list[np.ndarray], list[np.ndarray]]: - token_chunk_size = len(tokens) // chunk_count - return self._chunk(tokens, token_chunk_size), self._chunk( - qids, token_chunk_size - ) + def _chunk(self, data: list[Any], chunk_size: int) -> Iterator[list[Any]]: + for i in range(0, len(data), chunk_size): + yield data[i : i + chunk_size] - def _load_tokens_and_qids(self, chunk: list[Path]) -> tuple[np.ndarray, np.ndarray]: - def load_file(file_path): - return load_mentions(file_path) - with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor: - results = list(executor.map(load_file, chunk)) +class ParallelMixer(Mixer): + """Parallel version of the Mixer class that uses multiprocessing for mixing files. - all_tokens, all_qids = zip(*results) + This class should be faster than the original Mixer class but requires more memory. + """ - all_tokens = np.concatenate(all_tokens) - all_qids = np.concatenate(all_qids) + def __init__(self, buffer_size: int = 10, n_workers: int = 4) -> None: + super().__init__(buffer_size) + self.n_workers = n_workers - return all_tokens, all_qids + def _mix(self, file_paths: list[Path], compress_output: bool) -> None: + np.random.shuffle(file_paths) + chunks = list(self._chunk(file_paths, self.buffer_size)) - def _chunk(self, data: list[Any], chunk_size: int) -> Iterator[list[Any]]: - for i in range(0, len(data), chunk_size): - yield data[i : i + chunk_size] + process_chunk_with_compression = partial( + mix_chunk, compress_output=compress_output + ) + + with mp.Pool(processes=self.n_workers) as pool: + list( + tqdm( + pool.imap(process_chunk_with_compression, chunks), + total=len(chunks), + desc="Parallel Mixing", + ) + ) diff --git a/tests/multilingual_dataset/test_mixer.py b/tests/multilingual_dataset/test_mixer.py index bdffc89..203f0aa 100644 --- a/tests/multilingual_dataset/test_mixer.py +++ b/tests/multilingual_dataset/test_mixer.py @@ -2,10 +2,9 @@ from unittest.mock import patch import gin - import numpy as np import pytest -from multilingual_dataset.mixer import Mixer +from multilingual_dataset.mixer import Mixer, ParallelMixer gin.add_config_file_search_path("configs/general.gin") @@ -14,12 +13,26 @@ def mock_remap_qids(qids, _): return qids +@pytest.fixture( + params=[ + {"cls": Mixer, "kwargs": {}}, + {"cls": ParallelMixer, "kwargs": {"n_workers": 1}}, + {"cls": ParallelMixer, "kwargs": {"n_workers": 2}}, + ], + ids=["Mixer", "ParallelMixer-1", "ParallelMixer-2"], +) +def mixer_factory(request): + cls = request.param["cls"] + kwargs = request.param["kwargs"] + return lambda buffer_size: cls(buffer_size=buffer_size, **kwargs) + + @pytest.fixture def create_dummy_npz_files(tmpdir): file_paths = [] for i in range(10): file_path = Path(tmpdir) / f"mentions_{i}.npz" - tokens = np.random.randint(1, 1000, size=(100, 10)) # 100 rows, 10 columns + tokens = np.random.randint(1, 1000, size=(100, 10)) qids = np.random.randint(1, 1000, size=(100,)) np.savez(file_path, tokens=tokens, qids=qids) file_paths.append(file_path) @@ -32,11 +45,13 @@ def load_npz_content(file_path): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_changes_file_contents(mock_qids_remap, create_dummy_npz_files): +def test_mix_changes_file_contents( + mock_qids_remap, create_dummy_npz_files, mixer_factory +): file_paths = create_dummy_npz_files original_contents = [load_npz_content(path) for path in file_paths] - mixer = Mixer(buffer_size=10) + mixer = mixer_factory(buffer_size=10) mixer.mix(file_paths, n_of_mixings=1, compress_output=False) new_contents = [load_npz_content(path) for path in file_paths] @@ -48,12 +63,14 @@ def test_mix_changes_file_contents(mock_qids_remap, create_dummy_npz_files): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_preserves_total_content(mock_qids_remap, create_dummy_npz_files): +def test_mix_preserves_total_content( + mock_qids_remap, create_dummy_npz_files, mixer_factory +): file_paths = create_dummy_npz_files original_tokens = np.concatenate([load_npz_content(path)[0] for path in file_paths]) original_qids = np.concatenate([load_npz_content(path)[1] for path in file_paths]) - mixer = Mixer(buffer_size=10) + mixer = mixer_factory(buffer_size=10) mixer.mix(file_paths, n_of_mixings=1, compress_output=False) new_tokens = np.concatenate([load_npz_content(path)[0] for path in file_paths]) @@ -66,11 +83,11 @@ def test_mix_preserves_total_content(mock_qids_remap, create_dummy_npz_files): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_multiple_times(mock_qids_remap, create_dummy_npz_files): +def test_mix_multiple_times(mock_qids_remap, create_dummy_npz_files, mixer_factory): file_paths = create_dummy_npz_files original_contents = [load_npz_content(path) for path in file_paths] - mixer = Mixer(buffer_size=10) + mixer = mixer_factory(buffer_size=10) mixer.mix(file_paths, n_of_mixings=3, compress_output=False) new_contents = [load_npz_content(path) for path in file_paths] @@ -82,11 +99,11 @@ def test_mix_multiple_times(mock_qids_remap, create_dummy_npz_files): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_with_small_buffer(mock_qids_remap, create_dummy_npz_files): +def test_mix_with_small_buffer(mock_qids_remap, create_dummy_npz_files, mixer_factory): file_paths = create_dummy_npz_files original_contents = [load_npz_content(path) for path in file_paths] - mixer = Mixer(buffer_size=2) + mixer = mixer_factory(buffer_size=2) mixer.mix(file_paths, n_of_mixings=1, compress_output=False) new_contents = [load_npz_content(path) for path in file_paths] @@ -98,21 +115,20 @@ def test_mix_with_small_buffer(mock_qids_remap, create_dummy_npz_files): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_empty_file_list(mock_qids_remap): - mixer = Mixer(buffer_size=1000) +def test_mix_empty_file(mock_qids_remap, mixer_factory): + mixer = mixer_factory(buffer_size=1000) mixer.mix([], n_of_mixings=1, compress_output=False) - # This test passes if no exception is raised @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_single_file(mock_qids_remap, tmp_path): +def test_mix_single_file(mock_qids_remap, tmp_path, mixer_factory): file_path = tmp_path / "mentions_0.npz" tokens = np.random.randint(1, 1000, size=(100, 10)) qids = np.random.randint(1, 1000, size=(100,)) np.savez_compressed(file_path, tokens=tokens, qids=qids) - mixer = Mixer(buffer_size=1000) + mixer = mixer_factory(buffer_size=1000) mixer.mix([file_path], n_of_mixings=1, compress_output=False) new_tokens, new_qids = load_npz_content(file_path) @@ -122,11 +138,13 @@ def test_mix_single_file(mock_qids_remap, tmp_path): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_preserves_consistency(mock_qids_remap, create_dummy_npz_files): +def test_mix_preserves_consistency( + mock_qids_remap, create_dummy_npz_files, mixer_factory +): file_paths = create_dummy_npz_files original_shapes = [load_npz_content(path)[0].shape for path in file_paths] - mixer = Mixer(buffer_size=1000) + mixer = mixer_factory(buffer_size=1000) mixer.mix(file_paths, n_of_mixings=1, compress_output=False) for path, original_shape in zip(file_paths, original_shapes): @@ -135,7 +153,7 @@ def test_mix_preserves_consistency(mock_qids_remap, create_dummy_npz_files): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_mix_compress(mock_qids_remap, tmp_path): +def test_mix_compress(mock_qids_remap, tmp_path, mixer_factory): file_path_1 = tmp_path / "mentions_0.npz" file_path_2 = tmp_path / "mentions_1.npz" tokens = np.random.randint(1, 1000, size=(100, 10)) @@ -143,7 +161,7 @@ def test_mix_compress(mock_qids_remap, tmp_path): np.savez(file_path_1, tokens=tokens, qids=qids) np.savez(file_path_2, tokens=tokens, qids=qids) - mixer = Mixer(buffer_size=1000) + mixer = mixer_factory(buffer_size=1000) mixer.mix([file_path_1], n_of_mixings=1, compress_output=True) mixer.mix([file_path_2], n_of_mixings=1, compress_output=False) diff --git a/tests/multilingual_dataset/test_multilingual_dataset_creator.py b/tests/multilingual_dataset/test_multilingual_dataset_creator.py index 06a53d8..9f5ee88 100644 --- a/tests/multilingual_dataset/test_multilingual_dataset_creator.py +++ b/tests/multilingual_dataset/test_multilingual_dataset_creator.py @@ -75,7 +75,7 @@ def output_dir(tmpdir): class Test_LinksCreator: @pytest.fixture def links_creator(self, damuel_paths, sample_langs, output_dir): - return _LinksCreator(damuel_paths, sample_langs, Path(output_dir)) + return _LinksCreator(damuel_paths, sample_langs, Path(output_dir), 100000000) @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) def test_run(self, mock_qids_remap, links_creator, output_dir): @@ -198,7 +198,7 @@ class TestMultilingualDatasetCreator: @pytest.fixture def dataset_creator(self, sample_langs, output_dir, tmp_path): return MultilingualDatasetCreator( - Path(tmp_path), sample_langs, Path(output_dir) + Path(tmp_path), sample_langs, Path(output_dir), 10 ) @patch("multilingual_dataset.creator._KBCreator.run") From c8f5110fae7e0e74e946a48a834ae70a5fd30ce3 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sat, 9 Aug 2025 10:48:34 +0200 Subject: [PATCH 16/18] refactor(multiling): :recycle: better defaults --- src/multilingual_dataset/creator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/multilingual_dataset/creator.py b/src/multilingual_dataset/creator.py index 84dc62f..f4fcff3 100644 --- a/src/multilingual_dataset/creator.py +++ b/src/multilingual_dataset/creator.py @@ -33,9 +33,9 @@ def __init__( self.max_samples_per_qid: int = max_samples_per_qid - self.single_mixer = Mixer(buffer_size=1) - self.parallel_mixer = ParallelMixer(n_workers=20, buffer_size=5) - self.standard_mixer = Mixer(buffer_size=100) + self.single_mixer = ParallelMixer(n_workers=30, buffer_size=1) + self.parallel_mixer = ParallelMixer(n_workers=3, buffer_size=10) + self.standard_mixer = Mixer(buffer_size=30) def run(self) -> None: """Gathers links from all languages and writes them to dest_dir. @@ -58,9 +58,9 @@ def run(self) -> None: out_file_paths.append(out_file_path) self.single_mixer.mix(out_file_paths, n_of_mixings=1, compress_output=False) - self.parallel_mixer.mix(out_file_paths, n_of_mixings=5, compress_output=False) + self.parallel_mixer.mix(out_file_paths, n_of_mixings=130, compress_output=False) self._remove_often_qids(out_file_paths) - self.standard_mixer.mix(out_file_paths, n_of_mixings=5, compress_output=True) + self.standard_mixer.mix(out_file_paths, n_of_mixings=10, compress_output=True) def _copy_files( self, source_file_paths: Iterable[Path], dest_file_path: Path From e4b724c2bf50dad54af244e287949e727e5cccb4 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Wed, 10 Sep 2025 21:10:47 +0200 Subject: [PATCH 17/18] perf(train): :zap: decrease redundant work in evaluation --- src/finetunings/evaluation/evaluate.py | 11 +++++++++-- src/finetunings/evaluation/find_recall.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/finetunings/evaluation/evaluate.py b/src/finetunings/evaluation/evaluate.py index a463c90..a4fd51a 100644 --- a/src/finetunings/evaluation/evaluate.py +++ b/src/finetunings/evaluation/evaluate.py @@ -1,6 +1,10 @@ import logging -from finetunings.evaluation.find_recall import find_recall +from finetunings.evaluation.find_recall import ( + find_recall_with_searcher, + load_embs_and_qids_with_normalization, +) +from models.searchers.brute_force_searcher import BruteForceSearcher _RECALLS = [1, 10, 100] @@ -28,7 +32,10 @@ def evaluate( ): damuel_path = _construct_damuel_path(root_dir, finetuning_round) + damuel_embs, damuel_qids = load_embs_and_qids_with_normalization(damuel_path) + searcher = BruteForceSearcher(damuel_embs, damuel_qids) + for lang in langs: mewsli_path = _construct_mewsli_path(root_dir, finetuning_round, lang) _logger.info(f"Calculating recall for {lang}") - run_recall_calculation(damuel_path, mewsli_path) + find_recall_with_searcher(searcher, mewsli_path, _RECALLS) diff --git a/src/finetunings/evaluation/find_recall.py b/src/finetunings/evaluation/find_recall.py index 642afb0..6534955 100644 --- a/src/finetunings/evaluation/find_recall.py +++ b/src/finetunings/evaluation/find_recall.py @@ -82,6 +82,23 @@ def find_recall( _logger.info(f"Recall at {R}: {recall}") +@paths_exist(path_arg_ids=[1]) +def find_recall_with_searcher( + searcher: BruteForceSearcher, + mewsli: str, + recalls: list[int], +) -> None: + mewsli_embs, mewsli_qids = load_embs_and_qids_with_normalization(mewsli) + + rc = RecallCalculator(searcher) + + for R in recalls: + _logger.info("Calculating recall...") + recall = rc.recall(mewsli_embs, mewsli_qids, R) + wandb.log({f"recall_at_{R}": recall}) + _logger.info(f"Recall at {R}: {recall}") + + def find_candidates( damuel_entities: str, candidates_path: str, mewsli: str, recall: int ) -> None: From 0a6d8d67aa2ae290b28acb67a962cb46de5f2af2 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Wed, 10 Sep 2025 21:13:15 +0200 Subject: [PATCH 18/18] feat(train): :fire: attempts to improve model and kill performance :D --- src/finetunings/file_processing/gathers.py | 4 + src/finetunings/finetune_model/data.py | 8 +- src/finetunings/finetune_model/train_ddp.py | 48 ++++++++-- src/finetunings/generate_epochs/generate.py | 1 + src/models/negative_sampler.py | 54 +++++++++-- src/models/searchers/brute_force_searcher.py | 4 +- src/models/searchers/faiss_searcher.py | 90 ++++++++++++++++--- src/models/searchers/searcher.py | 2 +- .../simplified_brute_force_searcher.py | 2 +- src/utils/model_factory.py | 10 +++ src/utils/multifile_dataset.py | 3 + tests/models/test_faiss_searcher.py | 51 +++++++++++ 12 files changed, 241 insertions(+), 36 deletions(-) create mode 100644 tests/models/test_faiss_searcher.py diff --git a/src/finetunings/file_processing/gathers.py b/src/finetunings/file_processing/gathers.py index dc39d5c..d3137ed 100644 --- a/src/finetunings/file_processing/gathers.py +++ b/src/finetunings/file_processing/gathers.py @@ -31,6 +31,10 @@ def move_tokens(source, dest, m=1, r=0, max_to_copy=float("inf")): source = Path(source) dest = Path(dest) already_copied = 0 + print( + f"Moving tokens from {source} to {dest} with m={m}, r={r}, max_to_copy={max_to_copy}" + ) + print(os.listdir(source)) for fn in sorted(os.listdir(source)): if not _wanted_fn(fn, m, r): continue diff --git a/src/finetunings/finetune_model/data.py b/src/finetunings/finetune_model/data.py index 2d8f812..e182f01 100644 --- a/src/finetunings/finetune_model/data.py +++ b/src/finetunings/finetune_model/data.py @@ -16,8 +16,9 @@ class SaveInformation: output_path: Path is_final: bool - epoch: int = None - recall: int = None + epoch: int | None = None + recall: int | None = None + name: str | None = None def _load_epoch_npz(path: Path, epoch: int | str) -> tuple: @@ -34,7 +35,8 @@ def construct_non_final_name(): def _save_final_model(model: nn.Module, save_information: SaveInformation) -> None: - torch.save(model.state_dict(), f"{save_information.output_path}/final.pth") + name = save_information.name if save_information.name else "finals.pth" + torch.save(model.state_dict(), f"{save_information.output_path}/{name}") def save_model(model: nn.Module, save_information: SaveInformation) -> None: diff --git a/src/finetunings/finetune_model/train_ddp.py b/src/finetunings/finetune_model/train_ddp.py index d452f76..734c08f 100644 --- a/src/finetunings/finetune_model/train_ddp.py +++ b/src/finetunings/finetune_model/train_ddp.py @@ -1,3 +1,4 @@ +from copy import deepcopy import logging import os from pathlib import Path @@ -21,7 +22,6 @@ from finetunings.finetune_model.data import ( LightWeightDataset, - LightWeightIterableDataset, save_model, SaveInformation, ) @@ -81,11 +81,18 @@ def _calculate_loss( return loss, outputs -def save_final_model(model, MODEL_SAVE_DIR): +def save_final_model(model, MODEL_SAVE_DIR, name: str | None = None): save_information = SaveInformation(MODEL_SAVE_DIR, True) + save_information.name = name save_model(model, save_information) +def update_ema(model, ema_model, decay=0.9999): + with torch.no_grad(): + for param, ema_param in zip(model.parameters(), ema_model.parameters()): + ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay) + + def _ddp_train( rank: int, world_size: int, @@ -103,11 +110,17 @@ def _ddp_train( setup(rank, world_size) model = load_model(FOUNDATION_MODEL_PATH, STATE_DICT_PATH, TARGET_DIM) + model = DDP(model.to(rank), device_ids=[rank]) model = torch.compile(model) is_the_main_process = rank == 0 + if is_the_main_process: + ema_model = deepcopy(model) + ema_model.to(rank) + ema_model.eval() + if is_the_main_process: wandb.init( project="EL-train_ddp_process_0", @@ -128,6 +141,19 @@ def _ddp_train( scaler = torch.amp.GradScaler("cuda") + warmup_steps = 1000 + if is_the_main_process: + _logger.warning( + "Running with learning rate warmup for the first 1000 steps. This is hardcoded and should be made configurable from config." + ) + + def lr_lambda(step): + if step < warmup_steps: + return step / warmup_steps + return 1.0 + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + def step(): scaler.scale(loss).backward() scaler.unscale_(optimizer) @@ -135,6 +161,7 @@ def step(): scaler.step(optimizer) scaler.update() optimizer.zero_grad() + scheduler.step() # Step the scheduler after optimizer step step = torch.compile(step) @@ -146,6 +173,9 @@ def step(): labels = construct_labels(dataset) labels = torch.from_numpy(labels).to(rank) + + global_step = 0 # Track global step for warmup + for epoch in range(EPOCHS): if is_the_main_process: _logger.info(f"Starting epoch {epoch + 1}/{EPOCHS}") @@ -160,6 +190,7 @@ def step(): dataset, batch_size=None, pin_memory=True, num_workers=2, prefetch_factor=2 ) for replica_part in dataloader: + global_step += 1 with torch.autocast(device_type="cuda"): replica_part = forward_to_embeddings(replica_part, model) @@ -188,22 +219,27 @@ def step(): step() if is_the_main_process: + update_ema(model.module, ema_model) + loss_item = loss.item() + current_lr = scheduler.get_last_lr()[0] process_metrics( outputs, labels, loss_item, running_averages, - # { - # "gradient_norm": norm_for_logs, - # }, + { + "learning_rate": current_lr, + "global_step": global_step, + }, ) if is_the_main_process: # We only save the model on the main process and only once # Intermediate saves could mess up synchronization - save_final_model(model.module, MODEL_SAVE_DIR) + save_final_model(model.module, MODEL_SAVE_DIR, name="final.pth") + save_final_model(ema_model.module, MODEL_SAVE_DIR, name="ema.pth") cleanup() diff --git a/src/finetunings/generate_epochs/generate.py b/src/finetunings/generate_epochs/generate.py index 70bc62f..c6c1985 100644 --- a/src/finetunings/generate_epochs/generate.py +++ b/src/finetunings/generate_epochs/generate.py @@ -102,6 +102,7 @@ def generate( calculate_qids_distribution_from_links(LINKS_EMBS_DIR, index_qids) ) negative_sampler_kwargs["randomly_sampled_cnt"] = 1 + negative_sampler_kwargs["limit_negs"] = 10 batch_sampler = BatchSampler( index_embs, diff --git a/src/models/negative_sampler.py b/src/models/negative_sampler.py index a90c17b..7e54b1d 100644 --- a/src/models/negative_sampler.py +++ b/src/models/negative_sampler.py @@ -4,6 +4,7 @@ import numba as nb import numpy as np +import torch from models.searchers.searcher import Searcher @@ -132,6 +133,7 @@ def __init__( sampling_type: NegativeSamplingType, qids_distribution: np.ndarray | None = None, randomly_sampled_cnt: int | None = None, + limit_negs: int | None = None, ) -> None: assert len(embs) == len(qids) self.embs = embs @@ -144,27 +146,61 @@ def __init__( self.qids_distribution = qids_distribution self.randomly_sampled_cnt = randomly_sampled_cnt self._validate() - + self._limit_negs = limit_negs + if self._limit_negs is not None: + if torch.cuda.is_available(): + _logger.info("Running on CUDA.") + self.device: torch.device = torch.device("cuda") + else: + _logger.info("CUDA is not available.") + self.device: torch.device = torch.device("cpu") + self.pos_qids = torch.ones(max(qids) + 1, device=self.device) + self.neg_qids = torch.zeros(max(qids) + 1, device=self.device) + self.qids_t = torch.from_numpy(self.qids).to(self.device) def sample( self, batch_embs: np.ndarray, batch_qids: np.ndarray, negative_cnts: int ) -> np.ndarray: if self._should_sample_randomly(): negative_cnts -= self.randomly_sampled_cnt - neighbors = self.searcher.find( - batch_embs, max(negative_cnts + len(batch_embs), 100) - ) - # performance seems comparable with _get_neighbors_mask_set_arr - # by the Occams razor _get_neighbors_mask_set is better. - wanted_neighbors_mask = _get_neighbors_mask_set( - batch_qids, self.qids[neighbors] - ) + + enough_negatives_found = False + + multiplier = 1 + while not enough_negatives_found: + neighbors = self.searcher.find( + batch_embs, + max(int(multiplier * (negative_cnts + len(batch_embs))), 100), + ) + # performance seems comparable with _get_neighbors_mask_set_arr + # by the Occams razor _get_neighbors_mask_set is better. + wanted_neighbors_mask = _get_neighbors_mask_set( + batch_qids, self.qids[neighbors] + ) + + enough_negatives_found = True + + if self._limit_negs is not None: + mask = self.neg_qids <= self.pos_qids * self._limit_negs + wanted_neighbors_mask = torch.from_numpy(wanted_neighbors_mask).to( + self.device + ) + wanted_neighbors_mask &= mask[ + self.qids_t[torch.from_numpy(neighbors).to(self.device)] + ] + row_sums = torch.sum(wanted_neighbors_mask, dim=1) + enough_negatives_found = torch.all(row_sums >= negative_cnts).item() + wanted_neighbors_mask = wanted_neighbors_mask.cpu().numpy() + sampled = self.sample_f( batch_qids, negative_cnts, neighbors, wanted_neighbors_mask ) if self._should_sample_randomly(): randomly_sampled = self._sample_randomly(batch_qids) sampled = np.concatenate([sampled, randomly_sampled], axis=1) + if self._limit_negs is not None: + self.pos_qids[torch.from_numpy(batch_qids).to(self.device)] += 1 + self.neg_qids[torch.from_numpy(sampled).to(self.device)] += 1 return sampled def _should_sample_randomly(self): diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index ad1e7e3..cc9a70b 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -25,7 +25,7 @@ def __init__( self.device: torch.device = torch.device("cpu") super().__init__(embs, results, run_build_from_init) - def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: + def find(self, batch: np.ndarray, num_neighbors: int, mask=None) -> np.ndarray: # @torch.compile def _find(batch: np.ndarray) -> np.ndarray: batch_torch: torch.Tensor = torch.from_numpy(batch).to(self.device) @@ -72,7 +72,7 @@ def __init__( super().__init__(embs, results, run_build_from_init) @torch.compile - def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: + def find(self, batch: np.ndarray, num_neighbors: int, mask=None) -> np.ndarray: """ Finds the nearest neighbors for a given batch of input data. CAREFUL: This is an optimized version that comes with potential pitfalls to get better performance. diff --git a/src/models/searchers/faiss_searcher.py b/src/models/searchers/faiss_searcher.py index a5ad7c1..5a91871 100644 --- a/src/models/searchers/faiss_searcher.py +++ b/src/models/searchers/faiss_searcher.py @@ -1,29 +1,91 @@ -"""Currently BROKEN.""" +"""FAISS-based searcher with GPU support.""" import faiss import numpy as np +import math from models.searchers.searcher import Searcher class FaissSearcher(Searcher): def __init__(self, embs: np.ndarray, results: np.ndarray): - super().__init__(embs, results, True) + super().__init__(embs, results, False) + self.gpu_index = None + self.is_trained = False + + # Check GPU availability + self.ngpu = faiss.get_num_gpus() + if self.ngpu == 0: + raise RuntimeError("No GPU detected by Faiss") + + # IVFPQ parameters + self.d = embs.shape[1] # vector dimension + self.nb = embs.shape[0] # number of vectors to index + self.nlist = int(4 * math.sqrt(self.nb)) # rule-of-thumb: 4 × √N + self.m = 16 if self.d >= 16 else self.d # PQ subvectors (must divide d) + self.nbits = 8 # 8-bit codes + self.nprobe = 10 # number of clusters to search + + self.build() def find(self, batch, num_neighbors) -> np.ndarray: - print(self.index.search(batch, num_neighbors)) - return self.results[self.index.search(batch, num_neighbors)] + if self.gpu_index is None: + raise RuntimeError("Index not built. Call build() first.") + + if not self.is_trained: + raise RuntimeError("Index not trained. Call build() first.") + + # Ensure batch is float32 and 2D + query = np.asarray(batch, dtype=np.float32) + if query.ndim == 1: + query = query.reshape(1, -1) + + # Perform search + distances, indices = self.gpu_index.search(query, num_neighbors) + + # Return results for the queries + return self.results[indices] def build(self): self.build_index() - def build_index( - self, - num_leaves=1000, - ): - dim = self.embs.shape[-1] - quantizer = faiss.IndexFlatIP(dim) - self.index = faiss.IndexIVFFlat(quantizer, dim, num_leaves) - assert not self.index.is_trained - self.index.train(self.embs) - assert self.index.is_trained + def build_index(self): + print(f"Building FAISS GPU index with {self.ngpu} GPU(s)...") + + # Ensure embeddings are float32 + embs_f32 = self.embs.astype(np.float32) + + # Create CPU index first + cpu_quantizer = faiss.IndexFlatL2(self.d) + cpu_index = faiss.IndexIVFPQ( + cpu_quantizer, self.d, self.nlist, self.m, self.nbits + ) + + # Convert to GPU + if self.ngpu == 1: + print("Using single GPU mode") + res = faiss.StandardGpuResources() + self.gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index) + else: + print(f"Using multi-GPU mode with {self.ngpu} GPUs") + co = faiss.GpuMultipleClonerOptions() + co.shard = True # one shard per GPU + self.gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co) + + # Train the index (required for IVFPQ) + print("Training the IVFPQ index...") + # Use subset of data for training if dataset is large + training_size = min(50000, self.nb) + training_data = embs_f32[:training_size] + self.gpu_index.train(training_data) + self.is_trained = True + print("Training completed.") + + # Set search parameters + self.gpu_index.nprobe = self.nprobe + + # Add all vectors to the index + self.gpu_index.add(embs_f32) + print(f"Added {self.gpu_index.ntotal} vectors to GPU IVFPQ index.") + + print("FAISS GPU index build completed.") diff --git a/src/models/searchers/searcher.py b/src/models/searchers/searcher.py index 7251f22..a849186 100644 --- a/src/models/searchers/searcher.py +++ b/src/models/searchers/searcher.py @@ -1,4 +1,4 @@ -""" Wrapper around any searcher we might use. """ +"""Wrapper around any searcher we might use.""" import logging from abc import ABC, abstractmethod diff --git a/src/models/searchers/simplified_brute_force_searcher.py b/src/models/searchers/simplified_brute_force_searcher.py index e7be9d3..e2bc365 100644 --- a/src/models/searchers/simplified_brute_force_searcher.py +++ b/src/models/searchers/simplified_brute_force_searcher.py @@ -1,4 +1,4 @@ -""" Primarly for testing purposes """ +"""Primarly for testing purposes""" import numpy as np diff --git a/src/utils/model_factory.py b/src/utils/model_factory.py index ca78c21..9803925 100644 --- a/src/utils/model_factory.py +++ b/src/utils/model_factory.py @@ -64,6 +64,16 @@ def _add_state_dict_to_model( cls, state_dict_path: str, model: torch.nn.Module ) -> torch.nn.Module: d = torch.load(state_dict_path, map_location="cpu") + new_state_dict = {} + for k, v in d.items(): + if k.startswith("_orig_mod.module.model."): + new_k = k.replace("_orig_mod.module.model.", "") + elif k.startswith("module."): + new_k = k.replace("module.", "") + else: + new_k = k + new_state_dict[new_k] = v + d = new_state_dict try: model.load_state_dict(d) except RuntimeError as e: diff --git a/src/utils/multifile_dataset.py b/src/utils/multifile_dataset.py index bcbdd08..3116683 100644 --- a/src/utils/multifile_dataset.py +++ b/src/utils/multifile_dataset.py @@ -26,6 +26,9 @@ def _get_file_list(self): if f.endswith(self.file_pattern[1:]) ] ) + _logger.info( + f"Found {len(file_list)} files matching pattern {self.file_pattern} in {self.data_dir}." + ) return file_list def _load_data(self, file_path): diff --git a/tests/models/test_faiss_searcher.py b/tests/models/test_faiss_searcher.py new file mode 100644 index 0000000..afb20fe --- /dev/null +++ b/tests/models/test_faiss_searcher.py @@ -0,0 +1,51 @@ +import numpy as np +import pytest + +pytest.importorskip("faiss") + +from models.searchers.brute_force_searcher import BruteForceSearcher + + +@pytest.fixture +def generate_data(): + def _generate(num_points, dim, num_queries, seed=42): + rng = np.random.RandomState(seed) + embs = rng.randn(num_points, dim).astype(np.float32) + queries = rng.randn(num_queries, dim).astype(np.float32) + results = np.arange(num_points) + return embs, queries, results + + return _generate + + +def assert_equal_results(faiss_results, brute_results): + assert faiss_results.shape == brute_results.shape + np.testing.assert_array_equal(faiss_results, brute_results) + + +def test_small(generate_data): + from models.searchers.faiss_searcher import FaissSearcher + + embs, queries, results = generate_data(num_points=100, dim=16, num_queries=5) + bf = BruteForceSearcher(embs, results) + fs = FaissSearcher(embs, results) + # fs.build() + num_neighbors = 3 + brute_out = bf.find(queries, num_neighbors) + faiss_out = fs.find(queries, num_neighbors) + assert_equal_results(faiss_out, brute_out) + + +def test_large(generate_data): + from models.searchers.faiss_searcher import FaissSearcher + + embs, queries, results = generate_data( + num_points=10000, dim=64, num_queries=50, seed=123 + ) + bf = BruteForceSearcher(embs, results) + fs = FaissSearcher(embs, results) + # fs.build() + num_neighbors = 10 + brute_out = bf.find(queries, num_neighbors) + faiss_out = fs.find(queries, num_neighbors) + assert_equal_results(faiss_out, brute_out)