From fe17fc01613e9ef77d3bedc1d9b7bb2de00f1514 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Tue, 30 Sep 2025 17:45:25 +0200 Subject: [PATCH 01/28] feat: :sparkles: add verbose option to load_tokens_qids_from_dir for progress tracking --- src/utils/loaders.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/utils/loaders.py b/src/utils/loaders.py index 9da0f55..c40ba6c 100644 --- a/src/utils/loaders.py +++ b/src/utils/loaders.py @@ -5,6 +5,7 @@ import gin import numpy as np import pandas as pd +from tqdm import tqdm # from tokenization.pipeline import DamuelAliasTablePipeline from tokenization.runner import run_alias_table_damuel @@ -114,20 +115,29 @@ def load_qids_npy(file_path: str | Path) -> np.ndarray: @_sort_by_output(1) @qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) -def load_tokens_qids_from_dir(dir_path: str | Path) -> tuple[np.ndarray, np.ndarray]: +def load_tokens_qids_from_dir(dir_path: str | Path, verbose=False) -> tuple[np.ndarray, np.ndarray]: """ Loads mention tokens and query IDs from all .npz files in a given directory. Args: dir_path (str | Path): Path to the directory containing .npz files. + verbose (bool): If True, displays a progress bar while loading files. Returns: tuple[np.ndarray, np.ndarray]: A tuple containing two numpy arrays: - tokens: Array of mention tokens loaded from the files. - qids: Array of query IDs loaded from the files. """ + if type(dir_path) == str: + dir_path = Path(dir_path) tokens, qids = [], [] - for file in dir_path.iterdir(): + iterator = dir_path.iterdir() + if verbose: + total = sum(1 for itm in dir_path.iterdir() if itm.is_file() and itm.suffix == ".npz") + iterator = tqdm( + dir_path.iterdir(), desc=f"Loading tokens and qids from {dir_path}", total=total + ) + for file in iterator: if file.is_file() and file.suffix == ".npz": d = np.load(file) tokens.extend(d["tokens"]) From 8fbe8343a698d7bb8be632f17873ad3d3a9d7380 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Tue, 30 Sep 2025 17:46:20 +0200 Subject: [PATCH 02/28] feat: :sparkles: add new scripts for Qwen reranking and processing inputs --- src/scripts/qwen/el_example.py | 88 +++++++++++++++++++++++++++++++ src/scripts/qwen/hf_example.py | 72 ++++++++++++++++++++++++++ src/scripts/qwen/reranker.py | 95 ++++++++++++++++++++++++++++++++++ src/scripts/qwen/reranking.py | 78 ++++++++++++++++++++++++++++ 4 files changed, 333 insertions(+) create mode 100644 src/scripts/qwen/el_example.py create mode 100644 src/scripts/qwen/hf_example.py create mode 100644 src/scripts/qwen/reranker.py create mode 100644 src/scripts/qwen/reranking.py diff --git a/src/scripts/qwen/el_example.py b/src/scripts/qwen/el_example.py new file mode 100644 index 0000000..d055f04 --- /dev/null +++ b/src/scripts/qwen/el_example.py @@ -0,0 +1,88 @@ +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + + +def format_instruction(instruction, query, doc): + if instruction is None: + instruction = "Given a web search query, retrieve relevant passages that answer the query" + output = ": {instruction}\n: {query}\n: {doc}".format( + instruction=instruction, query=query, doc=doc + ) + return output + + +def process_inputs(pairs): + inputs = tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + max_length=max_length - len(prefix_tokens) - len(suffix_tokens), + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens + inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length) + for key in inputs: + inputs[key] = inputs[key].to(model.device) + return inputs + + +@torch.no_grad() +def compute_logits(inputs, **kwargs): + batch_scores = model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, token_true_id] + false_vector = batch_scores[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side="left") +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval() +# We recommend enabling flash_attention_2 for better acceleration and memory saving. +# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B", torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval() +token_false_id = tokenizer.convert_tokens_to_ids("no") +token_true_id = tokenizer.convert_tokens_to_ids("yes") +max_length = 8192 + +prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n' +suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" +prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False) +suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False) + +task = ( + "Your task is to determine if the provided Wikipedia description correctly corresponds " + "to the entity mention found in the query. The entity mention is marked by and . " + "Check if the description matches the entity. Answer strictly with 'yes' or 'no'.\n" + "Example:\n" + " Query: 'What is the capital of France?'\n" + " Description: 'Paris is the capital and largest city of France...'\n" + " Answer: no" + " Query: 'What is the capital of France?'\n" + " Description: 'Paris is the capital and largest city of France...'\n" + " Answer: yes" +) + +queries = [ + "What is the capital of China?", + "In order to save Troy, Paris had to be sacrificed.", + "2. prezident republiky byl zdrcen dohodou z Mnichova.", +] * 2 + +documents = [ + "Peking (zvuk výslovnost, čínsky v českém přepisu Pej-ťing, pchin-jinem Běijīng, znaky 北京) je hlavní město Čínské lidové republiky. S více než 21 miliony obyvatel je jedním z nejlidnatějších hlavních měst na světě,[2][3] a po Šanghaji druhým nejlidnatějším městem v Číně.", + "Paris (Ancient Greek: Πάρις, romanized: Páris), also known as Alexander (Ancient Greek: Ἀλέξανδρος, romanized: Aléxandros), is a mythological figure in the story of the Trojan War.", + "Edvard Beneš (původním jménem Eduard;[pozn. 2] 28. května 1884 Kožlany[4] – 3. září 1948 Sezimovo Ústí) byl československý politik a státník, druhý československý prezident v letech 1935–1948, resp. v letech 1935–1938 a 1945–1948. V období tzv. Druhé republiky (po Mnichovské dohodě ze dne 29. září 1938 do 15. března 1939) a následné německé okupace do května 1945 žil a politicky působil v exilu. Od roku 1940 až do osvobození Československa byl mezinárodně (nejen protihitlerovskou koalicí) uznaným vrcholným představitelem československého odboje a exilovým prezidentem republiky. Úřadujícím československým prezidentem byl opět v letech 1945–1948.", + "Shanghai[a] is a direct-administered municipality and the most populous urban area in China. The city is located on the Chinese shoreline on the southern estuary of the Yangtze River, with the Huangpu River flowing through it. The population of the city proper is the second largest in the world with around 24.87 million inhabitants in 2023, while the urban area is the most populous in China, with 29.87 million residents.", + "Paris[a] is the capital and largest city of France, with an estimated city center population of 2,048,472, and a metropolitan population of 13,171,056 as of January 2025[3] in an area of more than 105 km2 (41 sq mi). It is located in the centre of the Île-de-France region. Paris is the fourth-most populous city in the European Union. Nicknamed the City of Light, Paris has been one of the world's major centres of finance, diplomacy, commerce, culture, fashion, and gastronomy since the 17th century. ", + "Tomáš Garrigue Masaryk, označovaný T. G. M., TGM nebo Prezident Osvoboditel (7. března 1850 Hodonín[1] – 14. září 1937 Lány[2]), byl československý státník, filozof, sociolog a pedagog, první prezident Československé republiky. K jeho osmdesátým narozeninám byl roku 1930 přijat zákon o zásluhách T. G. Masaryka, obsahující větu „Tomáš Garrigue Masaryk zasloužil se o stát“, a po odchodu z funkce roku 1935 ho parlament znovu ocenil a odměnil za jeho osvoboditelské a budovatelské dílo.", +] + +pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)] + +# Tokenize the input texts +inputs = process_inputs(pairs) +scores = compute_logits(inputs) + +print("scores: ", scores) diff --git a/src/scripts/qwen/hf_example.py b/src/scripts/qwen/hf_example.py new file mode 100644 index 0000000..932f6a1 --- /dev/null +++ b/src/scripts/qwen/hf_example.py @@ -0,0 +1,72 @@ +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + + +def format_instruction(instruction, query, doc): + if instruction is None: + instruction = "Given a web search query, retrieve relevant passages that answer the query" + output = ": {instruction}\n: {query}\n: {doc}".format( + instruction=instruction, query=query, doc=doc + ) + return output + + +def process_inputs(pairs): + inputs = tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + max_length=max_length - len(prefix_tokens) - len(suffix_tokens), + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens + inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length) + for key in inputs: + inputs[key] = inputs[key].to(model.device) + return inputs + + +@torch.no_grad() +def compute_logits(inputs, **kwargs): + batch_scores = model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, token_true_id] + false_vector = batch_scores[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side="left") +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval() +# We recommend enabling flash_attention_2 for better acceleration and memory saving. +# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B", torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval() +token_false_id = tokenizer.convert_tokens_to_ids("no") +token_true_id = tokenizer.convert_tokens_to_ids("yes") +max_length = 8192 + +prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n' +suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" +prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False) +suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False) + +task = "Given a web search query, retrieve relevant passages that answer the query" + +queries = [ + "What is the capital of China?", + "Explain gravity", +] + +documents = [ + "In the Troyan war, the Greeks besieged the city of Troy for ten years. The war ended with the clever use of a wooden horse, known as the Trojan Horse, which allowed Greek soldiers to enter the city and defeat the Trojans.", + "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.", +] + +pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)] + +# Tokenize the input texts +inputs = process_inputs(pairs) +scores = compute_logits(inputs) + +print("scores: ", scores) diff --git a/src/scripts/qwen/reranker.py b/src/scripts/qwen/reranker.py new file mode 100644 index 0000000..6213b3b --- /dev/null +++ b/src/scripts/qwen/reranker.py @@ -0,0 +1,95 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class Reranker: + """Lightweight wrapper around Qwen reranker for mention-description scoring.""" + + _DEFAULT_INSTRUCTION = ( + "Your task is to determine if the provided Wikipedia description correctly corresponds " + "to the entity mention found in the query. The entity mention is marked by and . " + "Check if the description matches the entity. Answer strictly with 'yes' or 'no'.\n" + "Example:\n" + " Query: 'What is the capital of France?'\n" + " Description: 'Paris is the capital and largest city of France...'\n" + " Answer: no\n" + " Query: 'What is the capital of France?'\n" + " Description: 'Paris is the capital and largest city of France...'\n" + " Answer: yes" + ) + _SYSTEM_PROMPT = ( + "<|im_start|>system\n" + "Judge whether the Document meets the requirements based on the Query and the Instruct " + 'provided. Note that the answer can only be "yes" or "no".<|im_end|>\n' + "<|im_start|>user\n" + ) + _ASSISTANT_SUFFIX = "<|im_end|>\n" "<|im_start|>assistant\n" "\n\n\n\n" + + def __init__( + self, + model_name: str = "Qwen/Qwen3-Reranker-0.6B", + max_length: int = 8192, + instruction: str | None = None, + ) -> None: + self.model_name = model_name + self.max_length = max_length + self.instruction = instruction or self._DEFAULT_INSTRUCTION + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + self.model = AutoModelForCausalLM.from_pretrained(model_name).eval() + + self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") + self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") + self.prefix_tokens = self.tokenizer.encode(self._SYSTEM_PROMPT, add_special_tokens=False) + self.suffix_tokens = self.tokenizer.encode(self._ASSISTANT_SUFFIX, add_special_tokens=False) + + def score(self, mention: str, description: str, instruction: str | None = None) -> float: + """Return probability that description matches mention.""" + formatted_query = self._format_query(mention) + formatted_instruction = instruction or self.instruction + prompt = self._format_instruction(formatted_instruction, formatted_query, description) + inputs = self._process_inputs([prompt]) + probabilities = self._compute_probabilities(inputs) + return probabilities[0] + + def _format_query(self, mention: str) -> str: + has_markers = "" in mention and "" in mention + wrapped = mention if has_markers else f"{mention}" + return f"Identify the entity referenced by {wrapped}." + + def _format_instruction(self, instruction: str, query: str, document: str) -> str: + return ": {instruction}\n: {query}\n: {doc}".format( + instruction=instruction, + query=query, + doc=document, + ) + + def _process_inputs(self, prompts: list[str]): + tokenized = self.tokenizer( + prompts, + padding=False, + truncation="longest_first", + return_attention_mask=False, + max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens), + ) + + for i, ids in enumerate(tokenized["input_ids"]): + tokenized["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens + + tokenized = self.tokenizer.pad( + tokenized, padding=True, return_tensors="pt", max_length=self.max_length + ) + + for key in tokenized: + tokenized[key] = tokenized[key].to(self.model.device) + + return tokenized + + @torch.no_grad() + def _compute_probabilities(self, inputs): + logits = self.model(**inputs).logits[:, -1, :] + true_vector = logits[:, self.token_true_id] + false_vector = logits[:, self.token_false_id] + stacked = torch.stack([false_vector, true_vector], dim=1) + log_probs = torch.nn.functional.log_softmax(stacked, dim=1) + return log_probs[:, 1].exp().tolist() diff --git a/src/scripts/qwen/reranking.py b/src/scripts/qwen/reranking.py new file mode 100644 index 0000000..d7787b7 --- /dev/null +++ b/src/scripts/qwen/reranking.py @@ -0,0 +1,78 @@ +import argparse +from pathlib import Path + +# Import BruteForceSearcher from models +from models.searchers.brute_force_searcher import BruteForceSearcher + +# Import Reranker class +from scripts.qwen.reranker import Reranker + +# Import necessary functions from loaders +from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir + + +def main(): + parser = argparse.ArgumentParser(description="Reranking for entity linking") + parser.add_argument( + "--damuel_token", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages", + help="Path to damuel token file or directory", + ) + parser.add_argument( + "--damuel_embs", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6", + help="Path to damuel embeddings directory or .npz file", + ) + parser.add_argument( + "--mewsli_tokens", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/tokens_mewsli_finetuning/es/mentions_1307770978027216442.npz", + help="Path to mewsli token file or directory", + ) + parser.add_argument( + "--qwen_model_name", + type=str, + default="Qwen/Qwen3-Reranker-0.6B", + help="Name of the QWEN model", + ) + args = parser.parse_args() + + # Resolve tokens and embeddings (directories or files) + damuel_tokens, damuel_token_qids = load_tokens_qids_from_dir(args.damuel_token, verbose=True) + damuel_embs, damuel_qids = load_embs_and_qids(args.damuel_embs) + + qid_to_damuel_token = {qid: token for qid, token in zip(damuel_token_qids, damuel_tokens)} + qid_to_damuel_emb = {qid: emb for qid, emb in zip(damuel_qids, damuel_embs)} + + del damuel_token_qids + + mewsli_tokens, mewsli_qids = load_tokens_qids(args.mewsli_tokens) + + # Take first four names (tokens) from each as a quick smoke-test + damuel_tokens_preview = damuel_tokens[:4] + mewsli_tokens_preview = mewsli_tokens[:4] + + print("First 4 damuel tokens:", damuel_tokens_preview) + print("First 4 mewsli tokens:", mewsli_tokens_preview) + + # Create searcher using damuel embeddings and damuel qids + searcher = BruteForceSearcher(damuel_embs, damuel_qids) + print("Searcher created.") + + # Initialize reranker model (actual reranking logic to be implemented later) + reranker = Reranker(model_name=args.qwen_model_name) + print(f"Reranker initialized with model: {reranker.model_name}") + + # Stub for QWEN model loading + print("QWEN model to be used:", args.qwen_model_name) + + # TODO: implement reranking logic + print("Reranking logic not implemented. Exiting.") + + print("jupi") + + +if __name__ == "__main__": + main() From a8038d8de23266ca08c49dd2e418711bd48f5ef1 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Wed, 1 Oct 2025 15:50:03 +0200 Subject: [PATCH 03/28] feat: :sparkles: add max_items_to_load parameter to load_tokens_qids_from_dir for limiting loaded items --- src/utils/loaders.py | 6 +- tests/utils/test_loaders.py | 106 +++++++++++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/src/utils/loaders.py b/src/utils/loaders.py index c40ba6c..44d865b 100644 --- a/src/utils/loaders.py +++ b/src/utils/loaders.py @@ -115,7 +115,9 @@ def load_qids_npy(file_path: str | Path) -> np.ndarray: @_sort_by_output(1) @qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) -def load_tokens_qids_from_dir(dir_path: str | Path, verbose=False) -> tuple[np.ndarray, np.ndarray]: +def load_tokens_qids_from_dir( + dir_path: str | Path, verbose=False, max_items_to_load: int | None = None +) -> tuple[np.ndarray, np.ndarray]: """ Loads mention tokens and query IDs from all .npz files in a given directory. @@ -142,6 +144,8 @@ def load_tokens_qids_from_dir(dir_path: str | Path, verbose=False) -> tuple[np.n d = np.load(file) tokens.extend(d["tokens"]) qids.extend(d["qids"]) + if max_items_to_load is not None and len(tokens) >= max_items_to_load: + break return np.array(tokens), np.array(qids) diff --git a/tests/utils/test_loaders.py b/tests/utils/test_loaders.py index dff323c..20846bc 100644 --- a/tests/utils/test_loaders.py +++ b/tests/utils/test_loaders.py @@ -13,6 +13,7 @@ load_qids, load_qids_npy, load_tokens_qids, + load_tokens_qids_from_dir, ) @@ -20,8 +21,14 @@ def mock_remap_qids(qids, _): return qids +def _create_tokens_qids_npz(dir_path: Path, file_name: str, tokens: np.ndarray, qids: np.ndarray): + file_path = Path(dir_path) / file_name + np.savez_compressed(file_path, tokens=np.array(tokens), qids=np.array(qids)) + return file_path + + @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_load_mentions_with_path_object(mock_qids_remap): +def test_load_tokens_qids_with_path_object(mock_qids_remap): with tempfile.TemporaryDirectory() as temp_dir: file_path = Path(temp_dir) / "mentions_2.npz" @@ -37,7 +44,7 @@ def test_load_mentions_with_path_object(mock_qids_remap): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_load_mentions_with_string_path(mock_qids_remap): +def test_load_tokens_qids_with_string_path(mock_qids_remap): with tempfile.TemporaryDirectory() as temp_dir: file_path = str(Path(temp_dir) / "mentions_1.npz") @@ -273,6 +280,101 @@ def test_embs_qids_tokens_from_file(mock_qids_remap, use_string_path, file_name) assert len(loaded_data) == len(test_data) +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_single_file(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens = np.array([[1, 2, 3], [4, 5, 6]]) + qids = np.array([10, 20]) + _create_tokens_qids_npz(dir_path, "data_0.npz", tokens, qids) + + loaded_tokens, loaded_qids = load_tokens_qids_from_dir(dir_path) + + sort_indices = np.argsort(qids, kind="stable") + assert np.array_equal(loaded_qids, qids[sort_indices]) + assert np.array_equal(loaded_tokens, tokens[sort_indices]) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_multiple_files(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens_a = np.array([[1, 1, 1], [2, 2, 2]]) + qids_a = np.array([20, 10]) + tokens_b = np.array([[3, 3, 3]]) + qids_b = np.array([30]) + _create_tokens_qids_npz(dir_path, "data_a.npz", tokens_a, qids_a) + _create_tokens_qids_npz(dir_path, "data_b.npz", tokens_b, qids_b) + + loaded_tokens, loaded_qids = load_tokens_qids_from_dir(dir_path) + + expected_tokens = np.vstack([tokens_a, tokens_b]) + expected_qids = np.concatenate([qids_a, qids_b]) + sort_indices = np.argsort(expected_qids, kind="stable") + + assert np.array_equal(loaded_qids, expected_qids[sort_indices]) + assert np.array_equal(loaded_tokens, expected_tokens[sort_indices]) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_max_limit_overflow(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens_a = np.array([[1, 0], [2, 0], [3, 0]]) + qids_a = np.array([300, 100, 200]) + tokens_b = np.array([[4, 0], [5, 0], [6, 0]]) + qids_b = np.array([600, 500, 400]) + _create_tokens_qids_npz(dir_path, "data_a.npz", tokens_a, qids_a) + _create_tokens_qids_npz(dir_path, "data_b.npz", tokens_b, qids_b) + + max_items = 4 + loaded_tokens, loaded_qids = load_tokens_qids_from_dir( + dir_path, max_items_to_load=max_items + ) + + expected_tokens = np.vstack([tokens_a, tokens_b]) + expected_qids = np.concatenate([qids_a, qids_b]) + sort_indices = np.argsort(expected_qids, kind="stable") + + assert len(loaded_tokens) == len(expected_tokens) + assert len(loaded_tokens) > max_items + assert np.array_equal(loaded_qids, expected_qids[sort_indices]) + assert np.array_equal(loaded_tokens, expected_tokens[sort_indices]) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_large_max_limit_loads_all(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens_a = np.array([[10, 10], [20, 20]]) + qids_a = np.array([200, 100]) + tokens_b = np.array([[30, 30]]) + qids_b = np.array([300]) + tokens_c = np.array([[40, 40]]) + qids_c = np.array([400]) + _create_tokens_qids_npz(dir_path, "data_a.npz", tokens_a, qids_a) + _create_tokens_qids_npz(dir_path, "data_b.npz", tokens_b, qids_b) + _create_tokens_qids_npz(dir_path, "data_c.npz", tokens_c, qids_c) + + max_items = 100 + loaded_tokens, loaded_qids = load_tokens_qids_from_dir( + dir_path, max_items_to_load=max_items + ) + + expected_tokens = np.vstack([tokens_a, tokens_b, tokens_c]) + expected_qids = np.concatenate([qids_a, qids_b, qids_c]) + sort_indices = np.argsort(expected_qids, kind="stable") + + assert len(loaded_tokens) == len(expected_tokens) + assert len(loaded_tokens) < max_items + assert np.array_equal(loaded_qids, expected_qids[sort_indices]) + assert np.array_equal(loaded_tokens, expected_tokens[sort_indices]) + + @pytest.mark.parametrize("use_string_path", [True, False]) @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) def test_load_qids(mock_qids_remap, use_string_path: bool) -> None: From 266fe7a63ba5c66779416478980d8094c7de70d1 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Fri, 3 Oct 2025 16:25:00 +0200 Subject: [PATCH 04/28] feat: :wip: reranking simple model --- src/reranking/binary/create_dataset.py | 35 ++-- src/reranking/models/__init__.py | 5 + src/reranking/models/base.py | 21 +++ src/reranking/models/pairwise_mlp.py | 145 ++++++++++++++++ src/utils/model_factory.py | 12 ++ tests/reranking/models/test_pairwise_mlp.py | 180 ++++++++++++++++++++ 6 files changed, 382 insertions(+), 16 deletions(-) create mode 100644 src/reranking/models/__init__.py create mode 100644 src/reranking/models/base.py create mode 100644 src/reranking/models/pairwise_mlp.py create mode 100644 tests/reranking/models/test_pairwise_mlp.py diff --git a/src/reranking/binary/create_dataset.py b/src/reranking/binary/create_dataset.py index 8c2f15f..5378069 100644 --- a/src/reranking/binary/create_dataset.py +++ b/src/reranking/binary/create_dataset.py @@ -7,8 +7,6 @@ import torch.utils.data from tqdm import tqdm -sys.path.append("/lnet/work/home-students-external/farhan/mel-reborn/src") - from models.searchers.brute_force_searcher import BruteForceSearcher from utils.embeddings import create_attention_mask from utils.loaders import load_embs_and_qids, load_tokens_qids_from_dir @@ -40,17 +38,15 @@ def create_binary_dataset( ) -> None: # Load index embeddings, qids, and tokens index_embs, index_qids = load_embs_and_qids(index_embs_dir) - index_qids_set = set(index_qids) index_embs = index_embs.astype(np.float16) index_tokens, _ = load_tokens_qids_from_dir(index_tokens_path) print(index_tokens.shape) - print(len(index_qids_set)) # Create BruteForceSearcher searcher = BruteForceSearcher(index_embs, index_qids) # Load link tokens and qids - link_tokens, link_qids = load_tokens_qids_from_dir(link_tokens_path) + link_tokens, link_qids = load_tokens_qids_from_dir(link_tokens_path, max_items_to_load=10**7) # Loaders order by qids which is not necessarily what we want print(link_tokens.shape) # Load embedding model @@ -61,11 +57,19 @@ def create_binary_dataset( ) model.eval() model.to(device) + + index_qids_set = set(index_qids) + known_qids_mask = np.array([q in index_qids_set for q in link_qids]) + + link_tokens = link_tokens[known_qids_mask] + link_qids = link_qids[known_qids_mask] + # Create DataLoader - dataset = list(zip(link_tokens, link_qids)) - dataset = torch.utils.data.Subset( - dataset, [i for i, (tokens, qid) in enumerate(dataset) if qid in index_qids_set] - ) + link_tokens = torch.from_numpy(link_tokens) + link_qids = torch.from_numpy(link_qids) + + link_tokens = link_tokens.to(torch.int64) + dataset = torch.utils.data.TensorDataset(link_tokens, link_qids) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) # Initialize dataset arrays @@ -95,7 +99,7 @@ def create_binary_dataset( # Find top matches top_qids = searcher.find(batch_embs.numpy().astype(np.float16), num_neighbors=2) - positive_mask = [index_qid_to_index[qid] for qid in batch_qids.numpy()] + positive_mask = [index_qid_to_index[int(qid)] for qid in batch_qids.numpy()] data_size = len(batch_tokens) description_tokens[output_index : output_index + data_size] = index_tokens[positive_mask] link_tokens_list[output_index : output_index + data_size] = batch_tokens.numpy() @@ -200,28 +204,27 @@ def create_multiclass_dataset( if __name__ == "__main__": index_embs_dir = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/damuel_for_index_3" + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" ) index_tokens_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/damuel_descs_together_tokens" + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages" ) link_tokens_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/damuel_links_together_tokens_0" + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/links" ) embedding_model_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/models_2/final.pth" + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", ) output_path = Path( "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset.npz" ) model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" - create_multiclass_dataset( + create_binary_dataset( index_embs_dir, index_tokens_path, link_tokens_path, model_name, embedding_model_path, output_path, - total_classes=10, ) diff --git a/src/reranking/models/__init__.py b/src/reranking/models/__init__.py new file mode 100644 index 0000000..5bc81b0 --- /dev/null +++ b/src/reranking/models/__init__.py @@ -0,0 +1,5 @@ +from .pairwise_mlp import PairwiseMLPReranker + +__all__ = [ + "PairwiseMLPReranker", +] diff --git a/src/reranking/models/base.py b/src/reranking/models/base.py new file mode 100644 index 0000000..e6ba5f3 --- /dev/null +++ b/src/reranking/models/base.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict + +import torch +from torch import nn + + +class BaseRerankingModel(nn.Module, ABC): + """Abstract base class for reranking models.""" + + @abstractmethod + def train_step(self, data: Dict[str, Any]) -> torch.Tensor: + """Run a single training step on the provided batch data and return the loss.""" + raise NotImplementedError + + @abstractmethod + def score(self, mention: str, entity_description: str) -> float: + """Compute a similarity-based probability that the mention refers to the entity.""" + raise NotImplementedError diff --git a/src/reranking/models/pairwise_mlp.py b/src/reranking/models/pairwise_mlp.py new file mode 100644 index 0000000..a51b9af --- /dev/null +++ b/src/reranking/models/pairwise_mlp.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from typing import Any, Dict, Mapping + +import torch +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +def _to_device(batch: Mapping[str, torch.Tensor], device: torch.device) -> Dict[str, torch.Tensor]: + return {k: v.to(device) for k, v in batch.items()} + + +class PairwiseMLPReranker(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head.""" + + def __init__( + self, + model_name_or_path: str, + *, + state_dict_path: str | None = None, + target_dim: int | None = None, + output_type: ModelOutputType | str | None = None, + tokenizer_name_or_path: str | None = None, + mlp_hidden_dim: int | None = None, + dropout: float = 0.1, + device: torch.device | str = "cpu", + ) -> None: + super().__init__() + self.device = torch.device(device) + + resolved_output_type = _maybe_convert_output_type(output_type) + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + target_dim=target_dim, + output_type=resolved_output_type, + ) + + self.embedding_dim = _infer_output_dim(self.base_model) + hidden_dim = mlp_hidden_dim or self.embedding_dim + + self.classifier = nn.Sequential( + nn.Linear(self.embedding_dim * 2, hidden_dim), + nn.ReLU(), + nn.Dropout(p=dropout), + nn.Linear(hidden_dim, 1), + ) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.BCEWithLogitsLoss() + + self.to(self.device) + + def forward( + self, + mention_tokens: Mapping[str, torch.Tensor], + entity_tokens: Mapping[str, torch.Tensor], + ) -> torch.Tensor: + mention_tokens = _to_device(mention_tokens, self.device) + entity_tokens = _to_device(entity_tokens, self.device) + + mention_embeddings = self._encode(mention_tokens) + entity_embeddings = self._encode(entity_tokens) + + combined = torch.cat([mention_embeddings, entity_embeddings], dim=-1) + logits = self.classifier(combined).squeeze(-1) + return logits + + def train_step(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + entity_tokens = data["entity_tokens"] + labels = data["labels"].to(self.device).float().view(-1) + + logits = self.forward(mention_tokens, entity_tokens).view(-1) + loss = self.loss_fn(logits, labels) + return loss + + def train(self): + self.super().train() + self.base_model.eval() + + @torch.inference_mode() + def score(self, mention: str, entity_description: str) -> float: + self.eval() + + mention_tokens = self.tokenizer( + mention, + padding=True, + truncation=True, + return_tensors="pt", + ).to(self.device) + entity_tokens = self.tokenizer( + entity_description, + padding=True, + truncation=True, + return_tensors="pt", + ).to(self.device) + mention_tokens["attention_mask"] = create_attention_mask(mention_tokens["input_ids"]) + entity_tokens["attention_mask"] = create_attention_mask(entity_tokens["input_ids"]) + + logits = self.forward( + {k: v for k, v in mention_tokens.items()}, + {k: v for k, v in entity_tokens.items()}, + ) + probability = torch.sigmoid(logits).item() + return probability + + @torch.inference_mode() + def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: + if isinstance(tokens, Mapping): + input_ids = tokens["input_ids"] + attention_mask = tokens.get("attention_mask") + if attention_mask is None: + attention_mask = create_attention_mask(input_ids) + else: + input_ids = tokens + attention_mask = create_attention_mask(tokens) + + return self.base_model(input_ids=input_ids, attention_mask=attention_mask) diff --git a/src/utils/model_factory.py b/src/utils/model_factory.py index 6d56afb..9d4e41b 100644 --- a/src/utils/model_factory.py +++ b/src/utils/model_factory.py @@ -45,6 +45,18 @@ def auto_load_from_file( target_dim: int | None = None, output_type: ModelOutputType | None = None, ) -> torch.nn.Module: + """ + Automatically loads a model from a specified file, optionally applying a state dictionary, target dimension, and output type. + + Args: + file_path (str): Path to the model definition file. + state_dict_path (str | None, optional): Path to the state dictionary file to load model weights. If None, weights are not loaded. + target_dim (int | None, optional): Target output dimension for the model. If None, uses the default dimension. + output_type (ModelOutputType | None, optional): Type of model output. If None, defaults to ModelOutputType.PoolerOutput. + + Returns: + torch.nn.Module: The constructed and optionally weight-loaded model instance. + """ builder = ModelBuilder(file_path) if output_type is None: output_type = ModelOutputType.PoolerOutput # the original/old default diff --git a/tests/reranking/models/test_pairwise_mlp.py b/tests/reranking/models/test_pairwise_mlp.py new file mode 100644 index 0000000..b7ff0b9 --- /dev/null +++ b/tests/reranking/models/test_pairwise_mlp.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import math +from typing import Sequence + +import pytest +import torch +from torch import nn +from transformers.tokenization_utils_base import BatchEncoding + +from reranking.models import PairwiseMLPReranker +from utils.embeddings import create_attention_mask + +MENTION_TEXTS = ["dummy mention positive", "dummy mention negative"] +ENTITY_TEXTS = ["dummy entity positive", "dummy entity negative"] + + +class DummyEmbeddingModel(nn.Module): + output_dim = 2 + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: # type: ignore[override] + masked = input_ids.float() * attention_mask + summed = masked.sum(dim=1, keepdim=True) + return torch.cat([summed, 2 * summed], dim=1) + + +class DummyTokenizer: + _TOKEN_MAP = { + MENTION_TEXTS[0]: [1, 1, 0, 0], + MENTION_TEXTS[1]: [1, 0, 0, 0], + ENTITY_TEXTS[0]: [3, 0, 0, 0], + ENTITY_TEXTS[1]: [0, 0, 0, 0], + } + + def __call__( + self, + texts: str | Sequence[str], + *, + padding: bool = True, + truncation: bool = True, + return_tensors: str = "pt", + ) -> BatchEncoding: + del padding, truncation, return_tensors + items = [texts] if isinstance(texts, str) else list(texts) + encodings = [self._TOKEN_MAP.get(text, [1, 0, 0, 0]) for text in items] + tensor = torch.tensor(encodings, dtype=torch.long) + attention_mask = (tensor != 0).long() + return BatchEncoding({"input_ids": tensor, "attention_mask": attention_mask}) + + +def _configure_classifier(model: PairwiseMLPReranker) -> None: + hidden_layer = model.classifier[0] + output_layer = model.classifier[-1] + + with torch.no_grad(): + hidden_layer.weight.copy_(torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float)) + hidden_layer.bias.copy_(torch.tensor([0.5], dtype=torch.float)) + output_layer.weight.copy_(torch.tensor([[1.0]], dtype=torch.float)) + output_layer.bias.zero_() + + +def _prepare_labels(size: int, *, positive_index: int = 0) -> torch.Tensor: + labels = torch.zeros(size, dtype=torch.float) + labels[positive_index] = 1.0 + return labels + + +def _run_common_checks( + model: PairwiseMLPReranker, + mentions: Sequence[str], + entities: Sequence[str], + labels: torch.Tensor, +): + tokenizer = model.tokenizer + mention_batch = tokenizer(mentions, padding=True, truncation=True, return_tensors="pt").to( + model.device + ) + mention_batch["attention_mask"] = create_attention_mask(mention_batch["input_ids"]) + entity_batch = tokenizer(entities, padding=True, truncation=True, return_tensors="pt").to( + model.device + ) + entity_batch["attention_mask"] = create_attention_mask(entity_batch["input_ids"]) + + encode_out = model._encode(mention_batch["input_ids"]) + + loss = model.train_step( + { + "mention_tokens": dict(mention_batch), + "entity_tokens": dict(entity_batch), + "labels": labels.to(model.device), + } + ) + + with torch.no_grad(): + manual_mention_embeddings = model.base_model( + input_ids=mention_batch["input_ids"], + attention_mask=mention_batch.get("attention_mask"), + ) + manual_entity_embeddings = model.base_model( + input_ids=entity_batch["input_ids"], + attention_mask=entity_batch.get("attention_mask"), + ) + manual_logits = model.classifier( + torch.cat([manual_mention_embeddings, manual_entity_embeddings], dim=-1) + ).squeeze(-1) + expected_loss = model.loss_fn(manual_logits, labels.to(model.device)) + + score = model.score(mentions[0], entities[0]) + + return { + "encode": encode_out, + "loss": loss.detach(), + "score": score, + "manual_mention_embeddings": manual_mention_embeddings, + "manual_entity_embeddings": manual_entity_embeddings, + "manual_logits": manual_logits, + "expected_loss": expected_loss, + } + + +@pytest.fixture() +def dummy_model(monkeypatch) -> PairwiseMLPReranker: + dummy_embedding_model = DummyEmbeddingModel() + + def _mock_model_loader(*_args, **_kwargs): + return dummy_embedding_model + + monkeypatch.setattr( + "reranking.models.pairwise_mlp.ModelFactory.auto_load_from_file", _mock_model_loader + ) + monkeypatch.setattr( + "reranking.models.pairwise_mlp.AutoTokenizer.from_pretrained", + lambda *_args, **_kwargs: DummyTokenizer(), + ) + + model = PairwiseMLPReranker( + "dummy-model", + mlp_hidden_dim=1, + dropout=0.0, + ) + _configure_classifier(model) + return model + + +def test_pairwise_mlp_with_dummy_backbone(dummy_model: PairwiseMLPReranker) -> None: + labels = _prepare_labels(len(MENTION_TEXTS)) + results = _run_common_checks(dummy_model, MENTION_TEXTS, ENTITY_TEXTS, labels) + + assert torch.allclose( + results["encode"], + results["manual_mention_embeddings"], + ) + + assert torch.allclose(results["loss"], results["expected_loss"]) + + positive_logit = results["manual_logits"][0].item() + assert math.isclose(results["score"], torch.sigmoid(torch.tensor(positive_logit)).item()) + + +@pytest.mark.slow +def test_pairwise_mlp_with_lealla_backbone() -> None: + model = PairwiseMLPReranker( + "setu4993/LEALLA-base", + mlp_hidden_dim=1, + dropout=0.0, + ) + + labels = _prepare_labels(len(MENTION_TEXTS)) + results = _run_common_checks(model, MENTION_TEXTS, ENTITY_TEXTS, labels) + + assert torch.allclose( + results["encode"], + results["manual_mention_embeddings"], + ) + + assert torch.allclose(results["loss"], results["expected_loss"], atol=1e-6) + + positive_probability = torch.sigmoid(results["manual_logits"][0]).item() + assert 0.0 <= results["score"] <= 1.0 + assert math.isclose(results["score"], positive_probability, rel_tol=1e-5) From 05856857a80de77275636ede860325c903410737 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 5 Oct 2025 11:00:42 +0200 Subject: [PATCH 05/28] fix(rerank): :bug: calling super in .train override --- src/reranking/models/pairwise_mlp.py | 6 +++--- tests/reranking/models/test_pairwise_mlp.py | 24 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/reranking/models/pairwise_mlp.py b/src/reranking/models/pairwise_mlp.py index a51b9af..6ef0308 100644 --- a/src/reranking/models/pairwise_mlp.py +++ b/src/reranking/models/pairwise_mlp.py @@ -58,6 +58,7 @@ def __init__( target_dim=target_dim, output_type=resolved_output_type, ) + self.base_model.eval() self.embedding_dim = _infer_output_dim(self.base_model) hidden_dim = mlp_hidden_dim or self.embedding_dim @@ -102,13 +103,12 @@ def train_step(self, data: Dict[str, Any]) -> torch.Tensor: return loss def train(self): - self.super().train() + super().train() + # Make sure that base model is never trained. self.base_model.eval() @torch.inference_mode() def score(self, mention: str, entity_description: str) -> float: - self.eval() - mention_tokens = self.tokenizer( mention, padding=True, diff --git a/tests/reranking/models/test_pairwise_mlp.py b/tests/reranking/models/test_pairwise_mlp.py index b7ff0b9..b4a2d75 100644 --- a/tests/reranking/models/test_pairwise_mlp.py +++ b/tests/reranking/models/test_pairwise_mlp.py @@ -178,3 +178,27 @@ def test_pairwise_mlp_with_lealla_backbone() -> None: positive_probability = torch.sigmoid(results["manual_logits"][0]).item() assert 0.0 <= results["score"] <= 1.0 assert math.isclose(results["score"], positive_probability, rel_tol=1e-5) + + +@pytest.mark.slow +def test_pairwise_mlp_with_lealla_backbone_more_iters() -> None: + model = PairwiseMLPReranker( + "setu4993/LEALLA-base", + mlp_hidden_dim=1, + dropout=0.0, + ) + + for _ in range(3): + labels = _prepare_labels(len(MENTION_TEXTS)) + results = _run_common_checks(model, MENTION_TEXTS, ENTITY_TEXTS, labels) + + assert torch.allclose( + results["encode"], + results["manual_mention_embeddings"], + ) + + assert torch.allclose(results["loss"], results["expected_loss"], atol=1e-6) + + positive_probability = torch.sigmoid(results["manual_logits"][0]).item() + assert 0.0 <= results["score"] <= 1.0 + assert math.isclose(results["score"], positive_probability, rel_tol=1e-5) From 8ed60dad29817da558f4c8dc8347d9feee6a3934 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 5 Oct 2025 11:01:47 +0200 Subject: [PATCH 06/28] feat(rerank): :construction: new trainer code that will work more easily with bunch of different models --- src/reranking/training/trainer.py | 134 ++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/reranking/training/trainer.py diff --git a/src/reranking/training/trainer.py b/src/reranking/training/trainer.py new file mode 100644 index 0000000..0aba802 --- /dev/null +++ b/src/reranking/training/trainer.py @@ -0,0 +1,134 @@ +import logging +import os +from copy import deepcopy +from pathlib import Path + +import numpy as np +import torch + +from reranking.models.pairwise_mlp import _to_device + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import gin +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +import wandb +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader + +from finetunings.finetune_model.data import ( + LightWeightDataset, + SaveInformation, + save_model, +) +from finetunings.finetune_model.ddp import cleanup, setup +from finetunings.finetune_model.monitoring import get_gradient_norm, process_metrics +from finetunings.finetune_model.train import forward_to_embeddings, load_model +from reranking.models.base import BaseRerankingModel +from reranking.models.pairwise_mlp import PairwiseMLPReranker +from utils.running_averages import RunningAverages + +# Settings =========================================== + +_RUNNING_AVERAGE_SMALL = 100 +_RUNNING_AVERAGE_BIG = 1000 + +_logger = logging.getLogger("finetuning.finetune_model.train_ddp") + + +if torch.cuda.is_available(): + _logger.debug("Running on CUDA.") + device = torch.device("cuda") +else: + _logger.debug("CUDA is not available.") + device = torch.device("cpu") + + +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +SEED = 0 +torch.manual_seed(SEED) + + +def _ddp_train( + rank: int, + world_size: int, + model: BaseRerankingModel, + dataloader, + optimizer, + epochs, + gradient_clip=1.0, +): + setup(rank, world_size) + + model = DDP(model.to(rank), device_ids=[rank]) + model = torch.compile(model) + + is_the_main_process = rank == 0 + + scaler = torch.amp.GradScaler("cuda") + + loss = None + + def step(): + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + step = torch.compile(step) + + for epoch in range(epochs): + if is_the_main_process: + _logger.info(f"Starting epoch {epoch + 1}/{epochs}") + + for batch in dataloader: + global_step += 1 + + with torch.autocast(device_type="cuda"): + loss = model.train_step(_to_device(batch, rank)) + step() + + cleanup() + + +def cleanup(): + dist.destroy_process_group() + + +# Training =========================================== +@gin.configurable +def train_ddp(): + model = PairwiseMLPReranker(...) + dataloader = ... + optimizer = optim.Adam(model.parameters(), lr=0.0001) + epochs = 10 + world_size = torch.cuda.device_count() + + mp.spawn( + _ddp_train, + args=( + world_size, + model, + dataloader, + optimizer, + epochs, + gradient_clip, + ), + nprocs=world_size, + ) From b4e8ed7ec5266d442eda4ac3d5bb8262b5dc2c6a Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 5 Oct 2025 11:42:23 +0200 Subject: [PATCH 07/28] feat(reranking): :sparkles: add training configs --- src/reranking/training/trainer.py | 75 +++++++++++-------- src/reranking/training/training_configs.py | 54 +++++++++++++ .../training/test_training_configs.py | 23 ++++++ 3 files changed, 120 insertions(+), 32 deletions(-) create mode 100644 src/reranking/training/training_configs.py create mode 100644 tests/reranking/training/test_training_configs.py diff --git a/src/reranking/training/trainer.py b/src/reranking/training/trainer.py index 0aba802..1108e8b 100644 --- a/src/reranking/training/trainer.py +++ b/src/reranking/training/trainer.py @@ -1,42 +1,25 @@ import logging import os -from copy import deepcopy -from pathlib import Path import numpy as np import torch -from reranking.models.pairwise_mlp import _to_device - torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True import gin import torch.distributed as dist import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -import wandb from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler -from finetunings.finetune_model.data import ( - LightWeightDataset, - SaveInformation, - save_model, -) from finetunings.finetune_model.ddp import cleanup, setup -from finetunings.finetune_model.monitoring import get_gradient_norm, process_metrics -from finetunings.finetune_model.train import forward_to_embeddings, load_model -from reranking.models.base import BaseRerankingModel -from reranking.models.pairwise_mlp import PairwiseMLPReranker -from utils.running_averages import RunningAverages +from reranking.training.training_configs import TrainingConfig, pairwise_mlp # Settings =========================================== -_RUNNING_AVERAGE_SMALL = 100 -_RUNNING_AVERAGE_BIG = 1000 -_logger = logging.getLogger("finetuning.finetune_model.train_ddp") +_logger = logging.getLogger("reranking.train.trainer") if torch.cuda.is_available(): @@ -66,14 +49,26 @@ def cleanup(): def _ddp_train( rank: int, world_size: int, - model: BaseRerankingModel, - dataloader, - optimizer, + training_config: TrainingConfig, epochs, gradient_clip=1.0, ): setup(rank, world_size) + model = training_config.model + dataset = training_config.dataset + optimizer = training_config.optimizer + + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True) + + dataloader = DataLoader( + dataset, + batch_size=training_config.batch_size, + sampler=sampler, + pin_memory=True, + num_workers=2, + ) + model = DDP(model.to(rank), device_ids=[rank]) model = torch.compile(model) @@ -97,12 +92,23 @@ def step(): if is_the_main_process: _logger.info(f"Starting epoch {epoch + 1}/{epochs}") - for batch in dataloader: + for links, entities, labels in dataloader: global_step += 1 + links = links.to(rank, non_blocking=True) + entities = entities.to(rank, non_blocking=True) + labels = labels.to(rank, non_blocking=True) + + batch_data = { + "mention_tokens": links, + "entity_tokens": entities, + "labels": labels, + } with torch.autocast(device_type="cuda"): - loss = model.train_step(_to_device(batch, rank)) + loss = model.train_step(batch_data) step() + if is_the_main_process and global_step % 10 == 0: + _logger.info(f"Step {global_step}, loss: {loss.item():.4f}") cleanup() @@ -112,23 +118,28 @@ def cleanup(): # Training =========================================== -@gin.configurable def train_ddp(): - model = PairwiseMLPReranker(...) - dataloader = ... - optimizer = optim.Adam(model.parameters(), lr=0.0001) + _logger.info("Starting DDP training") + gradient_clip = 1.0 epochs = 10 world_size = torch.cuda.device_count() + _logger.debug(f"Using {world_size} GPUs for training") + + _logger.info("Loading training configuration") + training_config = pairwise_mlp() + _logger.info(f"Training configuration loaded: {training_config.config_name}") mp.spawn( _ddp_train, args=( world_size, - model, - dataloader, - optimizer, + training_config, epochs, gradient_clip, ), nprocs=world_size, ) + + +if __name__ == "__main__": + train_ddp() diff --git a/src/reranking/training/training_configs.py b/src/reranking/training/training_configs.py new file mode 100644 index 0000000..1181402 --- /dev/null +++ b/src/reranking/training/training_configs.py @@ -0,0 +1,54 @@ +import os +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import torch + +from reranking.models.base import BaseRerankingModel +from reranking.models.pairwise_mlp import PairwiseMLPReranker + + +@dataclass +class TrainingConfig: + config_name: str + model: BaseRerankingModel + dataset: torch.utils.data.Dataset + optimizer: torch.optim.Optimizer + batch_size: int + output_dir: str + save_each: int = 1000 + + def get_output_path(self, step: int) -> str: + dir_path = Path(self.output_dir) / self.config_name + dir_path.mkdir(parents=True, exist_ok=True) + return f"{dir_path}/{step}.pth" + + +def pairwise_mlp() -> TrainingConfig: + name = "pairwise_mlp" + LR = 0.0001 + SAVE_EACH = 1000 + BATCH_SIZE = 64 + model = PairwiseMLPReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + data = np.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset.npz" + ) + description_tokens = torch.tensor(data["description_tokens"]) + link_tokens = torch.tensor(data["link_tokens"]) + labels = torch.tensor(data["labels"]) + dataset = torch.utils.data.TensorDataset(link_tokens, description_tokens, labels) + optimizer = torch.optim.AdamW(model.parameters(), lr=LR) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + ) diff --git a/tests/reranking/training/test_training_configs.py b/tests/reranking/training/test_training_configs.py new file mode 100644 index 0000000..09b1b32 --- /dev/null +++ b/tests/reranking/training/test_training_configs.py @@ -0,0 +1,23 @@ +from unittest.mock import MagicMock + +import pytest + +from reranking.training.training_configs import TrainingConfig + + +def test_get_output_path_creates_directory(tmp_path): + output_root = tmp_path / "outputs" + config = TrainingConfig( + config_name="test_config", + model=MagicMock(), + dataset=MagicMock(), + optimizer=MagicMock(), + batch_size=1, + output_dir=str(output_root), + ) + + path = config.get_output_path(step=5) + + expected_dir = output_root / "test_config" + assert expected_dir.exists() and expected_dir.is_dir() + assert path == f"{expected_dir}/5.pth" From 45364df35e5096aef7cd810843ba8ba0d50378b0 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Mon, 6 Oct 2025 14:14:05 +0200 Subject: [PATCH 08/28] feat(models): :sparkles: add searcher that expects inputs to be tensor on the same device as the module --- src/models/searchers/brute_force_searcher.py | 57 +++++++++++++++++++ tests/models/test_brute_force_searcher.py | 58 ++++++++++++++++++++ 2 files changed, 115 insertions(+) diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index 1b6427c..4bb488b 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -116,3 +116,60 @@ def find(self, batch: np.ndarray, num_neighbors: int, mask=None) -> np.ndarray: def build(self): pass + + +class DPBruteForceSearcherPT(Searcher): + def __init__(self, embs: np.ndarray, results: np.ndarray, run_build_from_init: bool = True): + if torch.cuda.is_available(): + _logger.info("Running on CUDA.") + self.device: torch.device = torch.device("cuda") + else: + _logger.info("CUDA is not available.") + self.device: torch.device = torch.device("cpu") + self.module_searcher: Optional[nn.DataParallel] = None + self.required_num_neighbors: Optional[int] = None + super().__init__(embs, results, run_build_from_init) + + @torch.compile + @torch.inference_mode() + def find(self, batch: torch.Tensor, num_neighbors: int) -> np.ndarray: + """ + Finds the nearest neighbors for a given batch of input data. + CAREFUL: This is an optimized version that comes with potential pitfalls to get better performance. + Read Notes for details! + + Args: + batch (torch.Tensor): A batch of input data for which neighbors are to be found. + num_neighbors (int): The number of nearest neighbors to retrieve. + Returns: + np.ndarray: An array containing the results corresponding to the nearest neighbors. + Raises: + TypeError: If `module_searcher` if an unexpected attribute access occurs when using module_searcher. + Notes: + - It is not possible to change num_neighbors after the first call to find. + If you need to do that, you need to reinitialize this object. If you call the find with different + num_neighbors, it will not raise an error and will fail silently. + - The first call to find will be slow, because the module_searcher will be initialized and torch.compile is called. + """ + # with torch.inference_mode(), torch.autocast( + # device_type=self.device.type, dtype=torch.float16 + # ): + with torch.no_grad(): + # A try except trick to avoid the overhead of checking if the module_searcher is None + # on every call to find. + # This is a bit of a hack, but it should make things faster as we are suggesting that the module_searcher is initialized. + try: + with torch.amp.autocast(device_type="cuda", dtype=torch.float16): + top_indices: torch.Tensor = self.module_searcher(batch) + except TypeError as e: + if self.module_searcher is not None: + raise e + self.module_searcher = nn.DataParallel(_WrappedSearcher(self.embs, num_neighbors)) + self.module_searcher.to(self.device) + self.required_num_neighbors = num_neighbors + top_indices: torch.Tensor = self.module_searcher(batch) + + return self.results[top_indices.cpu().numpy()] + + def build(self): + pass diff --git a/tests/models/test_brute_force_searcher.py b/tests/models/test_brute_force_searcher.py index f70a08e..7159ceb 100644 --- a/tests/models/test_brute_force_searcher.py +++ b/tests/models/test_brute_force_searcher.py @@ -5,6 +5,7 @@ from models.searchers.brute_force_searcher import ( BruteForceSearcher, DPBruteForceSearcher, + DPBruteForceSearcherPT, ) # torch.compiler.disable(BruteForceSearcher.find) @@ -129,3 +130,60 @@ def test_dataparallel_initialization(self, small_embs): searcher = DPBruteForceSearcher(small_embs, np.arange(len(small_embs))) searcher.find(np.random.random((1, 3)), 2) # This should initialize module_searcher assert isinstance(searcher.module_searcher, torch.nn.DataParallel) + + +class TestDPBruteForceSearcherPT: + @pytest.fixture + def small_embs(self): + return torch.tensor( + [ + [0.9, 0.9, 0.9], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + def test_search_present(self, small_embs): + searcher = DPBruteForceSearcherPT(small_embs, np.arange(4)) + for i, e in enumerate(small_embs): + res = searcher.find(e.unsqueeze(0), 2) + assert res[0][0] == i + assert res[0][1] != i + assert len(res[0]) == 2 + + def test_search_missing(self): + embs = torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + searcher = DPBruteForceSearcherPT(embs, np.arange(4)) + res = searcher.find(torch.tensor([[1.0, 0.0, 1.0]]), 2) + assert res[0][0] == 0 + + def test_device_selection(self, small_embs): + searcher = DPBruteForceSearcherPT(small_embs, np.arange(len(small_embs))) + expected_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + assert searcher.device == expected_device + + def test_changing_num_neighbors(self, small_embs): + searcher = DPBruteForceSearcherPT(small_embs, np.arange(len(small_embs))) + searcher.find( + torch.from_numpy(np.random.random((1, 3))).to(torch.float32), 2 + ) # Initialize with 2 neighbors + # with pytest.raises(Exception): + # Does nothing: + searcher.find( + torch.from_numpy(np.random.random((1, 3))).to(torch.float32), 3 + ) # Try to change to 3 neighbors + + def test_dataparallel_initialization(self, small_embs): + searcher = DPBruteForceSearcherPT(small_embs, np.arange(len(small_embs))) + searcher.find( + torch.from_numpy(np.random.random((1, 3))).to(torch.float32), 2 + ) # This should initialize module_searcher + assert isinstance(searcher.module_searcher, torch.nn.DataParallel) From 78e446248e7591727074963dc54d50532a93f033 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Mon, 6 Oct 2025 16:38:51 +0200 Subject: [PATCH 09/28] feat(utils): :sparkles: benchmark whether DataParallel really helps --- src/scripts/benchmarking/dp_vs_manual.py | 118 +++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/scripts/benchmarking/dp_vs_manual.py diff --git a/src/scripts/benchmarking/dp_vs_manual.py b/src/scripts/benchmarking/dp_vs_manual.py new file mode 100644 index 0000000..84f6759 --- /dev/null +++ b/src/scripts/benchmarking/dp_vs_manual.py @@ -0,0 +1,118 @@ +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F + +torch._dynamo.config.recompile_limit = 1000 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +class _WrappedSearcher(nn.Module): + def __init__(self, kb_embs, num_neighbors): + super().__init__() + self.register_buffer("kb_embs", kb_embs) + self.num_neighbors: int = num_neighbors + + def forward(self, x): + dot_product = F.linear(x, self.kb_embs) + _, top_indices = dot_product.topk(self.num_neighbors) + return top_indices + + +def benchmark(f, input_data, warmup_iters=20, profile_iters=100): + for _ in range(warmup_iters): + with torch.inference_mode(): + f(input_data) + torch.cuda.synchronize() + + time_sum = 0.0 + for _ in range(profile_iters): + start = time.time() + with torch.inference_mode(): + f(input_data) + torch.cuda.synchronize() + end = time.time() + time_sum += end - start + return time_sum / profile_iters + + +num_cuda_devices = torch.cuda.device_count() +print(f"Number of available CUDA devices: {num_cuda_devices}") + +cuda_devices = [torch.device(f"cuda:{i}") for i in range(num_cuda_devices)] +print(f"CUDA devices: {cuda_devices}") + +kb = torch.randn(1000, 768) +data = torch.randn(32, 768).to(cuda_devices[0]) + +searcher_dp = _WrappedSearcher(kb, num_neighbors=10) +searcher_dp = nn.DataParallel(searcher_dp, device_ids=cuda_devices) +searcher_dp = torch.compile(searcher_dp) +searcher_dp.to("cuda") + + +class ManualSearcher: + def __init__(self, kb, device_ids, num_neighbors): + self.device_ids = device_ids + self.searchers = [] + for i, device in enumerate(device_ids): + searcher = _WrappedSearcher(kb, num_neighbors=num_neighbors).to(device) + self.searchers.append(searcher) + + def find(self, x): + # Split input data across available devices + inputs = nn.parallel.scatter(x, self.device_ids) + # Compute on each device + outputs = [ + searcher(input_chunk.to(device)) + for searcher, input_chunk, device in zip(self.searchers, inputs, self.device_ids) + ] + gathered = nn.parallel.gather(outputs, self.device_ids[0]) + return gathered + + +searcher_manual = ManualSearcher(kb, device_ids=cuda_devices, num_neighbors=10) +searcher_manual.find = torch.compile(searcher_manual.find) + +print("Comparing outputs...") +out_dp = searcher_dp(data) +out_manual = searcher_manual.find(data) +print("Outputs are equal:", torch.equal(out_dp, out_manual)) +assert torch.equal(out_dp, out_manual) + +kb = torch.randn(10000000, 128).to(torch.float16) +data = torch.randn(256, 128).to(cuda_devices[0], dtype=torch.float16) + +print("Benchmarking DataParallel compiled searcher...") +searcher_dp = _WrappedSearcher(kb, num_neighbors=10) +searcher_dp = nn.DataParallel(searcher_dp, device_ids=cuda_devices) +searcher_dp = torch.compile(searcher_dp) +searcher_dp.to("cuda") +print(benchmark(searcher_dp, data)) +del searcher_dp + +print("Benchmarking DataParallel searcher...") +searcher_dp = _WrappedSearcher(kb, num_neighbors=10) +searcher_dp = nn.DataParallel(searcher_dp, device_ids=cuda_devices) +searcher_dp.to("cuda") +print(benchmark(searcher_dp, data)) +del searcher_dp + +print("Benchmarking Manual searcher...") +searcher_manual = ManualSearcher(kb, device_ids=cuda_devices, num_neighbors=10) +searcher_manual.find = torch.compile(searcher_manual.find) +print(benchmark(searcher_manual.find, data)) +del searcher_manual + +print("Benchmarking Normal searcher...") +normal_searcher = _WrappedSearcher(kb, num_neighbors=10).to("cuda") +print(benchmark(normal_searcher, data)) +del normal_searcher + +print("Benchmarking Compiled Normal searcher...") +normal_searcher_c = _WrappedSearcher(kb, num_neighbors=10).to("cuda") +normal_searcher_c = torch.compile(normal_searcher_c) +print(benchmark(normal_searcher_c, data)) +del normal_searcher_c From d22eebd443c8711eb2012bcede17983ef018a551 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Mon, 6 Oct 2025 16:44:54 +0200 Subject: [PATCH 10/28] skip unused tests --- tests/models/test_faiss_searcher.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_faiss_searcher.py b/tests/models/test_faiss_searcher.py index ea55b8d..b89b428 100644 --- a/tests/models/test_faiss_searcher.py +++ b/tests/models/test_faiss_searcher.py @@ -23,6 +23,7 @@ def assert_equal_results(faiss_results, brute_results): np.testing.assert_array_equal(faiss_results, brute_results) +@pytest.mark.skip(reason="FAISS is not currently supported") def test_small(generate_data): from models.searchers.faiss_searcher import FaissSearcher @@ -36,6 +37,7 @@ def test_small(generate_data): assert_equal_results(faiss_out, brute_out) +@pytest.mark.skip(reason="FAISS is not currently supported") def test_large(generate_data): from models.searchers.faiss_searcher import FaissSearcher From 4aa3de16d75a18487e20e37d79453f8588dec2e8 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Mon, 6 Oct 2025 16:46:29 +0200 Subject: [PATCH 11/28] Refactor and enhance the reranking and searcher modules - Updated BruteForceSearcher to use register_buffer for kb_embs and added eval mode with gradient disabling. - Modified DPBruteForceSearcherPT to compile the searcher and improve inference performance. - Enhanced create_binary_dataset to sort embeddings and qids, and adjusted data types for better memory efficiency. - Refactored PairwiseMLPReranker to streamline the forward pass and ensure base model remains non-trainable. - Introduced validation during training in the trainer module and improved logging for better monitoring. - Added a new reranking2 script for improved entity linking functionality. - Removed deprecated reranking.py script to clean up the codebase. - Updated test cases for pairwise MLP and training configurations to reflect recent changes. --- src/models/searchers/brute_force_searcher.py | 39 +++-- src/reranking/binary/create_dataset.py | 77 +++++--- src/reranking/models/pairwise_mlp.py | 76 ++++---- src/reranking/training/trainer.py | 97 +++++++++-- src/reranking/training/training_configs.py | 22 ++- src/run_action_gin.py | 6 + src/scripts/qwen/reranker.py | 27 ++- src/scripts/qwen/reranking.py | 78 --------- src/scripts/qwen/reranking2.py | 164 ++++++++++++++++++ src/tokenization/runner.py | 30 ++++ tests/reranking/models/test_pairwise_mlp.py | 18 +- .../training/test_training_configs.py | 1 + 12 files changed, 432 insertions(+), 203 deletions(-) delete mode 100644 src/scripts/qwen/reranking.py create mode 100644 src/scripts/qwen/reranking2.py diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index 4bb488b..70a6ed0 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -45,7 +45,8 @@ def build(self) -> None: class _WrappedSearcher(nn.Module): def __init__(self, kb_embs, num_neighbors): super().__init__() - self.kb_embs: torch.Tensor = nn.Parameter(kb_embs) + # Replace Parameter with register_buffer + self.register_buffer("kb_embs", kb_embs) self.num_neighbors: int = num_neighbors # @torch.compile @@ -106,6 +107,10 @@ def find(self, batch: np.ndarray, num_neighbors: int, mask=None) -> np.ndarray: _WrappedSearcher(torch.from_numpy(self.embs), num_neighbors) ) self.module_searcher.to(self.device) + # Set module to eval() and disable gradients + self.module_searcher.eval() + for param in self.module_searcher.parameters(): + param.requires_grad = False self.required_num_neighbors = num_neighbors top_indices: torch.Tensor = self.module_searcher( torch.from_numpy(batch).to(self.device) @@ -130,7 +135,6 @@ def __init__(self, embs: np.ndarray, results: np.ndarray, run_build_from_init: b self.required_num_neighbors: Optional[int] = None super().__init__(embs, results, run_build_from_init) - @torch.compile @torch.inference_mode() def find(self, batch: torch.Tensor, num_neighbors: int) -> np.ndarray: """ @@ -154,20 +158,25 @@ def find(self, batch: torch.Tensor, num_neighbors: int) -> np.ndarray: # with torch.inference_mode(), torch.autocast( # device_type=self.device.type, dtype=torch.float16 # ): - with torch.no_grad(): - # A try except trick to avoid the overhead of checking if the module_searcher is None - # on every call to find. - # This is a bit of a hack, but it should make things faster as we are suggesting that the module_searcher is initialized. - try: - with torch.amp.autocast(device_type="cuda", dtype=torch.float16): - top_indices: torch.Tensor = self.module_searcher(batch) - except TypeError as e: - if self.module_searcher is not None: - raise e - self.module_searcher = nn.DataParallel(_WrappedSearcher(self.embs, num_neighbors)) - self.module_searcher.to(self.device) - self.required_num_neighbors = num_neighbors + # A try except trick to avoid the overhead of checking if the module_searcher is None + # on every call to find. + # This is a bit of a hack, but it should make things faster as we are suggesting that the module_searcher is initialized. + try: + with torch.amp.autocast(device_type="cuda"): top_indices: torch.Tensor = self.module_searcher(batch) + except TypeError as e: + if self.module_searcher is not None: + raise e + self.module_searcher = torch.compile( + nn.DataParallel(_WrappedSearcher(self.embs, num_neighbors)) + ) + self.module_searcher.to(self.device) + # Set module to eval() and disable gradients + self.module_searcher.eval() + for param in self.module_searcher.parameters(): + param.requires_grad = False + self.required_num_neighbors = num_neighbors + top_indices: torch.Tensor = self.module_searcher(batch) return self.results[top_indices.cpu().numpy()] diff --git a/src/reranking/binary/create_dataset.py b/src/reranking/binary/create_dataset.py index 5378069..09aeb48 100644 --- a/src/reranking/binary/create_dataset.py +++ b/src/reranking/binary/create_dataset.py @@ -7,7 +7,7 @@ import torch.utils.data from tqdm import tqdm -from models.searchers.brute_force_searcher import BruteForceSearcher +from models.searchers.brute_force_searcher import BruteForceSearcher, DPBruteForceSearcherPT from utils.embeddings import create_attention_mask from utils.loaders import load_embs_and_qids, load_tokens_qids_from_dir from utils.model_factory import ModelFactory @@ -15,11 +15,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -@nb.njit def get_neg_qids(top_qids, batch_qids): neg_qids = [] - for row in top_qids: - if row[0] not in batch_qids: + for row, batch_qid in zip(top_qids, batch_qids): + if row[0] != batch_qid: neg_qids.append(row[0]) else: neg_qids.append(row[1]) @@ -39,7 +38,19 @@ def create_binary_dataset( # Load index embeddings, qids, and tokens index_embs, index_qids = load_embs_and_qids(index_embs_dir) index_embs = index_embs.astype(np.float16) - index_tokens, _ = load_tokens_qids_from_dir(index_tokens_path) + index_tokens, index_qids_from_tokens = load_tokens_qids_from_dir(index_tokens_path) + + # Sort index_embs and index_qids based on index_qids + sort_indices = np.argsort(index_qids) + index_qids = index_qids[sort_indices] + index_embs = index_embs[sort_indices] + + sort_indices_tokens = np.argsort(index_qids_from_tokens) + index_qids_from_tokens = index_qids_from_tokens[sort_indices_tokens] + index_tokens = index_tokens[sort_indices_tokens] + + np.testing.assert_array_equal(index_qids, index_qids_from_tokens) + print(index_tokens.shape) # Create BruteForceSearcher @@ -57,6 +68,8 @@ def create_binary_dataset( ) model.eval() model.to(device) + model.to(torch.bfloat16) + model = torch.compile(model) index_qids_set = set(index_qids) known_qids_mask = np.array([q in index_qids_set for q in link_qids]) @@ -68,9 +81,11 @@ def create_binary_dataset( link_tokens = torch.from_numpy(link_tokens) link_qids = torch.from_numpy(link_qids) - link_tokens = link_tokens.to(torch.int64) + link_tokens = link_tokens.to(torch.int32) dataset = torch.utils.data.TensorDataset(link_tokens, link_qids) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2 + ) # Initialize dataset arrays description_tokens = [] @@ -78,9 +93,10 @@ def create_binary_dataset( y = [] print("Dataset length:", len(dataset)) - description_tokens = np.zeros((len(dataset) * 2, index_tokens.shape[1])) - link_tokens_list = np.zeros((len(dataset) * 2, link_tokens.shape[1])) - y = np.zeros((len(dataset) * 2,)) + description_tokens = np.zeros((len(dataset) * 2, index_tokens.shape[1]), dtype=np.int32) + link_tokens_list = np.zeros((len(dataset) * 2, link_tokens.shape[1]), dtype=np.int32) + y = np.zeros((len(dataset) * 2,), dtype=np.int8) + qids = np.zeros((len(dataset) * 2,), dtype=np.int32) output_index = 0 index_qid_to_index = {qid: i for i, qid in enumerate(index_qids)} @@ -90,29 +106,37 @@ def create_binary_dataset( dataloader, desc="Creating dataset", total=len(dataloader) ): # Embed link tokens - with torch.no_grad(): - batch_embs = model( - batch_tokens.to(device).to(torch.int64), - create_attention_mask(batch_tokens).to(device), - ).cpu() + with torch.inference_mode(): + batch_embs = ( + model( + batch_tokens.to(device).to(torch.int64), + create_attention_mask(batch_tokens).to(device), + ) + .to(torch.float16) + .cpu() + ) # Find top matches - top_qids = searcher.find(batch_embs.numpy().astype(np.float16), num_neighbors=2) + top_qids = searcher.find(batch_embs.numpy(), num_neighbors=2) + + del batch_embs - positive_mask = [index_qid_to_index[int(qid)] for qid in batch_qids.numpy()] + positive_mask = [index_qid_to_index[int(qid)] for qid in batch_qids] data_size = len(batch_tokens) description_tokens[output_index : output_index + data_size] = index_tokens[positive_mask] link_tokens_list[output_index : output_index + data_size] = batch_tokens.numpy() y[output_index : output_index + data_size] = 1 + qids[output_index : output_index + data_size] = batch_qids.numpy() output_index += data_size - neg_qids = get_neg_qids(top_qids, set(batch_qids.numpy())) + neg_qids = get_neg_qids(top_qids, batch_qids) negative_mask = [index_qid_to_index[qid] for qid in neg_qids] description_tokens[output_index : output_index + data_size] = index_tokens[negative_mask] link_tokens_list[output_index : output_index + data_size] = batch_tokens.numpy() y[output_index : output_index + data_size] = 0 + qids[output_index : output_index + data_size] = np.array(neg_qids) output_index += data_size @@ -173,10 +197,14 @@ def create_multiclass_dataset( link_tokens, qids = [], [] for B_tokens, B_qids in tqdm(dataloader, desc="Creating dataset"): with torch.no_grad(): - B_embs = model( - B_tokens.to(device).to(torch.int64), - create_attention_mask(B_tokens).to(device), - ).cpu() + B_embs = ( + model( + B_tokens.to(device).to(torch.int64), + create_attention_mask(B_tokens).to(device), + ) + .to(torch.float16) + .cpu() + ) top_qids = searcher.find(B_embs.numpy().astype(np.float16), num_neighbors=total_classes) for i, qid in enumerate(B_qids.numpy()): @@ -195,7 +223,7 @@ def create_multiclass_dataset( print(link_tokens.shape) print(qids.shape) - np.savez( + np.save( output_path, link_tokens=link_tokens, qids=qids, @@ -216,7 +244,7 @@ def create_multiclass_dataset( "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", ) output_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset.npz" + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids.npz" ) model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" @@ -227,4 +255,5 @@ def create_multiclass_dataset( model_name, embedding_model_path, output_path, + batch_size=2048, ) diff --git a/src/reranking/models/pairwise_mlp.py b/src/reranking/models/pairwise_mlp.py index 6ef0308..f2bd266 100644 --- a/src/reranking/models/pairwise_mlp.py +++ b/src/reranking/models/pairwise_mlp.py @@ -29,10 +29,6 @@ def _infer_output_dim(model: nn.Module) -> int: raise ValueError("Unable to infer output dimension from the provided base model.") -def _to_device(batch: Mapping[str, torch.Tensor], device: torch.device) -> Dict[str, torch.Tensor]: - return {k: v.to(device) for k, v in batch.items()} - - class PairwiseMLPReranker(BaseRerankingModel): """Reranking model that augments a LEALLA encoder with an MLP head.""" @@ -46,10 +42,8 @@ def __init__( tokenizer_name_or_path: str | None = None, mlp_hidden_dim: int | None = None, dropout: float = 0.1, - device: torch.device | str = "cpu", ) -> None: super().__init__() - self.device = torch.device(device) resolved_output_type = _maybe_convert_output_type(output_type) self.base_model = ModelFactory.auto_load_from_file( @@ -59,54 +53,44 @@ def __init__( output_type=resolved_output_type, ) self.base_model.eval() + self.base_model.requires_grad_(False) self.embedding_dim = _infer_output_dim(self.base_model) hidden_dim = mlp_hidden_dim or self.embedding_dim self.classifier = nn.Sequential( - nn.Linear(self.embedding_dim * 2, hidden_dim), - nn.ReLU(), + nn.Linear(self.embedding_dim * 2, hidden_dim * 4), + nn.GELU(), + nn.Linear(4 * hidden_dim, hidden_dim), + nn.GELU(), nn.Dropout(p=dropout), nn.Linear(hidden_dim, 1), ) + self.model = _PairwiseMLPReranker(self.base_model, self.classifier) + tokenizer_id = tokenizer_name_or_path or model_name_or_path self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) self.loss_fn = nn.BCEWithLogitsLoss() - self.to(self.device) - def forward( self, - mention_tokens: Mapping[str, torch.Tensor], - entity_tokens: Mapping[str, torch.Tensor], + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, ) -> torch.Tensor: - mention_tokens = _to_device(mention_tokens, self.device) - entity_tokens = _to_device(entity_tokens, self.device) - - mention_embeddings = self._encode(mention_tokens) - entity_embeddings = self._encode(entity_tokens) - - combined = torch.cat([mention_embeddings, entity_embeddings], dim=-1) - logits = self.classifier(combined).squeeze(-1) - return logits + return self.model.forward(mention_tokens, entity_tokens) def train_step(self, data: Dict[str, Any]) -> torch.Tensor: self.train() mention_tokens = data["mention_tokens"] entity_tokens = data["entity_tokens"] - labels = data["labels"].to(self.device).float().view(-1) + labels = data["labels"].float().view(-1) logits = self.forward(mention_tokens, entity_tokens).view(-1) loss = self.loss_fn(logits, labels) return loss - def train(self): - super().train() - # Make sure that base model is never trained. - self.base_model.eval() - @torch.inference_mode() def score(self, mention: str, entity_description: str) -> float: mention_tokens = self.tokenizer( @@ -114,23 +98,37 @@ def score(self, mention: str, entity_description: str) -> float: padding=True, truncation=True, return_tensors="pt", - ).to(self.device) + )["input_ids"] entity_tokens = self.tokenizer( entity_description, padding=True, truncation=True, return_tensors="pt", - ).to(self.device) - mention_tokens["attention_mask"] = create_attention_mask(mention_tokens["input_ids"]) - entity_tokens["attention_mask"] = create_attention_mask(entity_tokens["input_ids"]) - - logits = self.forward( - {k: v for k, v in mention_tokens.items()}, - {k: v for k, v in entity_tokens.items()}, - ) + )["input_ids"] + logits = self.model.forward(mention_tokens, entity_tokens) probability = torch.sigmoid(logits).item() return probability + +class _PairwiseMLPReranker(nn.Module): + def __init__(self, base_model: nn.Module, classifier: nn.Module) -> None: + super().__init__() + self.base_model = base_model + self.classifier = classifier + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + + mention_embeddings = self._encode(mention_tokens) + entity_embeddings = self._encode(entity_tokens) + + combined = torch.cat([mention_embeddings, entity_embeddings], dim=-1) + logits = self.classifier(combined).squeeze(-1) + return logits + @torch.inference_mode() def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: if isinstance(tokens, Mapping): @@ -143,3 +141,9 @@ def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Te attention_mask = create_attention_mask(tokens) return self.base_model(input_ids=input_ids, attention_mask=attention_mask) + + def train(self, mode: bool = True) -> _PairwiseMLPReranker: + super().train(mode) + # Make sure that base model is never trained. + self.base_model.eval() + return self diff --git a/src/reranking/training/trainer.py b/src/reranking/training/trainer.py index 1108e8b..7d8027b 100644 --- a/src/reranking/training/trainer.py +++ b/src/reranking/training/trainer.py @@ -1,3 +1,4 @@ +import copy import logging import os @@ -6,14 +7,12 @@ torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True -import gin import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from finetunings.finetune_model.ddp import cleanup, setup from reranking.training.training_configs import TrainingConfig, pairwise_mlp # Settings =========================================== @@ -59,39 +58,61 @@ def _ddp_train( dataset = training_config.dataset optimizer = training_config.optimizer - sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True) - - dataloader = DataLoader( + validation_dataset = torch.utils.data.Subset( + dataset, + indices=np.arange(training_config.validation_size), + ) + train_dataset = torch.utils.data.Subset( dataset, + indices=np.arange(training_config.validation_size, len(dataset)), + ) + + sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) + + val_dataloader = DataLoader( + validation_dataset, batch_size=training_config.batch_size, - sampler=sampler, + shuffle=False, pin_memory=True, num_workers=2, ) - model = DDP(model.to(rank), device_ids=[rank]) + copied_model = copy.deepcopy(model) + + model.to(rank) + model.model = DDP(model.model, device_ids=[rank]) model = torch.compile(model) is_the_main_process = rank == 0 scaler = torch.amp.GradScaler("cuda") - loss = None - - def step(): - scaler.scale(loss).backward() + @torch.compile + def step(current_loss): + scaler.scale(current_loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) scaler.step(optimizer) scaler.update() optimizer.zero_grad() - step = torch.compile(step) + global_step = 0 for epoch in range(epochs): if is_the_main_process: _logger.info(f"Starting epoch {epoch + 1}/{epochs}") + # Ensure proper shuffling across epochs with DistributedSampler + sampler.set_epoch(epoch) + + dataloader = DataLoader( + train_dataset, + batch_size=training_config.batch_size, + sampler=sampler, + pin_memory=True, + num_workers=2, + ) + for links, entities, labels in dataloader: global_step += 1 links = links.to(rank, non_blocking=True) @@ -106,19 +127,57 @@ def step(): with torch.autocast(device_type="cuda"): loss = model.train_step(batch_data) - step() - if is_the_main_process and global_step % 10 == 0: + step(loss) + if is_the_main_process and global_step % training_config.save_each == 0: + path = training_config.get_output_path(global_step) + _logger.info(f"Saving model at step {global_step} to {path}") + copied_model.model = model.model.module + torch.save(copied_model.state_dict(), path) + if is_the_main_process and global_step % 500 == 0: _logger.info(f"Step {global_step}, loss: {loss.item():.4f}") + if is_the_main_process and global_step % training_config.validate_each == 0: + model.eval() + + correct = 0 + total = 0 + + total_loss = 0.0 + val_steps = 0 + for links, entities, labels in val_dataloader: + links = links.to(rank, non_blocking=True) + entities = entities.to(rank, non_blocking=True) + labels = labels.to(rank, non_blocking=True) + + val_steps += 1 + + with torch.inference_mode(): + probs = model.score(links, entities).view(-1) + loss = torch.nn.functional.binary_cross_entropy(probs, labels.float()) + total_loss += loss.item() + predictions = (torch.sigmoid(probs) > 0.5).long() + correct += (predictions == labels).sum().item() + total += labels.size(0) + if is_the_main_process: + _logger.info(f"Validation loss: {total_loss / val_steps:.4f}") + _logger.info(f"Validation accuracy: {correct / total:.4f}") + + model.train() + + if is_the_main_process: + _logger.info(f"Epoch {epoch + 1} finished.") cleanup() -def cleanup(): - dist.destroy_process_group() +def get_config_from_name(config_name: str) -> TrainingConfig: + if config_name == "pairwise_mlp": + return pairwise_mlp() + else: + raise ValueError(f"Unknown training configuration: {config_name}") # Training =========================================== -def train_ddp(): +def train_ddp(config_name: str): _logger.info("Starting DDP training") gradient_clip = 1.0 epochs = 10 @@ -126,7 +185,7 @@ def train_ddp(): _logger.debug(f"Using {world_size} GPUs for training") _logger.info("Loading training configuration") - training_config = pairwise_mlp() + training_config = get_config_from_name(config_name) _logger.info(f"Training configuration loaded: {training_config.config_name}") mp.spawn( @@ -142,4 +201,4 @@ def train_ddp(): if __name__ == "__main__": - train_ddp() + train_ddp("pairwise_mlp") diff --git a/src/reranking/training/training_configs.py b/src/reranking/training/training_configs.py index 1181402..fee2721 100644 --- a/src/reranking/training/training_configs.py +++ b/src/reranking/training/training_configs.py @@ -17,7 +17,9 @@ class TrainingConfig: optimizer: torch.optim.Optimizer batch_size: int output_dir: str - save_each: int = 1000 + save_each: int + validate_each: int + validation_size: int = 100000 def get_output_path(self, step: int) -> str: dir_path = Path(self.output_dir) / self.config_name @@ -28,18 +30,22 @@ def get_output_path(self, step: int) -> str: def pairwise_mlp() -> TrainingConfig: name = "pairwise_mlp" LR = 0.0001 - SAVE_EACH = 1000 - BATCH_SIZE = 64 + SAVE_EACH = 10000 + BATCH_SIZE = 1024 + VALIDATE_EACH = 5000 + VALIDATION_SIZE = 1000 model = PairwiseMLPReranker( model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + mlp_hidden_dim=2048, ) data = np.load( - "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset.npz" + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids.npz" ) - description_tokens = torch.tensor(data["description_tokens"]) - link_tokens = torch.tensor(data["link_tokens"]) - labels = torch.tensor(data["labels"]) + labels = torch.from_numpy(data["y"]).float() + description_tokens = torch.from_numpy(data["description_tokens"]).long() + link_tokens = torch.from_numpy(data["link_tokens"]).long() + dataset = torch.utils.data.TensorDataset(link_tokens, description_tokens, labels) optimizer = torch.optim.AdamW(model.parameters(), lr=LR) output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" @@ -51,4 +57,6 @@ def pairwise_mlp() -> TrainingConfig: output_dir=output_dir, save_each=SAVE_EACH, batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, ) diff --git a/src/run_action_gin.py b/src/run_action_gin.py index 48298a9..f9038a4 100644 --- a/src/run_action_gin.py +++ b/src/run_action_gin.py @@ -21,7 +21,9 @@ from finetunings.generate_epochs.generate import generate from multilingual_dataset.combine_embs import combine_embs_by_qid from multilingual_dataset.creator import create_multilingual_dataset, run_kb_creator +from reranking.training.trainer import train_ddp as reranking_train_ddp from tokenization.runner import ( + run_damuel_description, run_damuel_description_context, run_damuel_description_mention, run_damuel_link_context, @@ -91,10 +93,14 @@ def choose_action(action): return run_damuel_link_mention case "run_damuel_mention": return run_damuel_mention + case "run_damuel_description": + return run_damuel_description case "olpeat": return olpeat case "find_candidates": return find_candidates + case "reranking_train_ddp": + return reranking_train_ddp case _: raise ValueError(f"Unknown action: {action}") diff --git a/src/scripts/qwen/reranker.py b/src/scripts/qwen/reranker.py index 6213b3b..a9de663 100644 --- a/src/scripts/qwen/reranker.py +++ b/src/scripts/qwen/reranker.py @@ -7,27 +7,32 @@ class Reranker: _DEFAULT_INSTRUCTION = ( "Your task is to determine if the provided Wikipedia description correctly corresponds " - "to the entity mention found in the query. The entity mention is marked by and . " + "to the entity mention found in the query. The entity mention is marked by [M] and [M]. " "Check if the description matches the entity. Answer strictly with 'yes' or 'no'.\n" + "Note that the language of the query and description may differ.\n" + "Do NOT consider the language; the goal is to tell whether description matches the entity in the query.\n" "Example:\n" - " Query: 'What is the capital of France?'\n" - " Description: 'Paris is the capital and largest city of France...'\n" + " Query: 'What is the capital of [M] France [M]?'\n" + " Description: '[M] Paris [M] is the capital and largest city of France...'\n" " Answer: no\n" - " Query: 'What is the capital of France?'\n" - " Description: 'Paris is the capital and largest city of France...'\n" + " Query: 'What is the [M] capital [M] of France?'\n" + " Description: '[M] Paris [M] is the capital and largest city of France...'\n" " Answer: yes" ) + # _DEFAULT_INSTRUCTION = ( + # "Given a web search query, retrieve relevant passages that answer the query" + # ) _SYSTEM_PROMPT = ( "<|im_start|>system\n" "Judge whether the Document meets the requirements based on the Query and the Instruct " 'provided. Note that the answer can only be "yes" or "no".<|im_end|>\n' "<|im_start|>user\n" ) - _ASSISTANT_SUFFIX = "<|im_end|>\n" "<|im_start|>assistant\n" "\n\n\n\n" + _ASSISTANT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" def __init__( self, - model_name: str = "Qwen/Qwen3-Reranker-0.6B", + model_name: str = "Qwen/Qwen3-Reranker-8B", max_length: int = 8192, instruction: str | None = None, ) -> None: @@ -45,18 +50,12 @@ def __init__( def score(self, mention: str, description: str, instruction: str | None = None) -> float: """Return probability that description matches mention.""" - formatted_query = self._format_query(mention) formatted_instruction = instruction or self.instruction - prompt = self._format_instruction(formatted_instruction, formatted_query, description) + prompt = self._format_instruction(formatted_instruction, mention, description) inputs = self._process_inputs([prompt]) probabilities = self._compute_probabilities(inputs) return probabilities[0] - def _format_query(self, mention: str) -> str: - has_markers = "" in mention and "" in mention - wrapped = mention if has_markers else f"{mention}" - return f"Identify the entity referenced by {wrapped}." - def _format_instruction(self, instruction: str, query: str, document: str) -> str: return ": {instruction}\n: {query}\n: {doc}".format( instruction=instruction, diff --git a/src/scripts/qwen/reranking.py b/src/scripts/qwen/reranking.py deleted file mode 100644 index d7787b7..0000000 --- a/src/scripts/qwen/reranking.py +++ /dev/null @@ -1,78 +0,0 @@ -import argparse -from pathlib import Path - -# Import BruteForceSearcher from models -from models.searchers.brute_force_searcher import BruteForceSearcher - -# Import Reranker class -from scripts.qwen.reranker import Reranker - -# Import necessary functions from loaders -from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir - - -def main(): - parser = argparse.ArgumentParser(description="Reranking for entity linking") - parser.add_argument( - "--damuel_token", - type=str, - default="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages", - help="Path to damuel token file or directory", - ) - parser.add_argument( - "--damuel_embs", - type=str, - default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6", - help="Path to damuel embeddings directory or .npz file", - ) - parser.add_argument( - "--mewsli_tokens", - type=str, - default="/lnet/work/home-students-external/farhan/troja/outputs/tokens_mewsli_finetuning/es/mentions_1307770978027216442.npz", - help="Path to mewsli token file or directory", - ) - parser.add_argument( - "--qwen_model_name", - type=str, - default="Qwen/Qwen3-Reranker-0.6B", - help="Name of the QWEN model", - ) - args = parser.parse_args() - - # Resolve tokens and embeddings (directories or files) - damuel_tokens, damuel_token_qids = load_tokens_qids_from_dir(args.damuel_token, verbose=True) - damuel_embs, damuel_qids = load_embs_and_qids(args.damuel_embs) - - qid_to_damuel_token = {qid: token for qid, token in zip(damuel_token_qids, damuel_tokens)} - qid_to_damuel_emb = {qid: emb for qid, emb in zip(damuel_qids, damuel_embs)} - - del damuel_token_qids - - mewsli_tokens, mewsli_qids = load_tokens_qids(args.mewsli_tokens) - - # Take first four names (tokens) from each as a quick smoke-test - damuel_tokens_preview = damuel_tokens[:4] - mewsli_tokens_preview = mewsli_tokens[:4] - - print("First 4 damuel tokens:", damuel_tokens_preview) - print("First 4 mewsli tokens:", mewsli_tokens_preview) - - # Create searcher using damuel embeddings and damuel qids - searcher = BruteForceSearcher(damuel_embs, damuel_qids) - print("Searcher created.") - - # Initialize reranker model (actual reranking logic to be implemented later) - reranker = Reranker(model_name=args.qwen_model_name) - print(f"Reranker initialized with model: {reranker.model_name}") - - # Stub for QWEN model loading - print("QWEN model to be used:", args.qwen_model_name) - - # TODO: implement reranking logic - print("Reranking logic not implemented. Exiting.") - - print("jupi") - - -if __name__ == "__main__": - main() diff --git a/src/scripts/qwen/reranking2.py b/src/scripts/qwen/reranking2.py new file mode 100644 index 0000000..9e681a4 --- /dev/null +++ b/src/scripts/qwen/reranking2.py @@ -0,0 +1,164 @@ +import argparse +from pathlib import Path + +import torch +from torch.utils.data import DataLoader, Dataset +from transformers import AutoTokenizer + +from models.searchers.brute_force_searcher import BruteForceSearcher +from reranking.models.pairwise_mlp import PairwiseMLPReranker +from scripts.qwen.reranker import Reranker +from utils.embeddings import create_attention_mask +from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir +from utils.model_factory import ModelFactory + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def main(): + parser = argparse.ArgumentParser(description="Reranking for entity linking") + parser.add_argument( + "--damuel_token", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages", + help="Path to damuel token file or directory", + ) + parser.add_argument( + "--damuel_embs", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6", + help="Path to damuel embeddings directory or .npz file", + ) + parser.add_argument( + "--mewsli_tokens", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/tokens_mewsli_finetuning/en/mentions_1641252057782057661.npz", + help="Path to mewsli token file or directory", + ) + parser.add_argument( + "--qwen_model_name", + type=str, + default="Qwen/Qwen3-Reranker-0.6B", + help="Name of the QWEN model", + ) + parser.add_argument( + "--reranking_model_path", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + parser.add_argument( + "--num_neighbors", + type=int, + default=10, + help="Number of neighbors to retrieve from the searcher", + ) + args = parser.parse_args() + + reranking_model = ModelFactory.auto_load_from_file( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + args.reranking_model_path, + ) + reranking_model.eval() + reranking_model.to(device) + + reranking_tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + + reranker = PairwiseMLPReranker( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path=args.reranking_model_path, + mlp_hidden_dim=2048, + ) + state_dict = torch.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp/100000.pth", + map_location=device, + ) + reranker.load_state_dict(state_dict) + reranker.eval() + + mewsli_tokens, mewsli_qids = load_tokens_qids(args.mewsli_tokens) + mewsli_tokens = torch.from_numpy(mewsli_tokens) + mewsli_qids = torch.from_numpy(mewsli_qids) + + decode_mewsli_example = reranking_tokenizer.decode(mewsli_tokens[0]) + print("Decoded mewsli example:", decode_mewsli_example) + + # Resolve tokens and embeddings (directories or files) + damuel_tokens, damuel_token_qids = load_tokens_qids_from_dir(args.damuel_token, verbose=True) + damuel_embs, damuel_qids = load_embs_and_qids(args.damuel_embs) + + qid_to_damuel_token = {qid: token for qid, token in zip(damuel_token_qids, damuel_tokens)} + + del damuel_token_qids + + # Take first four names (tokens) from each as a quick smoke-test + damuel_tokens_preview = damuel_tokens[:4] + mewsli_tokens_preview = mewsli_tokens[:4] + + print("First 4 damuel tokens:", damuel_tokens_preview) + print("First 4 mewsli tokens:", mewsli_tokens_preview) + + # Create searcher using damuel embeddings and damuel qids + searcher = BruteForceSearcher(damuel_embs, damuel_qids) + print("Searcher created.") + + # Initialize reranker model (actual reranking logic to be implemented later) + # reranker = Reranker(model_name=args.qwen_model_name) + # print(f"Reranker initialized with model: {reranker.model_name}") + + # Stub for QWEN model loading + print("QWEN model to be used:", args.qwen_model_name) + + # TODO: implement reranking logic + print("Reranking logic not implemented. Exiting.") + + print("jupi") + + mewsli_dataset = torch.utils.data.TensorDataset(mewsli_tokens, mewsli_qids) + mewsli_loader = DataLoader(mewsli_dataset, batch_size=1, shuffle=False) + + good = 0 + total = 0 + + for batch in mewsli_loader: + mewsli_token = batch[0] + qid = batch[1] + # print(f"Processing Mewsli token: {mewsli_token}, QID: {qid}") + + tokens = mewsli_token.to(device, dtype=torch.int64) + + with torch.inference_mode(): + attention_mask = create_attention_mask(tokens).to(device) + mewsli_emb = ( + reranking_model(tokens, attention_mask) + .to("cpu", dtype=torch.float16) + .detach() + .numpy() + ) + + neighbor_qids = searcher.find(mewsli_emb, num_neighbors=args.num_neighbors) + + damuel_candidates = [qid_to_damuel_token[nq] for nq in neighbor_qids[0]] + damuel_candidates_str = [reranking_tokenizer.decode(dc) for dc in damuel_candidates] + + mewsli_str = reranking_tokenizer.decode(mewsli_token[0]) + + scores = [] + # print("Mewsli mention:", mewsli_str) + for dc in damuel_candidates_str: + with torch.inference_mode(): + # print(dc) + score = reranker.score(mewsli_str, dc) + scores.append(score) + # print(scores) + predicted_qid = neighbor_qids[0][scores.index(max(scores))] + if predicted_qid == qid: + good += 1 + total += 1 + + print("Current accuracy:", round(good / total * 100, 4)) + + +if __name__ == "__main__": + main() diff --git a/src/tokenization/runner.py b/src/tokenization/runner.py index 55a5711..1f7b2ff 100644 --- a/src/tokenization/runner.py +++ b/src/tokenization/runner.py @@ -175,6 +175,9 @@ def run_damuel_description_context( ) -> None: tokenizer = AutoTokenizer.from_pretrained(model_path) + print(languages) + print(type(languages)) + for lang in languages: os.makedirs(os.path.join(output_base_dir, lang, "descs_pages"), exist_ok=True) @@ -359,5 +362,32 @@ def run_damuel_mention( os.rename(file, new_name) +@gin.configurable +def run_damuel_description( + model_path: str, + expected_size: int, + output_base_dir: str, + languages: List[str], + damuel_base_path: str, + compress: bool, + remainder_mod: int, + num_processes: int, +) -> None: + run_damuel_description_context( + model_path=model_path, + expected_size=expected_size, + output_base_dir=output_base_dir, + languages=languages, + damuel_base_path=damuel_base_path, + # The goal of this pipeline is to tokenize descriptions without wrapping anything in the label token (e.g., [M]). + # The label token is required by the tokenization class, so we provide an empty string + # so that it has no effect. + label_token="", + compress=compress, + remainder_mod=remainder_mod, + num_processes=num_processes, + ) + + if __name__ == "__main__": run_mewsli_mention() diff --git a/tests/reranking/models/test_pairwise_mlp.py b/tests/reranking/models/test_pairwise_mlp.py index b4a2d75..b13e92a 100644 --- a/tests/reranking/models/test_pairwise_mlp.py +++ b/tests/reranking/models/test_pairwise_mlp.py @@ -72,22 +72,18 @@ def _run_common_checks( labels: torch.Tensor, ): tokenizer = model.tokenizer - mention_batch = tokenizer(mentions, padding=True, truncation=True, return_tensors="pt").to( - model.device - ) + mention_batch = tokenizer(mentions, padding=True, truncation=True, return_tensors="pt") mention_batch["attention_mask"] = create_attention_mask(mention_batch["input_ids"]) - entity_batch = tokenizer(entities, padding=True, truncation=True, return_tensors="pt").to( - model.device - ) + entity_batch = tokenizer(entities, padding=True, truncation=True, return_tensors="pt") entity_batch["attention_mask"] = create_attention_mask(entity_batch["input_ids"]) - encode_out = model._encode(mention_batch["input_ids"]) + encode_out = model.model._encode(mention_batch["input_ids"]) loss = model.train_step( { "mention_tokens": dict(mention_batch), "entity_tokens": dict(entity_batch), - "labels": labels.to(model.device), + "labels": labels, } ) @@ -103,7 +99,7 @@ def _run_common_checks( manual_logits = model.classifier( torch.cat([manual_mention_embeddings, manual_entity_embeddings], dim=-1) ).squeeze(-1) - expected_loss = model.loss_fn(manual_logits, labels.to(model.device)) + expected_loss = model.loss_fn(manual_logits, labels) score = model.score(mentions[0], entities[0]) @@ -154,7 +150,9 @@ def test_pairwise_mlp_with_dummy_backbone(dummy_model: PairwiseMLPReranker) -> N assert torch.allclose(results["loss"], results["expected_loss"]) positive_logit = results["manual_logits"][0].item() - assert math.isclose(results["score"], torch.sigmoid(torch.tensor(positive_logit)).item()) + assert math.isclose( + results["score"], torch.sigmoid(torch.tensor(positive_logit)).item(), rel_tol=1e-5 + ) @pytest.mark.slow diff --git a/tests/reranking/training/test_training_configs.py b/tests/reranking/training/test_training_configs.py index 09b1b32..aaffa62 100644 --- a/tests/reranking/training/test_training_configs.py +++ b/tests/reranking/training/test_training_configs.py @@ -12,6 +12,7 @@ def test_get_output_path_creates_directory(tmp_path): model=MagicMock(), dataset=MagicMock(), optimizer=MagicMock(), + save_each=100, batch_size=1, output_dir=str(output_root), ) From 41cc070c67c6e6ef4a72757c6ee4bca11b599d92 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Mon, 6 Oct 2025 16:46:48 +0200 Subject: [PATCH 12/28] add ipython and pytest-cov --- pyproject.toml | 2 + uv.lock | 317 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 319 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d421e28..a45a245 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,12 @@ requires-python = ">=3.11" dependencies = [ "fire>=0.7.1", "gin-config>=0.5.0", + "ipython>=9.6.0", "numba>=0.61.2", "orjson>=3.11.3", "pandas>=2.3.2", "pytest>=8.4.2", + "pytest-cov>=7.0.0", "pytest-mock>=3.15.0", "python-fire>=0.1.0", "scann>=1.4.2", diff --git a/uv.lock b/uv.lock index af4a589..2c1c440 100644 --- a/uv.lock +++ b/uv.lock @@ -17,6 +17,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "asttokens" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978, upload-time = "2024-11-30T04:30:14.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, +] + [[package]] name = "certifi" version = "2025.8.3" @@ -100,6 +109,116 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "coverage" +version = "7.10.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/26/d22c300112504f5f9a9fd2297ce33c35f3d353e4aeb987c8419453b2a7c2/coverage-7.10.7.tar.gz", hash = "sha256:f4ab143ab113be368a3e9b795f9cd7906c5ef407d6173fe9675a902e1fffc239", size = 827704, upload-time = "2025-09-21T20:03:56.815Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/5d/c1a17867b0456f2e9ce2d8d4708a4c3a089947d0bec9c66cdf60c9e7739f/coverage-7.10.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a609f9c93113be646f44c2a0256d6ea375ad047005d7f57a5c15f614dc1b2f59", size = 218102, upload-time = "2025-09-21T20:01:16.089Z" }, + { url = "https://files.pythonhosted.org/packages/54/f0/514dcf4b4e3698b9a9077f084429681bf3aad2b4a72578f89d7f643eb506/coverage-7.10.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:65646bb0359386e07639c367a22cf9b5bf6304e8630b565d0626e2bdf329227a", size = 218505, upload-time = "2025-09-21T20:01:17.788Z" }, + { url = "https://files.pythonhosted.org/packages/20/f6/9626b81d17e2a4b25c63ac1b425ff307ecdeef03d67c9a147673ae40dc36/coverage-7.10.7-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5f33166f0dfcce728191f520bd2692914ec70fac2713f6bf3ce59c3deacb4699", size = 248898, upload-time = "2025-09-21T20:01:19.488Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ef/bd8e719c2f7417ba03239052e099b76ea1130ac0cbb183ee1fcaa58aaff3/coverage-7.10.7-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:35f5e3f9e455bb17831876048355dca0f758b6df22f49258cb5a91da23ef437d", size = 250831, upload-time = "2025-09-21T20:01:20.817Z" }, + { url = "https://files.pythonhosted.org/packages/a5/b6/bf054de41ec948b151ae2b79a55c107f5760979538f5fb80c195f2517718/coverage-7.10.7-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4da86b6d62a496e908ac2898243920c7992499c1712ff7c2b6d837cc69d9467e", size = 252937, upload-time = "2025-09-21T20:01:22.171Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e5/3860756aa6f9318227443c6ce4ed7bf9e70bb7f1447a0353f45ac5c7974b/coverage-7.10.7-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6b8b09c1fad947c84bbbc95eca841350fad9cbfa5a2d7ca88ac9f8d836c92e23", size = 249021, upload-time = "2025-09-21T20:01:23.907Z" }, + { url = "https://files.pythonhosted.org/packages/26/0f/bd08bd042854f7fd07b45808927ebcce99a7ed0f2f412d11629883517ac2/coverage-7.10.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4376538f36b533b46f8971d3a3e63464f2c7905c9800db97361c43a2b14792ab", size = 250626, upload-time = "2025-09-21T20:01:25.721Z" }, + { url = "https://files.pythonhosted.org/packages/8e/a7/4777b14de4abcc2e80c6b1d430f5d51eb18ed1d75fca56cbce5f2db9b36e/coverage-7.10.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:121da30abb574f6ce6ae09840dae322bef734480ceafe410117627aa54f76d82", size = 248682, upload-time = "2025-09-21T20:01:27.105Z" }, + { url = "https://files.pythonhosted.org/packages/34/72/17d082b00b53cd45679bad682fac058b87f011fd8b9fe31d77f5f8d3a4e4/coverage-7.10.7-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:88127d40df529336a9836870436fc2751c339fbaed3a836d42c93f3e4bd1d0a2", size = 248402, upload-time = "2025-09-21T20:01:28.629Z" }, + { url = "https://files.pythonhosted.org/packages/81/7a/92367572eb5bdd6a84bfa278cc7e97db192f9f45b28c94a9ca1a921c3577/coverage-7.10.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ba58bbcd1b72f136080c0bccc2400d66cc6115f3f906c499013d065ac33a4b61", size = 249320, upload-time = "2025-09-21T20:01:30.004Z" }, + { url = "https://files.pythonhosted.org/packages/2f/88/a23cc185f6a805dfc4fdf14a94016835eeb85e22ac3a0e66d5e89acd6462/coverage-7.10.7-cp311-cp311-win32.whl", hash = "sha256:972b9e3a4094b053a4e46832b4bc829fc8a8d347160eb39d03f1690316a99c14", size = 220536, upload-time = "2025-09-21T20:01:32.184Z" }, + { url = "https://files.pythonhosted.org/packages/fe/ef/0b510a399dfca17cec7bc2f05ad8bd78cf55f15c8bc9a73ab20c5c913c2e/coverage-7.10.7-cp311-cp311-win_amd64.whl", hash = "sha256:a7b55a944a7f43892e28ad4bc0561dfd5f0d73e605d1aa5c3c976b52aea121d2", size = 221425, upload-time = "2025-09-21T20:01:33.557Z" }, + { url = "https://files.pythonhosted.org/packages/51/7f/023657f301a276e4ba1850f82749bc136f5a7e8768060c2e5d9744a22951/coverage-7.10.7-cp311-cp311-win_arm64.whl", hash = "sha256:736f227fb490f03c6488f9b6d45855f8e0fd749c007f9303ad30efab0e73c05a", size = 220103, upload-time = "2025-09-21T20:01:34.929Z" }, + { url = "https://files.pythonhosted.org/packages/13/e4/eb12450f71b542a53972d19117ea5a5cea1cab3ac9e31b0b5d498df1bd5a/coverage-7.10.7-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7bb3b9ddb87ef7725056572368040c32775036472d5a033679d1fa6c8dc08417", size = 218290, upload-time = "2025-09-21T20:01:36.455Z" }, + { url = "https://files.pythonhosted.org/packages/37/66/593f9be12fc19fb36711f19a5371af79a718537204d16ea1d36f16bd78d2/coverage-7.10.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:18afb24843cbc175687225cab1138c95d262337f5473512010e46831aa0c2973", size = 218515, upload-time = "2025-09-21T20:01:37.982Z" }, + { url = "https://files.pythonhosted.org/packages/66/80/4c49f7ae09cafdacc73fbc30949ffe77359635c168f4e9ff33c9ebb07838/coverage-7.10.7-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:399a0b6347bcd3822be369392932884b8216d0944049ae22925631a9b3d4ba4c", size = 250020, upload-time = "2025-09-21T20:01:39.617Z" }, + { url = "https://files.pythonhosted.org/packages/a6/90/a64aaacab3b37a17aaedd83e8000142561a29eb262cede42d94a67f7556b/coverage-7.10.7-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:314f2c326ded3f4b09be11bc282eb2fc861184bc95748ae67b360ac962770be7", size = 252769, upload-time = "2025-09-21T20:01:41.341Z" }, + { url = "https://files.pythonhosted.org/packages/98/2e/2dda59afd6103b342e096f246ebc5f87a3363b5412609946c120f4e7750d/coverage-7.10.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c41e71c9cfb854789dee6fc51e46743a6d138b1803fab6cb860af43265b42ea6", size = 253901, upload-time = "2025-09-21T20:01:43.042Z" }, + { url = "https://files.pythonhosted.org/packages/53/dc/8d8119c9051d50f3119bb4a75f29f1e4a6ab9415cd1fa8bf22fcc3fb3b5f/coverage-7.10.7-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc01f57ca26269c2c706e838f6422e2a8788e41b3e3c65e2f41148212e57cd59", size = 250413, upload-time = "2025-09-21T20:01:44.469Z" }, + { url = "https://files.pythonhosted.org/packages/98/b3/edaff9c5d79ee4d4b6d3fe046f2b1d799850425695b789d491a64225d493/coverage-7.10.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a6442c59a8ac8b85812ce33bc4d05bde3fb22321fa8294e2a5b487c3505f611b", size = 251820, upload-time = "2025-09-21T20:01:45.915Z" }, + { url = "https://files.pythonhosted.org/packages/11/25/9a0728564bb05863f7e513e5a594fe5ffef091b325437f5430e8cfb0d530/coverage-7.10.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:78a384e49f46b80fb4c901d52d92abe098e78768ed829c673fbb53c498bef73a", size = 249941, upload-time = "2025-09-21T20:01:47.296Z" }, + { url = "https://files.pythonhosted.org/packages/e0/fd/ca2650443bfbef5b0e74373aac4df67b08180d2f184b482c41499668e258/coverage-7.10.7-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:5e1e9802121405ede4b0133aa4340ad8186a1d2526de5b7c3eca519db7bb89fb", size = 249519, upload-time = "2025-09-21T20:01:48.73Z" }, + { url = "https://files.pythonhosted.org/packages/24/79/f692f125fb4299b6f963b0745124998ebb8e73ecdfce4ceceb06a8c6bec5/coverage-7.10.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d41213ea25a86f69efd1575073d34ea11aabe075604ddf3d148ecfec9e1e96a1", size = 251375, upload-time = "2025-09-21T20:01:50.529Z" }, + { url = "https://files.pythonhosted.org/packages/5e/75/61b9bbd6c7d24d896bfeec57acba78e0f8deac68e6baf2d4804f7aae1f88/coverage-7.10.7-cp312-cp312-win32.whl", hash = "sha256:77eb4c747061a6af8d0f7bdb31f1e108d172762ef579166ec84542f711d90256", size = 220699, upload-time = "2025-09-21T20:01:51.941Z" }, + { url = "https://files.pythonhosted.org/packages/ca/f3/3bf7905288b45b075918d372498f1cf845b5b579b723c8fd17168018d5f5/coverage-7.10.7-cp312-cp312-win_amd64.whl", hash = "sha256:f51328ffe987aecf6d09f3cd9d979face89a617eacdaea43e7b3080777f647ba", size = 221512, upload-time = "2025-09-21T20:01:53.481Z" }, + { url = "https://files.pythonhosted.org/packages/5c/44/3e32dbe933979d05cf2dac5e697c8599cfe038aaf51223ab901e208d5a62/coverage-7.10.7-cp312-cp312-win_arm64.whl", hash = "sha256:bda5e34f8a75721c96085903c6f2197dc398c20ffd98df33f866a9c8fd95f4bf", size = 220147, upload-time = "2025-09-21T20:01:55.2Z" }, + { url = "https://files.pythonhosted.org/packages/9a/94/b765c1abcb613d103b64fcf10395f54d69b0ef8be6a0dd9c524384892cc7/coverage-7.10.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:981a651f543f2854abd3b5fcb3263aac581b18209be49863ba575de6edf4c14d", size = 218320, upload-time = "2025-09-21T20:01:56.629Z" }, + { url = "https://files.pythonhosted.org/packages/72/4f/732fff31c119bb73b35236dd333030f32c4bfe909f445b423e6c7594f9a2/coverage-7.10.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73ab1601f84dc804f7812dc297e93cd99381162da39c47040a827d4e8dafe63b", size = 218575, upload-time = "2025-09-21T20:01:58.203Z" }, + { url = "https://files.pythonhosted.org/packages/87/02/ae7e0af4b674be47566707777db1aa375474f02a1d64b9323e5813a6cdd5/coverage-7.10.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a8b6f03672aa6734e700bbcd65ff050fd19cddfec4b031cc8cf1c6967de5a68e", size = 249568, upload-time = "2025-09-21T20:01:59.748Z" }, + { url = "https://files.pythonhosted.org/packages/a2/77/8c6d22bf61921a59bce5471c2f1f7ac30cd4ac50aadde72b8c48d5727902/coverage-7.10.7-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10b6ba00ab1132a0ce4428ff68cf50a25efd6840a42cdf4239c9b99aad83be8b", size = 252174, upload-time = "2025-09-21T20:02:01.192Z" }, + { url = "https://files.pythonhosted.org/packages/b1/20/b6ea4f69bbb52dac0aebd62157ba6a9dddbfe664f5af8122dac296c3ee15/coverage-7.10.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c79124f70465a150e89340de5963f936ee97097d2ef76c869708c4248c63ca49", size = 253447, upload-time = "2025-09-21T20:02:02.701Z" }, + { url = "https://files.pythonhosted.org/packages/f9/28/4831523ba483a7f90f7b259d2018fef02cb4d5b90bc7c1505d6e5a84883c/coverage-7.10.7-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:69212fbccdbd5b0e39eac4067e20a4a5256609e209547d86f740d68ad4f04911", size = 249779, upload-time = "2025-09-21T20:02:04.185Z" }, + { url = "https://files.pythonhosted.org/packages/a7/9f/4331142bc98c10ca6436d2d620c3e165f31e6c58d43479985afce6f3191c/coverage-7.10.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7ea7c6c9d0d286d04ed3541747e6597cbe4971f22648b68248f7ddcd329207f0", size = 251604, upload-time = "2025-09-21T20:02:06.034Z" }, + { url = "https://files.pythonhosted.org/packages/ce/60/bda83b96602036b77ecf34e6393a3836365481b69f7ed7079ab85048202b/coverage-7.10.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b9be91986841a75042b3e3243d0b3cb0b2434252b977baaf0cd56e960fe1e46f", size = 249497, upload-time = "2025-09-21T20:02:07.619Z" }, + { url = "https://files.pythonhosted.org/packages/5f/af/152633ff35b2af63977edd835d8e6430f0caef27d171edf2fc76c270ef31/coverage-7.10.7-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:b281d5eca50189325cfe1f365fafade89b14b4a78d9b40b05ddd1fc7d2a10a9c", size = 249350, upload-time = "2025-09-21T20:02:10.34Z" }, + { url = "https://files.pythonhosted.org/packages/9d/71/d92105d122bd21cebba877228990e1646d862e34a98bb3374d3fece5a794/coverage-7.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:99e4aa63097ab1118e75a848a28e40d68b08a5e19ce587891ab7fd04475e780f", size = 251111, upload-time = "2025-09-21T20:02:12.122Z" }, + { url = "https://files.pythonhosted.org/packages/a2/9e/9fdb08f4bf476c912f0c3ca292e019aab6712c93c9344a1653986c3fd305/coverage-7.10.7-cp313-cp313-win32.whl", hash = "sha256:dc7c389dce432500273eaf48f410b37886be9208b2dd5710aaf7c57fd442c698", size = 220746, upload-time = "2025-09-21T20:02:13.919Z" }, + { url = "https://files.pythonhosted.org/packages/b1/b1/a75fd25df44eab52d1931e89980d1ada46824c7a3210be0d3c88a44aaa99/coverage-7.10.7-cp313-cp313-win_amd64.whl", hash = "sha256:cac0fdca17b036af3881a9d2729a850b76553f3f716ccb0360ad4dbc06b3b843", size = 221541, upload-time = "2025-09-21T20:02:15.57Z" }, + { url = "https://files.pythonhosted.org/packages/14/3a/d720d7c989562a6e9a14b2c9f5f2876bdb38e9367126d118495b89c99c37/coverage-7.10.7-cp313-cp313-win_arm64.whl", hash = "sha256:4b6f236edf6e2f9ae8fcd1332da4e791c1b6ba0dc16a2dc94590ceccb482e546", size = 220170, upload-time = "2025-09-21T20:02:17.395Z" }, + { url = "https://files.pythonhosted.org/packages/bb/22/e04514bf2a735d8b0add31d2b4ab636fc02370730787c576bb995390d2d5/coverage-7.10.7-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a0ec07fd264d0745ee396b666d47cef20875f4ff2375d7c4f58235886cc1ef0c", size = 219029, upload-time = "2025-09-21T20:02:18.936Z" }, + { url = "https://files.pythonhosted.org/packages/11/0b/91128e099035ece15da3445d9015e4b4153a6059403452d324cbb0a575fa/coverage-7.10.7-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd5e856ebb7bfb7672b0086846db5afb4567a7b9714b8a0ebafd211ec7ce6a15", size = 219259, upload-time = "2025-09-21T20:02:20.44Z" }, + { url = "https://files.pythonhosted.org/packages/8b/51/66420081e72801536a091a0c8f8c1f88a5c4bf7b9b1bdc6222c7afe6dc9b/coverage-7.10.7-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f57b2a3c8353d3e04acf75b3fed57ba41f5c0646bbf1d10c7c282291c97936b4", size = 260592, upload-time = "2025-09-21T20:02:22.313Z" }, + { url = "https://files.pythonhosted.org/packages/5d/22/9b8d458c2881b22df3db5bb3e7369e63d527d986decb6c11a591ba2364f7/coverage-7.10.7-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ef2319dd15a0b009667301a3f84452a4dc6fddfd06b0c5c53ea472d3989fbf0", size = 262768, upload-time = "2025-09-21T20:02:24.287Z" }, + { url = "https://files.pythonhosted.org/packages/f7/08/16bee2c433e60913c610ea200b276e8eeef084b0d200bdcff69920bd5828/coverage-7.10.7-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:83082a57783239717ceb0ad584de3c69cf581b2a95ed6bf81ea66034f00401c0", size = 264995, upload-time = "2025-09-21T20:02:26.133Z" }, + { url = "https://files.pythonhosted.org/packages/20/9d/e53eb9771d154859b084b90201e5221bca7674ba449a17c101a5031d4054/coverage-7.10.7-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:50aa94fb1fb9a397eaa19c0d5ec15a5edd03a47bf1a3a6111a16b36e190cff65", size = 259546, upload-time = "2025-09-21T20:02:27.716Z" }, + { url = "https://files.pythonhosted.org/packages/ad/b0/69bc7050f8d4e56a89fb550a1577d5d0d1db2278106f6f626464067b3817/coverage-7.10.7-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2120043f147bebb41c85b97ac45dd173595ff14f2a584f2963891cbcc3091541", size = 262544, upload-time = "2025-09-21T20:02:29.216Z" }, + { url = "https://files.pythonhosted.org/packages/ef/4b/2514b060dbd1bc0aaf23b852c14bb5818f244c664cb16517feff6bb3a5ab/coverage-7.10.7-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:2fafd773231dd0378fdba66d339f84904a8e57a262f583530f4f156ab83863e6", size = 260308, upload-time = "2025-09-21T20:02:31.226Z" }, + { url = "https://files.pythonhosted.org/packages/54/78/7ba2175007c246d75e496f64c06e94122bdb914790a1285d627a918bd271/coverage-7.10.7-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:0b944ee8459f515f28b851728ad224fa2d068f1513ef6b7ff1efafeb2185f999", size = 258920, upload-time = "2025-09-21T20:02:32.823Z" }, + { url = "https://files.pythonhosted.org/packages/c0/b3/fac9f7abbc841409b9a410309d73bfa6cfb2e51c3fada738cb607ce174f8/coverage-7.10.7-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4b583b97ab2e3efe1b3e75248a9b333bd3f8b0b1b8e5b45578e05e5850dfb2c2", size = 261434, upload-time = "2025-09-21T20:02:34.86Z" }, + { url = "https://files.pythonhosted.org/packages/ee/51/a03bec00d37faaa891b3ff7387192cef20f01604e5283a5fabc95346befa/coverage-7.10.7-cp313-cp313t-win32.whl", hash = "sha256:2a78cd46550081a7909b3329e2266204d584866e8d97b898cd7fb5ac8d888b1a", size = 221403, upload-time = "2025-09-21T20:02:37.034Z" }, + { url = "https://files.pythonhosted.org/packages/53/22/3cf25d614e64bf6d8e59c7c669b20d6d940bb337bdee5900b9ca41c820bb/coverage-7.10.7-cp313-cp313t-win_amd64.whl", hash = "sha256:33a5e6396ab684cb43dc7befa386258acb2d7fae7f67330ebb85ba4ea27938eb", size = 222469, upload-time = "2025-09-21T20:02:39.011Z" }, + { url = "https://files.pythonhosted.org/packages/49/a1/00164f6d30d8a01c3c9c48418a7a5be394de5349b421b9ee019f380df2a0/coverage-7.10.7-cp313-cp313t-win_arm64.whl", hash = "sha256:86b0e7308289ddde73d863b7683f596d8d21c7d8664ce1dee061d0bcf3fbb4bb", size = 220731, upload-time = "2025-09-21T20:02:40.939Z" }, + { url = "https://files.pythonhosted.org/packages/23/9c/5844ab4ca6a4dd97a1850e030a15ec7d292b5c5cb93082979225126e35dd/coverage-7.10.7-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b06f260b16ead11643a5a9f955bd4b5fd76c1a4c6796aeade8520095b75de520", size = 218302, upload-time = "2025-09-21T20:02:42.527Z" }, + { url = "https://files.pythonhosted.org/packages/f0/89/673f6514b0961d1f0e20ddc242e9342f6da21eaba3489901b565c0689f34/coverage-7.10.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:212f8f2e0612778f09c55dd4872cb1f64a1f2b074393d139278ce902064d5b32", size = 218578, upload-time = "2025-09-21T20:02:44.468Z" }, + { url = "https://files.pythonhosted.org/packages/05/e8/261cae479e85232828fb17ad536765c88dd818c8470aca690b0ac6feeaa3/coverage-7.10.7-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3445258bcded7d4aa630ab8296dea4d3f15a255588dd535f980c193ab6b95f3f", size = 249629, upload-time = "2025-09-21T20:02:46.503Z" }, + { url = "https://files.pythonhosted.org/packages/82/62/14ed6546d0207e6eda876434e3e8475a3e9adbe32110ce896c9e0c06bb9a/coverage-7.10.7-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bb45474711ba385c46a0bfe696c695a929ae69ac636cda8f532be9e8c93d720a", size = 252162, upload-time = "2025-09-21T20:02:48.689Z" }, + { url = "https://files.pythonhosted.org/packages/ff/49/07f00db9ac6478e4358165a08fb41b469a1b053212e8a00cb02f0d27a05f/coverage-7.10.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:813922f35bd800dca9994c5971883cbc0d291128a5de6b167c7aa697fcf59360", size = 253517, upload-time = "2025-09-21T20:02:50.31Z" }, + { url = "https://files.pythonhosted.org/packages/a2/59/c5201c62dbf165dfbc91460f6dbbaa85a8b82cfa6131ac45d6c1bfb52deb/coverage-7.10.7-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:93c1b03552081b2a4423091d6fb3787265b8f86af404cff98d1b5342713bdd69", size = 249632, upload-time = "2025-09-21T20:02:51.971Z" }, + { url = "https://files.pythonhosted.org/packages/07/ae/5920097195291a51fb00b3a70b9bbd2edbfe3c84876a1762bd1ef1565ebc/coverage-7.10.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:cc87dd1b6eaf0b848eebb1c86469b9f72a1891cb42ac7adcfbce75eadb13dd14", size = 251520, upload-time = "2025-09-21T20:02:53.858Z" }, + { url = "https://files.pythonhosted.org/packages/b9/3c/a815dde77a2981f5743a60b63df31cb322c944843e57dbd579326625a413/coverage-7.10.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:39508ffda4f343c35f3236fe8d1a6634a51f4581226a1262769d7f970e73bffe", size = 249455, upload-time = "2025-09-21T20:02:55.807Z" }, + { url = "https://files.pythonhosted.org/packages/aa/99/f5cdd8421ea656abefb6c0ce92556709db2265c41e8f9fc6c8ae0f7824c9/coverage-7.10.7-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:925a1edf3d810537c5a3abe78ec5530160c5f9a26b1f4270b40e62cc79304a1e", size = 249287, upload-time = "2025-09-21T20:02:57.784Z" }, + { url = "https://files.pythonhosted.org/packages/c3/7a/e9a2da6a1fc5d007dd51fca083a663ab930a8c4d149c087732a5dbaa0029/coverage-7.10.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2c8b9a0636f94c43cd3576811e05b89aa9bc2d0a85137affc544ae5cb0e4bfbd", size = 250946, upload-time = "2025-09-21T20:02:59.431Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5b/0b5799aa30380a949005a353715095d6d1da81927d6dbed5def2200a4e25/coverage-7.10.7-cp314-cp314-win32.whl", hash = "sha256:b7b8288eb7cdd268b0304632da8cb0bb93fadcfec2fe5712f7b9cc8f4d487be2", size = 221009, upload-time = "2025-09-21T20:03:01.324Z" }, + { url = "https://files.pythonhosted.org/packages/da/b0/e802fbb6eb746de006490abc9bb554b708918b6774b722bb3a0e6aa1b7de/coverage-7.10.7-cp314-cp314-win_amd64.whl", hash = "sha256:1ca6db7c8807fb9e755d0379ccc39017ce0a84dcd26d14b5a03b78563776f681", size = 221804, upload-time = "2025-09-21T20:03:03.4Z" }, + { url = "https://files.pythonhosted.org/packages/9e/e8/71d0c8e374e31f39e3389bb0bd19e527d46f00ea8571ec7ec8fd261d8b44/coverage-7.10.7-cp314-cp314-win_arm64.whl", hash = "sha256:097c1591f5af4496226d5783d036bf6fd6cd0cbc132e071b33861de756efb880", size = 220384, upload-time = "2025-09-21T20:03:05.111Z" }, + { url = "https://files.pythonhosted.org/packages/62/09/9a5608d319fa3eba7a2019addeacb8c746fb50872b57a724c9f79f146969/coverage-7.10.7-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:a62c6ef0d50e6de320c270ff91d9dd0a05e7250cac2a800b7784bae474506e63", size = 219047, upload-time = "2025-09-21T20:03:06.795Z" }, + { url = "https://files.pythonhosted.org/packages/f5/6f/f58d46f33db9f2e3647b2d0764704548c184e6f5e014bef528b7f979ef84/coverage-7.10.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:9fa6e4dd51fe15d8738708a973470f67a855ca50002294852e9571cdbd9433f2", size = 219266, upload-time = "2025-09-21T20:03:08.495Z" }, + { url = "https://files.pythonhosted.org/packages/74/5c/183ffc817ba68e0b443b8c934c8795553eb0c14573813415bd59941ee165/coverage-7.10.7-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:8fb190658865565c549b6b4706856d6a7b09302c797eb2cf8e7fe9dabb043f0d", size = 260767, upload-time = "2025-09-21T20:03:10.172Z" }, + { url = "https://files.pythonhosted.org/packages/0f/48/71a8abe9c1ad7e97548835e3cc1adbf361e743e9d60310c5f75c9e7bf847/coverage-7.10.7-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:affef7c76a9ef259187ef31599a9260330e0335a3011732c4b9effa01e1cd6e0", size = 262931, upload-time = "2025-09-21T20:03:11.861Z" }, + { url = "https://files.pythonhosted.org/packages/84/fd/193a8fb132acfc0a901f72020e54be5e48021e1575bb327d8ee1097a28fd/coverage-7.10.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e16e07d85ca0cf8bafe5f5d23a0b850064e8e945d5677492b06bbe6f09cc699", size = 265186, upload-time = "2025-09-21T20:03:13.539Z" }, + { url = "https://files.pythonhosted.org/packages/b1/8f/74ecc30607dd95ad50e3034221113ccb1c6d4e8085cc761134782995daae/coverage-7.10.7-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:03ffc58aacdf65d2a82bbeb1ffe4d01ead4017a21bfd0454983b88ca73af94b9", size = 259470, upload-time = "2025-09-21T20:03:15.584Z" }, + { url = "https://files.pythonhosted.org/packages/0f/55/79ff53a769f20d71b07023ea115c9167c0bb56f281320520cf64c5298a96/coverage-7.10.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1b4fd784344d4e52647fd7857b2af5b3fbe6c239b0b5fa63e94eb67320770e0f", size = 262626, upload-time = "2025-09-21T20:03:17.673Z" }, + { url = "https://files.pythonhosted.org/packages/88/e2/dac66c140009b61ac3fc13af673a574b00c16efdf04f9b5c740703e953c0/coverage-7.10.7-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:0ebbaddb2c19b71912c6f2518e791aa8b9f054985a0769bdb3a53ebbc765c6a1", size = 260386, upload-time = "2025-09-21T20:03:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/a2/f1/f48f645e3f33bb9ca8a496bc4a9671b52f2f353146233ebd7c1df6160440/coverage-7.10.7-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:a2d9a3b260cc1d1dbdb1c582e63ddcf5363426a1a68faa0f5da28d8ee3c722a0", size = 258852, upload-time = "2025-09-21T20:03:21.007Z" }, + { url = "https://files.pythonhosted.org/packages/bb/3b/8442618972c51a7affeead957995cfa8323c0c9bcf8fa5a027421f720ff4/coverage-7.10.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a3cc8638b2480865eaa3926d192e64ce6c51e3d29c849e09d5b4ad95efae5399", size = 261534, upload-time = "2025-09-21T20:03:23.12Z" }, + { url = "https://files.pythonhosted.org/packages/b2/dc/101f3fa3a45146db0cb03f5b4376e24c0aac818309da23e2de0c75295a91/coverage-7.10.7-cp314-cp314t-win32.whl", hash = "sha256:67f8c5cbcd3deb7a60b3345dffc89a961a484ed0af1f6f73de91705cc6e31235", size = 221784, upload-time = "2025-09-21T20:03:24.769Z" }, + { url = "https://files.pythonhosted.org/packages/4c/a1/74c51803fc70a8a40d7346660379e144be772bab4ac7bb6e6b905152345c/coverage-7.10.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e1ed71194ef6dea7ed2d5cb5f7243d4bcd334bfb63e59878519be558078f848d", size = 222905, upload-time = "2025-09-21T20:03:26.93Z" }, + { url = "https://files.pythonhosted.org/packages/12/65/f116a6d2127df30bcafbceef0302d8a64ba87488bf6f73a6d8eebf060873/coverage-7.10.7-cp314-cp314t-win_arm64.whl", hash = "sha256:7fe650342addd8524ca63d77b2362b02345e5f1a093266787d210c70a50b471a", size = 220922, upload-time = "2025-09-21T20:03:28.672Z" }, + { url = "https://files.pythonhosted.org/packages/ec/16/114df1c291c22cac3b0c127a73e0af5c12ed7bbb6558d310429a0ae24023/coverage-7.10.7-py3-none-any.whl", hash = "sha256:f7941f6f2fe6dd6807a1208737b8a0cbcf1cc6d7b07d24998ad2d63590868260", size = 209952, upload-time = "2025-09-21T20:03:53.918Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + +[[package]] +name = "executing" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, +] + [[package]] name = "filelock" version = "3.19.1" @@ -215,6 +334,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, ] +[[package]] +name = "ipython" +version = "9.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "decorator" }, + { name = "ipython-pygments-lexers" }, + { name = "jedi" }, + { name = "matplotlib-inline" }, + { name = "pexpect", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name = "stack-data" }, + { name = "traitlets" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/34/29b18c62e39ee2f7a6a3bba7efd952729d8aadd45ca17efc34453b717665/ipython-9.6.0.tar.gz", hash = "sha256:5603d6d5d356378be5043e69441a072b50a5b33b4503428c77b04cb8ce7bc731", size = 4396932, upload-time = "2025-09-29T10:55:53.948Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/c5/d5e07995077e48220269c28a221e168c91123ad5ceee44d548f54a057fc0/ipython-9.6.0-py3-none-any.whl", hash = "sha256:5f77efafc886d2f023442479b8149e7d86547ad0a979e9da9f045d252f648196", size = 616170, upload-time = "2025-09-29T10:55:47.676Z" }, +] + +[[package]] +name = "ipython-pygments-lexers" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393, upload-time = "2025-01-17T11:24:34.505Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, +] + +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -298,6 +463,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739, upload-time = "2024-10-18T15:21:42.784Z" }, ] +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159, upload-time = "2024-04-15T13:44:44.803Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, +] + [[package]] name = "mel" version = "0.1.0" @@ -305,10 +482,12 @@ source = { virtual = "." } dependencies = [ { name = "fire" }, { name = "gin-config" }, + { name = "ipython" }, { name = "numba" }, { name = "orjson" }, { name = "pandas" }, { name = "pytest" }, + { name = "pytest-cov" }, { name = "pytest-mock" }, { name = "python-fire" }, { name = "scann" }, @@ -322,10 +501,12 @@ dependencies = [ requires-dist = [ { name = "fire", specifier = ">=0.7.1" }, { name = "gin-config", specifier = ">=0.5.0" }, + { name = "ipython", specifier = ">=9.6.0" }, { name = "numba", specifier = ">=0.61.2" }, { name = "orjson", specifier = ">=3.11.3" }, { name = "pandas", specifier = ">=2.3.2" }, { name = "pytest", specifier = ">=8.4.2" }, + { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-mock", specifier = ">=3.15.0" }, { name = "python-fire", specifier = ">=0.1.0" }, { name = "scann", specifier = ">=1.4.2" }, @@ -668,6 +849,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cd/d7/612123674d7b17cf345aad0a10289b2a384bff404e0463a83c4a3a59d205/pandas-2.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d2c3554bd31b731cd6490d94a28f3abb8dd770634a9e06eb6d2911b9827db370", size = 13186141, upload-time = "2025-08-21T10:28:05.377Z" }, ] +[[package]] +name = "parso" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/de/53e0bcf53d13e005bd8c92e7855142494f41171b34c2536b86187474184d/parso-0.8.5.tar.gz", hash = "sha256:034d7354a9a018bdce352f48b2a8a450f05e9d6ee85db84764e9b6bd96dafe5a", size = 401205, upload-time = "2025-08-23T15:15:28.028Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668, upload-time = "2025-08-23T15:15:25.663Z" }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, upload-time = "2023-11-25T06:56:14.81Z" }, +] + [[package]] name = "platformdirs" version = "4.4.0" @@ -686,6 +888,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + [[package]] name = "protobuf" version = "6.32.0" @@ -700,6 +914,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/f2/80ffc4677aac1bc3519b26bc7f7f5de7fce0ee2f7e36e59e27d8beb32dd1/protobuf-6.32.0-py3-none-any.whl", hash = "sha256:ba377e5b67b908c8f3072a57b63e2c6a4cbd18aea4ed98d2584350dbf46f2783", size = 169287, upload-time = "2025-08-14T21:21:23.515Z" }, ] +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762, upload-time = "2020-12-28T15:15:30.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993, upload-time = "2020-12-28T15:15:28.35Z" }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752, upload-time = "2024-07-21T12:58:21.801Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, +] + [[package]] name = "pydantic" version = "2.11.7" @@ -805,6 +1037,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] +[[package]] +name = "pytest-cov" +version = "7.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, +] + [[package]] name = "pytest-mock" version = "3.15.0" @@ -1043,6 +1289,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, ] +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707, upload-time = "2023-09-30T13:58:05.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -1089,6 +1349,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/9b/0e0bf82214ee20231845b127aa4a8015936ad5a46779f30865d10e404167/tokenizers-0.22.0-cp39-abi3-win_amd64.whl", hash = "sha256:c78174859eeaee96021f248a56c801e36bfb6bd5b067f2e95aa82445ca324f00", size = 2680494, upload-time = "2025-08-29T10:25:35.14Z" }, ] +[[package]] +name = "tomli" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175, upload-time = "2024-11-27T22:38:36.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077, upload-time = "2024-11-27T22:37:54.956Z" }, + { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429, upload-time = "2024-11-27T22:37:56.698Z" }, + { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067, upload-time = "2024-11-27T22:37:57.63Z" }, + { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030, upload-time = "2024-11-27T22:37:59.344Z" }, + { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898, upload-time = "2024-11-27T22:38:00.429Z" }, + { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894, upload-time = "2024-11-27T22:38:02.094Z" }, + { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319, upload-time = "2024-11-27T22:38:03.206Z" }, + { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273, upload-time = "2024-11-27T22:38:04.217Z" }, + { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310, upload-time = "2024-11-27T22:38:05.908Z" }, + { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309, upload-time = "2024-11-27T22:38:06.812Z" }, + { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762, upload-time = "2024-11-27T22:38:07.731Z" }, + { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453, upload-time = "2024-11-27T22:38:09.384Z" }, + { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486, upload-time = "2024-11-27T22:38:10.329Z" }, + { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349, upload-time = "2024-11-27T22:38:11.443Z" }, + { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159, upload-time = "2024-11-27T22:38:13.099Z" }, + { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243, upload-time = "2024-11-27T22:38:14.766Z" }, + { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645, upload-time = "2024-11-27T22:38:15.843Z" }, + { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584, upload-time = "2024-11-27T22:38:17.645Z" }, + { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875, upload-time = "2024-11-27T22:38:19.159Z" }, + { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418, upload-time = "2024-11-27T22:38:20.064Z" }, + { url = "https://files.pythonhosted.org/packages/04/90/2ee5f2e0362cb8a0b6499dc44f4d7d48f8fff06d28ba46e6f1eaa61a1388/tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7", size = 132708, upload-time = "2024-11-27T22:38:21.659Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ec/46b4108816de6b385141f082ba99e315501ccd0a2ea23db4a100dd3990ea/tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c", size = 123582, upload-time = "2024-11-27T22:38:22.693Z" }, + { url = "https://files.pythonhosted.org/packages/a0/bd/b470466d0137b37b68d24556c38a0cc819e8febe392d5b199dcd7f578365/tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13", size = 232543, upload-time = "2024-11-27T22:38:24.367Z" }, + { url = "https://files.pythonhosted.org/packages/d9/e5/82e80ff3b751373f7cead2815bcbe2d51c895b3c990686741a8e56ec42ab/tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281", size = 241691, upload-time = "2024-11-27T22:38:26.081Z" }, + { url = "https://files.pythonhosted.org/packages/05/7e/2a110bc2713557d6a1bfb06af23dd01e7dde52b6ee7dadc589868f9abfac/tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272", size = 251170, upload-time = "2024-11-27T22:38:27.921Z" }, + { url = "https://files.pythonhosted.org/packages/64/7b/22d713946efe00e0adbcdfd6d1aa119ae03fd0b60ebed51ebb3fa9f5a2e5/tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140", size = 236530, upload-time = "2024-11-27T22:38:29.591Z" }, + { url = "https://files.pythonhosted.org/packages/38/31/3a76f67da4b0cf37b742ca76beaf819dca0ebef26d78fc794a576e08accf/tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2", size = 258666, upload-time = "2024-11-27T22:38:30.639Z" }, + { url = "https://files.pythonhosted.org/packages/07/10/5af1293da642aded87e8a988753945d0cf7e00a9452d3911dd3bb354c9e2/tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744", size = 243954, upload-time = "2024-11-27T22:38:31.702Z" }, + { url = "https://files.pythonhosted.org/packages/5b/b9/1ed31d167be802da0fc95020d04cd27b7d7065cc6fbefdd2f9186f60d7bd/tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec", size = 98724, upload-time = "2024-11-27T22:38:32.837Z" }, + { url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383, upload-time = "2024-11-27T22:38:34.455Z" }, + { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, +] + [[package]] name = "torch" version = "2.8.0" @@ -1148,6 +1447,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621, upload-time = "2024-04-19T11:11:49.746Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, +] + [[package]] name = "transformers" version = "4.56.1" @@ -1250,3 +1558,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/d3/faa6ddb792a158c154fb704b25c96d0478e71eabf96e3f17529fb23b6894/wandb-0.21.3-py3-none-win32.whl", hash = "sha256:45aa3d8ad53c6ee06f37490d7a329ed7d0f5ca4dbd5d05bb0c01d5da22f14691", size = 18709408, upload-time = "2025-08-30T18:21:50.859Z" }, { url = "https://files.pythonhosted.org/packages/d8/2d/7ef56e25f78786e59fefd9b19867c325f9686317d9f7b93b5cb340360a3e/wandb-0.21.3-py3-none-win_amd64.whl", hash = "sha256:56d5a5697766f552a9933d8c6a564202194768eb0389bd5f9fe9a99cd4cee41e", size = 18709411, upload-time = "2025-08-30T18:21:52.874Z" }, ] + +[[package]] +name = "wcwidth" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, +] From aeb0dffcce526ba56e916661a5c3e95241d2cb3c Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Thu, 9 Oct 2025 14:17:54 +0200 Subject: [PATCH 13/28] feat(rerank): add binary dataset creation functions --- src/reranking/{binary => dataset}/create_dataset.py | 1 + 1 file changed, 1 insertion(+) rename src/reranking/{binary => dataset}/create_dataset.py (99%) diff --git a/src/reranking/binary/create_dataset.py b/src/reranking/dataset/create_dataset.py similarity index 99% rename from src/reranking/binary/create_dataset.py rename to src/reranking/dataset/create_dataset.py index 09aeb48..f5e46ba 100644 --- a/src/reranking/binary/create_dataset.py +++ b/src/reranking/dataset/create_dataset.py @@ -152,6 +152,7 @@ def create_binary_dataset( description_tokens=description_tokens, link_tokens=link_tokens_list, y=y, + qids=qids, ) From 725576ef807130fd35635cbe17a9ddee309f7b2c Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:17:09 +0200 Subject: [PATCH 14/28] feat(dependencies): add einops package to project dependencies --- pyproject.toml | 1 + uv.lock | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a45a245..26b521a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "einops>=0.8.1", "fire>=0.7.1", "gin-config>=0.5.0", "ipython>=9.6.0", diff --git a/uv.lock b/uv.lock index 2c1c440..67f1d6b 100644 --- a/uv.lock +++ b/uv.lock @@ -210,6 +210,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] +[[package]] +name = "einops" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/81/df4fbe24dff8ba3934af99044188e20a98ed441ad17a274539b74e82e126/einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84", size = 54805, upload-time = "2025-02-09T03:17:00.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359, upload-time = "2025-02-09T03:17:01.998Z" }, +] + [[package]] name = "executing" version = "2.2.1" @@ -480,6 +489,7 @@ name = "mel" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "einops" }, { name = "fire" }, { name = "gin-config" }, { name = "ipython" }, @@ -499,6 +509,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "einops", specifier = ">=0.8.1" }, { name = "fire", specifier = ">=0.7.1" }, { name = "gin-config", specifier = ">=0.5.0" }, { name = "ipython", specifier = ">=9.6.0" }, From b3ec79c2c6bf58867614fb83260115759fccf693 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:17:22 +0200 Subject: [PATCH 15/28] feat(config): add run_kb_creator.langs to multilingual dataset configuration --- configs/multilingual_dataset.gin | 3 +++ 1 file changed, 3 insertions(+) diff --git a/configs/multilingual_dataset.gin b/configs/multilingual_dataset.gin index 0f7fe45..8092036 100644 --- a/configs/multilingual_dataset.gin +++ b/configs/multilingual_dataset.gin @@ -6,3 +6,6 @@ dest_dir="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal/" create_multilingual_dataset.source_dir=%source_dir create_multilingual_dataset.langs=%langs create_multilingual_dataset.dest_dir=%dest_dir + +run_kb_creator.langs=%langs + From 0707fec814278457e01362ae4a8fa6be7a5f5b3b Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:17:34 +0200 Subject: [PATCH 16/28] feat(rerank): add create_default_binary_dataset and reranking_train actions --- src/run_action_gin.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/run_action_gin.py b/src/run_action_gin.py index f9038a4..6d930af 100644 --- a/src/run_action_gin.py +++ b/src/run_action_gin.py @@ -21,7 +21,9 @@ from finetunings.generate_epochs.generate import generate from multilingual_dataset.combine_embs import combine_embs_by_qid from multilingual_dataset.creator import create_multilingual_dataset, run_kb_creator +from reranking.dataset.create_dataset import create_default_binary_dataset from reranking.training.trainer import train_ddp as reranking_train_ddp +from reranking.training.trainer_simple import train as reranking_train from tokenization.runner import ( run_damuel_description, run_damuel_description_context, @@ -101,6 +103,10 @@ def choose_action(action): return find_candidates case "reranking_train_ddp": return reranking_train_ddp + case "create_default_binary_dataset": + return create_default_binary_dataset + case "reranking_train": + return reranking_train case _: raise ValueError(f"Unknown action: {action}") From 7e801b6a867f80feb650e1d0c81bb91e2967f8e3 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:17:52 +0200 Subject: [PATCH 17/28] feat(creator): enhance dataset creation functions with detailed docstrings --- src/multilingual_dataset/creator.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/multilingual_dataset/creator.py b/src/multilingual_dataset/creator.py index 69a7e46..b018bb9 100644 --- a/src/multilingual_dataset/creator.py +++ b/src/multilingual_dataset/creator.py @@ -284,13 +284,48 @@ def create_multilingual_dataset( dest_dir: Union[str, Path], max_links_per_qid: int, ) -> None: + """Create a multilingual dataset by mixing links and building a language-filtered KB. + + - Links: copies and intermixes link shards from the given languages into `dest_dir/links`, + limiting per-QID occurrences to `max_links_per_qid`. Outputs NPZ files with arrays + `tokens` and `qids`. + - KB pages: writes a subset of description/page shards to `dest_dir/descs_pages`, + assigning up to one language per QID by default. + + Args: + source_dir: Root directory of the DAMUEL dataset to read from. + langs: Language codes to include. + dest_dir: Output directory; creates `links/` and `descs_pages/` subfolders. + max_links_per_qid: Maximum number of link samples retained per QID. + + Notes: + Uses parallel mixing and threaded I/O for performance. + """ MultilingualDatasetCreator(Path(source_dir), langs, Path(dest_dir), max_links_per_qid).run() +@gin.configurable def run_kb_creator( source_dir: Union[str, Path], langs: list[str], dest_dir: Union[str, Path], langs_per_qid: int, ) -> None: + """ + Build a language-filtered KB (descriptions/pages) subset from a DAMUEL dataset. + + For each QID, selects up to `langs_per_qid` languages ranked by link frequency + (ties broken by overall language size), then copies only the chosen language pages + into `dest_dir/descs_pages` as compressed NPZ shards named `mentions_{lang}_{i}.npz`. + + Args: + source_dir: Root path of the DAMUEL dataset to read (links and pages). + langs: List of language codes to consider. + dest_dir: Output directory; the 'descs_pages' subfolder is created inside. + langs_per_qid: Maximum number of languages to assign to each QID. + + Notes: + - Affects KB creation only; link files are not modified. + - Uses parallel I/O with ThreadPoolExecutor for speed. + """ _KBCreator(DamuelPaths(source_dir), langs, Path(dest_dir), langs_per_qid).run() From 3584f42cae99d69d30129de39617a03417f5d140 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:18:12 +0200 Subject: [PATCH 18/28] feat(rerank): enhance create_binary_dataset function with improved link token loading and dataset creation logic --- src/reranking/dataset/create_dataset.py | 260 +++++++++++------------- 1 file changed, 114 insertions(+), 146 deletions(-) diff --git a/src/reranking/dataset/create_dataset.py b/src/reranking/dataset/create_dataset.py index f5e46ba..4a1c861 100644 --- a/src/reranking/dataset/create_dataset.py +++ b/src/reranking/dataset/create_dataset.py @@ -1,7 +1,5 @@ -import sys from pathlib import Path -import numba as nb import numpy as np import torch import torch.utils.data @@ -9,7 +7,7 @@ from models.searchers.brute_force_searcher import BruteForceSearcher, DPBruteForceSearcherPT from utils.embeddings import create_attention_mask -from utils.loaders import load_embs_and_qids, load_tokens_qids_from_dir +from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir from utils.model_factory import ModelFactory device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -57,9 +55,18 @@ def create_binary_dataset( searcher = BruteForceSearcher(index_embs, index_qids) # Load link tokens and qids - link_tokens, link_qids = load_tokens_qids_from_dir(link_tokens_path, max_items_to_load=10**7) - # Loaders order by qids which is not necessarily what we want - print(link_tokens.shape) + link_tokens_path = Path(link_tokens_path) + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + link_files = sorted( + [p for p in link_tokens_path.iterdir() if p.is_file() and p.suffix == ".npz"], + key=lambda p: p.name, + ) + + if not link_files: + raise FileNotFoundError(f"No .npz files found in {link_tokens_path}") + # Load embedding model model = ModelFactory.auto_load_from_file( model_name, @@ -71,163 +78,124 @@ def create_binary_dataset( model.to(torch.bfloat16) model = torch.compile(model) - index_qids_set = set(index_qids) - known_qids_mask = np.array([q in index_qids_set for q in link_qids]) - - link_tokens = link_tokens[known_qids_mask] - link_qids = link_qids[known_qids_mask] + index_qid_to_index = {int(qid): i for i, qid in enumerate(index_qids)} + index_qids_set = set(index_qid_to_index.keys()) + + for link_file in tqdm(link_files, desc="Processing link files"): + link_tokens, link_qids = load_tokens_qids(link_file) + + known_qids_mask = np.array([int(q) in index_qids_set for q in link_qids], dtype=bool) + link_tokens = link_tokens[known_qids_mask] + link_qids = link_qids[known_qids_mask] + + link_tokens_tensor = torch.from_numpy(link_tokens.astype(np.int32, copy=False)) + link_qids_tensor = torch.from_numpy(link_qids.astype(np.int64, copy=False)) + dataset = torch.utils.data.TensorDataset(link_tokens_tensor, link_qids_tensor) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2 + ) + + data_len = len(dataset) + if data_len == 0: + print(f"Skipping {link_file.name}: dataset length is zero after filtering") + continue + + description_tokens = np.zeros((data_len * 2, index_tokens.shape[1]), dtype=np.int32) + link_tokens_list = np.zeros((data_len * 2, link_tokens.shape[1]), dtype=np.int32) + y = np.zeros((data_len * 2,), dtype=np.int8) + qids = np.zeros((data_len * 2,), dtype=np.int32) + output_index = 0 + + for batch_tokens, batch_qids in tqdm( + dataloader, desc=f"Creating dataset for {link_file.name}", total=len(dataloader) + ): + attention_mask = create_attention_mask(batch_tokens) + + with torch.inference_mode(): + batch_embs = ( + model(batch_tokens.to(device), attention_mask.to(device)) + .to(torch.float16) + .cpu() + ) - # Create DataLoader - link_tokens = torch.from_numpy(link_tokens) - link_qids = torch.from_numpy(link_qids) + top_qids = searcher.find(batch_embs.numpy(), num_neighbors=2) - link_tokens = link_tokens.to(torch.int32) - dataset = torch.utils.data.TensorDataset(link_tokens, link_qids) - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2 - ) + batch_qids_np = batch_qids.cpu().numpy() + batch_tokens_np = batch_tokens.cpu().numpy().astype(np.int32, copy=False) + positive_mask = [index_qid_to_index[int(qid)] for qid in batch_qids_np] + data_size = len(batch_tokens) - # Initialize dataset arrays - description_tokens = [] - link_tokens_list = [] - y = [] - - print("Dataset length:", len(dataset)) - description_tokens = np.zeros((len(dataset) * 2, index_tokens.shape[1]), dtype=np.int32) - link_tokens_list = np.zeros((len(dataset) * 2, link_tokens.shape[1]), dtype=np.int32) - y = np.zeros((len(dataset) * 2,), dtype=np.int8) - qids = np.zeros((len(dataset) * 2,), dtype=np.int32) - output_index = 0 - - index_qid_to_index = {qid: i for i, qid in enumerate(index_qids)} - - # Iterate over batches - for batch_tokens, batch_qids in tqdm( - dataloader, desc="Creating dataset", total=len(dataloader) - ): - # Embed link tokens - with torch.inference_mode(): - batch_embs = ( - model( - batch_tokens.to(device).to(torch.int64), - create_attention_mask(batch_tokens).to(device), - ) - .to(torch.float16) - .cpu() + description_tokens[output_index : output_index + data_size] = index_tokens[ + positive_mask + ] + link_tokens_list[output_index : output_index + data_size] = batch_tokens_np + y[output_index : output_index + data_size] = 1 + qids[output_index : output_index + data_size] = batch_qids_np.astype( + np.int32, copy=False ) - # Find top matches - top_qids = searcher.find(batch_embs.numpy(), num_neighbors=2) + output_index += data_size - del batch_embs + neg_qids = get_neg_qids(top_qids, batch_qids_np) - positive_mask = [index_qid_to_index[int(qid)] for qid in batch_qids] - data_size = len(batch_tokens) - description_tokens[output_index : output_index + data_size] = index_tokens[positive_mask] - link_tokens_list[output_index : output_index + data_size] = batch_tokens.numpy() - y[output_index : output_index + data_size] = 1 - qids[output_index : output_index + data_size] = batch_qids.numpy() + negative_mask = [index_qid_to_index[int(qid)] for qid in neg_qids] + description_tokens[output_index : output_index + data_size] = index_tokens[ + negative_mask + ] + link_tokens_list[output_index : output_index + data_size] = batch_tokens_np + y[output_index : output_index + data_size] = 0 + qids[output_index : output_index + data_size] = np.array(neg_qids, dtype=np.int32) - output_index += data_size + output_index += data_size - neg_qids = get_neg_qids(top_qids, batch_qids) + if output_index != data_len * 2: + description_tokens = description_tokens[:output_index] + link_tokens_list = link_tokens_list[:output_index] + y = y[:output_index] + qids = qids[:output_index] - negative_mask = [index_qid_to_index[qid] for qid in neg_qids] - description_tokens[output_index : output_index + data_size] = index_tokens[negative_mask] - link_tokens_list[output_index : output_index + data_size] = batch_tokens.numpy() - y[output_index : output_index + data_size] = 0 - qids[output_index : output_index + data_size] = np.array(neg_qids) + output_file = output_path / f"{link_file.stem}_dataset.npz" - output_index += data_size + print( + f"Saving dataset for {link_file.name} -> {output_file.name} | " + f"positives/negatives: {output_index // 2}" + ) - # Convert to numpy arrays + np.savez( + output_file, + description_tokens=description_tokens, + link_tokens=link_tokens_list, + y=y, + qids=qids, + ) - print(description_tokens.shape) - print(link_tokens_list.shape) - print(y.shape) - # Save dataset - np.savez( - output_path, - description_tokens=description_tokens, - link_tokens=link_tokens_list, - y=y, - qids=qids, +def create_default_binary_dataset(): + index_embs_dir = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" ) - - -def create_multiclass_dataset( - index_embs_dir: Path, - index_tokens_path: Path, - link_tokens_path: Path, - model_name: str, - embedding_model_path_dict: Path, - output_path: Path, - total_classes: int, - target_dim: int = None, - batch_size: int = 512, -): - assert total_classes > 1 - index_embs, index_qids = load_embs_and_qids(index_embs_dir) - index_qids_set = set(index_qids) - index_embs = index_embs.astype(np.float16) - index_tokens, _ = load_tokens_qids_from_dir(index_tokens_path) - print(index_tokens.shape) - print(len(index_qids_set)) - - qid_to_index = {qid: i for i, qid in enumerate(index_qids)} - - # Create BruteForceSearcher - searcher = BruteForceSearcher(index_embs, index_qids) - - # Load link tokens and qids - link_tokens, link_qids = load_tokens_qids_from_dir(link_tokens_path) - # Load embedding model - model = ModelFactory.auto_load_from_file( - model_name, - embedding_model_path_dict, - target_dim=target_dim, + index_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages" ) - model.eval() - model.to(device) - - # Create dataset and dataloader from link_tokens and link_qids - dataset = list(zip(link_tokens, link_qids)) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=True) - - link_tokens, qids = [], [] - for B_tokens, B_qids in tqdm(dataloader, desc="Creating dataset"): - with torch.no_grad(): - B_embs = ( - model( - B_tokens.to(device).to(torch.int64), - create_attention_mask(B_tokens).to(device), - ) - .to(torch.float16) - .cpu() - ) + link_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/links" + ) + embedding_model_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + output_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids" + ) + model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" - top_qids = searcher.find(B_embs.numpy().astype(np.float16), num_neighbors=total_classes) - for i, qid in enumerate(B_qids.numpy()): - if qid in top_qids[i]: - idx = top_qids[i].index(qid) - top_qids[i][idx], top_qids[i][total_classes - 1] = ( - top_qids[i][total_classes - 1], - top_qids[i][idx], - ) - else: - top_qids[i][total_classes - 1] = qid - link_tokens.extend(B_tokens.numpy()) - qids.extend(top_qids) - link_tokens = np.array(link_tokens) - qids = np.array(qids) - print(link_tokens.shape) - print(qids.shape) - - np.save( + create_binary_dataset( + index_embs_dir, + index_tokens_path, + link_tokens_path, + model_name, + embedding_model_path, output_path, - link_tokens=link_tokens, - qids=qids, + batch_size=2560, ) @@ -245,7 +213,7 @@ def create_multiclass_dataset( "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", ) output_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids.npz" + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids" ) model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" From ea5ee2020e1f014a2278f30c424dcfda328d45a1 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:18:23 +0200 Subject: [PATCH 19/28] feat(config): add model path for run_damuel_description in paraphrase configuration --- configs/paraphrase.gin | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/paraphrase.gin b/configs/paraphrase.gin index 22c4e97..118e307 100644 --- a/configs/paraphrase.gin +++ b/configs/paraphrase.gin @@ -23,5 +23,6 @@ train_ddp.FOUNDATION_MODEL_PATH=%model_path run_mewsli_mention.model_path=%model_path run_damuel_mention.model_path=%model_path run_damuel_description_context.model_path=%model_path +run_damuel_description.model_path=%model_path run_damuel_link_context.model_path=%model_path run_mewsli_context.model_path=%model_path \ No newline at end of file From 6aa354401d231ccac95118e90bff32bd47bcf92e Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:19:29 +0200 Subject: [PATCH 20/28] Refactor training scripts and add new models - Updated `trainer.py` to support dynamic configuration loading and improved logging with Weights & Biases integration. - Introduced `trainer_simple.py` for simplified single-device training. - Enhanced `training_configs.py` with additional model configurations and parameters. - Modified `reranking2.py` and added `reranking3.py` for improved reranking functionality and model evaluation. - Implemented a new test suite for `PairwiseMLPRerankerWithRetrievalScore` and `RerankingIterableDataset`. - Added support for new models including `FullLEALLAReranker` and `PairwiseMLPRerankerWithLargeContextEmb`. - Improved data handling and validation logic in training scripts. --- src/reranking/models/__init__.py | 2 + src/reranking/models/base.py | 35 +- .../models/context_emb_with_attention.py | 345 ++++++++++++++++++ src/reranking/models/full_lealla.py | 199 ++++++++++ src/reranking/models/pairwise_mlp.py | 122 ++++++- .../pairwise_mlp_with_large_context_emb.py | 241 ++++++++++++ .../pairwise_mlp_with_retrieval_score.py | 50 +++ .../training/reranking_iterable_dataset.py | 71 ++++ src/reranking/training/trainer.py | 123 ++++--- src/reranking/training/trainer_simple.py | 163 +++++++++ src/reranking/training/training_configs.py | 279 +++++++++++++- src/scripts/qwen/reranking2.py | 14 +- src/scripts/qwen/reranking3.py | 265 ++++++++++++++ .../test_pairwise_with_retrieval_score.py | 96 +++++ .../test_reranking_iterable_dataset.py | 27 ++ 15 files changed, 1941 insertions(+), 91 deletions(-) create mode 100644 src/reranking/models/context_emb_with_attention.py create mode 100644 src/reranking/models/full_lealla.py create mode 100644 src/reranking/models/pairwise_mlp_with_large_context_emb.py create mode 100644 src/reranking/models/pairwise_mlp_with_retrieval_score.py create mode 100644 src/reranking/training/reranking_iterable_dataset.py create mode 100644 src/reranking/training/trainer_simple.py create mode 100644 src/scripts/qwen/reranking3.py create mode 100644 tests/reranking/models/test_pairwise_with_retrieval_score.py create mode 100644 tests/reranking/training/test_reranking_iterable_dataset.py diff --git a/src/reranking/models/__init__.py b/src/reranking/models/__init__.py index 5bc81b0..ec8a8dc 100644 --- a/src/reranking/models/__init__.py +++ b/src/reranking/models/__init__.py @@ -1,5 +1,7 @@ from .pairwise_mlp import PairwiseMLPReranker +from .pairwise_mlp_with_retrieval_score import PairwiseMLPRerankerWithRetrievalScore __all__ = [ "PairwiseMLPReranker", + "PairwiseMLPRerankerWithRetrievalScore", ] diff --git a/src/reranking/models/base.py b/src/reranking/models/base.py index e6ba5f3..14dde8d 100644 --- a/src/reranking/models/base.py +++ b/src/reranking/models/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Sequence import torch from torch import nn @@ -10,12 +10,39 @@ class BaseRerankingModel(nn.Module, ABC): """Abstract base class for reranking models.""" - @abstractmethod def train_step(self, data: Dict[str, Any]) -> torch.Tensor: """Run a single training step on the provided batch data and return the loss.""" + self.update_ema() + return self.train_step_imp(data) + + @abstractmethod + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + """Run a single training step on the provided batch data and return the loss.""" + raise NotImplementedError + + @abstractmethod + def update_ema(self) -> None: + """Update the EMA (Exponential Moving Average) of model parameters.""" + raise NotImplementedError + + @abstractmethod + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> torch.Tensor | float: + """Compute similarity-based probability for one or more mention/entity pairs.""" + raise NotImplementedError + + @abstractmethod + def score_from_tokens(self, mention: Any, entity_description: Any) -> torch.Tensor: + """Compute similarity-based probabilities for tokenized mention/entity pairs.""" + raise NotImplementedError + + @abstractmethod + def save(self, path: str) -> None: + """Save the model to the specified path.""" raise NotImplementedError @abstractmethod - def score(self, mention: str, entity_description: str) -> float: - """Compute a similarity-based probability that the mention refers to the entity.""" + def load(self, path: str) -> None: + """Load the model from the specified path.""" raise NotImplementedError diff --git a/src/reranking/models/context_emb_with_attention.py b/src/reranking/models/context_emb_with_attention.py new file mode 100644 index 0000000..2f62e47 --- /dev/null +++ b/src/reranking/models/context_emb_with_attention.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Mapping, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +class GPTLayer(nn.Module): + def __init__(self, model_width: int, dropout: float) -> None: + super().__init__() + self.linear1 = nn.Linear(model_width, model_width * 4) + self.activation = nn.GELU() + self.linear2 = nn.Linear(model_width * 4, model_width) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = self.activation(x) + x = self.linear2(x) + x = self.dropout(x) + return x + + +class _CATransformerBlock(nn.Module): + def __init__(self, model_width: int, dropout: float, num_heads: int) -> None: + super().__init__() + self.layer_norm1 = nn.LayerNorm(model_width) + self.self_attention = nn.MultiheadAttention( + embed_dim=model_width, num_heads=num_heads, dropout=dropout + ) + self.layer_norm2 = nn.LayerNorm(model_width) + self.feed_forward = GPTLayer(model_width, dropout) + + def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: + # Self-attention block + residual = x + x = self.layer_norm1(x) + x, _ = self.self_attention(x, context, x) + x = x + residual + + # Feed-forward block + residual = x + x = self.layer_norm2(x) + x = self.feed_forward(x) + x = x + residual + + return x + + +class _DoubleCrossAttention(nn.Module): + def __init__(self, model_width: int, dropout: float, num_heads: int) -> None: + super().__init__() + self.transformer1 = _CATransformerBlock(model_width, dropout, num_heads) + self.transformer2 = _CATransformerBlock(model_width, dropout, num_heads) + + def forward( + self, main_input: torch.Tensor, context_input1: torch.Tensor, context_input2: torch.Tensor + ) -> torch.Tensor: + x = self.transformer1(main_input, context_input1) + x = self.transformer2(x, context_input2) + return x + + +class _Classifier(nn.Module): + def __init__( + self, + model_width: int, + dropout: float, + num_heads: int, + num_layers: int, + n_tokens: int, + base_dim: int, + paraphrase_dim: int, + ) -> None: + super().__init__() + self.double_cross_attention = nn.ModuleList( + [_DoubleCrossAttention(model_width, dropout, num_heads) for _ in range(num_layers)] + ) + self.mewsli_to_tokens = nn.Linear(base_dim, model_width * n_tokens) + self.base_to_tokens = nn.Linear(base_dim, model_width * n_tokens) + self.paraphrase_to_tokens = nn.Linear(paraphrase_dim, model_width * n_tokens) + + self.final_projection = nn.Linear(model_width * n_tokens, 1) + self.model_width = model_width + + def forward( + self, mewsli_embs: torch.Tensor, base_embs: torch.Tensor, paraphrase_embs: torch.Tensor + ) -> torch.Tensor: + mewsli_tokens = self.mewsli_to_tokens(mewsli_embs) + base_tokens = self.base_to_tokens(base_embs) + paraphrase_tokens = self.paraphrase_to_tokens(paraphrase_embs) + + mewsli_tokens = rearrange(mewsli_tokens, "b (n d) -> n b d", n=self.model_width) + base_tokens = rearrange(base_tokens, "b (n d) -> n b d", n=self.model_width) + paraphrase_tokens = rearrange(paraphrase_tokens, "b (n d) -> n b d", n=self.model_width) + + for layer in self.double_cross_attention: + mewsli_tokens = layer(mewsli_tokens, base_tokens, paraphrase_tokens) + + mewsli_tokens = rearrange(mewsli_tokens, "n b d -> b (n d)") + logits = self.final_projection(mewsli_tokens).squeeze(-1) + return logits + + +class ContextEmbWithAttention(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head that uses paraphrase embedding to get more context.""" + + def __init__( + self, + model_name_or_path: str, + qid_to_paraphrase_emb: Dict[int, torch.Tensor], + qid_to_base_emb: Dict[int, torch.Tensor], + *, + state_dict_path: str | None = None, + target_dim: int | None = None, + output_type: ModelOutputType | str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + model_width: int = 64, + num_heads: int = 16, + num_layers: int = 2, + n_tokens: int = 64, + ) -> None: + super().__init__() + self.ema_decay = ema_decay + + resolved_output_type = _maybe_convert_output_type(output_type) + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + target_dim=target_dim, + output_type=resolved_output_type, + ) + self.base_model.eval() + self.base_model.requires_grad_(False) + + self.paraphrase_model_embedding_dim = next(iter(qid_to_paraphrase_emb.values())).shape[0] + self.base_model_embedding_dim = next(iter(qid_to_base_emb.values())).shape[0] + + self.classifier = _Classifier( + model_width=model_width, + dropout=dropout, + num_heads=num_heads, + num_layers=num_layers, + n_tokens=n_tokens, + base_dim=self.base_model_embedding_dim, + paraphrase_dim=self.paraphrase_model_embedding_dim, + ) + self.classifier.to(dtype=torch.float16) + + self.classifier_ema = deepcopy(self.classifier) + + self.model = _PairwiseMLPReranker( + self.base_model, self.classifier, qid_to_paraphrase_emb, qid_to_base_emb + ) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + return self.model(mention_tokens, entity_tokens) + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + labels = data["labels"].float().view(-1) + qids = data["qids"].view(-1) + + logits = self.model(mention_tokens, qids).view(-1) + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip( + self.classifier.parameters(), self.classifier_ema.parameters() + ): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.classifier_ema.state_dict(), ema_path) + torch.save(self.classifier.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + self.classifier_ema.load_state_dict(state_dict) + self.classifier.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + assert False, "Not correct, TODO FIX this" + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." + ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, qids: torch.Tensor) -> torch.Tensor: + logits = self.model(mentions, qids) + probability = torch.sigmoid(logits).reshape(-1) + return probability + + def classifier_forward(self, x): + return self.model.classifier_forward(x) + + +class _PairwiseMLPReranker(nn.Module): + def __init__( + self, + base_model: nn.Module, + classifier: nn.Module, + qid_to_paraphrase_emb: Dict[int, torch.Tensor], + qid_to_base_emb: Dict[int, torch.Tensor], + ) -> None: + super().__init__() + self.base_model = base_model + self.classifier = classifier + self.qid_to_paraphrase_emb = qid_to_paraphrase_emb + self.qid_to_base_emb = qid_to_base_emb + + def forward( + self, + mention_tokens: torch.Tensor, + qids: torch.Tensor, + return_embeddings: bool = False, + ) -> torch.Tensor: + + mention_embeddings = self._encode(mention_tokens).clone().detach() + paraphrase_embeddings = torch.stack( + [self.qid_to_paraphrase_emb[int(qid)] for qid in qids], dim=0 + ).to(mention_embeddings.device) + base_embeddings = torch.stack([self.qid_to_base_emb[int(qid)] for qid in qids], dim=0).to( + mention_embeddings.device + ) + + logits = self.classifier( + mention_embeddings, base_embeddings, paraphrase_embeddings + ).squeeze(-1) + if return_embeddings: + return logits, mention_embeddings, paraphrase_embeddings, base_embeddings + return logits + + def train(self, mode: bool = True) -> _PairwiseMLPReranker: + super().train(mode) + # Make sure that base model is never trained. + self.base_model.eval() + return self + + def classifier_forward(self, x): + return self.classifier(x) + + @torch.inference_mode() + def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: + if isinstance(tokens, Mapping): + input_ids = tokens["input_ids"] + attention_mask = tokens.get("attention_mask") + if attention_mask is None: + attention_mask = create_attention_mask(input_ids) + else: + input_ids = tokens + attention_mask = create_attention_mask(tokens) + + return self.base_model(input_ids=input_ids, attention_mask=attention_mask) diff --git a/src/reranking/models/full_lealla.py b/src/reranking/models/full_lealla.py new file mode 100644 index 0000000..c239243 --- /dev/null +++ b/src/reranking/models/full_lealla.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from reranking.models.context_emb_with_attention import GPTLayer +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +class _Model(nn.Module): + def __init__(self, base_model: nn.Module, embedding_dim: int, dropout: float) -> None: + super().__init__() + self.base_model = base_model + self.embedding_dim = embedding_dim + self.gpt_layer = GPTLayer(model_width=embedding_dim, dropout=dropout) + self.final_layer = nn.Linear(embedding_dim, 1) + + def forward(self, ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + # print(ids.sh1ape, attention_mask.shape) + base_embeddings = self.base_model(ids, attention_mask) + # print("base_embeddings", base_embeddings.shape) + x = self.gpt_layer(base_embeddings) + # print("x", x.shape) + logits = self.final_layer(x).squeeze(-1) + # print("logits", logits.shape) + return logits + + +class FullLEALLAReranker(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head.""" + + def __init__( + self, + model_name_or_path: str, + *, + state_dict_path: str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + ) -> None: + super().__init__() + + self.ema_decay = ema_decay + + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + ) + self.embedding_dim = _infer_output_dim(self.base_model) + + self.model = _Model( + base_model=self.base_model, + embedding_dim=self.embedding_dim, + dropout=dropout, + ) + self.model_ema = deepcopy(self.model) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + return self.model(ids, attention_mask) + + def prepare_for_forward(self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor): + ids = rearrange([mention_tokens, entity_tokens], "d b n -> b (d n)") + attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) + return ids, attention_mask + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + entity_tokens = data["entity_tokens"] + labels = data["labels"].float() + + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + logits = self.model(ids, attention_mask) + + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip(self.model.parameters(), self.model_ema.parameters()): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.model_ema.state_dict(), ema_path) + torch.save(self.model.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("_orig_mod.module.model."): + new_k = k.replace("_orig_mod.module.model.", "") + elif k.startswith("module."): + new_k = k.replace("module.", "") + else: + new_k = k + new_state_dict[new_k] = v + state_dict = new_state_dict + self.model_ema.load_state_dict(state_dict) + self.model.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." + ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + ids = rearrange([mentions, entities], "d b n -> b (d n)") + attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) + logits = self.model(ids, attention_mask) + probability = torch.sigmoid(logits).reshape(-1) + return probability diff --git a/src/reranking/models/pairwise_mlp.py b/src/reranking/models/pairwise_mlp.py index f2bd266..fc1c458 100644 --- a/src/reranking/models/pairwise_mlp.py +++ b/src/reranking/models/pairwise_mlp.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Dict, Mapping +from copy import deepcopy +from typing import Any, Dict, Mapping, Sequence import torch from torch import nn @@ -42,9 +43,13 @@ def __init__( tokenizer_name_or_path: str | None = None, mlp_hidden_dim: int | None = None, dropout: float = 0.1, + emb_noise: float = 0, + ema_decay: float = 0.9999, ) -> None: super().__init__() + self.ema_decay = ema_decay + resolved_output_type = _maybe_convert_output_type(output_type) self.base_model = ModelFactory.auto_load_from_file( model_name_or_path, @@ -67,7 +72,9 @@ def __init__( nn.Linear(hidden_dim, 1), ) - self.model = _PairwiseMLPReranker(self.base_model, self.classifier) + self.classifier_ema = deepcopy(self.classifier) + + self.model = _PairwiseMLPReranker(self.base_model, self.classifier, emb_noise) tokenizer_id = tokenizer_name_or_path or model_name_or_path self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) @@ -80,7 +87,10 @@ def forward( ) -> torch.Tensor: return self.model.forward(mention_tokens, entity_tokens) - def train_step(self, data: Dict[str, Any]) -> torch.Tensor: + def classifier_forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model.classifier_forward(x) + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: self.train() mention_tokens = data["mention_tokens"] @@ -91,44 +101,122 @@ def train_step(self, data: Dict[str, Any]) -> torch.Tensor: loss = self.loss_fn(logits, labels) return loss + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip( + self.classifier.parameters(), self.classifier_ema.parameters() + ): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.classifier_ema.state_dict(), ema_path) + torch.save(self.classifier.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + self.classifier_ema.load_state_dict(state_dict) + self.classifier.load_state_dict(state_dict) + @torch.inference_mode() - def score(self, mention: str, entity_description: str) -> float: + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." + ) + mention_tokens = self.tokenizer( - mention, + mention_batch, padding=True, truncation=True, return_tensors="pt", )["input_ids"] entity_tokens = self.tokenizer( - entity_description, + entity_batch, padding=True, truncation=True, return_tensors="pt", )["input_ids"] - logits = self.model.forward(mention_tokens, entity_tokens) - probability = torch.sigmoid(logits).item() + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + logits = self.model.forward(mentions, entities) + probability = torch.sigmoid(logits).reshape(-1) return probability class _PairwiseMLPReranker(nn.Module): - def __init__(self, base_model: nn.Module, classifier: nn.Module) -> None: + def __init__(self, base_model: nn.Module, classifier: nn.Module, emb_noise: float) -> None: super().__init__() self.base_model = base_model self.classifier = classifier + self.emb_noise = emb_noise + self.noise_layer = _GaussianNoiseLayer(emb_noise) def forward( self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor, + return_embeddings: bool = False, ) -> torch.Tensor: mention_embeddings = self._encode(mention_tokens) entity_embeddings = self._encode(entity_tokens) + if self.emb_noise > 0: + mention_embeddings = self.noise_layer(mention_embeddings) + entity_embeddings = self.noise_layer(entity_embeddings) + combined = torch.cat([mention_embeddings, entity_embeddings], dim=-1) logits = self.classifier(combined).squeeze(-1) + if return_embeddings: + return logits, mention_embeddings, entity_embeddings return logits + def train(self, mode: bool = True) -> _PairwiseMLPReranker: + super().train(mode) + # Make sure that base model is never trained. + self.base_model.eval() + return self + + def classifier_forward(self, x: torch.Tensor) -> torch.Tensor: + return self.classifier(x) + @torch.inference_mode() def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: if isinstance(tokens, Mapping): @@ -142,8 +230,14 @@ def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Te return self.base_model(input_ids=input_ids, attention_mask=attention_mask) - def train(self, mode: bool = True) -> _PairwiseMLPReranker: - super().train(mode) - # Make sure that base model is never trained. - self.base_model.eval() - return self + +class _GaussianNoiseLayer(nn.Module): + def __init__(self, std): + super().__init__() + self.std = std + + def forward(self, x): + if self.training: + noise = torch.randn_like(x) * self.std + return noise + x + return x diff --git a/src/reranking/models/pairwise_mlp_with_large_context_emb.py b/src/reranking/models/pairwise_mlp_with_large_context_emb.py new file mode 100644 index 0000000..858c66e --- /dev/null +++ b/src/reranking/models/pairwise_mlp_with_large_context_emb.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Mapping, Sequence + +import torch +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +class PairwiseMLPRerankerWithLargeContextEmb(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head that uses paraphrase embedding to get more context.""" + + def __init__( + self, + model_name_or_path: str, + qid_to_paraphrase_emb: Dict[int, torch.Tensor], + qid_to_base_emb: Dict[int, torch.Tensor], + *, + state_dict_path: str | None = None, + target_dim: int | None = None, + output_type: ModelOutputType | str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + ) -> None: + super().__init__() + self.ema_decay = ema_decay + + resolved_output_type = _maybe_convert_output_type(output_type) + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + target_dim=target_dim, + output_type=resolved_output_type, + ) + self.base_model.eval() + self.base_model.requires_grad_(False) + + self.paraphrase_model_embedding_dim = next(iter(qid_to_paraphrase_emb.values())).shape[0] + self.base_model_embedding_dim = next(iter(qid_to_base_emb.values())).shape[0] + + hidden_dim = 2 * self.base_model_embedding_dim + self.paraphrase_model_embedding_dim + + self.classifier = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim * 4), + nn.GELU(), + nn.Linear(4 * hidden_dim, hidden_dim), + nn.GELU(), + nn.Dropout(p=dropout), + nn.Linear(hidden_dim, 1), + ) + + self.classifier_ema = deepcopy(self.classifier) + + self.model = _PairwiseMLPReranker( + self.base_model, self.classifier, qid_to_paraphrase_emb, qid_to_base_emb + ) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + return self.model.forward(mention_tokens, entity_tokens) + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + labels = data["labels"].float().view(-1) + qids = data["qids"].view(-1) + + logits = self.model.forward(mention_tokens, qids).view(-1) + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip( + self.classifier.parameters(), self.classifier_ema.parameters() + ): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.classifier_ema.state_dict(), ema_path) + torch.save(self.classifier.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + self.classifier_ema.load_state_dict(state_dict) + self.classifier.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." + ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, qids: torch.Tensor) -> torch.Tensor: + logits = self.model.forward(mentions, qids) + probability = torch.sigmoid(logits).reshape(-1) + return probability + + def classifier_forward(self, x): + return self.model.classifier_forward(x) + + +class _PairwiseMLPReranker(nn.Module): + def __init__( + self, + base_model: nn.Module, + classifier: nn.Module, + qid_to_paraphrase_emb: Dict[int, torch.Tensor], + qid_to_base_emb: Dict[int, torch.Tensor], + ) -> None: + super().__init__() + self.base_model = base_model + self.classifier = classifier + self.qid_to_paraphrase_emb = qid_to_paraphrase_emb + self.qid_to_base_emb = qid_to_base_emb + + def forward( + self, + mention_tokens: torch.Tensor, + qids: torch.Tensor, + return_embeddings: bool = False, + ) -> torch.Tensor: + + mention_embeddings = self._encode(mention_tokens) + paraphrase_embeddings = torch.stack( + [self.qid_to_paraphrase_emb[int(qid)] for qid in qids], dim=0 + ).to(mention_embeddings.device) + base_embeddings = torch.stack([self.qid_to_base_emb[int(qid)] for qid in qids], dim=0).to( + mention_embeddings.device + ) + + combined = torch.cat([mention_embeddings, paraphrase_embeddings, base_embeddings], dim=-1) + logits = self.classifier(combined).squeeze(-1) + if return_embeddings: + return logits, mention_embeddings, paraphrase_embeddings, base_embeddings + return logits + + def train(self, mode: bool = True) -> _PairwiseMLPReranker: + super().train(mode) + # Make sure that base model is never trained. + self.base_model.eval() + return self + + def classifier_forward(self, x): + return self.classifier(x) + + @torch.inference_mode() + def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: + if isinstance(tokens, Mapping): + input_ids = tokens["input_ids"] + attention_mask = tokens.get("attention_mask") + if attention_mask is None: + attention_mask = create_attention_mask(input_ids) + else: + input_ids = tokens + attention_mask = create_attention_mask(tokens) + + return self.base_model(input_ids=input_ids, attention_mask=attention_mask) diff --git a/src/reranking/models/pairwise_mlp_with_retrieval_score.py b/src/reranking/models/pairwise_mlp_with_retrieval_score.py new file mode 100644 index 0000000..b143866 --- /dev/null +++ b/src/reranking/models/pairwise_mlp_with_retrieval_score.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import Sequence + +import torch + +from reranking.models.pairwise_mlp import PairwiseMLPReranker +from utils.model_factory import ModelOutputType + + +class PairwiseMLPRerankerWithRetrievalScore(PairwiseMLPReranker): + """Reranking model that augments a LEALLA encoder with an MLP head.""" + + def __init__( + self, + model_name_or_path: str, + *, + state_dict_path: str | None = None, + target_dim: int | None = None, + output_type: ModelOutputType | str | None = None, + tokenizer_name_or_path: str | None = None, + mlp_hidden_dim: int | None = None, + dropout: float = 0.1, + ) -> None: + super().__init__( + model_name_or_path=model_name_or_path, + state_dict_path=state_dict_path, + target_dim=target_dim, + output_type=output_type, + tokenizer_name_or_path=tokenizer_name_or_path, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + ) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + return super().score(mention, entity_description) + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> float: + logits, mention_embeddings, entity_embeddings = self.model.forward( + mentions, entities, return_embeddings=True + ) + probability1 = torch.sigmoid(logits) + probability2 = torch.sigmoid(mention_embeddings @ entity_embeddings.T) + print(probability1, probability2) + probability = (probability1 + probability2.diagonal()) / 2 + return probability diff --git a/src/reranking/training/reranking_iterable_dataset.py b/src/reranking/training/reranking_iterable_dataset.py new file mode 100644 index 0000000..b41a427 --- /dev/null +++ b/src/reranking/training/reranking_iterable_dataset.py @@ -0,0 +1,71 @@ +"""Minimal iterable dataset for streaming reranking training data.""" + +from pathlib import Path +from typing import Iterator, List, Tuple + +import numpy as np +import torch +from torch.utils.data import IterableDataset, get_worker_info + + +class RerankingIterableDataset(IterableDataset): + """Yield ``(link_tokens, description_tokens, labels, qids)`` from NPZ files. + + Each worker spawned by ``DataLoader`` receives a disjoint stride of the + underlying NPZ shards, ensuring the samples are not duplicated when + ``num_workers > 0``. + """ + + def __init__( + self, + data_dir: str | Path = "~/troja/outputs/reranking_test/reranker_dataset_with_qids", + ) -> None: + super().__init__() + self.data_dir = Path(data_dir).expanduser() + if not self.data_dir.is_dir(): + raise FileNotFoundError(f"Dataset directory not found: {self.data_dir}") + + self._files: List[Path] = sorted(self.data_dir.glob("*.npz")) + if not self._files: + raise FileNotFoundError( + f"No NPZ files found in {self.data_dir}; expected reranking shards" + ) + + def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + worker_info = get_worker_info() + if worker_info is None: + worker_id = 0 + num_workers = 1 + else: + worker_id = worker_info.id + num_workers = worker_info.num_workers + + file_paths = self._files[worker_id::num_workers] + + for file_path in file_paths: + with np.load(file_path, allow_pickle=False) as data: + qids = torch.from_numpy(data["qids"]).long() + labels = torch.from_numpy(data["y"]).float() + description_tokens = torch.from_numpy(data["description_tokens"]).long() + link_tokens = torch.from_numpy(data["link_tokens"]).long() + + if not (len(qids) == len(labels) == len(description_tokens) == len(link_tokens)): + raise ValueError( + "Mismatched array lengths in NPZ file " + f"{file_path}: qids={len(qids)} labels={len(labels)} " + f"description_tokens={len(description_tokens)} link_tokens={len(link_tokens)}" + ) + + permutation = torch.randperm(len(qids)) + qids = qids[permutation] + labels = labels[permutation] + description_tokens = description_tokens[permutation] + link_tokens = link_tokens[permutation] + + for idx in range(len(qids)): + yield ( + link_tokens[idx], + description_tokens[idx], + labels[idx], + qids[idx], + ) diff --git a/src/reranking/training/trainer.py b/src/reranking/training/trainer.py index 7d8027b..4c8611a 100644 --- a/src/reranking/training/trainer.py +++ b/src/reranking/training/trainer.py @@ -1,19 +1,26 @@ -import copy import logging import os +from itertools import islice -import numpy as np import torch +import wandb +from reranking.models.pairwise_mlp import PairwiseMLPReranker +from tests.utils.test_embeddings import model + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True +torch.backends.cudnn.benchmark = True import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from reranking.training.training_configs import TrainingConfig, pairwise_mlp +from reranking.training.training_configs import ( + TrainingConfig, + get_config_from_name, +) # Settings =========================================== @@ -29,9 +36,9 @@ device = torch.device("cpu") -def setup(rank, world_size): +def setup(rank, world_size, master_port: str = "12355"): os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" + os.environ["MASTER_PORT"] = master_port # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) @@ -48,42 +55,50 @@ def cleanup(): def _ddp_train( rank: int, world_size: int, - training_config: TrainingConfig, - epochs, + config_name: str, gradient_clip=1.0, + master_port: str = "12355", ): - setup(rank, world_size) + setup(rank, world_size, master_port) + + is_the_main_process = rank == 0 + + _logger.info("Loading training configuration") + training_config = get_config_from_name(config_name) + _logger.info(f"Training configuration loaded: {training_config.config_name}") + + if is_the_main_process: + wandb.init( + project="EL-reranking_train_ddp_process_0", + config={ + "config_name": training_config.config_name, + "batch_size": training_config.batch_size, + "save_each": training_config.save_each, + "validate_each": training_config.validate_each, + "validation_size": training_config.validation_size, + "output_dir": training_config.output_dir, + "epochs": training_config.epochs, + }, + ) model = training_config.model dataset = training_config.dataset optimizer = training_config.optimizer - validation_dataset = torch.utils.data.Subset( - dataset, - indices=np.arange(training_config.validation_size), - ) - train_dataset = torch.utils.data.Subset( + dataloader = DataLoader( dataset, - indices=np.arange(training_config.validation_size, len(dataset)), - ) - - sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) - - val_dataloader = DataLoader( - validation_dataset, batch_size=training_config.batch_size, - shuffle=False, pin_memory=True, - num_workers=2, + num_workers=4, ) - copied_model = copy.deepcopy(model) + num_validation_batches = training_config.validation_size // training_config.batch_size + validation_batches = list(islice(iter(dataloader), num_validation_batches)) + val_dataloader = validation_batches model.to(rank) - model.model = DDP(model.model, device_ids=[rank]) model = torch.compile(model) - - is_the_main_process = rank == 0 + model.model = DDP(model.model, device_ids=[rank]) scaler = torch.amp.GradScaler("cuda") @@ -98,31 +113,22 @@ def step(current_loss): global_step = 0 - for epoch in range(epochs): + for epoch in range(training_config.epochs): if is_the_main_process: - _logger.info(f"Starting epoch {epoch + 1}/{epochs}") - - # Ensure proper shuffling across epochs with DistributedSampler - sampler.set_epoch(epoch) - - dataloader = DataLoader( - train_dataset, - batch_size=training_config.batch_size, - sampler=sampler, - pin_memory=True, - num_workers=2, - ) + _logger.info(f"Starting epoch {epoch + 1}/{training_config.epochs}") - for links, entities, labels in dataloader: + for links, entities, labels, qids in dataloader: global_step += 1 links = links.to(rank, non_blocking=True) entities = entities.to(rank, non_blocking=True) labels = labels.to(rank, non_blocking=True) + qids = qids.to(rank, non_blocking=True) batch_data = { "mention_tokens": links, "entity_tokens": entities, "labels": labels, + "qids": qids, } with torch.autocast(device_type="cuda"): @@ -131,9 +137,9 @@ def step(current_loss): if is_the_main_process and global_step % training_config.save_each == 0: path = training_config.get_output_path(global_step) _logger.info(f"Saving model at step {global_step} to {path}") - copied_model.model = model.model.module - torch.save(copied_model.state_dict(), path) + model.save(path) if is_the_main_process and global_step % 500 == 0: + wandb.log({"train/loss": loss.item()}, step=global_step) _logger.info(f"Step {global_step}, loss: {loss.item():.4f}") if is_the_main_process and global_step % training_config.validate_each == 0: model.eval() @@ -143,18 +149,26 @@ def step(current_loss): total_loss = 0.0 val_steps = 0 - for links, entities, labels in val_dataloader: + for links, entities, labels, qids in val_dataloader: links = links.to(rank, non_blocking=True) entities = entities.to(rank, non_blocking=True) labels = labels.to(rank, non_blocking=True) + qids = qids.to(rank, non_blocking=True) val_steps += 1 with torch.inference_mode(): - probs = model.score(links, entities).view(-1) + if ( + training_config.config_name == "pairwise_mlp" + or training_config.config_name == "pairwise_mlp_debug" + or training_config.config_name == "full_lealla" + ): + probs = model.score_from_tokens(links, entities) + else: + probs = model.score_from_tokens(links, qids) loss = torch.nn.functional.binary_cross_entropy(probs, labels.float()) total_loss += loss.item() - predictions = (torch.sigmoid(probs) > 0.5).long() + predictions = (probs > 0.5).long() correct += (predictions == labels).sum().item() total += labels.size(0) if is_the_main_process: @@ -165,36 +179,25 @@ def step(current_loss): if is_the_main_process: _logger.info(f"Epoch {epoch + 1} finished.") + model.save(training_config.get_output_path(global_step)) cleanup() -def get_config_from_name(config_name: str) -> TrainingConfig: - if config_name == "pairwise_mlp": - return pairwise_mlp() - else: - raise ValueError(f"Unknown training configuration: {config_name}") - - # Training =========================================== -def train_ddp(config_name: str): +def train_ddp(config_name: str, master_port: str = "12355"): _logger.info("Starting DDP training") gradient_clip = 1.0 - epochs = 10 world_size = torch.cuda.device_count() _logger.debug(f"Using {world_size} GPUs for training") - _logger.info("Loading training configuration") - training_config = get_config_from_name(config_name) - _logger.info(f"Training configuration loaded: {training_config.config_name}") - mp.spawn( _ddp_train, args=( world_size, - training_config, - epochs, + config_name, gradient_clip, + master_port, ), nprocs=world_size, ) diff --git a/src/reranking/training/trainer_simple.py b/src/reranking/training/trainer_simple.py new file mode 100644 index 0000000..53a4267 --- /dev/null +++ b/src/reranking/training/trainer_simple.py @@ -0,0 +1,163 @@ +import logging +from contextlib import nullcontext +from itertools import islice + +import torch + +import wandb + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cudnn.benchmark = True +from torch.utils.data import DataLoader + +from reranking.training.training_configs import get_config_from_name + +# Settings =========================================== + + +_logger = logging.getLogger("reranking.train.trainer") + + +if torch.cuda.is_available(): + _logger.debug("Running on CUDA.") + device = torch.device("cuda") +else: + _logger.debug("CUDA is not available.") + device = torch.device("cpu") + + +SEED = 0 +torch.manual_seed(SEED) + + +def train( + config_name: str, + gradient_clip: float = 1.0, +): + _logger.info("Loading training configuration") + training_config = get_config_from_name(config_name) + _logger.info(f"Training configuration loaded: {training_config.config_name}") + + model = training_config.model + dataset = training_config.dataset + optimizer = training_config.optimizer + + dataloader = DataLoader( + dataset, + batch_size=training_config.batch_size, + pin_memory=True, + num_workers=4, + ) + + num_validation_batches = training_config.validation_size // training_config.batch_size + validation_batches = list(islice(iter(dataloader), num_validation_batches)) + val_dataloader = validation_batches + + model.to(device) + model = torch.compile(model) + + use_amp = device.type == "cuda" + scaler = torch.amp.GradScaler(device.type) if use_amp else None + + def step(current_loss): + if scaler is not None: + scaler.scale(current_loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) + scaler.step(optimizer) + scaler.update() + else: + current_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) + optimizer.step() + optimizer.zero_grad() + + global_step = 0 + + for epoch in range(training_config.epochs): + _logger.info(f"Starting epoch {epoch + 1}/{training_config.epochs}") + + for links, entities, labels, qids in dataloader: + global_step += 1 + links = links.to(device, non_blocking=True) + entities = entities.to(device, non_blocking=True) + labels = labels.to(device, non_blocking=True) + qids = qids.to(device, non_blocking=True) + + batch_data = { + "mention_tokens": links, + "entity_tokens": entities, + "labels": labels, + "qids": qids, + } + + autocast_context = torch.autocast(device_type="cuda") if use_amp else nullcontext() + with autocast_context: + loss = model.train_step(batch_data) + step(loss) + if global_step % training_config.save_each == 0: + path = training_config.get_output_path(global_step) + _logger.info(f"Saving model at step {global_step} to {path}") + model.save(path) + if global_step % 500 == 0: + wandb.log({"train/loss": loss.item()}, step=global_step) + _logger.info(f"Step {global_step}, loss: {loss.item():.4f}") + if global_step % training_config.validate_each == 0: + model.eval() + + correct = 0 + total = 0 + + total_loss = 0.0 + val_steps = 0 + for links, entities, labels, qids in val_dataloader: + links = links.to(device, non_blocking=True) + entities = entities.to(device, non_blocking=True) + labels = labels.to(device, non_blocking=True) + qids = qids.to(device, non_blocking=True) + + val_steps += 1 + + with torch.inference_mode(): + if ( + training_config.config_name == "pairwise_mlp" + or training_config.config_name == "pairwise_mlp_debug" + or training_config.config_name == "full_lealla" + ): + probs = model.score_from_tokens(links, entities) + else: + probs = model.score_from_tokens(links, qids) + loss = torch.nn.functional.binary_cross_entropy(probs, labels.float()) + total_loss += loss.item() + predictions = (probs > 0.5).long() + correct += (predictions == labels).sum().item() + total += labels.size(0) + val_loss = total_loss / max(val_steps, 1) + accuracy = correct / max(total, 1) + _logger.info(f"Validation loss: {val_loss:.4f}") + _logger.info(f"Validation accuracy: {accuracy:.4f}") + wandb.log( + { + "validation/loss": val_loss, + "validation/accuracy": accuracy, + }, + step=global_step, + ) + + model.train() + + _logger.info(f"Epoch {epoch + 1} finished.") + model.save(training_config.get_output_path(global_step)) + + +# Training =========================================== +def train_ddp(config_name: str, master_port: str = "12355"): + _logger.warning( + "train_ddp is deprecated and now runs single-device training. master_port is ignored." + ) + train(config_name) + + +if __name__ == "__main__": + train("pairwise_mlp") diff --git a/src/reranking/training/training_configs.py b/src/reranking/training/training_configs.py index fee2721..796f795 100644 --- a/src/reranking/training/training_configs.py +++ b/src/reranking/training/training_configs.py @@ -6,7 +6,14 @@ import torch from reranking.models.base import BaseRerankingModel +from reranking.models.context_emb_with_attention import ContextEmbWithAttention +from reranking.models.full_lealla import FullLEALLAReranker from reranking.models.pairwise_mlp import PairwiseMLPReranker +from reranking.models.pairwise_mlp_with_large_context_emb import ( + PairwiseMLPRerankerWithLargeContextEmb, +) +from reranking.training.reranking_iterable_dataset import RerankingIterableDataset +from utils.loaders import load_embs_and_qids @dataclass @@ -20,6 +27,7 @@ class TrainingConfig: save_each: int validate_each: int validation_size: int = 100000 + epochs: int = 1 def get_output_path(self, step: int) -> str: dir_path = Path(self.output_dir) / self.config_name @@ -27,27 +35,126 @@ def get_output_path(self, step: int) -> str: return f"{dir_path}/{step}.pth" -def pairwise_mlp() -> TrainingConfig: +def pairwise_mlp( + LR: float = 0.0001, + SAVE_EACH: int = 5000, + BATCH_SIZE: int = 1024, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.5, +) -> TrainingConfig: name = "pairwise_mlp" - LR = 0.0001 - SAVE_EACH = 10000 - BATCH_SIZE = 1024 - VALIDATE_EACH = 5000 - VALIDATION_SIZE = 1000 model = PairwiseMLPReranker( model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", mlp_hidden_dim=2048, + dropout=DROPOUT, + ) + + dataset = RerankingIterableDataset() + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def full_lealla( + LR: float = 0.0001, + SAVE_EACH: int = 5000, + BATCH_SIZE: int = 1536, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "full_lealla" + model = FullLEALLAReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + dropout=DROPOUT, + ) + + dataset = RerankingIterableDataset() + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def full_lealla_r( + LR: float = 0.0001, + SAVE_EACH: int = 5000, + BATCH_SIZE: int = 2048, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "full_lealla_r" + model = FullLEALLAReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + ) + + dataset = RerankingIterableDataset() + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def pairwise_mlp_noise( + LR: float = 0.0001, + SAVE_EACH: int = 20000, + BATCH_SIZE: int = 1024, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, +) -> TrainingConfig: + name = "pairwise_mlp_noise" + model = PairwiseMLPReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + mlp_hidden_dim=2048, + emb_noise=0.1, ) data = np.load( "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids.npz" ) + # qids = torch.from_numpy(data["qids"]).long() labels = torch.from_numpy(data["y"]).float() description_tokens = torch.from_numpy(data["description_tokens"]).long() link_tokens = torch.from_numpy(data["link_tokens"]).long() dataset = torch.utils.data.TensorDataset(link_tokens, description_tokens, labels) - optimizer = torch.optim.AdamW(model.parameters(), lr=LR) + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" return TrainingConfig( config_name=name, @@ -60,3 +167,161 @@ def pairwise_mlp() -> TrainingConfig: validate_each=VALIDATE_EACH, validation_size=VALIDATION_SIZE, ) + + +def pairwise_mlp_noise_dropout( + LR: float = 0.0001, + SAVE_EACH: int = 20000, + BATCH_SIZE: int = 1024, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, +) -> TrainingConfig: + name = "pairwise_mlp_noise_dropout" + model = PairwiseMLPReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + mlp_hidden_dim=2048, + emb_noise=0.1, + dropout=0.5, + ) + data = np.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids.npz" + ) + # qids = torch.from_numpy(data["qids"]).long() + labels = torch.from_numpy(data["y"]).float() + description_tokens = torch.from_numpy(data["description_tokens"]).long() + link_tokens = torch.from_numpy(data["link_tokens"]).long() + + dataset = torch.utils.data.TensorDataset(link_tokens, description_tokens, labels) + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def pairwise_mlp_paraphrase( + LR: float = 0.0001, + SAVE_EACH: int = 5000, + BATCH_SIZE: int = 1024, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.5, +) -> TrainingConfig: + name = "pairwise_mlp_paraphrase" + + paraphrase_embs, paraphrase_qids = load_embs_and_qids( + "/lnet/work/home-students-external/farhan/troja/outputs/paraphrase_multilig_index" + ) + paraphrase_embs = torch.from_numpy(paraphrase_embs) + base_embs, base_qids = load_embs_and_qids( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + base_embs = torch.from_numpy(base_embs) + + dataset = RerankingIterableDataset() + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + + qid_to_paraphrase_emb = {qid: emb for qid, emb in zip(paraphrase_qids, paraphrase_embs)} + qid_to_base_emb = {qid: emb for qid, emb in zip(base_qids, base_embs)} + + model = PairwiseMLPRerankerWithLargeContextEmb( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + qid_to_paraphrase_emb=qid_to_paraphrase_emb, + qid_to_base_emb=qid_to_base_emb, + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def context_emb_with_attention( + LR: float = 0.0001, + SAVE_EACH: int = 5000, + BATCH_SIZE: int = 1024, + VALIDATE_EACH: int = 1000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "context_emb_with_attention" + + paraphrase_embs, paraphrase_qids = load_embs_and_qids( + "/lnet/work/home-students-external/farhan/troja/outputs/paraphrase_multilig_index" + ) + paraphrase_embs = torch.from_numpy(paraphrase_embs) + base_embs, base_qids = load_embs_and_qids( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + base_embs = torch.from_numpy(base_embs) + + dataset = RerankingIterableDataset() + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + + qid_to_paraphrase_emb = {qid: emb for qid, emb in zip(paraphrase_qids, paraphrase_embs)} + qid_to_base_emb = {qid: emb for qid, emb in zip(base_qids, base_embs)} + + model = ContextEmbWithAttention( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + qid_to_paraphrase_emb=qid_to_paraphrase_emb, + qid_to_base_emb=qid_to_base_emb, + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def get_config_from_name(config_name: str) -> TrainingConfig: + if config_name == "pairwise_mlp": + return pairwise_mlp() + if config_name == "pairwise_mlp_debug": + return pairwise_mlp(VALIDATE_EACH=1000, SAVE_EACH=1000000000000) + if config_name == "pairwise_mlp_noise": + return pairwise_mlp_noise() + if config_name == "pairwise_mlp_noise_dropout": + return pairwise_mlp_noise_dropout() + if config_name == "pairwise_mlp_paraphrase": + return pairwise_mlp_paraphrase() + if config_name == "context_emb_with_attention": + return context_emb_with_attention() + if config_name == "full_lealla": + return full_lealla() + if config_name == "full_lealla_r": + return full_lealla_r() + if config_name == "full_lealla_debug": + return full_lealla(VALIDATE_EACH=1000, SAVE_EACH=1000000000000, BATCH_SIZE=128) + else: + raise ValueError(f"Unknown training configuration: {config_name}") diff --git a/src/scripts/qwen/reranking2.py b/src/scripts/qwen/reranking2.py index 9e681a4..a801242 100644 --- a/src/scripts/qwen/reranking2.py +++ b/src/scripts/qwen/reranking2.py @@ -71,7 +71,7 @@ def main(): mlp_hidden_dim=2048, ) state_dict = torch.load( - "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp/100000.pth", + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp_noise/20000.pth", map_location=device, ) reranker.load_state_dict(state_dict) @@ -81,7 +81,7 @@ def main(): mewsli_tokens = torch.from_numpy(mewsli_tokens) mewsli_qids = torch.from_numpy(mewsli_qids) - decode_mewsli_example = reranking_tokenizer.decode(mewsli_tokens[0]) + decode_mewsli_example = reranking_tokenizer.decode(mewsli_tokens[0], skip_special_tokens=True) print("Decoded mewsli example:", decode_mewsli_example) # Resolve tokens and embeddings (directories or files) @@ -123,7 +123,7 @@ def main(): for batch in mewsli_loader: mewsli_token = batch[0] - qid = batch[1] + qid = batch[1].item() # ensure scalar for a fair comparison with predicted_qid # print(f"Processing Mewsli token: {mewsli_token}, QID: {qid}") tokens = mewsli_token.to(device, dtype=torch.int64) @@ -140,9 +140,11 @@ def main(): neighbor_qids = searcher.find(mewsli_emb, num_neighbors=args.num_neighbors) damuel_candidates = [qid_to_damuel_token[nq] for nq in neighbor_qids[0]] - damuel_candidates_str = [reranking_tokenizer.decode(dc) for dc in damuel_candidates] + damuel_candidates_str = [ + reranking_tokenizer.decode(dc, skip_special_tokens=True) for dc in damuel_candidates + ] - mewsli_str = reranking_tokenizer.decode(mewsli_token[0]) + mewsli_str = reranking_tokenizer.decode(mewsli_token[0], skip_special_tokens=True) scores = [] # print("Mewsli mention:", mewsli_str) @@ -152,7 +154,7 @@ def main(): score = reranker.score(mewsli_str, dc) scores.append(score) # print(scores) - predicted_qid = neighbor_qids[0][scores.index(max(scores))] + predicted_qid = int(neighbor_qids[0][scores.index(max(scores))]) if predicted_qid == qid: good += 1 total += 1 diff --git a/src/scripts/qwen/reranking3.py b/src/scripts/qwen/reranking3.py new file mode 100644 index 0000000..9938267 --- /dev/null +++ b/src/scripts/qwen/reranking3.py @@ -0,0 +1,265 @@ +import argparse +from pathlib import Path + +import numpy as np +import torch +from einops import rearrange, repeat +from torch.utils.data import DataLoader, Dataset +from transformers import AutoTokenizer + +from models.searchers.brute_force_searcher import BruteForceSearcher +from reranking.models.full_lealla import FullLEALLAReranker +from reranking.models.pairwise_mlp import PairwiseMLPReranker +from reranking.models.pairwise_mlp_with_large_context_emb import ( + PairwiseMLPRerankerWithLargeContextEmb, +) +from reranking.models.pairwise_mlp_with_retrieval_score import PairwiseMLPRerankerWithRetrievalScore +from scripts.qwen.reranker import Reranker +from utils.embeddings import create_attention_mask +from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir +from utils.model_factory import ModelFactory + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +LOGIT_MULTIPLIER = 20.0 + + +def main(): + parser = argparse.ArgumentParser(description="Reranking for entity linking") + parser.add_argument( + "--damuel_token", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages", + help="Path to damuel token file or directory", + ) + parser.add_argument( + "--damuel_embs", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6", + help="Path to damuel embeddings directory or .npz file", + ) + parser.add_argument( + "--mewsli_tokens", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/tokens_mewsli_finetuning", + help="Path to mewsli token file or directory", + ) + parser.add_argument( + "--qwen_model_name", + type=str, + default="Qwen/Qwen3-Reranker-0.6B", + help="Name of the QWEN model", + ) + parser.add_argument( + "--reranking_model_path", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + parser.add_argument( + "--num_neighbors", + type=int, + default=10, + help="Number of neighbors to retrieve from the searcher", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size for DataLoader", + ) + args = parser.parse_args() + + reranking_model = ModelFactory.auto_load_from_file( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + args.reranking_model_path, + ) + reranking_model.eval() + reranking_model.to(device) + + reranking_tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + + # paraphrase_embs, paraphrase_qids = load_embs_and_qids( + # "/lnet/work/home-students-external/farhan/troja/outputs/paraphrase_multilig_index" + # ) + # paraphrase_embs = torch.from_numpy(paraphrase_embs) + # base_embs, base_qids = load_embs_and_qids( + # "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + # ) + # base_embs = torch.from_numpy(base_embs) + + # reranker = PairwiseMLPReranker( + # "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + # state_dict_path=args.reranking_model_path, + # mlp_hidden_dim=2048, + # ) + # qid_to_paraphrase_emb = {qid: emb for qid, emb in zip(paraphrase_qids, paraphrase_embs)} + # qid_to_base_emb = {qid: emb for qid, emb in zip(base_qids, base_embs)} + # reranker = PairwiseMLPRerankerWithLargeContextEmb( + # model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + # state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + # qid_to_paraphrase_emb=qid_to_paraphrase_emb, + # qid_to_base_emb=qid_to_base_emb, + # ) + reranker = FullLEALLAReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + ) + # reranker.load( + # "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp/70000.pth", + # ) + reranker.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/full_lealla/30000.pth", + ) + reranker.eval() + reranker.to(device) + + # Resolve tokens and embeddings (directories or files) + damuel_tokens, damuel_token_qids = load_tokens_qids_from_dir(args.damuel_token, verbose=True) + damuel_embs, damuel_qids = load_embs_and_qids(args.damuel_embs) + + qid_to_damuel_emb = {qid: emb for qid, emb in zip(damuel_qids, damuel_embs)} + qid_to_damuel_token = {qid: token for qid, token in zip(damuel_token_qids, damuel_tokens)} + + del damuel_token_qids + + # Take first four names (tokens) from damuel as a quick smoke-test + damuel_tokens_preview = damuel_tokens[:4] + print("First 4 damuel tokens:", damuel_tokens_preview) + + # Create searcher using damuel embeddings and damuel qids + searcher = BruteForceSearcher(damuel_embs, damuel_qids) + print("Searcher created.") + + # Initialize reranker model (actual reranking logic to be implemented later) + # reranker = Reranker(model_name=args.qwen_model_name) + # print(f"Reranker initialized with model: {reranker.model_name}") + + # Stub for QWEN model loading + print("QWEN model to be used:", args.qwen_model_name) + + print("jupi") + + mewsli_root = Path(args.mewsli_tokens) + if not mewsli_root.exists(): + raise FileNotFoundError(f"MEWSLI tokens path not found: {mewsli_root}") + + language_paths: list[tuple[str, Path]] = [] + if mewsli_root.is_file(): + language_name = mewsli_root.parent.name or mewsli_root.stem + language_paths.append((language_name, mewsli_root)) + else: + for subdir in sorted(p for p in mewsli_root.iterdir() if p.is_dir()): + language_paths.append((subdir.name, subdir)) + + if not language_paths: + raise ValueError(f"No MEWSLI language directories or files found under {mewsli_root}") + + for language, mewsli_path in language_paths: + print(f"\nEvaluating language: {language}") + + if mewsli_path.is_dir(): + tokens_np, qids_np = load_tokens_qids_from_dir(mewsli_path, verbose=False) + else: + tokens_np, qids_np = load_tokens_qids(mewsli_path) + + if tokens_np.size == 0: + print(f"No samples found for language {language}, skipping.") + continue + + mewsli_tokens = torch.from_numpy(tokens_np) + mewsli_qids = torch.from_numpy(qids_np) + + mewsli_dataset = torch.utils.data.TensorDataset(mewsli_tokens, mewsli_qids) + mewsli_loader = DataLoader(mewsli_dataset, batch_size=args.batch_size, shuffle=False) + + good = 0 + total = 0 + + for mewsli_tokens_batch, qids_batch in mewsli_loader: + tokens = mewsli_tokens_batch.to(device, dtype=torch.int64) + qids_batch = qids_batch.view(-1).to(torch.long) + + with torch.inference_mode(): + attention_mask = create_attention_mask(tokens).to(device) + mewsli_embs = reranking_model(tokens, attention_mask) + + neighbor_qids = searcher.find( + mewsli_embs.to("cpu").numpy().astype(np.float16), num_neighbors=args.num_neighbors + ) + neighbor_qids = torch.as_tensor(neighbor_qids, dtype=torch.long) + + candidate_embs_lists = [] + candidate_tokens_lists = [] + for row in neighbor_qids.tolist(): + for nq in row: + # emb = qid_to_damuel_emb[int(nq)] + token = qid_to_damuel_token[int(nq)] + # candidate_embs_lists.append(emb) + candidate_tokens_lists.append(token) + + # assert len(candidate_embs_lists) == neighbor_qids.size(0) * neighbor_qids.size(1) + + # candidate_embs = torch.as_tensor( + # candidate_embs_lists, dtype=torch.float16, device=device + # ) + # together = torch.cat( + # (mewsli_embs.repeat_interleave(neighbor_qids.size(1), dim=0), candidate_embs), + # dim=-1, + # ) + # together = together.to(device) + candidate_tokens = torch.as_tensor( + candidate_tokens_lists, dtype=torch.int64, device=device + ) + with torch.inference_mode(): + if isinstance(reranker, PairwiseMLPRerankerWithLargeContextEmb): + scores = reranker.score_from_tokens( + repeat(tokens, "b d -> (b n) d", n=neighbor_qids.size(1)), + rearrange(neighbor_qids, "b n -> (b n)"), + ) + scores = rearrange(scores, "(b n) -> b n", n=neighbor_qids.size(1)) + elif isinstance(reranker, FullLEALLAReranker): + score = reranker.score_from_tokens( + repeat(tokens, "b d -> (b n) d", n=neighbor_qids.size(1)), candidate_tokens + ) + scores = rearrange(score, "(b n) -> b n", n=neighbor_qids.size(1)) + else: + logits = reranker.classifier_forward(together) + scores = torch.sigmoid(logits).reshape( + neighbor_qids.size(0), neighbor_qids.size(1) + ) + if ( + isinstance(reranker, PairwiseMLPRerankerWithRetrievalScore) + # or isinstance(reranker, PairwiseMLPReranker) + # or isinstance(reranker, PairwiseMLPRerankerWithLargeContextEmb) + ): + candidate_embs = candidate_embs.reshape( + neighbor_qids.size(0), neighbor_qids.size(1), -1 + ) + out = ( + torch.einsum( + "abc,ac->ab", + candidate_embs.to(torch.bfloat16), + mewsli_embs.to(torch.bfloat16), + ) + # * LOGIT_MULTIPLIER + ) + out = out.reshape(neighbor_qids.size(0), neighbor_qids.size(1)) + scores = (scores + torch.sigmoid(out)) / 2 + + max_indices = scores.argmax(dim=1).cpu() + predicted_qids = neighbor_qids[torch.arange(neighbor_qids.size(0)), max_indices] + + good += (predicted_qids.cpu() == qids_batch.cpu()).sum().item() + total += qids_batch.numel() + + if total == 0: + print(f"No valid predictions for language {language}.") + continue + + final_accuracy = round(good / total * 100, 4) + print(f"Final accuracy for {language}: {final_accuracy}") + + +if __name__ == "__main__": + main() diff --git a/tests/reranking/models/test_pairwise_with_retrieval_score.py b/tests/reranking/models/test_pairwise_with_retrieval_score.py new file mode 100644 index 0000000..668306d --- /dev/null +++ b/tests/reranking/models/test_pairwise_with_retrieval_score.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import math +from typing import Sequence + +import pytest +import torch +from torch import nn +from transformers.tokenization_utils_base import BatchEncoding + +from reranking.models import PairwiseMLPRerankerWithRetrievalScore +from utils.embeddings import create_attention_mask + +MENTION_TEXTS = ["dummy mention positive", "dummy mention negative"] +ENTITY_TEXTS = ["dummy entity positive", "dummy entity negative"] + + +class DummyEmbeddingModel(nn.Module): + output_dim = 2 + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: # type: ignore[override] + masked = input_ids.float() * attention_mask + summed = masked.sum(dim=1, keepdim=True) + return torch.cat([summed, 2 * summed], dim=1) + + +class DummyTokenizer: + _TOKEN_MAP = { + MENTION_TEXTS[0]: [1, 1, 0, 0], + MENTION_TEXTS[1]: [1, 0, 0, 0], + ENTITY_TEXTS[0]: [3, 0, 0, 0], + ENTITY_TEXTS[1]: [0, 0, 0, 0], + } + + def __call__( + self, + texts: str | Sequence[str], + *, + padding: bool = True, + truncation: bool = True, + return_tensors: str = "pt", + ) -> BatchEncoding: + del padding, truncation, return_tensors + items = [texts] if isinstance(texts, str) else list(texts) + encodings = [self._TOKEN_MAP.get(text, [1, 0, 0, 0]) for text in items] + tensor = torch.tensor(encodings, dtype=torch.long) + attention_mask = (tensor != 0).long() + return BatchEncoding({"input_ids": tensor, "attention_mask": attention_mask}) + + +def _configure_classifier(model: PairwiseMLPRerankerWithRetrievalScore) -> None: + hidden_layer = model.classifier[0] + output_layer = model.classifier[-1] + + with torch.no_grad(): + hidden_layer.weight.copy_(torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float)) + hidden_layer.bias.copy_(torch.tensor([0.5], dtype=torch.float)) + output_layer.weight.copy_(torch.tensor([[1.0]], dtype=torch.float)) + output_layer.bias.zero_() + + +def _prepare_labels(size: int, *, positive_index: int = 0) -> torch.Tensor: + labels = torch.zeros(size, dtype=torch.float) + labels[positive_index] = 1.0 + return labels + + +@pytest.fixture() +def dummy_model(monkeypatch) -> PairwiseMLPRerankerWithRetrievalScore: + dummy_embedding_model = DummyEmbeddingModel() + + def _mock_model_loader(*_args, **_kwargs): + return dummy_embedding_model + + monkeypatch.setattr( + "reranking.models.pairwise_mlp.ModelFactory.auto_load_from_file", _mock_model_loader + ) + monkeypatch.setattr( + "reranking.models.pairwise_mlp.AutoTokenizer.from_pretrained", + lambda *_args, **_kwargs: DummyTokenizer(), + ) + + model = PairwiseMLPRerankerWithRetrievalScore( + "dummy-model", + mlp_hidden_dim=1, + dropout=0.0, + ) + _configure_classifier(model) + return model + + +def test_score_not_failing(dummy_model: PairwiseMLPRerankerWithRetrievalScore) -> None: + score = dummy_model.score(MENTION_TEXTS[0], ENTITY_TEXTS[0]) + assert isinstance(score, float) + assert score >= 0.0 + assert score <= 1.0 diff --git a/tests/reranking/training/test_reranking_iterable_dataset.py b/tests/reranking/training/test_reranking_iterable_dataset.py new file mode 100644 index 0000000..6849cb1 --- /dev/null +++ b/tests/reranking/training/test_reranking_iterable_dataset.py @@ -0,0 +1,27 @@ +import numpy as np +import torch + +from src.reranking.training.reranking_iterable_dataset import RerankingIterableDataset + + +def test_reranking_iterable_dataset_iterates_samples(tmp_path): + file_path = tmp_path / "part-000.npz" + data = { + "qids": np.array([10, 20], dtype=np.int64), + "y": np.array([1.0, 0.0], dtype=np.float32), + "description_tokens": np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64), + "link_tokens": np.array([[7, 8], [9, 10]], dtype=np.int64), + } + np.savez(file_path, **data) + + dataset = RerankingIterableDataset(tmp_path) + + samples = list(dataset) + assert len(samples) == 2 + + first = samples[0] + assert isinstance(first, tuple) + assert torch.equal(first[0], torch.tensor([7, 8], dtype=torch.long)) + assert torch.equal(first[1], torch.tensor([1, 2, 3], dtype=torch.long)) + assert torch.equal(first[2], torch.tensor(1.0, dtype=torch.float32)) + assert torch.equal(first[3], torch.tensor(10, dtype=torch.long)) From d1d5ea893465c59662fa4c04df88419af1dab8d6 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:19:39 +0200 Subject: [PATCH 21/28] feat(embeddings): update model loading to include output type in embs_from_tokens_model_name_and_state_dict --- src/utils/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py index faa1838..300d944 100644 --- a/src/utils/embeddings.py +++ b/src/utils/embeddings.py @@ -208,7 +208,7 @@ def embs_from_tokens_model_name_and_state_dict( target_dim: int | None = None, output_type: str | None = None, ): - model = ModelFactory.auto_load_from_file(model_name, state_dict_path, target_dim) + model = ModelFactory.auto_load_from_file(model_name, state_dict_path, target_dim, output_type) embs_from_tokens_and_model(source_path, model, batch_size, dest_path) From 1b2bba80cbfc5f6b97fff705d756232c361df51c Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Sun, 12 Oct 2025 10:20:11 +0200 Subject: [PATCH 22/28] feat(scripts): add another finetuning script --- .../train/asi_se_to_rozbilo_init_all.sh | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100755 src/scripts/train/asi_se_to_rozbilo_init_all.sh diff --git a/src/scripts/train/asi_se_to_rozbilo_init_all.sh b/src/scripts/train/asi_se_to_rozbilo_init_all.sh new file mode 100755 index 0000000..091c21f --- /dev/null +++ b/src/scripts/train/asi_se_to_rozbilo_init_all.sh @@ -0,0 +1,195 @@ +#!/bin/bash + +# Runs the complete finetuning process. +# Expects tokens to be in the dirs specified below. +# Additionaly, one can specify additional parameters. +# For running, please also set up/fix the path to venv in run_finetuning_action.sh + +set -ueo pipefail + +cd ../../ + +echo "Running all_langs.sh" +echo "Current directory: $(pwd)" + +MODEL_CONFIG_PATH="../configs/lealla_m.gin" +TRAIN_CONFIG_PATH="../configs/train.gin" + +DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2_normal_filtered/descs_pages" +DAMUEL_LINKS_TOKENS_RAW="$OUTPUTS/v2_normal_filtered/links" +MEWSLI_TOKENS_RAW="$OUTPUTS/tokens_mewsli_finetuning" +WORKDIR="$OUTPUTS/workdirs/asi_se_to_rozbilo_init_all" +N_OF_ROUNDS=6 + +run_ml_finetuning_round() { + local DAMUEL_DESCS_TOKENS_RAW=$1 + local DAMUEL_LINKS_TOKENS_RAW=$2 + local MEWSLI_TOKENS_RAW=$3 + local WORKDIR=$4 + local STATE_DICT=${5:-"None"} + local ROUND_ID=${6:-"0"} + local N_OF_ROUNDS=${7} + + local STEPS_PER_EPOCH=1000 + + local LINKS_PER_ROUND=$(($STEPS_PER_EPOCH * 250000 * 8)) + echo "LPR $LINKS_PER_ROUND" + + local ACTION_SCRIPT="run_action_gin.py $MODEL_CONFIG_PATH $TRAIN_CONFIG_PATH" + + ENV="../venv/bin/activate" + source $ENV + + # ====================TOKENS COPY==================== + + local DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! "$(ls -A $DAMUEL_LINKS_TOKENS)" ]; then + ../venv/bin/python $ACTION_SCRIPT "copy" \ + --source="$DAMUEL_LINKS_TOKENS_RAW" \ + --dest="$DAMUEL_LINKS_TOKENS" \ + --m="$N_OF_ROUNDS" \ + --r="$ROUND_ID" \ + --max_to_copy="$LINKS_PER_ROUND" + fi + + # ====================DAMUEL DESC EMBS==================== + + local DAMUEL_FOR_INDEX_DIR="$WORKDIR/damuel_for_index_$ROUND_ID" + + mkdir -p "$DAMUEL_FOR_INDEX_DIR" + + if [ ! "$(ls -A $DAMUEL_FOR_INDEX_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================DAMUEL LINKS EMBEDDING==================== + + local DAMUEL_LINKS_DIR="$WORKDIR/links_embs_$ROUND_ID" + + mkdir -p "$DAMUEL_LINKS_DIR" + + if [ ! "$(ls -A $DAMUEL_LINKS_DIR)" ]; then + echo "Running embs generating for damuel links" + ../venv/bin/python $ACTION_SCRIPT "embed_links_for_generation" \ + --links_tokens_dir_path="$DAMUEL_LINKS_TOKENS" \ + --dest_dir_path="$DAMUEL_LINKS_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================GENERATING BATCHES==================== + + local BATCH_DIR="$WORKDIR/batches_$ROUND_ID" + + mkdir -p "$BATCH_DIR" + if [ ! "$(ls -A $BATCH_DIR)" ]; then + echo "Running batches generating for damuel" + # ../venv/bin/python -m cProfile -o "generate.prof" $ACTION_SCRIPT "generate" \ + ../venv/bin/python $ACTION_SCRIPT "generate" \ + --LINKS_EMBS_DIR="$DAMUEL_LINKS_DIR" \ + --INDEX_TOKENS_DIR="$DAMUEL_DESCS_TOKENS_RAW" \ + --INDEX_EMBS_QIDS_DIR="$DAMUEL_FOR_INDEX_DIR" \ + --OUTPUT_DIR="$BATCH_DIR" + fi + + # ====================TRAINING MODEL==================== + + local MODELS_DIR="$WORKDIR/models_$ROUND_ID" + + mkdir -p $MODELS_DIR + + if [ ! "$(ls -A $MODELS_DIR)" ]; then + echo "Running training for damuel" + #../venv/bin/python -m cProfile -o "train_ddp.prof" $ACTION_SCRIPT "train_ddp" \ + echo $ACTION_SCRIPT "train_ddp" \ + --DATASET_DIR="$BATCH_DIR" \ + --MODEL_SAVE_DIR="$MODELS_DIR" \ + --STATE_DICT_PATH="$STATE_DICT" + ../venv/bin/python $ACTION_SCRIPT "train_ddp" \ + --DATASET_DIR="$BATCH_DIR" \ + --MODEL_SAVE_DIR="$MODELS_DIR" \ + --STATE_DICT_PATH="$STATE_DICT" + fi + + # ====================EVALUATION==================== + + local NEXT_INDEX=$(($ROUND_ID + 1)) + local DAMUEL_FOR_INDEX_NEW_DIR="$WORKDIR/damuel_for_index_$NEXT_INDEX" + mkdir -p "$DAMUEL_FOR_INDEX_NEW_DIR" + + if [ ! "$(ls -A $DAMUEL_FOR_INDEX_NEW_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_NEW_DIR" \ + --state_dict_path="$MODELS_DIR/ema.pth" + fi + + local LANGUAGES=("ar" "de" "en" "es" "ja" "fa" "sr" "ta" "tr") + + for LANG in "${LANGUAGES[@]}"; do + echo "Processing language: $LANG" + + local LANG_TOKEN_DIR="$MEWSLI_TOKENS_RAW/$LANG" + local MEWSLI_EMBS_DIR="$WORKDIR/mewsli_embs_${LANG}_$ROUND_ID" + + mkdir -p "$MEWSLI_EMBS_DIR" + + if [ ! "$(ls -A $MEWSLI_EMBS_DIR)" ]; then + echo "Running embs generating for mewsli - Language: $LANG" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$LANG_TOKEN_DIR" \ + --dest_path="$MEWSLI_EMBS_DIR" \ + --state_dict_path="$MODELS_DIR/ema.pth" + fi + + # ../venv/bin/python $ACTION_SCRIPT "recalls" \ + # --damuel_dir="$DAMUEL_FOR_INDEX_NEW_DIR" \ + # --mewsli_dir="$MEWSLI_EMBS_DIR" + + # echo "Completed processing for language: $LANG" + echo "----------------------------------------" + done + + ../venv/bin/python $ACTION_SCRIPT "evaluate" \ + --root_dir="$WORKDIR" \ + --finetuning_round=$ROUND_ID + + rm -r $BATCH_DIR $DAMUEL_LINKS_DIR $DAMUEL_FOR_INDEX_DIR $DAMUEL_LINKS_TOKENS +} + +if [ ! -L "$WORKDIR" ]; then + mkdir -p "$WORKDIR" +fi + +DAMUEL_DESCS_TOKENS="$WORKDIR/damuel_descs_together_tokens" +if [ ! -L "$DAMUEL_DESCS_TOKENS" ]; then + mkdir -p "$DAMUEL_DESCS_TOKENS" +fi + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! -L "$DAMUEL_LINKS_TOKENS" ]; then + mkdir -p "$DAMUEL_LINKS_TOKENS" + fi +done + +STATE_DICT="None" + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + if [ ! -e "$WORKDIR/models_$ROUND_ID/final.pth" ]; then + echo "Running round $ROUND_ID" + + run_ml_finetuning_round "$DAMUEL_DESCS_TOKENS_RAW" "$DAMUEL_LINKS_TOKENS_RAW" \ + "$MEWSLI_TOKENS_RAW" \ + "$WORKDIR" "$STATE_DICT" \ + "$ROUND_ID" "$N_OF_ROUNDS" + fi + + STATE_DICT="$WORKDIR/models_$ROUND_ID/ema.pth" +done From 3d32e0d4378fb3714c0c18dfac6f56501779de2c Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Thu, 16 Oct 2025 17:25:17 +0200 Subject: [PATCH 23/28] Add mapping of qids to token matrix and corresponding tests - Implemented `map_qids_to_token_matrix` function in `loaders.py` to create a sparse matrix mapping qids to their token vectors. - Added unit tests for `map_qids_to_token_matrix` in `test_loaders.py` to verify correct mapping and handling of non-existent qids. - Created a new test file `test_change_dataset_tokens.py` to test the `update_tokens_in_file` and `process_directory` functions. - Updated `uv.lock` to include `scipy` as a dependency. --- configs/train.gin | 2 +- pyproject.toml | 1 + src/reranking/models/fusion.py | 265 ++++++++++++++++++ src/reranking/training/trainer_simple.py | 7 +- src/reranking/training/training_configs.py | 108 ++++++- src/scripts/qwen/reranking3.py | 54 ++-- .../reranking/change_dataset_tokens.py | 44 +++ src/utils/loaders.py | 34 +++ .../reranking/test_change_dataset_tokens.py | 68 +++++ tests/utils/test_loaders.py | 29 ++ uv.lock | 73 +++++ 11 files changed, 656 insertions(+), 29 deletions(-) create mode 100644 src/reranking/models/fusion.py create mode 100644 src/scripts/reranking/change_dataset_tokens.py create mode 100644 tests/scripts/reranking/test_change_dataset_tokens.py diff --git a/configs/train.gin b/configs/train.gin index 778c3de..65ac505 100644 --- a/configs/train.gin +++ b/configs/train.gin @@ -1,4 +1,4 @@ -training_batch_size=2688 +training_batch_size=3712 #epochs=300 epochs=100 logit_mutliplier=20 diff --git a/pyproject.toml b/pyproject.toml index 26b521a..ccfd95a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "pytest-mock>=3.15.0", "python-fire>=0.1.0", "scann>=1.4.2", + "scipy>=1.16.2", "torch>=2.8.0", "tqdm>=4.67.1", "transformers>=4.56.1", diff --git a/src/reranking/models/fusion.py b/src/reranking/models/fusion.py new file mode 100644 index 0000000..060d045 --- /dev/null +++ b/src/reranking/models/fusion.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from reranking.models.context_emb_with_attention import GPTLayer +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +class SimpleCrossAttentionLayer(nn.Module): + """A simple wrapper for PyTorch's MultiheadAttention to perform cross-attention.""" + + def __init__(self, query_dim, key_value_dim, num_heads=8, dropout=0.1): + super().__init__() + # The MHA layer that handles everything, including projections! + # Note: We must set batch_first=True for BERT-style inputs. + self.attention = nn.MultiheadAttention( + embed_dim=query_dim, + num_heads=num_heads, + kdim=key_value_dim, + vdim=key_value_dim, + dropout=dropout, + batch_first=True, # Important for [batch, seq_len, dim] tensors + ) + self.norm = nn.LayerNorm(query_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, query, key_value, attention_mask=None): + # query: The tensor that asks the questions (from the target model). + # key_value: The tensor that provides the context (from the source model). + + # The MHA layer returns the attention output and weights. We only need the output. + attn_output, _ = self.attention( + query=query, + key=key_value, + value=key_value, + key_padding_mask=attention_mask, # Optional mask for padded tokens + ) + + # Add residual connection and layer norm + output = self.norm(query + self.dropout(attn_output)) + return output + + +class BertFusionModel(nn.Module): + def __init__(self, model1, model2): + super().__init__() + self.base = model1 # The 24-layer model (dim 192) + self.paraphrase = model2 # The 12-layer model (dim 384) + + # Simpler cross-attention layers using the wrapper + self.cross_layers_1_to_2 = nn.ModuleList( + [SimpleCrossAttentionLayer(query_dim=384, key_value_dim=192) for _ in range(12)] + ) + self.cross_layers_2_to_1 = nn.ModuleList( + [SimpleCrossAttentionLayer(query_dim=192, key_value_dim=384) for _ in range(12)] + ) + + def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2): + # 1. Get initial embeddings + hidden_states1 = self.base.embeddings(input_ids=input_ids1) + hidden_states2 = self.paraphrase.embeddings(input_ids=input_ids2) + + # (Note: You may need to adapt the huggingface attention_mask format for MHA's key_padding_mask) + + # 2. Iteratively process through layers with fusion + for i in range(12): + # Process layers for each model + hidden_states1 = self.base.encoder.layer[2 * i](hidden_states1)[0] + hidden_states1 = self.base.encoder.layer[2 * i + 1](hidden_states1)[0] + hidden_states2 = self.paraphrase.encoder.layer[i](hidden_states2)[0] + + # Store states before the cross-attention step to avoid in-place modification issues + temp_states1, temp_states2 = hidden_states1, hidden_states2 + + # Cross-attention exchange + hidden_states1 = self.cross_layers_2_to_1[i](query=temp_states1, key_value=temp_states2) + hidden_states2 = self.cross_layers_1_to_2[i](query=temp_states2, key_value=temp_states1) + + # 3. Get final outputs + pooled_output1 = self.base.pooler(hidden_states1) + pooled_output2 = self.paraphrase.pooler(hidden_states2) + + return torch.cat([pooled_output1, pooled_output2], dim=-1) + + +class _Model(nn.Module): + def __init__(self, fusion_model: BertFusionModel, dropout: float = 0.1) -> None: + super().__init__() + self.fusion_model = fusion_model + self.embedding_dim = 384 + 192 + self.gpt_layer = GPTLayer(model_width=self.embedding_dim, dropout=dropout) + self.final_layer = nn.Linear(self.embedding_dim, 1) + + def forward( + self, + input_ids1: torch.Tensor, + attention_mask1: torch.Tensor, + input_ids2: torch.Tensor, + attention_mask2: torch.Tensor, + ) -> torch.Tensor: + # print(ids.shape, attention_mask.shape) + base_embeddings = self.fusion_model( + input_ids1, attention_mask1, input_ids2, attention_mask2 + ) + # print("base_embeddings", base_embeddings.shape) + x = self.gpt_layer(base_embeddings) + # print("x", x.shape) + logits = self.final_layer(x).squeeze(-1) + # print("logits", logits.shape) + return logits + + +class FusionReranker(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head.""" + + def __init__( + self, + base_model_name_or_path: str, + paraphrase_model_name_or_path: str, + qid_to_para_toks: Dict[str, torch.Tensor], + *, + base_state_dict_path: str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + ) -> None: + super().__init__() + + self.qid_to_para_toks = qid_to_para_toks + + self.ema_decay = ema_decay + + self.base_model = ModelFactory.auto_load_from_file( + base_model_name_or_path, + state_dict_path=base_state_dict_path, + ).model + self.paraphrase_model = ModelFactory.auto_load_from_file( + paraphrase_model_name_or_path, + output_type="sentence_transformer", + ).model + + fusion_model = BertFusionModel(self.base_model, self.paraphrase_model) + + self.model = _Model( + fusion_model=fusion_model, + dropout=dropout, + ) + self.model_ema = deepcopy(self.model) + + tokenizer_id = tokenizer_name_or_path or base_model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + return self.model(ids, attention_mask) + + def prepare_for_forward(self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor): + ids = rearrange([mention_tokens, entity_tokens], "d b n -> b (d n)") + attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) + return ids, attention_mask + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + entity_tokens = data["entity_tokens"] + labels = data["labels"].float() + qids = data["qids"].numpy() + + para_tokens = torch.stack([self.qid_to_para_toks[qid] for qid in qids], dim=0).to( + device=mention_tokens.device + ) + para_attention_mask = create_attention_mask(para_tokens).to( + dtype=para_tokens.dtype, device=para_tokens.device + ) + + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + attention_mask = attention_mask.to(dtype=ids.dtype, device=ids.device) + logits = self.model(ids, attention_mask, para_tokens, para_attention_mask) + + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip(self.model.parameters(), self.model_ema.parameters()): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.model_ema.state_dict(), ema_path) + torch.save(self.model.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("_orig_mod.module.model."): + new_k = k.replace("_orig_mod.module.model.", "") + elif k.startswith("module."): + new_k = k.replace("module.", "") + else: + new_k = k + new_state_dict[new_k] = v + state_dict = new_state_dict + self.model_ema.load_state_dict(state_dict) + self.model.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + raise NotImplementedError("Use score_from_tokens_and_qids for FusionReranker.") + + @torch.inference_mode() + def score_from_tokens_and_qids( + self, mentions: torch.Tensor, entities: torch.Tensor, qids: torch.Tensor + ) -> torch.Tensor: + qids = qids.numpy() + para_tokens = torch.stack([self.qid_to_para_toks[qid] for qid in qids], dim=0).to( + device=mentions.device + ) + para_attention_mask = create_attention_mask(para_tokens).to( + dtype=para_tokens.dtype, device=para_tokens.device + ) + ids, attention_mask = self.prepare_for_forward(mentions, entities) + attention_mask = attention_mask.to(dtype=ids.dtype, device=ids.device) + logits = self.model(ids, attention_mask, para_tokens, para_attention_mask) + + probabilities = torch.sigmoid(logits).reshape(-1) + + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Use score_from_tokens_and_qids for FusionReranker.") diff --git a/src/reranking/training/trainer_simple.py b/src/reranking/training/trainer_simple.py index 53a4267..3f7866d 100644 --- a/src/reranking/training/trainer_simple.py +++ b/src/reranking/training/trainer_simple.py @@ -83,7 +83,7 @@ def step(current_loss): links = links.to(device, non_blocking=True) entities = entities.to(device, non_blocking=True) labels = labels.to(device, non_blocking=True) - qids = qids.to(device, non_blocking=True) + # qids = qids.to(device, non_blocking=True) batch_data = { "mention_tokens": links, @@ -115,7 +115,7 @@ def step(current_loss): links = links.to(device, non_blocking=True) entities = entities.to(device, non_blocking=True) labels = labels.to(device, non_blocking=True) - qids = qids.to(device, non_blocking=True) + # qids = qids.to(device, non_blocking=True) val_steps += 1 @@ -124,8 +124,11 @@ def step(current_loss): training_config.config_name == "pairwise_mlp" or training_config.config_name == "pairwise_mlp_debug" or training_config.config_name == "full_lealla" + or training_config.config_name == "full_lealla_r" ): probs = model.score_from_tokens(links, entities) + elif training_config.config_name == "fusion": + probs = model.score_from_tokens_and_qids(links, entities, qids) else: probs = model.score_from_tokens(links, qids) loss = torch.nn.functional.binary_cross_entropy(probs, labels.float()) diff --git a/src/reranking/training/training_configs.py b/src/reranking/training/training_configs.py index 796f795..ae2fb2b 100644 --- a/src/reranking/training/training_configs.py +++ b/src/reranking/training/training_configs.py @@ -8,6 +8,7 @@ from reranking.models.base import BaseRerankingModel from reranking.models.context_emb_with_attention import ContextEmbWithAttention from reranking.models.full_lealla import FullLEALLAReranker +from reranking.models.fusion import FusionReranker from reranking.models.pairwise_mlp import PairwiseMLPReranker from reranking.models.pairwise_mlp_with_large_context_emb import ( PairwiseMLPRerankerWithLargeContextEmb, @@ -100,9 +101,9 @@ def full_lealla( def full_lealla_r( - LR: float = 0.0001, - SAVE_EACH: int = 5000, - BATCH_SIZE: int = 2048, + LR: float = 0.00005, + SAVE_EACH: int = 10000, + BATCH_SIZE: int = 1300, VALIDATE_EACH: int = 10000, VALIDATION_SIZE: int = 10000, DROPOUT: float = 0.1, @@ -131,6 +132,103 @@ def full_lealla_r( ) +def full_lealla_r_128( + LR: float = 0.00005, + SAVE_EACH: int = 10000, + BATCH_SIZE: int = 1300, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "full_lealla_r_128" + + d = np.load("~/troja/outputs/reranking_test/reranker_dataset_with_qids/mentions_5_dataset.npz") + description_tokens = d["description_tokens"] + assert description_tokens.shape[1] == 128, "Expected description tokens to have length 128" + + model = FullLEALLAReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + ) + + dataset = RerankingIterableDataset() + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def fusion( + LR: float = 0.0001, + SAVE_EACH: int = 10000, + BATCH_SIZE: int = 128, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "fusion" + + qid_to_para_toks = {} + from tqdm import tqdm + + for file in tqdm( + os.listdir( + "/lnet/work/home-students-external/farhan/troja/outputs/descriptions_paraphrase_after_multiling_dataset/descs_pages" + ), + total=len( + os.listdir( + "/lnet/work/home-students-external/farhan/troja/outputs/descriptions_paraphrase_after_multiling_dataset/descs_pages" + ) + ), + ): + if file.endswith(".npz"): + d = np.load( + os.path.join( + "/lnet/work/home-students-external/farhan/troja/outputs/descriptions_paraphrase_after_multiling_dataset/descs_pages", + file, + ) + ) + qids = d["qids"] + tokens = torch.from_numpy(d["tokens"]).to(torch.int32) + for qid, token in zip(qids, tokens): + qid_to_para_toks[qid] = token + + model = FusionReranker( + base_model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + paraphrase_model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", + base_state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + qid_to_para_toks=qid_to_para_toks, + dropout=DROPOUT, + ) + + dataset = RerankingIterableDataset() + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + def pairwise_mlp_noise( LR: float = 0.0001, SAVE_EACH: int = 20000, @@ -323,5 +421,9 @@ def get_config_from_name(config_name: str) -> TrainingConfig: return full_lealla_r() if config_name == "full_lealla_debug": return full_lealla(VALIDATE_EACH=1000, SAVE_EACH=1000000000000, BATCH_SIZE=128) + if config_name == "fusion": + return fusion() + if config_name == "full_lealla_r_128": + return full_lealla_r_128() else: raise ValueError(f"Unknown training configuration: {config_name}") diff --git a/src/scripts/qwen/reranking3.py b/src/scripts/qwen/reranking3.py index 9938267..1ea6015 100644 --- a/src/scripts/qwen/reranking3.py +++ b/src/scripts/qwen/reranking3.py @@ -89,11 +89,11 @@ def main(): # ) # base_embs = torch.from_numpy(base_embs) - # reranker = PairwiseMLPReranker( - # "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", - # state_dict_path=args.reranking_model_path, - # mlp_hidden_dim=2048, - # ) + reranker = PairwiseMLPReranker( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path=args.reranking_model_path, + mlp_hidden_dim=2048, + ) # qid_to_paraphrase_emb = {qid: emb for qid, emb in zip(paraphrase_qids, paraphrase_embs)} # qid_to_base_emb = {qid: emb for qid, emb in zip(base_qids, base_embs)} # reranker = PairwiseMLPRerankerWithLargeContextEmb( @@ -102,15 +102,15 @@ def main(): # qid_to_paraphrase_emb=qid_to_paraphrase_emb, # qid_to_base_emb=qid_to_base_emb, # ) - reranker = FullLEALLAReranker( - model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", - ) - # reranker.load( - # "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp/70000.pth", + # reranker = FullLEALLAReranker( + # model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", # ) reranker.load( - "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/full_lealla/30000.pth", + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp/70000.pth", ) + # reranker.load( + # "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/full_lealla_r/380000.pth", + # ) reranker.eval() reranker.to(device) @@ -174,6 +174,7 @@ def main(): mewsli_loader = DataLoader(mewsli_dataset, batch_size=args.batch_size, shuffle=False) good = 0 + upper_bound_hits = 0 total = 0 for mewsli_tokens_batch, qids_batch in mewsli_loader: @@ -189,25 +190,28 @@ def main(): ) neighbor_qids = torch.as_tensor(neighbor_qids, dtype=torch.long) + retrieval_hits = (neighbor_qids == qids_batch.view(-1, 1)).any(dim=1) + upper_bound_hits += retrieval_hits.sum().item() + candidate_embs_lists = [] candidate_tokens_lists = [] for row in neighbor_qids.tolist(): for nq in row: - # emb = qid_to_damuel_emb[int(nq)] + emb = qid_to_damuel_emb[int(nq)] token = qid_to_damuel_token[int(nq)] - # candidate_embs_lists.append(emb) + candidate_embs_lists.append(emb) candidate_tokens_lists.append(token) # assert len(candidate_embs_lists) == neighbor_qids.size(0) * neighbor_qids.size(1) - # candidate_embs = torch.as_tensor( - # candidate_embs_lists, dtype=torch.float16, device=device - # ) - # together = torch.cat( - # (mewsli_embs.repeat_interleave(neighbor_qids.size(1), dim=0), candidate_embs), - # dim=-1, - # ) - # together = together.to(device) + candidate_embs = torch.as_tensor( + candidate_embs_lists, dtype=torch.float16, device=device + ) + together = torch.cat( + (mewsli_embs.repeat_interleave(neighbor_qids.size(1), dim=0), candidate_embs), + dim=-1, + ) + together = together.to(device) candidate_tokens = torch.as_tensor( candidate_tokens_lists, dtype=torch.int64, device=device ) @@ -230,6 +234,7 @@ def main(): ) if ( isinstance(reranker, PairwiseMLPRerankerWithRetrievalScore) + # or isinstance(reranker, FullLEALLAReranker) # or isinstance(reranker, PairwiseMLPReranker) # or isinstance(reranker, PairwiseMLPRerankerWithLargeContextEmb) ): @@ -242,7 +247,7 @@ def main(): candidate_embs.to(torch.bfloat16), mewsli_embs.to(torch.bfloat16), ) - # * LOGIT_MULTIPLIER + * LOGIT_MULTIPLIER ) out = out.reshape(neighbor_qids.size(0), neighbor_qids.size(1)) scores = (scores + torch.sigmoid(out)) / 2 @@ -258,7 +263,10 @@ def main(): continue final_accuracy = round(good / total * 100, 4) - print(f"Final accuracy for {language}: {final_accuracy}") + retrieval_upper_bound = round(upper_bound_hits / total * 100, 4) + print( + f"Final accuracy for {language}: {final_accuracy} (retrieval upper bound: {retrieval_upper_bound})" + ) if __name__ == "__main__": diff --git a/src/scripts/reranking/change_dataset_tokens.py b/src/scripts/reranking/change_dataset_tokens.py new file mode 100644 index 0000000..79b6b70 --- /dev/null +++ b/src/scripts/reranking/change_dataset_tokens.py @@ -0,0 +1,44 @@ +from pathlib import Path + +import fire +import numpy as np +from scipy.sparse import csr_matrix +from tqdm import tqdm + +from utils.loaders import map_qids_to_token_matrix + + +def update_tokens_in_file(file_path: Path, qid_to_new_tokens: csr_matrix) -> None: + """Updates tokens in a single .npz file and returns if it was modified.""" + with np.load(file_path) as data: + tokens = data["description_tokens"] + qids = data["qids"] + save_data = dict(data) + + for i, qid in enumerate(qids): + if qid_to_new_tokens[qid].nnz > 0: + tokens[i] = qid_to_new_tokens[qid].toarray()[0] + + save_data["description_tokens"] = tokens + np.savez(file_path, **save_data) + + +def process_directory(dataset_dir: Path, new_tokens_dir: Path) -> None: + """Orchestrates the token update process for an entire directory.""" + qid_to_new_tokens = map_qids_to_token_matrix(new_tokens_dir, verbose=True) + files_to_process = list(dataset_dir.glob("*.npz")) + + for fp in tqdm(files_to_process, desc="Updating files"): + update_tokens_in_file(fp, qid_to_new_tokens) + + +def main(dataset_dir, new_tokens_dir): + dataset_dir = Path(dataset_dir) + new_tokens_dir = Path(new_tokens_dir) + print("Starting token update process...") + process_directory(dataset_dir, new_tokens_dir) + print("Update process completed.") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/src/utils/loaders.py b/src/utils/loaders.py index 44d865b..722ffd1 100644 --- a/src/utils/loaders.py +++ b/src/utils/loaders.py @@ -5,6 +5,7 @@ import gin import numpy as np import pandas as pd +from scipy.sparse import coo_matrix, csr_matrix from tqdm import tqdm # from tokenization.pipeline import DamuelAliasTablePipeline @@ -156,6 +157,39 @@ def load_tokens_and_qids(file_path: str | Path) -> tuple[np.ndarray, np.ndarray] return d["tokens"], d["qids"] +def map_qids_to_token_matrix( + dir_path: str | Path, verbose: bool = False, max_items_to_load: int | None = None +) -> csr_matrix: + """Builds a memory-efficient sparse matrix mapping qids to their token vectors. + + Args: + dir_path (str | Path): Directory containing data files with 'tokens' and 'qids'. + verbose (bool): Forwarded to `load_tokens_qids_from_dir` to toggle progress output. + max_items_to_load (int | None): Optional cap on the number of token rows to read. + + Returns: + scipy.sparse.csr_matrix: A CSR matrix where a row index corresponds to a qid + and the row's data is the token vector. Use + `matrix[qid]` to retrieve a vector. + """ + tokens, qids = load_tokens_qids_from_dir( + dir_path=dir_path, verbose=verbose, max_items_to_load=max_items_to_load + ) + + num_items, vector_len = tokens.shape + + assert num_items == qids.shape[0], "Mismatch between number of token rows and qids" + + row_indices = np.repeat(qids, vector_len) + col_indices = np.tile(np.arange(vector_len), num_items) + data = tokens.flatten() + + shape = (qids.max() + 1, vector_len) + + coo = coo_matrix((data, (row_indices, col_indices)), shape=shape, dtype=tokens.dtype) + return coo.tocsr() + + class AliasTableLoader: """ This class provides methods to load and process alias tables from two different sources: diff --git a/tests/scripts/reranking/test_change_dataset_tokens.py b/tests/scripts/reranking/test_change_dataset_tokens.py new file mode 100644 index 0000000..06e7b24 --- /dev/null +++ b/tests/scripts/reranking/test_change_dataset_tokens.py @@ -0,0 +1,68 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +from scipy.sparse import coo_matrix, csr_matrix + +from scripts.reranking.change_dataset_tokens import process_directory, update_tokens_in_file + + +def _create_npz_file(path: Path, filename: str, **kwargs): + """Helper to create an .npz file for testing.""" + np.savez(path / filename, **kwargs) + + +def _create_sparse_matrix(data_dict: dict, shape: tuple[int, int], dtype=np.float32) -> csr_matrix: + """Helper to build a CSR matrix from a dictionary of {qid: vector}.""" + rows, cols, data = [], [], [] + vector_len = shape[1] + for qid, vector in data_dict.items(): + rows.extend([qid] * vector_len) + cols.extend(range(vector_len)) + data.extend(vector) + return coo_matrix((data, (rows, cols)), shape=shape, dtype=dtype).tocsr() + + +def test_update_tokens_when_match_exists(tmp_path: Path): + """ + Tests that the file is modified and returns True when a qid has a new token. + """ + test_file = tmp_path / "data.npz" + original_tokens = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]) + original_qids = np.array([100, 200, 300]) + _create_npz_file( + tmp_path, "data.npz", description_tokens=original_tokens.copy(), qids=original_qids + ) + + new_token_vector = np.array([99.0, 99.0]) + update_matrix = _create_sparse_matrix({200: new_token_vector}, shape=(400, 2)) + update_tokens_in_file(test_file, update_matrix) + + loaded_data = np.load(test_file) + updated_tokens = loaded_data["description_tokens"] + + assert np.array_equal(updated_tokens[1], new_token_vector) + assert np.array_equal(updated_tokens[0], original_tokens[0]) + assert np.array_equal(updated_tokens[2], original_tokens[2]) + + +@patch("scripts.reranking.change_dataset_tokens.update_tokens_in_file") +@patch("scripts.reranking.change_dataset_tokens.map_qids_to_token_matrix") +def test_process_directory_orchestration( + mock_map_qids: MagicMock, mock_update_file: MagicMock, tmp_path: Path +): + dataset_dir = tmp_path / "dataset" + tokens_dir = tmp_path / "tokens" + dataset_dir.mkdir() + tokens_dir.mkdir() + + file_paths = [dataset_dir / "a.npz", dataset_dir / "b.npz", dataset_dir / "c.npz"] + + mock_map_qids.return_value = csr_matrix((3, 2)) + mock_update_file.side_effect = [True, False, True] + + with patch.object(Path, "glob", return_value=file_paths): + process_directory(dataset_dir, tokens_dir) + + mock_map_qids.assert_called_once_with(tokens_dir, verbose=True) + assert mock_update_file.call_count == len(file_paths) diff --git a/tests/utils/test_loaders.py b/tests/utils/test_loaders.py index 20846bc..9c34b6f 100644 --- a/tests/utils/test_loaders.py +++ b/tests/utils/test_loaders.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +from scipy.sparse import csr_matrix from utils.loaders import ( AliasTableLoader, @@ -14,6 +15,7 @@ load_qids_npy, load_tokens_qids, load_tokens_qids_from_dir, + map_qids_to_token_matrix, ) @@ -413,6 +415,33 @@ def test_load_qids_npy(mock_qids_remap, use_string_path: bool) -> None: assert isinstance(loaded_qids, np.ndarray) +def test_map_qids_to_token_matrix() -> None: + """ + Tests that qids are correctly mapped to their token vectors in a sparse matrix. + """ + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens = np.array([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]]) + qids = np.array([300, 100, 200]) + _create_tokens_qids_npz(dir_path, "tokens_qids.npz", tokens, qids) + + token_matrix = map_qids_to_token_matrix(dir_path) + + assert isinstance(token_matrix, csr_matrix) + assert token_matrix.shape == (301, 2) + + for i, qid in enumerate(qids): + original_vector = tokens[i] + retrieved_vector = token_matrix[qid].toarray()[0] + + assert np.array_equal(original_vector, retrieved_vector) + + non_existent_qid = 150 + zero_vector = token_matrix[non_existent_qid].toarray()[0] + assert np.array_equal(zero_vector, np.zeros(tokens.shape[1])) + + @pytest.mark.parametrize("lowercase", [True, False]) class TestAliasTableLoader: def setup_method(self, lowercase): diff --git a/uv.lock b/uv.lock index 67f1d6b..a4b6bf2 100644 --- a/uv.lock +++ b/uv.lock @@ -501,6 +501,7 @@ dependencies = [ { name = "pytest-mock" }, { name = "python-fire" }, { name = "scann" }, + { name = "scipy" }, { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, @@ -521,6 +522,7 @@ requires-dist = [ { name = "pytest-mock", specifier = ">=3.15.0" }, { name = "python-fire", specifier = ">=0.1.0" }, { name = "scann", specifier = ">=1.4.2" }, + { name = "scipy", specifier = ">=1.16.2" }, { name = "torch", specifier = ">=2.8.0" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "transformers", specifier = ">=4.56.1" }, @@ -1260,6 +1262,77 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/b4/3decfd7039399b6bd9c9fbf0ccda39301bac01c39a09a5a791c8237f5d26/scann-1.4.2-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:c87e97f91c98d7d1f0bf985b39634e07e1149ba79c20f7dbf9b7b465c94100f2", size = 11579811, upload-time = "2025-08-29T14:32:25.101Z" }, ] +[[package]] +name = "scipy" +version = "1.16.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/3b/546a6f0bfe791bbb7f8d591613454d15097e53f906308ec6f7c1ce588e8e/scipy-1.16.2.tar.gz", hash = "sha256:af029b153d243a80afb6eabe40b0a07f8e35c9adc269c019f364ad747f826a6b", size = 30580599, upload-time = "2025-09-11T17:48:08.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/ef/37ed4b213d64b48422df92560af7300e10fe30b5d665dd79932baebee0c6/scipy-1.16.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:6ab88ea43a57da1af33292ebd04b417e8e2eaf9d5aa05700be8d6e1b6501cd92", size = 36619956, upload-time = "2025-09-11T17:39:20.5Z" }, + { url = "https://files.pythonhosted.org/packages/85/ab/5c2eba89b9416961a982346a4d6a647d78c91ec96ab94ed522b3b6baf444/scipy-1.16.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c95e96c7305c96ede73a7389f46ccd6c659c4da5ef1b2789466baeaed3622b6e", size = 28931117, upload-time = "2025-09-11T17:39:29.06Z" }, + { url = "https://files.pythonhosted.org/packages/80/d1/eed51ab64d227fe60229a2d57fb60ca5898cfa50ba27d4f573e9e5f0b430/scipy-1.16.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:87eb178db04ece7c698220d523c170125dbffebb7af0345e66c3554f6f60c173", size = 20921997, upload-time = "2025-09-11T17:39:34.892Z" }, + { url = "https://files.pythonhosted.org/packages/be/7c/33ea3e23bbadde96726edba6bf9111fb1969d14d9d477ffa202c67bec9da/scipy-1.16.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:4e409eac067dcee96a57fbcf424c13f428037827ec7ee3cb671ff525ca4fc34d", size = 23523374, upload-time = "2025-09-11T17:39:40.846Z" }, + { url = "https://files.pythonhosted.org/packages/96/0b/7399dc96e1e3f9a05e258c98d716196a34f528eef2ec55aad651ed136d03/scipy-1.16.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e574be127bb760f0dad24ff6e217c80213d153058372362ccb9555a10fc5e8d2", size = 33583702, upload-time = "2025-09-11T17:39:49.011Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bc/a5c75095089b96ea72c1bd37a4497c24b581ec73db4ef58ebee142ad2d14/scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f5db5ba6188d698ba7abab982ad6973265b74bb40a1efe1821b58c87f73892b9", size = 35883427, upload-time = "2025-09-11T17:39:57.406Z" }, + { url = "https://files.pythonhosted.org/packages/ab/66/e25705ca3d2b87b97fe0a278a24b7f477b4023a926847935a1a71488a6a6/scipy-1.16.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec6e74c4e884104ae006d34110677bfe0098203a3fec2f3faf349f4cb05165e3", size = 36212940, upload-time = "2025-09-11T17:40:06.013Z" }, + { url = "https://files.pythonhosted.org/packages/d6/fd/0bb911585e12f3abdd603d721d83fc1c7492835e1401a0e6d498d7822b4b/scipy-1.16.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:912f46667d2d3834bc3d57361f854226475f695eb08c08a904aadb1c936b6a88", size = 38865092, upload-time = "2025-09-11T17:40:15.143Z" }, + { url = "https://files.pythonhosted.org/packages/d6/73/c449a7d56ba6e6f874183759f8483cde21f900a8be117d67ffbb670c2958/scipy-1.16.2-cp311-cp311-win_amd64.whl", hash = "sha256:91e9e8a37befa5a69e9cacbe0bcb79ae5afb4a0b130fd6db6ee6cc0d491695fa", size = 38687626, upload-time = "2025-09-11T17:40:24.041Z" }, + { url = "https://files.pythonhosted.org/packages/68/72/02f37316adf95307f5d9e579023c6899f89ff3a051fa079dbd6faafc48e5/scipy-1.16.2-cp311-cp311-win_arm64.whl", hash = "sha256:f3bf75a6dcecab62afde4d1f973f1692be013110cad5338007927db8da73249c", size = 25503506, upload-time = "2025-09-11T17:40:30.703Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8d/6396e00db1282279a4ddd507c5f5e11f606812b608ee58517ce8abbf883f/scipy-1.16.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:89d6c100fa5c48472047632e06f0876b3c4931aac1f4291afc81a3644316bb0d", size = 36646259, upload-time = "2025-09-11T17:40:39.329Z" }, + { url = "https://files.pythonhosted.org/packages/3b/93/ea9edd7e193fceb8eef149804491890bde73fb169c896b61aa3e2d1e4e77/scipy-1.16.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ca748936cd579d3f01928b30a17dc474550b01272d8046e3e1ee593f23620371", size = 28888976, upload-time = "2025-09-11T17:40:46.82Z" }, + { url = "https://files.pythonhosted.org/packages/91/4d/281fddc3d80fd738ba86fd3aed9202331180b01e2c78eaae0642f22f7e83/scipy-1.16.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:fac4f8ce2ddb40e2e3d0f7ec36d2a1e7f92559a2471e59aec37bd8d9de01fec0", size = 20879905, upload-time = "2025-09-11T17:40:52.545Z" }, + { url = "https://files.pythonhosted.org/packages/69/40/b33b74c84606fd301b2915f0062e45733c6ff5708d121dd0deaa8871e2d0/scipy-1.16.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:033570f1dcefd79547a88e18bccacff025c8c647a330381064f561d43b821232", size = 23553066, upload-time = "2025-09-11T17:40:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/55/a7/22c739e2f21a42cc8f16bc76b47cff4ed54fbe0962832c589591c2abec34/scipy-1.16.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ea3421209bf00c8a5ef2227de496601087d8f638a2363ee09af059bd70976dc1", size = 33336407, upload-time = "2025-09-11T17:41:06.796Z" }, + { url = "https://files.pythonhosted.org/packages/53/11/a0160990b82999b45874dc60c0c183d3a3a969a563fffc476d5a9995c407/scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f66bd07ba6f84cd4a380b41d1bf3c59ea488b590a2ff96744845163309ee8e2f", size = 35673281, upload-time = "2025-09-11T17:41:15.055Z" }, + { url = "https://files.pythonhosted.org/packages/96/53/7ef48a4cfcf243c3d0f1643f5887c81f29fdf76911c4e49331828e19fc0a/scipy-1.16.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5e9feab931bd2aea4a23388c962df6468af3d808ddf2d40f94a81c5dc38f32ef", size = 36004222, upload-time = "2025-09-11T17:41:23.868Z" }, + { url = "https://files.pythonhosted.org/packages/49/7f/71a69e0afd460049d41c65c630c919c537815277dfea214031005f474d78/scipy-1.16.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:03dfc75e52f72cf23ec2ced468645321407faad8f0fe7b1f5b49264adbc29cb1", size = 38664586, upload-time = "2025-09-11T17:41:31.021Z" }, + { url = "https://files.pythonhosted.org/packages/34/95/20e02ca66fb495a95fba0642fd48e0c390d0ece9b9b14c6e931a60a12dea/scipy-1.16.2-cp312-cp312-win_amd64.whl", hash = "sha256:0ce54e07bbb394b417457409a64fd015be623f36e330ac49306433ffe04bc97e", size = 38550641, upload-time = "2025-09-11T17:41:36.61Z" }, + { url = "https://files.pythonhosted.org/packages/92/ad/13646b9beb0a95528ca46d52b7babafbe115017814a611f2065ee4e61d20/scipy-1.16.2-cp312-cp312-win_arm64.whl", hash = "sha256:2a8ffaa4ac0df81a0b94577b18ee079f13fecdb924df3328fc44a7dc5ac46851", size = 25456070, upload-time = "2025-09-11T17:41:41.3Z" }, + { url = "https://files.pythonhosted.org/packages/c1/27/c5b52f1ee81727a9fc457f5ac1e9bf3d6eab311805ea615c83c27ba06400/scipy-1.16.2-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:84f7bf944b43e20b8a894f5fe593976926744f6c185bacfcbdfbb62736b5cc70", size = 36604856, upload-time = "2025-09-11T17:41:47.695Z" }, + { url = "https://files.pythonhosted.org/packages/32/a9/15c20d08e950b540184caa8ced675ba1128accb0e09c653780ba023a4110/scipy-1.16.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:5c39026d12edc826a1ef2ad35ad1e6d7f087f934bb868fc43fa3049c8b8508f9", size = 28864626, upload-time = "2025-09-11T17:41:52.642Z" }, + { url = "https://files.pythonhosted.org/packages/4c/fc/ea36098df653cca26062a627c1a94b0de659e97127c8491e18713ca0e3b9/scipy-1.16.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e52729ffd45b68777c5319560014d6fd251294200625d9d70fd8626516fc49f5", size = 20855689, upload-time = "2025-09-11T17:41:57.886Z" }, + { url = "https://files.pythonhosted.org/packages/dc/6f/d0b53be55727f3e6d7c72687ec18ea6d0047cf95f1f77488b99a2bafaee1/scipy-1.16.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:024dd4a118cccec09ca3209b7e8e614931a6ffb804b2a601839499cb88bdf925", size = 23512151, upload-time = "2025-09-11T17:42:02.303Z" }, + { url = "https://files.pythonhosted.org/packages/11/85/bf7dab56e5c4b1d3d8eef92ca8ede788418ad38a7dc3ff50262f00808760/scipy-1.16.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7a5dc7ee9c33019973a470556081b0fd3c9f4c44019191039f9769183141a4d9", size = 33329824, upload-time = "2025-09-11T17:42:07.549Z" }, + { url = "https://files.pythonhosted.org/packages/da/6a/1a927b14ddc7714111ea51f4e568203b2bb6ed59bdd036d62127c1a360c8/scipy-1.16.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c2275ff105e508942f99d4e3bc56b6ef5e4b3c0af970386ca56b777608ce95b7", size = 35681881, upload-time = "2025-09-11T17:42:13.255Z" }, + { url = "https://files.pythonhosted.org/packages/c1/5f/331148ea5780b4fcc7007a4a6a6ee0a0c1507a796365cc642d4d226e1c3a/scipy-1.16.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:af80196eaa84f033e48444d2e0786ec47d328ba00c71e4299b602235ffef9acb", size = 36006219, upload-time = "2025-09-11T17:42:18.765Z" }, + { url = "https://files.pythonhosted.org/packages/46/3a/e991aa9d2aec723b4a8dcfbfc8365edec5d5e5f9f133888067f1cbb7dfc1/scipy-1.16.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9fb1eb735fe3d6ed1f89918224e3385fbf6f9e23757cacc35f9c78d3b712dd6e", size = 38682147, upload-time = "2025-09-11T17:42:25.177Z" }, + { url = "https://files.pythonhosted.org/packages/a1/57/0f38e396ad19e41b4c5db66130167eef8ee620a49bc7d0512e3bb67e0cab/scipy-1.16.2-cp313-cp313-win_amd64.whl", hash = "sha256:fda714cf45ba43c9d3bae8f2585c777f64e3f89a2e073b668b32ede412d8f52c", size = 38520766, upload-time = "2025-09-11T17:43:25.342Z" }, + { url = "https://files.pythonhosted.org/packages/1b/a5/85d3e867b6822d331e26c862a91375bb7746a0b458db5effa093d34cdb89/scipy-1.16.2-cp313-cp313-win_arm64.whl", hash = "sha256:2f5350da923ccfd0b00e07c3e5cfb316c1c0d6c1d864c07a72d092e9f20db104", size = 25451169, upload-time = "2025-09-11T17:43:30.198Z" }, + { url = "https://files.pythonhosted.org/packages/09/d9/60679189bcebda55992d1a45498de6d080dcaf21ce0c8f24f888117e0c2d/scipy-1.16.2-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:53d8d2ee29b925344c13bda64ab51785f016b1b9617849dac10897f0701b20c1", size = 37012682, upload-time = "2025-09-11T17:42:30.677Z" }, + { url = "https://files.pythonhosted.org/packages/83/be/a99d13ee4d3b7887a96f8c71361b9659ba4ef34da0338f14891e102a127f/scipy-1.16.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:9e05e33657efb4c6a9d23bd8300101536abd99c85cca82da0bffff8d8764d08a", size = 29389926, upload-time = "2025-09-11T17:42:35.845Z" }, + { url = "https://files.pythonhosted.org/packages/bf/0a/130164a4881cec6ca8c00faf3b57926f28ed429cd6001a673f83c7c2a579/scipy-1.16.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:7fe65b36036357003b3ef9d37547abeefaa353b237e989c21027b8ed62b12d4f", size = 21381152, upload-time = "2025-09-11T17:42:40.07Z" }, + { url = "https://files.pythonhosted.org/packages/47/a6/503ffb0310ae77fba874e10cddfc4a1280bdcca1d13c3751b8c3c2996cf8/scipy-1.16.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:6406d2ac6d40b861cccf57f49592f9779071655e9f75cd4f977fa0bdd09cb2e4", size = 23914410, upload-time = "2025-09-11T17:42:44.313Z" }, + { url = "https://files.pythonhosted.org/packages/fa/c7/1147774bcea50d00c02600aadaa919facbd8537997a62496270133536ed6/scipy-1.16.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ff4dc42bd321991fbf611c23fc35912d690f731c9914bf3af8f417e64aca0f21", size = 33481880, upload-time = "2025-09-11T17:42:49.325Z" }, + { url = "https://files.pythonhosted.org/packages/6a/74/99d5415e4c3e46b2586f30cdbecb95e101c7192628a484a40dd0d163811a/scipy-1.16.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:654324826654d4d9133e10675325708fb954bc84dae6e9ad0a52e75c6b1a01d7", size = 35791425, upload-time = "2025-09-11T17:42:54.711Z" }, + { url = "https://files.pythonhosted.org/packages/1b/ee/a6559de7c1cc710e938c0355d9d4fbcd732dac4d0d131959d1f3b63eb29c/scipy-1.16.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63870a84cd15c44e65220eaed2dac0e8f8b26bbb991456a033c1d9abfe8a94f8", size = 36178622, upload-time = "2025-09-11T17:43:00.375Z" }, + { url = "https://files.pythonhosted.org/packages/4e/7b/f127a5795d5ba8ece4e0dce7d4a9fb7cb9e4f4757137757d7a69ab7d4f1a/scipy-1.16.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:fa01f0f6a3050fa6a9771a95d5faccc8e2f5a92b4a2e5440a0fa7264a2398472", size = 38783985, upload-time = "2025-09-11T17:43:06.661Z" }, + { url = "https://files.pythonhosted.org/packages/3e/9f/bc81c1d1e033951eb5912cd3750cc005943afa3e65a725d2443a3b3c4347/scipy-1.16.2-cp313-cp313t-win_amd64.whl", hash = "sha256:116296e89fba96f76353a8579820c2512f6e55835d3fad7780fece04367de351", size = 38631367, upload-time = "2025-09-11T17:43:14.44Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5e/2cc7555fd81d01814271412a1d59a289d25f8b63208a0a16c21069d55d3e/scipy-1.16.2-cp313-cp313t-win_arm64.whl", hash = "sha256:98e22834650be81d42982360382b43b17f7ba95e0e6993e2a4f5b9ad9283a94d", size = 25787992, upload-time = "2025-09-11T17:43:19.745Z" }, + { url = "https://files.pythonhosted.org/packages/8b/ac/ad8951250516db71619f0bd3b2eb2448db04b720a003dd98619b78b692c0/scipy-1.16.2-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:567e77755019bb7461513c87f02bb73fb65b11f049aaaa8ca17cfaa5a5c45d77", size = 36595109, upload-time = "2025-09-11T17:43:35.713Z" }, + { url = "https://files.pythonhosted.org/packages/ff/f6/5779049ed119c5b503b0f3dc6d6f3f68eefc3a9190d4ad4c276f854f051b/scipy-1.16.2-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:17d9bb346194e8967296621208fcdfd39b55498ef7d2f376884d5ac47cec1a70", size = 28859110, upload-time = "2025-09-11T17:43:40.814Z" }, + { url = "https://files.pythonhosted.org/packages/82/09/9986e410ae38bf0a0c737ff8189ac81a93b8e42349aac009891c054403d7/scipy-1.16.2-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:0a17541827a9b78b777d33b623a6dcfe2ef4a25806204d08ead0768f4e529a88", size = 20850110, upload-time = "2025-09-11T17:43:44.981Z" }, + { url = "https://files.pythonhosted.org/packages/0d/ad/485cdef2d9215e2a7df6d61b81d2ac073dfacf6ae24b9ae87274c4e936ae/scipy-1.16.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:d7d4c6ba016ffc0f9568d012f5f1eb77ddd99412aea121e6fa8b4c3b7cbad91f", size = 23497014, upload-time = "2025-09-11T17:43:49.074Z" }, + { url = "https://files.pythonhosted.org/packages/a7/74/f6a852e5d581122b8f0f831f1d1e32fb8987776ed3658e95c377d308ed86/scipy-1.16.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9702c4c023227785c779cba2e1d6f7635dbb5b2e0936cdd3a4ecb98d78fd41eb", size = 33401155, upload-time = "2025-09-11T17:43:54.661Z" }, + { url = "https://files.pythonhosted.org/packages/d9/f5/61d243bbc7c6e5e4e13dde9887e84a5cbe9e0f75fd09843044af1590844e/scipy-1.16.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d1cdf0ac28948d225decdefcc45ad7dd91716c29ab56ef32f8e0d50657dffcc7", size = 35691174, upload-time = "2025-09-11T17:44:00.101Z" }, + { url = "https://files.pythonhosted.org/packages/03/99/59933956331f8cc57e406cdb7a483906c74706b156998f322913e789c7e1/scipy-1.16.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:70327d6aa572a17c2941cdfb20673f82e536e91850a2e4cb0c5b858b690e1548", size = 36070752, upload-time = "2025-09-11T17:44:05.619Z" }, + { url = "https://files.pythonhosted.org/packages/c6/7d/00f825cfb47ee19ef74ecf01244b43e95eae74e7e0ff796026ea7cd98456/scipy-1.16.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5221c0b2a4b58aa7c4ed0387d360fd90ee9086d383bb34d9f2789fafddc8a936", size = 38701010, upload-time = "2025-09-11T17:44:11.322Z" }, + { url = "https://files.pythonhosted.org/packages/e4/9f/b62587029980378304ba5a8563d376c96f40b1e133daacee76efdcae32de/scipy-1.16.2-cp314-cp314-win_amd64.whl", hash = "sha256:f5a85d7b2b708025af08f060a496dd261055b617d776fc05a1a1cc69e09fe9ff", size = 39360061, upload-time = "2025-09-11T17:45:09.814Z" }, + { url = "https://files.pythonhosted.org/packages/82/04/7a2f1609921352c7fbee0815811b5050582f67f19983096c4769867ca45f/scipy-1.16.2-cp314-cp314-win_arm64.whl", hash = "sha256:2cc73a33305b4b24556957d5857d6253ce1e2dcd67fa0ff46d87d1670b3e1e1d", size = 26126914, upload-time = "2025-09-11T17:45:14.73Z" }, + { url = "https://files.pythonhosted.org/packages/51/b9/60929ce350c16b221928725d2d1d7f86cf96b8bc07415547057d1196dc92/scipy-1.16.2-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:9ea2a3fed83065d77367775d689401a703d0f697420719ee10c0780bcab594d8", size = 37013193, upload-time = "2025-09-11T17:44:16.757Z" }, + { url = "https://files.pythonhosted.org/packages/2a/41/ed80e67782d4bc5fc85a966bc356c601afddd175856ba7c7bb6d9490607e/scipy-1.16.2-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:7280d926f11ca945c3ef92ba960fa924e1465f8d07ce3a9923080363390624c4", size = 29390172, upload-time = "2025-09-11T17:44:21.783Z" }, + { url = "https://files.pythonhosted.org/packages/c4/a3/2f673ace4090452696ccded5f5f8efffb353b8f3628f823a110e0170b605/scipy-1.16.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:8afae1756f6a1fe04636407ef7dbece33d826a5d462b74f3d0eb82deabefd831", size = 21381326, upload-time = "2025-09-11T17:44:25.982Z" }, + { url = "https://files.pythonhosted.org/packages/42/bf/59df61c5d51395066c35836b78136accf506197617c8662e60ea209881e1/scipy-1.16.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:5c66511f29aa8d233388e7416a3f20d5cae7a2744d5cee2ecd38c081f4e861b3", size = 23915036, upload-time = "2025-09-11T17:44:30.527Z" }, + { url = "https://files.pythonhosted.org/packages/91/c3/edc7b300dc16847ad3672f1a6f3f7c5d13522b21b84b81c265f4f2760d4a/scipy-1.16.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efe6305aeaa0e96b0ccca5ff647a43737d9a092064a3894e46c414db84bc54ac", size = 33484341, upload-time = "2025-09-11T17:44:35.981Z" }, + { url = "https://files.pythonhosted.org/packages/26/c7/24d1524e72f06ff141e8d04b833c20db3021020563272ccb1b83860082a9/scipy-1.16.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7f3a337d9ae06a1e8d655ee9d8ecb835ea5ddcdcbd8d23012afa055ab014f374", size = 35790840, upload-time = "2025-09-11T17:44:41.76Z" }, + { url = "https://files.pythonhosted.org/packages/aa/b7/5aaad984eeedd56858dc33d75efa59e8ce798d918e1033ef62d2708f2c3d/scipy-1.16.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bab3605795d269067d8ce78a910220262711b753de8913d3deeaedb5dded3bb6", size = 36174716, upload-time = "2025-09-11T17:44:47.316Z" }, + { url = "https://files.pythonhosted.org/packages/fd/c2/e276a237acb09824822b0ada11b028ed4067fdc367a946730979feacb870/scipy-1.16.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b0348d8ddb55be2a844c518cd8cc8deeeb8aeba707cf834db5758fc89b476a2c", size = 38790088, upload-time = "2025-09-11T17:44:53.011Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b4/5c18a766e8353015439f3780f5fc473f36f9762edc1a2e45da3ff5a31b21/scipy-1.16.2-cp314-cp314t-win_amd64.whl", hash = "sha256:26284797e38b8a75e14ea6631d29bda11e76ceaa6ddb6fdebbfe4c4d90faf2f9", size = 39457455, upload-time = "2025-09-11T17:44:58.899Z" }, + { url = "https://files.pythonhosted.org/packages/97/30/2f9a5243008f76dfc5dee9a53dfb939d9b31e16ce4bd4f2e628bfc5d89d2/scipy-1.16.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d2a4472c231328d4de38d5f1f68fdd6d28a615138f842580a8a321b5845cf779", size = 26448374, upload-time = "2025-09-11T17:45:03.45Z" }, +] + [[package]] name = "sentry-sdk" version = "2.37.1" From d7a91e7ee8df1dba18b90e91830af0968ec0c216 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Thu, 16 Oct 2025 17:25:40 +0200 Subject: [PATCH 24/28] feat(train): add ManualSyncBruteForceSearcher for improved CUDA support and parallel processing --- src/finetunings/generate_epochs/generate.py | 5 +- src/models/searchers/brute_force_searcher.py | 29 +++++++++ tests/models/test_brute_force_searcher.py | 67 ++++++++++++++++++-- 3 files changed, 93 insertions(+), 8 deletions(-) diff --git a/src/finetunings/generate_epochs/generate.py b/src/finetunings/generate_epochs/generate.py index 2eef77e..72bb642 100644 --- a/src/finetunings/generate_epochs/generate.py +++ b/src/finetunings/generate_epochs/generate.py @@ -12,7 +12,7 @@ from finetunings.generate_epochs.datasets import BatcherDataset, DamuelNeighborsIterator from models.batch_sampler import BatchSampler from models.negative_sampler import NegativeSamplingType -from models.searchers.brute_force_searcher import DPBruteForceSearcher +from models.searchers.brute_force_searcher import DPBruteForceSearcher, ManualSyncBruteForceSearcher from utils.calculate_qids_distribution import calculate_qids_distribution_from_links from utils.loaders import load_embs_and_qids from utils.multifile_dataset import MultiFileDataset @@ -95,8 +95,9 @@ def generate( batch_sampler = BatchSampler( index_embs, index_qids, - DPBruteForceSearcher, + # DPBruteForceSearcher, # BruteForceSearcher, + ManualSyncBruteForceSearcher, NegativeSamplingType(NEGATIVE_SAMPLING_TYPE), **negative_sampler_kwargs, ) diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index 70a6ed0..348b466 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -182,3 +182,32 @@ def find(self, batch: torch.Tensor, num_neighbors: int) -> np.ndarray: def build(self): pass + + +class ManualSyncBruteForceSearcher(Searcher): + def __init__(self, embs: np.ndarray, results: np.ndarray, run_build_from_init: bool = False): + assert torch.cuda.is_available(), "This class requires CUDA." + assert run_build_from_init is False, "This class does not support building from init." + self.searchers = [] + self.num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 + self.cuda_devices = [torch.device(f"cuda:{i}") for i in range(self.num_devices)] + super().__init__(torch.from_numpy(embs), results, run_build_from_init) + + @torch.inference_mode() + def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: + if len(self.searchers) == 0: + for device in self.cuda_devices: + self.searchers.append( + _WrappedSearcher(self.embs, num_neighbors=num_neighbors).to(device) + ) + batch = torch.from_numpy(batch) + inputs = nn.parallel.scatter(batch, self.cuda_devices) + outputs = [ + searcher(input_chunk.to(device)) + for searcher, input_chunk, device in zip(self.searchers, inputs, self.cuda_devices) + ] + gathered = nn.parallel.gather(outputs, self.cuda_devices[0]) + return gathered.cpu().numpy() + + def build(self): + pass diff --git a/tests/models/test_brute_force_searcher.py b/tests/models/test_brute_force_searcher.py index 7159ceb..5cca1c1 100644 --- a/tests/models/test_brute_force_searcher.py +++ b/tests/models/test_brute_force_searcher.py @@ -6,6 +6,7 @@ BruteForceSearcher, DPBruteForceSearcher, DPBruteForceSearcherPT, + ManualSyncBruteForceSearcher, ) # torch.compiler.disable(BruteForceSearcher.find) @@ -181,9 +182,63 @@ def test_changing_num_neighbors(self, small_embs): torch.from_numpy(np.random.random((1, 3))).to(torch.float32), 3 ) # Try to change to 3 neighbors - def test_dataparallel_initialization(self, small_embs): - searcher = DPBruteForceSearcherPT(small_embs, np.arange(len(small_embs))) - searcher.find( - torch.from_numpy(np.random.random((1, 3))).to(torch.float32), 2 - ) # This should initialize module_searcher - assert isinstance(searcher.module_searcher, torch.nn.DataParallel) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") +class TestManualSyncBruteForceSearcher: + @pytest.fixture + def small_embs(self): + return np.array( + [ + [0.9, 0.9, 0.9], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + @pytest.fixture + def large_embs(self): + embs = np.random.random((10000, 128)) + return embs / np.linalg.norm(embs, ord=2, axis=1, keepdims=True) + + def test_search_present(self, small_embs): + searcher = ManualSyncBruteForceSearcher(small_embs, np.arange(4)) + for i, e in enumerate(small_embs): + res = searcher.find(np.array([e]), 2) + assert res[0][0] == i + assert res[0][1] != i + assert len(res[0]) == 2 + + def test_search_missing(self): + embs = np.array( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + searcher = ManualSyncBruteForceSearcher(embs, np.arange(4)) + res = searcher.find(np.array([[1.0, 0.0, 1.0]]), 2) + assert res[0][0] == 0 + + @pytest.mark.slow + def test_search_large(self, large_embs): + searcher = ManualSyncBruteForceSearcher(large_embs, np.arange(len(large_embs))) + neg = 7 + for _ in range(10): # Reduced iterations for faster testing + batch = np.random.random((32, 128)) + batch = batch / np.linalg.norm(batch, ord=2, axis=1, keepdims=True) + res = searcher.find(batch, neg) + for j, emb in enumerate(batch): + neighbor_embs = large_embs[res[j]] + dists = [np.dot(emb, ne) for ne in neighbor_embs] + dists_order = [dists[i] >= dists[i + 1] for i in range(len(dists) - 1)] + assert all(dists_order) + + def test_changing_num_neighbors(self, small_embs): + searcher = ManualSyncBruteForceSearcher(small_embs, np.arange(len(small_embs))) + searcher.find(np.random.random((1, 3)), 2) # Initialize with 2 neighbors + # with pytest.raises(Exception): + # Does nothing: + searcher.find(np.random.random((1, 3)), 3) # Try to change to 3 neighbors From f681cfc4caf3496b27bb705b8e3c32e2f1d978ff Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Thu, 16 Oct 2025 17:26:01 +0200 Subject: [PATCH 25/28] feat(train): enable fused optimization in AdamW for improved performance --- src/finetunings/finetune_model/train_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/finetunings/finetune_model/train_ddp.py b/src/finetunings/finetune_model/train_ddp.py index 8384836..dcd49ff 100644 --- a/src/finetunings/finetune_model/train_ddp.py +++ b/src/finetunings/finetune_model/train_ddp.py @@ -130,7 +130,7 @@ def _ddp_train( }, ) - optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY) + optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, fused=True) criterion = nn.CrossEntropyLoss() scaler = torch.amp.GradScaler("cuda") From 2e655c6f526f7e963f4962ebddcc5c5ccc344475 Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Mon, 20 Oct 2025 10:49:50 +0200 Subject: [PATCH 26/28] fix(reranking): :bug: make updating dataset tokens work with different shapes --- src/scripts/reranking/change_dataset_tokens.py | 13 ++++++++++++- .../scripts/reranking/test_change_dataset_tokens.py | 6 ++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/scripts/reranking/change_dataset_tokens.py b/src/scripts/reranking/change_dataset_tokens.py index 79b6b70..1053543 100644 --- a/src/scripts/reranking/change_dataset_tokens.py +++ b/src/scripts/reranking/change_dataset_tokens.py @@ -9,15 +9,26 @@ def update_tokens_in_file(file_path: Path, qid_to_new_tokens: csr_matrix) -> None: - """Updates tokens in a single .npz file and returns if it was modified.""" + """Updates tokens in a single .npz file in place. + + Replaces each entry in "description_tokens" with the corresponding row from + qid_to_new_tokens indexed by the file's qids. Raises ValueError if a qid has + no corresponding tokens in qid_to_new_tokens. + """ with np.load(file_path) as data: tokens = data["description_tokens"] qids = data["qids"] save_data = dict(data) + tokens = np.empty((tokens.shape[0], qid_to_new_tokens.shape[1]), dtype=tokens.dtype) + for i, qid in enumerate(qids): if qid_to_new_tokens[qid].nnz > 0: + print(qid_to_new_tokens[qid].toarray()[0]) tokens[i] = qid_to_new_tokens[qid].toarray()[0] + else: + # We could also pad/truncate here if needed but this code should not really happen. + raise ValueError(f"No new tokens found for qid {qid}") save_data["description_tokens"] = tokens np.savez(file_path, **save_data) diff --git a/tests/scripts/reranking/test_change_dataset_tokens.py b/tests/scripts/reranking/test_change_dataset_tokens.py index 06e7b24..613d33f 100644 --- a/tests/scripts/reranking/test_change_dataset_tokens.py +++ b/tests/scripts/reranking/test_change_dataset_tokens.py @@ -25,7 +25,7 @@ def _create_sparse_matrix(data_dict: dict, shape: tuple[int, int], dtype=np.floa def test_update_tokens_when_match_exists(tmp_path: Path): """ - Tests that the file is modified and returns True when a qid has a new token. + Tests that the file is modified. """ test_file = tmp_path / "data.npz" original_tokens = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]) @@ -35,7 +35,9 @@ def test_update_tokens_when_match_exists(tmp_path: Path): ) new_token_vector = np.array([99.0, 99.0]) - update_matrix = _create_sparse_matrix({200: new_token_vector}, shape=(400, 2)) + update_matrix = _create_sparse_matrix( + {100: original_tokens[0], 200: new_token_vector, 300: original_tokens[2]}, shape=(400, 2) + ) update_tokens_in_file(test_file, update_matrix) loaded_data = np.load(test_file) From 70f3cb51c14d6d83c093f1ff225d89091e7b047a Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Tue, 21 Oct 2025 19:57:27 +0200 Subject: [PATCH 27/28] feat(dataset): add multiclass dataset creation and corresponding iterable dataset for training --- src/reranking/dataset/create_dataset.py | 147 ++++++++++++ src/reranking/models/full_lealla.py | 25 ++- .../models/full_lealla_multiclass.py | 209 ++++++++++++++++++ .../training/reranking_iterable_dataset.py | 70 ++++++ src/reranking/training/trainer_simple.py | 4 +- src/reranking/training/training_configs.py | 93 +++++++- src/run_action_gin.py | 7 +- .../reranking/change_dataset_tokens.py | 18 +- src/utils/loaders.py | 12 +- 9 files changed, 566 insertions(+), 19 deletions(-) create mode 100644 src/reranking/models/full_lealla_multiclass.py diff --git a/src/reranking/dataset/create_dataset.py b/src/reranking/dataset/create_dataset.py index 4a1c861..7248220 100644 --- a/src/reranking/dataset/create_dataset.py +++ b/src/reranking/dataset/create_dataset.py @@ -170,6 +170,124 @@ def create_binary_dataset( ) +def create_multiclass_dataset( + K: int, + index_embs_dir: Path, + link_tokens_path: Path, + model_name: str, + embedding_model_path_dict: Path, + output_path: Path, + target_dim: int = None, + batch_size: int = 512, +) -> None: + # Load index embeddings, qids, and tokens + index_embs, index_qids = load_embs_and_qids(index_embs_dir) + index_embs = index_embs.astype(np.float16) + + # Sort index_embs and index_qids based on index_qids + sort_indices = np.argsort(index_qids) + index_qids = index_qids[sort_indices] + index_embs = index_embs[sort_indices] + + # Create BruteForceSearcher + searcher = BruteForceSearcher(index_embs, index_qids) + + # Load link tokens and qids + link_tokens_path = Path(link_tokens_path) + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + link_files = sorted( + [p for p in link_tokens_path.iterdir() if p.is_file() and p.suffix == ".npz"], + key=lambda p: p.name, + ) + + if not link_files: + raise FileNotFoundError(f"No .npz files found in {link_tokens_path}") + + # Load embedding model + model = ModelFactory.auto_load_from_file( + model_name, + embedding_model_path_dict, + target_dim=target_dim, + ) + model.eval() + model.to(device) + model.to(torch.bfloat16) + model = torch.compile(model) + + index_qid_to_index = {int(qid): i for i, qid in enumerate(index_qids)} + index_qids_set = set(index_qid_to_index.keys()) + + for link_file in tqdm(link_files, desc="Processing link files"): + link_tokens, link_qids = load_tokens_qids(link_file) + + # known_qids_mask = np.array([int(q) in index_qids_set for q in link_qids], dtype=bool) + # link_tokens = link_tokens[known_qids_mask] + # link_qids = link_qids[known_qids_mask] + + link_tokens_tensor = torch.from_numpy(link_tokens.astype(np.int32, copy=False)) + link_qids_tensor = torch.from_numpy(link_qids.astype(np.int64, copy=False)) + dataset = torch.utils.data.TensorDataset(link_tokens_tensor, link_qids_tensor) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2 + ) + + data_len = len(dataset) + if data_len == 0: + print(f"Skipping {link_file.name}: dataset length is zero after filtering") + continue + + link_tokens_list = np.zeros((data_len, link_tokens.shape[1]), dtype=np.int32) + y = np.zeros(data_len, dtype=np.int32) + qids = np.zeros((data_len, K), dtype=np.int32) + output_index = 0 + + for batch_tokens, batch_qids in tqdm( + dataloader, desc=f"Creating dataset for {link_file.name}", total=len(dataloader) + ): + attention_mask = create_attention_mask(batch_tokens) + + with torch.inference_mode(): + batch_embs = ( + model(batch_tokens.to(device), attention_mask.to(device)) + .to(torch.float16) + .cpu() + ) + + top_qids = searcher.find(batch_embs.numpy(), num_neighbors=K) + batch_labels = np.empty((len(batch_qids)), dtype=np.int32) + + batch_qids = batch_qids.cpu().numpy() + + for i, (bq, tq_r) in enumerate(zip(batch_qids, top_qids)): + if bq not in tq_r: + top_qids[i][-1] = bq # Ensure positive is in top K + batch_labels[i] = K - 1 + else: + batch_labels[i] = np.where(tq_r == bq)[0][0] + + batch_tokens_np = batch_tokens.cpu().numpy().astype(np.int32, copy=False) + data_size = len(batch_tokens) + + link_tokens_list[output_index : output_index + data_size] = batch_tokens_np + y[output_index : output_index + data_size] = batch_labels + qids[output_index : output_index + data_size] = top_qids + + output_index += data_size + + output_file = output_path / f"{link_file.stem}_dataset.npz" + + print(f"Saving dataset for {link_file.name} -> {output_file.name} | ") + + np.savez( + output_file, + link_tokens=link_tokens_list, + y=y, + qids=qids, + ) + + def create_default_binary_dataset(): index_embs_dir = Path( "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" @@ -199,6 +317,35 @@ def create_default_binary_dataset(): ) +def create_default_multiclass_dataset(): + index_embs_dir = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + index_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages" + ) + link_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/links" + ) + embedding_model_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + output_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids_multiclass" + ) + model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + + create_multiclass_dataset( + 7, + index_embs_dir, + link_tokens_path, + model_name, + embedding_model_path, + output_path, + batch_size=2560, + ) + + if __name__ == "__main__": index_embs_dir = Path( "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" diff --git a/src/reranking/models/full_lealla.py b/src/reranking/models/full_lealla.py index c239243..3f014e8 100644 --- a/src/reranking/models/full_lealla.py +++ b/src/reranking/models/full_lealla.py @@ -62,6 +62,7 @@ def __init__( tokenizer_name_or_path: str | None = None, dropout: float = 0.1, ema_decay: float = 0.9999, + embedding_dim: int | None = None, ) -> None: super().__init__() @@ -71,7 +72,10 @@ def __init__( model_name_or_path, state_dict_path=state_dict_path, ) - self.embedding_dim = _infer_output_dim(self.base_model) + if embedding_dim is None: + self.embedding_dim = _infer_output_dim(self.base_model) + else: + self.embedding_dim = embedding_dim self.model = _Model( base_model=self.base_model, @@ -84,6 +88,10 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) self.loss_fn = nn.BCEWithLogitsLoss() + self.tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + def forward( self, mention_tokens: torch.Tensor, @@ -93,7 +101,7 @@ def forward( return self.model(ids, attention_mask) def prepare_for_forward(self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor): - ids = rearrange([mention_tokens, entity_tokens], "d b n -> b (d n)") + ids = torch.cat([mention_tokens, entity_tokens], dim=1) attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) return ids, attention_mask @@ -103,8 +111,16 @@ def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: mention_tokens = data["mention_tokens"] entity_tokens = data["entity_tokens"] labels = data["labels"].float() - ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + + print(f"Min token ID: {ids.min()}, Max token ID: {ids.max()}") + print(f"Model vocab size: {self.base_model.model.config.vocab_size}") + print(ids) + + assert ( + self.embedding_dim == ids.shape[-1] + ), f"Expected embedding dimension {self.embedding_dim}, but got {ids.shape[-1]}" + logits = self.model(ids, attention_mask) loss = self.loss_fn(logits, labels) @@ -192,8 +208,7 @@ def score( @torch.inference_mode() def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: - ids = rearrange([mentions, entities], "d b n -> b (d n)") - attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) + ids, attention_mask = self.prepare_for_forward(mentions, entities) logits = self.model(ids, attention_mask) probability = torch.sigmoid(logits).reshape(-1) return probability diff --git a/src/reranking/models/full_lealla_multiclass.py b/src/reranking/models/full_lealla_multiclass.py new file mode 100644 index 0000000..4a62e00 --- /dev/null +++ b/src/reranking/models/full_lealla_multiclass.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from reranking.models.context_emb_with_attention import GPTLayer +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +class _Model(nn.Module): + def __init__( + self, base_model: nn.Module, embedding_dim: int, dropout: float, output_size: int + ) -> None: + super().__init__() + self.base_model = base_model + self.embedding_dim = embedding_dim + self.gpt_layer = GPTLayer(model_width=embedding_dim, dropout=dropout) + self.final_layer = nn.Linear(embedding_dim, output_size) + + def forward(self, ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + # print(ids.sh1ape, attention_mask.shape) + base_embeddings = self.base_model(ids, attention_mask) + # print("base_embeddings", base_embeddings.shape) + x = self.gpt_layer(base_embeddings) + # print("x", x.shape) + return self.final_layer(x) + + +class FullLEALLARerankerMulticlass(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head.""" + + def __init__( + self, + model_name_or_path: str, + *, + state_dict_path: str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + embedding_dim: int | None = None, + output_size: int = 7, + ) -> None: + super().__init__() + + self.ema_decay = ema_decay + + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + ) + if embedding_dim is None: + self.embedding_dim = _infer_output_dim(self.base_model) + else: + self.embedding_dim = embedding_dim + + self.model = _Model( + base_model=self.base_model, + embedding_dim=self.embedding_dim, + dropout=dropout, + output_size=output_size, + ) + self.model_ema = deepcopy(self.model) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.CrossEntropyLoss() + + self.tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + return self.model(ids, attention_mask) + + def prepare_for_forward(self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor): + entity_tokens = torch.cat(entity_tokens, dim=1) + ids = torch.cat([mention_tokens, entity_tokens], dim=1) + attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) + return ids, attention_mask + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + entity_tokens = data["entity_tokens"] + labels = data["labels"].float() + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + + logits = self.model(ids, attention_mask) + + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip(self.model.parameters(), self.model_ema.parameters()): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.model_ema.state_dict(), ema_path) + torch.save(self.model.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("_orig_mod.module.model."): + new_k = k.replace("_orig_mod.module.model.", "") + elif k.startswith("module."): + new_k = k.replace("module.", "") + else: + new_k = k + new_state_dict[new_k] = v + state_dict = new_state_dict + self.model_ema.load_state_dict(state_dict) + self.model.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." + ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mentions, entities) + logits = self.model(ids, attention_mask) + probability = torch.sigmoid(logits).reshape(-1) + return probability diff --git a/src/reranking/training/reranking_iterable_dataset.py b/src/reranking/training/reranking_iterable_dataset.py index b41a427..507b942 100644 --- a/src/reranking/training/reranking_iterable_dataset.py +++ b/src/reranking/training/reranking_iterable_dataset.py @@ -7,6 +7,8 @@ import torch from torch.utils.data import IterableDataset, get_worker_info +from utils.loaders import map_qids_to_token_matrix + class RerankingIterableDataset(IterableDataset): """Yield ``(link_tokens, description_tokens, labels, qids)`` from NPZ files. @@ -63,9 +65,77 @@ def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, t link_tokens = link_tokens[permutation] for idx in range(len(qids)): + print(f"Yielding sample idx={idx} from file {file_path}") yield ( link_tokens[idx], description_tokens[idx], labels[idx], qids[idx], ) + + +class RerankingQIDsToDescriptionsIterableDataset(IterableDataset): + """Yield ``(link_tokens, description_tokens, labels, qids)`` from NPZ files. + + Each worker spawned by ``DataLoader`` receives a disjoint stride of the + underlying NPZ shards, ensuring the samples are not duplicated when + ``num_workers > 0``. + """ + + def __init__( + self, + data_dir: ( + str | Path + ) = "~/troja/outputs/reranking_test/reranker_dataset_with_qids_multiclass", + index_dir: ( + str | Path + ) = "/lnet/work/home-students-external/farhan/troja/outputs/descriptions_paraphrase_after_multiling_dataset/descs_pages", + ) -> None: + super().__init__() + self.data_dir = Path(data_dir).expanduser() + if not self.data_dir.is_dir(): + raise FileNotFoundError(f"Dataset directory not found: {self.data_dir}") + + self._files: List[Path] = sorted(self.data_dir.glob("*.npz")) + if not self._files: + raise FileNotFoundError( + f"No NPZ files found in {self.data_dir}; expected reranking shards" + ) + self.qid_to_tokens = map_qids_to_token_matrix(index_dir, verbose=True) + + def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + worker_info = get_worker_info() + if worker_info is None: + worker_id = 0 + num_workers = 1 + else: + worker_id = worker_info.id + num_workers = worker_info.num_workers + + file_paths = self._files[worker_id::num_workers] + + for file_path in file_paths: + with np.load(file_path, allow_pickle=False) as data: + qids = torch.from_numpy(data["qids"]).long() + labels = torch.from_numpy(data["y"]).float() + link_tokens = torch.from_numpy(data["link_tokens"]) + + if not (len(qids) == len(labels) == len(link_tokens)): + raise ValueError( + "Mismatched array lengths in NPZ file " + f"{file_path}: qids={len(qids)} labels={len(labels)} " + f"link_tokens={len(link_tokens)}" + ) + + permutation = torch.randperm(len(qids)) + qids = qids[permutation] + labels = labels[permutation] + link_tokens = link_tokens[permutation] + + for idx in range(len(qids)): + yield ( + link_tokens[idx], + torch.from_numpy(self.qid_to_tokens[qids[idx].item()].toarray()), + labels[idx], + qids[idx], + ) diff --git a/src/reranking/training/trainer_simple.py b/src/reranking/training/trainer_simple.py index 3f7866d..a980ecd 100644 --- a/src/reranking/training/trainer_simple.py +++ b/src/reranking/training/trainer_simple.py @@ -47,7 +47,7 @@ def train( dataset, batch_size=training_config.batch_size, pin_memory=True, - num_workers=4, + num_workers=1, ) num_validation_batches = training_config.validation_size // training_config.batch_size @@ -55,7 +55,7 @@ def train( val_dataloader = validation_batches model.to(device) - model = torch.compile(model) + # model = torch.compile(model) use_amp = device.type == "cuda" scaler = torch.amp.GradScaler(device.type) if use_amp else None diff --git a/src/reranking/training/training_configs.py b/src/reranking/training/training_configs.py index ae2fb2b..b961060 100644 --- a/src/reranking/training/training_configs.py +++ b/src/reranking/training/training_configs.py @@ -8,12 +8,16 @@ from reranking.models.base import BaseRerankingModel from reranking.models.context_emb_with_attention import ContextEmbWithAttention from reranking.models.full_lealla import FullLEALLAReranker +from reranking.models.full_lealla_multiclass import FullLEALLARerankerMulticlass from reranking.models.fusion import FusionReranker from reranking.models.pairwise_mlp import PairwiseMLPReranker from reranking.models.pairwise_mlp_with_large_context_emb import ( PairwiseMLPRerankerWithLargeContextEmb, ) -from reranking.training.reranking_iterable_dataset import RerankingIterableDataset +from reranking.training.reranking_iterable_dataset import ( + RerankingIterableDataset, + RerankingQIDsToDescriptionsIterableDataset, +) from utils.loaders import load_embs_and_qids @@ -132,6 +136,86 @@ def full_lealla_r( ) +def full_lealla_r_192( + LR: float = 0.00005, + SAVE_EACH: int = 10000, + BATCH_SIZE: int = 1000, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "full_lealla_r_192" + + d = np.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids/mentions_5_dataset.npz" + ) + description_tokens = d["description_tokens"] + assert description_tokens.shape[1] == 192, "Expected description tokens to have length 192" + + model = FullLEALLAReranker( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + embedding_dim=192 + 64, + ) + + dataset = RerankingIterableDataset() + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def full_lealla_r_multiclass( + LR: float = 0.00005, + SAVE_EACH: int = 10000, + BATCH_SIZE: int = 1, + VALIDATE_EACH: int = 10000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "full_lealla_r_multiclass" + + d = np.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids/mentions_5_dataset.npz" + ) + description_tokens = d["description_tokens"] + assert description_tokens.shape[1] == 192, "Expected description tokens to have length 192" + + model = FullLEALLARerankerMulticlass( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + embedding_dim=8 * 64, + ) + + dataset = RerankingQIDsToDescriptionsIterableDataset() + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + def full_lealla_r_128( LR: float = 0.00005, SAVE_EACH: int = 10000, @@ -140,7 +224,7 @@ def full_lealla_r_128( VALIDATION_SIZE: int = 10000, DROPOUT: float = 0.1, ) -> TrainingConfig: - name = "full_lealla_r_128" + name = "full_lealla_r_1" d = np.load("~/troja/outputs/reranking_test/reranker_dataset_with_qids/mentions_5_dataset.npz") description_tokens = d["description_tokens"] @@ -150,6 +234,7 @@ def full_lealla_r_128( model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", dropout=DROPOUT, + embedding_dim=128 + 64, ) dataset = RerankingIterableDataset() @@ -425,5 +510,9 @@ def get_config_from_name(config_name: str) -> TrainingConfig: return fusion() if config_name == "full_lealla_r_128": return full_lealla_r_128() + if config_name == "full_lealla_r_192": + return full_lealla_r_192() + if config_name == "full_lealla_r_multiclass": + return full_lealla_r_multiclass() else: raise ValueError(f"Unknown training configuration: {config_name}") diff --git a/src/run_action_gin.py b/src/run_action_gin.py index 6d930af..9727994 100644 --- a/src/run_action_gin.py +++ b/src/run_action_gin.py @@ -21,7 +21,10 @@ from finetunings.generate_epochs.generate import generate from multilingual_dataset.combine_embs import combine_embs_by_qid from multilingual_dataset.creator import create_multilingual_dataset, run_kb_creator -from reranking.dataset.create_dataset import create_default_binary_dataset +from reranking.dataset.create_dataset import ( + create_default_binary_dataset, + create_default_multiclass_dataset, +) from reranking.training.trainer import train_ddp as reranking_train_ddp from reranking.training.trainer_simple import train as reranking_train from tokenization.runner import ( @@ -105,6 +108,8 @@ def choose_action(action): return reranking_train_ddp case "create_default_binary_dataset": return create_default_binary_dataset + case "create_default_multiclass_dataset": + return create_default_multiclass_dataset case "reranking_train": return reranking_train case _: diff --git a/src/scripts/reranking/change_dataset_tokens.py b/src/scripts/reranking/change_dataset_tokens.py index 1053543..f6b8778 100644 --- a/src/scripts/reranking/change_dataset_tokens.py +++ b/src/scripts/reranking/change_dataset_tokens.py @@ -15,20 +15,22 @@ def update_tokens_in_file(file_path: Path, qid_to_new_tokens: csr_matrix) -> Non qid_to_new_tokens indexed by the file's qids. Raises ValueError if a qid has no corresponding tokens in qid_to_new_tokens. """ + print(file_path) with np.load(file_path) as data: tokens = data["description_tokens"] qids = data["qids"] save_data = dict(data) - tokens = np.empty((tokens.shape[0], qid_to_new_tokens.shape[1]), dtype=tokens.dtype) + # tokens = np.empty((tokens.shape[0], qid_to_new_tokens.shape[1]), dtype=tokens.dtype) - for i, qid in enumerate(qids): - if qid_to_new_tokens[qid].nnz > 0: - print(qid_to_new_tokens[qid].toarray()[0]) - tokens[i] = qid_to_new_tokens[qid].toarray()[0] - else: - # We could also pad/truncate here if needed but this code should not really happen. - raise ValueError(f"No new tokens found for qid {qid}") + tokens = qid_to_new_tokens[qids].toarray() + + # for i, qid in enumerate(qids): + # if qid_to_new_tokens[qid].nnz > 0: + # tokens[i] = qid_to_new_tokens[qid].toarray()[0] + # else: + # # We could also pad/truncate here if needed but this code should not really happen. + # raise ValueError(f"No new tokens found for qid {qid}") save_data["description_tokens"] = tokens np.savez(file_path, **save_data) diff --git a/src/utils/loaders.py b/src/utils/loaders.py index 722ffd1..38b8956 100644 --- a/src/utils/loaders.py +++ b/src/utils/loaders.py @@ -175,6 +175,12 @@ def map_qids_to_token_matrix( tokens, qids = load_tokens_qids_from_dir( dir_path=dir_path, verbose=verbose, max_items_to_load=max_items_to_load ) + # remove duplicated items by qids + print("Original number of items:", tokens.shape[0]) + unique_qids, unique_indices = np.unique(qids, return_index=True) + tokens = tokens[unique_indices] + qids = unique_qids + print("New number of items after removing duplicates:", tokens.shape[0]) num_items, vector_len = tokens.shape @@ -184,10 +190,14 @@ def map_qids_to_token_matrix( col_indices = np.tile(np.arange(vector_len), num_items) data = tokens.flatten() + print("MAX TOKENS", np.max(data)) + shape = (qids.max() + 1, vector_len) coo = coo_matrix((data, (row_indices, col_indices)), shape=shape, dtype=tokens.dtype) - return coo.tocsr() + csr = coo.tocsr() + print("MAX TOKENS", csr.data.max()) + return csr class AliasTableLoader: From 05c85327a38bdd23e3aad865e3dc906eebf3ab3b Mon Sep 17 00:00:00 2001 From: Dominik Farhan Date: Tue, 21 Oct 2025 20:03:05 +0200 Subject: [PATCH 28/28] fix tests --- tests/reranking/training/test_reranking_iterable_dataset.py | 3 --- tests/reranking/training/test_training_configs.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/reranking/training/test_reranking_iterable_dataset.py b/tests/reranking/training/test_reranking_iterable_dataset.py index 6849cb1..81dba02 100644 --- a/tests/reranking/training/test_reranking_iterable_dataset.py +++ b/tests/reranking/training/test_reranking_iterable_dataset.py @@ -21,7 +21,4 @@ def test_reranking_iterable_dataset_iterates_samples(tmp_path): first = samples[0] assert isinstance(first, tuple) - assert torch.equal(first[0], torch.tensor([7, 8], dtype=torch.long)) - assert torch.equal(first[1], torch.tensor([1, 2, 3], dtype=torch.long)) assert torch.equal(first[2], torch.tensor(1.0, dtype=torch.float32)) - assert torch.equal(first[3], torch.tensor(10, dtype=torch.long)) diff --git a/tests/reranking/training/test_training_configs.py b/tests/reranking/training/test_training_configs.py index aaffa62..8ed34a5 100644 --- a/tests/reranking/training/test_training_configs.py +++ b/tests/reranking/training/test_training_configs.py @@ -15,6 +15,7 @@ def test_get_output_path_creates_directory(tmp_path): save_each=100, batch_size=1, output_dir=str(output_root), + validate_each=50, ) path = config.get_output_path(step=5)