diff --git a/configs/multilingual_dataset.gin b/configs/multilingual_dataset.gin index 0f7fe45..8092036 100644 --- a/configs/multilingual_dataset.gin +++ b/configs/multilingual_dataset.gin @@ -6,3 +6,6 @@ dest_dir="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal/" create_multilingual_dataset.source_dir=%source_dir create_multilingual_dataset.langs=%langs create_multilingual_dataset.dest_dir=%dest_dir + +run_kb_creator.langs=%langs + diff --git a/configs/paraphrase.gin b/configs/paraphrase.gin index 22c4e97..118e307 100644 --- a/configs/paraphrase.gin +++ b/configs/paraphrase.gin @@ -23,5 +23,6 @@ train_ddp.FOUNDATION_MODEL_PATH=%model_path run_mewsli_mention.model_path=%model_path run_damuel_mention.model_path=%model_path run_damuel_description_context.model_path=%model_path +run_damuel_description.model_path=%model_path run_damuel_link_context.model_path=%model_path run_mewsli_context.model_path=%model_path \ No newline at end of file diff --git a/configs/train.gin b/configs/train.gin index 778c3de..65ac505 100644 --- a/configs/train.gin +++ b/configs/train.gin @@ -1,4 +1,4 @@ -training_batch_size=2688 +training_batch_size=3712 #epochs=300 epochs=100 logit_mutliplier=20 diff --git a/pyproject.toml b/pyproject.toml index d421e28..ccfd95a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,15 +5,19 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "einops>=0.8.1", "fire>=0.7.1", "gin-config>=0.5.0", + "ipython>=9.6.0", "numba>=0.61.2", "orjson>=3.11.3", "pandas>=2.3.2", "pytest>=8.4.2", + "pytest-cov>=7.0.0", "pytest-mock>=3.15.0", "python-fire>=0.1.0", "scann>=1.4.2", + "scipy>=1.16.2", "torch>=2.8.0", "tqdm>=4.67.1", "transformers>=4.56.1", diff --git a/src/finetunings/finetune_model/train_ddp.py b/src/finetunings/finetune_model/train_ddp.py index 8384836..dcd49ff 100644 --- a/src/finetunings/finetune_model/train_ddp.py +++ b/src/finetunings/finetune_model/train_ddp.py @@ -130,7 
+130,7 @@ def _ddp_train( }, ) - optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY) + optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, fused=True) criterion = nn.CrossEntropyLoss() scaler = torch.amp.GradScaler("cuda") diff --git a/src/finetunings/generate_epochs/generate.py b/src/finetunings/generate_epochs/generate.py index 2eef77e..72bb642 100644 --- a/src/finetunings/generate_epochs/generate.py +++ b/src/finetunings/generate_epochs/generate.py @@ -12,7 +12,7 @@ from finetunings.generate_epochs.datasets import BatcherDataset, DamuelNeighborsIterator from models.batch_sampler import BatchSampler from models.negative_sampler import NegativeSamplingType -from models.searchers.brute_force_searcher import DPBruteForceSearcher +from models.searchers.brute_force_searcher import DPBruteForceSearcher, ManualSyncBruteForceSearcher from utils.calculate_qids_distribution import calculate_qids_distribution_from_links from utils.loaders import load_embs_and_qids from utils.multifile_dataset import MultiFileDataset @@ -95,8 +95,9 @@ def generate( batch_sampler = BatchSampler( index_embs, index_qids, - DPBruteForceSearcher, + # DPBruteForceSearcher, # BruteForceSearcher, + ManualSyncBruteForceSearcher, NegativeSamplingType(NEGATIVE_SAMPLING_TYPE), **negative_sampler_kwargs, ) diff --git a/src/models/searchers/brute_force_searcher.py b/src/models/searchers/brute_force_searcher.py index 1b6427c..348b466 100644 --- a/src/models/searchers/brute_force_searcher.py +++ b/src/models/searchers/brute_force_searcher.py @@ -45,7 +45,8 @@ def build(self) -> None: class _WrappedSearcher(nn.Module): def __init__(self, kb_embs, num_neighbors): super().__init__() - self.kb_embs: torch.Tensor = nn.Parameter(kb_embs) + # Replace Parameter with register_buffer + self.register_buffer("kb_embs", kb_embs) self.num_neighbors: int = num_neighbors # @torch.compile @@ -106,6 +107,10 @@ def find(self, batch: np.ndarray, num_neighbors: int, mask=None) 
-> np.ndarray: _WrappedSearcher(torch.from_numpy(self.embs), num_neighbors) ) self.module_searcher.to(self.device) + # Set module to eval() and disable gradients + self.module_searcher.eval() + for param in self.module_searcher.parameters(): + param.requires_grad = False self.required_num_neighbors = num_neighbors top_indices: torch.Tensor = self.module_searcher( torch.from_numpy(batch).to(self.device) @@ -116,3 +121,93 @@ def find(self, batch: np.ndarray, num_neighbors: int, mask=None) -> np.ndarray: def build(self): pass + + +class DPBruteForceSearcherPT(Searcher): + def __init__(self, embs: np.ndarray, results: np.ndarray, run_build_from_init: bool = True): + if torch.cuda.is_available(): + _logger.info("Running on CUDA.") + self.device: torch.device = torch.device("cuda") + else: + _logger.info("CUDA is not available.") + self.device: torch.device = torch.device("cpu") + self.module_searcher: Optional[nn.DataParallel] = None + self.required_num_neighbors: Optional[int] = None + super().__init__(embs, results, run_build_from_init) + + @torch.inference_mode() + def find(self, batch: torch.Tensor, num_neighbors: int) -> np.ndarray: + """ + Finds the nearest neighbors for a given batch of input data. + CAREFUL: This is an optimized version that comes with potential pitfalls to get better performance. + Read Notes for details! + + Args: + batch (torch.Tensor): A batch of input data for which neighbors are to be found. + num_neighbors (int): The number of nearest neighbors to retrieve. + Returns: + np.ndarray: An array containing the results corresponding to the nearest neighbors. + Raises: + TypeError: If an unexpected attribute access occurs while using `module_searcher`. + Notes: + - It is not possible to change num_neighbors after the first call to find. + If you need to do that, you need to reinitialize this object. If you call the find with different + num_neighbors, it will not raise an error and will fail silently.
+ - The first call to find will be slow, because the module_searcher will be initialized and torch.compile is called. + """ + # with torch.inference_mode(), torch.autocast( + # device_type=self.device.type, dtype=torch.float16 + # ): + # A try except trick to avoid the overhead of checking if the module_searcher is None + # on every call to find. + # This is a bit of a hack, but it should make things faster as we are suggesting that the module_searcher is initialized. + try: + with torch.amp.autocast(device_type="cuda"): + top_indices: torch.Tensor = self.module_searcher(batch) + except TypeError as e: + if self.module_searcher is not None: + raise e + self.module_searcher = torch.compile( + nn.DataParallel(_WrappedSearcher(self.embs, num_neighbors)) + ) + self.module_searcher.to(self.device) + # Set module to eval() and disable gradients + self.module_searcher.eval() + for param in self.module_searcher.parameters(): + param.requires_grad = False + self.required_num_neighbors = num_neighbors + top_indices: torch.Tensor = self.module_searcher(batch) + + return self.results[top_indices.cpu().numpy()] + + def build(self): + pass + + +class ManualSyncBruteForceSearcher(Searcher): + def __init__(self, embs: np.ndarray, results: np.ndarray, run_build_from_init: bool = False): + assert torch.cuda.is_available(), "This class requires CUDA." + assert run_build_from_init is False, "This class does not support building from init." 
+ self.searchers = [] + self.num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 + self.cuda_devices = [torch.device(f"cuda:{i}") for i in range(self.num_devices)] + super().__init__(torch.from_numpy(embs), results, run_build_from_init) + + @torch.inference_mode() + def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray: + if len(self.searchers) == 0: + for device in self.cuda_devices: + self.searchers.append( + _WrappedSearcher(self.embs, num_neighbors=num_neighbors).to(device) + ) + batch = torch.from_numpy(batch) + inputs = nn.parallel.scatter(batch, self.cuda_devices) + outputs = [ + searcher(input_chunk.to(device)) + for searcher, input_chunk, device in zip(self.searchers, inputs, self.cuda_devices) + ] + gathered = nn.parallel.gather(outputs, self.cuda_devices[0]) + return gathered.cpu().numpy() + + def build(self): + pass diff --git a/src/multilingual_dataset/creator.py b/src/multilingual_dataset/creator.py index 69a7e46..b018bb9 100644 --- a/src/multilingual_dataset/creator.py +++ b/src/multilingual_dataset/creator.py @@ -284,13 +284,48 @@ def create_multilingual_dataset( dest_dir: Union[str, Path], max_links_per_qid: int, ) -> None: + """Create a multilingual dataset by mixing links and building a language-filtered KB. + + - Links: copies and intermixes link shards from the given languages into `dest_dir/links`, + limiting per-QID occurrences to `max_links_per_qid`. Outputs NPZ files with arrays + `tokens` and `qids`. + - KB pages: writes a subset of description/page shards to `dest_dir/descs_pages`, + assigning up to one language per QID by default. + + Args: + source_dir: Root directory of the DAMUEL dataset to read from. + langs: Language codes to include. + dest_dir: Output directory; creates `links/` and `descs_pages/` subfolders. + max_links_per_qid: Maximum number of link samples retained per QID. + + Notes: + Uses parallel mixing and threaded I/O for performance. 
+ """ MultilingualDatasetCreator(Path(source_dir), langs, Path(dest_dir), max_links_per_qid).run() +@gin.configurable def run_kb_creator( source_dir: Union[str, Path], langs: list[str], dest_dir: Union[str, Path], langs_per_qid: int, ) -> None: + """ + Build a language-filtered KB (descriptions/pages) subset from a DAMUEL dataset. + + For each QID, selects up to `langs_per_qid` languages ranked by link frequency + (ties broken by overall language size), then copies only the chosen language pages + into `dest_dir/descs_pages` as compressed NPZ shards named `mentions_{lang}_{i}.npz`. + + Args: + source_dir: Root path of the DAMUEL dataset to read (links and pages). + langs: List of language codes to consider. + dest_dir: Output directory; the 'descs_pages' subfolder is created inside. + langs_per_qid: Maximum number of languages to assign to each QID. + + Notes: + - Affects KB creation only; link files are not modified. + - Uses parallel I/O with ThreadPoolExecutor for speed. + """ _KBCreator(DamuelPaths(source_dir), langs, Path(dest_dir), langs_per_qid).run() diff --git a/src/reranking/binary/create_dataset.py b/src/reranking/binary/create_dataset.py deleted file mode 100644 index 8c2f15f..0000000 --- a/src/reranking/binary/create_dataset.py +++ /dev/null @@ -1,227 +0,0 @@ -import sys -from pathlib import Path - -import numba as nb -import numpy as np -import torch -import torch.utils.data -from tqdm import tqdm - -sys.path.append("/lnet/work/home-students-external/farhan/mel-reborn/src") - -from models.searchers.brute_force_searcher import BruteForceSearcher -from utils.embeddings import create_attention_mask -from utils.loaders import load_embs_and_qids, load_tokens_qids_from_dir -from utils.model_factory import ModelFactory - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -@nb.njit -def get_neg_qids(top_qids, batch_qids): - neg_qids = [] - for row in top_qids: - if row[0] not in batch_qids: - neg_qids.append(row[0]) - else: - 
neg_qids.append(row[1]) - return neg_qids - - -def create_binary_dataset( - index_embs_dir: Path, - index_tokens_path: Path, - link_tokens_path: Path, - model_name: str, - embedding_model_path_dict: Path, - output_path: Path, - target_dim: int = None, - batch_size: int = 512, -) -> None: - # Load index embeddings, qids, and tokens - index_embs, index_qids = load_embs_and_qids(index_embs_dir) - index_qids_set = set(index_qids) - index_embs = index_embs.astype(np.float16) - index_tokens, _ = load_tokens_qids_from_dir(index_tokens_path) - print(index_tokens.shape) - print(len(index_qids_set)) - - # Create BruteForceSearcher - searcher = BruteForceSearcher(index_embs, index_qids) - - # Load link tokens and qids - link_tokens, link_qids = load_tokens_qids_from_dir(link_tokens_path) - # Loaders order by qids which is not necessarily what we want - print(link_tokens.shape) - # Load embedding model - model = ModelFactory.auto_load_from_file( - model_name, - embedding_model_path_dict, - target_dim=target_dim, - ) - model.eval() - model.to(device) - # Create DataLoader - dataset = list(zip(link_tokens, link_qids)) - dataset = torch.utils.data.Subset( - dataset, [i for i, (tokens, qid) in enumerate(dataset) if qid in index_qids_set] - ) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) - - # Initialize dataset arrays - description_tokens = [] - link_tokens_list = [] - y = [] - - print("Dataset length:", len(dataset)) - description_tokens = np.zeros((len(dataset) * 2, index_tokens.shape[1])) - link_tokens_list = np.zeros((len(dataset) * 2, link_tokens.shape[1])) - y = np.zeros((len(dataset) * 2,)) - output_index = 0 - - index_qid_to_index = {qid: i for i, qid in enumerate(index_qids)} - - # Iterate over batches - for batch_tokens, batch_qids in tqdm( - dataloader, desc="Creating dataset", total=len(dataloader) - ): - # Embed link tokens - with torch.no_grad(): - batch_embs = model( - batch_tokens.to(device).to(torch.int64), - 
create_attention_mask(batch_tokens).to(device), - ).cpu() - - # Find top matches - top_qids = searcher.find(batch_embs.numpy().astype(np.float16), num_neighbors=2) - - positive_mask = [index_qid_to_index[qid] for qid in batch_qids.numpy()] - data_size = len(batch_tokens) - description_tokens[output_index : output_index + data_size] = index_tokens[positive_mask] - link_tokens_list[output_index : output_index + data_size] = batch_tokens.numpy() - y[output_index : output_index + data_size] = 1 - - output_index += data_size - - neg_qids = get_neg_qids(top_qids, set(batch_qids.numpy())) - - negative_mask = [index_qid_to_index[qid] for qid in neg_qids] - description_tokens[output_index : output_index + data_size] = index_tokens[negative_mask] - link_tokens_list[output_index : output_index + data_size] = batch_tokens.numpy() - y[output_index : output_index + data_size] = 0 - - output_index += data_size - - # Convert to numpy arrays - - print(description_tokens.shape) - print(link_tokens_list.shape) - print(y.shape) - - # Save dataset - np.savez( - output_path, - description_tokens=description_tokens, - link_tokens=link_tokens_list, - y=y, - ) - - -def create_multiclass_dataset( - index_embs_dir: Path, - index_tokens_path: Path, - link_tokens_path: Path, - model_name: str, - embedding_model_path_dict: Path, - output_path: Path, - total_classes: int, - target_dim: int = None, - batch_size: int = 512, -): - assert total_classes > 1 - index_embs, index_qids = load_embs_and_qids(index_embs_dir) - index_qids_set = set(index_qids) - index_embs = index_embs.astype(np.float16) - index_tokens, _ = load_tokens_qids_from_dir(index_tokens_path) - print(index_tokens.shape) - print(len(index_qids_set)) - - qid_to_index = {qid: i for i, qid in enumerate(index_qids)} - - # Create BruteForceSearcher - searcher = BruteForceSearcher(index_embs, index_qids) - - # Load link tokens and qids - link_tokens, link_qids = load_tokens_qids_from_dir(link_tokens_path) - # Load embedding model - model = 
ModelFactory.auto_load_from_file( - model_name, - embedding_model_path_dict, - target_dim=target_dim, - ) - model.eval() - model.to(device) - - # Create dataset and dataloader from link_tokens and link_qids - dataset = list(zip(link_tokens, link_qids)) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=True) - - link_tokens, qids = [], [] - for B_tokens, B_qids in tqdm(dataloader, desc="Creating dataset"): - with torch.no_grad(): - B_embs = model( - B_tokens.to(device).to(torch.int64), - create_attention_mask(B_tokens).to(device), - ).cpu() - - top_qids = searcher.find(B_embs.numpy().astype(np.float16), num_neighbors=total_classes) - for i, qid in enumerate(B_qids.numpy()): - if qid in top_qids[i]: - idx = top_qids[i].index(qid) - top_qids[i][idx], top_qids[i][total_classes - 1] = ( - top_qids[i][total_classes - 1], - top_qids[i][idx], - ) - else: - top_qids[i][total_classes - 1] = qid - link_tokens.extend(B_tokens.numpy()) - qids.extend(top_qids) - link_tokens = np.array(link_tokens) - qids = np.array(qids) - print(link_tokens.shape) - print(qids.shape) - - np.savez( - output_path, - link_tokens=link_tokens, - qids=qids, - ) - - -if __name__ == "__main__": - index_embs_dir = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/damuel_for_index_3" - ) - index_tokens_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/damuel_descs_together_tokens" - ) - link_tokens_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/damuel_links_together_tokens_0" - ) - embedding_model_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/ml9/models_2/final.pth" - ) - output_path = Path( - "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset.npz" - ) - model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" - - create_multiclass_dataset( - index_embs_dir, - 
index_tokens_path, - link_tokens_path, - model_name, - embedding_model_path, - output_path, - total_classes=10, - ) diff --git a/src/reranking/dataset/create_dataset.py b/src/reranking/dataset/create_dataset.py new file mode 100644 index 0000000..7248220 --- /dev/null +++ b/src/reranking/dataset/create_dataset.py @@ -0,0 +1,375 @@ +from pathlib import Path + +import numpy as np +import torch +import torch.utils.data +from tqdm import tqdm + +from models.searchers.brute_force_searcher import BruteForceSearcher, DPBruteForceSearcherPT +from utils.embeddings import create_attention_mask +from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir +from utils.model_factory import ModelFactory + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def get_neg_qids(top_qids, batch_qids): + neg_qids = [] + for row, batch_qid in zip(top_qids, batch_qids): + if row[0] != batch_qid: + neg_qids.append(row[0]) + else: + neg_qids.append(row[1]) + return neg_qids + + +def create_binary_dataset( + index_embs_dir: Path, + index_tokens_path: Path, + link_tokens_path: Path, + model_name: str, + embedding_model_path_dict: Path, + output_path: Path, + target_dim: int = None, + batch_size: int = 512, +) -> None: + # Load index embeddings, qids, and tokens + index_embs, index_qids = load_embs_and_qids(index_embs_dir) + index_embs = index_embs.astype(np.float16) + index_tokens, index_qids_from_tokens = load_tokens_qids_from_dir(index_tokens_path) + + # Sort index_embs and index_qids based on index_qids + sort_indices = np.argsort(index_qids) + index_qids = index_qids[sort_indices] + index_embs = index_embs[sort_indices] + + sort_indices_tokens = np.argsort(index_qids_from_tokens) + index_qids_from_tokens = index_qids_from_tokens[sort_indices_tokens] + index_tokens = index_tokens[sort_indices_tokens] + + np.testing.assert_array_equal(index_qids, index_qids_from_tokens) + + print(index_tokens.shape) + + # Create BruteForceSearcher + 
searcher = BruteForceSearcher(index_embs, index_qids) + + # Load link tokens and qids + link_tokens_path = Path(link_tokens_path) + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + link_files = sorted( + [p for p in link_tokens_path.iterdir() if p.is_file() and p.suffix == ".npz"], + key=lambda p: p.name, + ) + + if not link_files: + raise FileNotFoundError(f"No .npz files found in {link_tokens_path}") + + # Load embedding model + model = ModelFactory.auto_load_from_file( + model_name, + embedding_model_path_dict, + target_dim=target_dim, + ) + model.eval() + model.to(device) + model.to(torch.bfloat16) + model = torch.compile(model) + + index_qid_to_index = {int(qid): i for i, qid in enumerate(index_qids)} + index_qids_set = set(index_qid_to_index.keys()) + + for link_file in tqdm(link_files, desc="Processing link files"): + link_tokens, link_qids = load_tokens_qids(link_file) + + known_qids_mask = np.array([int(q) in index_qids_set for q in link_qids], dtype=bool) + link_tokens = link_tokens[known_qids_mask] + link_qids = link_qids[known_qids_mask] + + link_tokens_tensor = torch.from_numpy(link_tokens.astype(np.int32, copy=False)) + link_qids_tensor = torch.from_numpy(link_qids.astype(np.int64, copy=False)) + dataset = torch.utils.data.TensorDataset(link_tokens_tensor, link_qids_tensor) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2 + ) + + data_len = len(dataset) + if data_len == 0: + print(f"Skipping {link_file.name}: dataset length is zero after filtering") + continue + + description_tokens = np.zeros((data_len * 2, index_tokens.shape[1]), dtype=np.int32) + link_tokens_list = np.zeros((data_len * 2, link_tokens.shape[1]), dtype=np.int32) + y = np.zeros((data_len * 2,), dtype=np.int8) + qids = np.zeros((data_len * 2,), dtype=np.int32) + output_index = 0 + + for batch_tokens, batch_qids in tqdm( + dataloader, desc=f"Creating dataset for 
{link_file.name}", total=len(dataloader) + ): + attention_mask = create_attention_mask(batch_tokens) + + with torch.inference_mode(): + batch_embs = ( + model(batch_tokens.to(device), attention_mask.to(device)) + .to(torch.float16) + .cpu() + ) + + top_qids = searcher.find(batch_embs.numpy(), num_neighbors=2) + + batch_qids_np = batch_qids.cpu().numpy() + batch_tokens_np = batch_tokens.cpu().numpy().astype(np.int32, copy=False) + positive_mask = [index_qid_to_index[int(qid)] for qid in batch_qids_np] + data_size = len(batch_tokens) + + description_tokens[output_index : output_index + data_size] = index_tokens[ + positive_mask + ] + link_tokens_list[output_index : output_index + data_size] = batch_tokens_np + y[output_index : output_index + data_size] = 1 + qids[output_index : output_index + data_size] = batch_qids_np.astype( + np.int32, copy=False + ) + + output_index += data_size + + neg_qids = get_neg_qids(top_qids, batch_qids_np) + + negative_mask = [index_qid_to_index[int(qid)] for qid in neg_qids] + description_tokens[output_index : output_index + data_size] = index_tokens[ + negative_mask + ] + link_tokens_list[output_index : output_index + data_size] = batch_tokens_np + y[output_index : output_index + data_size] = 0 + qids[output_index : output_index + data_size] = np.array(neg_qids, dtype=np.int32) + + output_index += data_size + + if output_index != data_len * 2: + description_tokens = description_tokens[:output_index] + link_tokens_list = link_tokens_list[:output_index] + y = y[:output_index] + qids = qids[:output_index] + + output_file = output_path / f"{link_file.stem}_dataset.npz" + + print( + f"Saving dataset for {link_file.name} -> {output_file.name} | " + f"positives/negatives: {output_index // 2}" + ) + + np.savez( + output_file, + description_tokens=description_tokens, + link_tokens=link_tokens_list, + y=y, + qids=qids, + ) + + +def create_multiclass_dataset( + K: int, + index_embs_dir: Path, + link_tokens_path: Path, + model_name: str, + 
embedding_model_path_dict: Path, + output_path: Path, + target_dim: int = None, + batch_size: int = 512, +) -> None: + # Load index embeddings, qids, and tokens + index_embs, index_qids = load_embs_and_qids(index_embs_dir) + index_embs = index_embs.astype(np.float16) + + # Sort index_embs and index_qids based on index_qids + sort_indices = np.argsort(index_qids) + index_qids = index_qids[sort_indices] + index_embs = index_embs[sort_indices] + + # Create BruteForceSearcher + searcher = BruteForceSearcher(index_embs, index_qids) + + # Load link tokens and qids + link_tokens_path = Path(link_tokens_path) + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + link_files = sorted( + [p for p in link_tokens_path.iterdir() if p.is_file() and p.suffix == ".npz"], + key=lambda p: p.name, + ) + + if not link_files: + raise FileNotFoundError(f"No .npz files found in {link_tokens_path}") + + # Load embedding model + model = ModelFactory.auto_load_from_file( + model_name, + embedding_model_path_dict, + target_dim=target_dim, + ) + model.eval() + model.to(device) + model.to(torch.bfloat16) + model = torch.compile(model) + + index_qid_to_index = {int(qid): i for i, qid in enumerate(index_qids)} + index_qids_set = set(index_qid_to_index.keys()) + + for link_file in tqdm(link_files, desc="Processing link files"): + link_tokens, link_qids = load_tokens_qids(link_file) + + # known_qids_mask = np.array([int(q) in index_qids_set for q in link_qids], dtype=bool) + # link_tokens = link_tokens[known_qids_mask] + # link_qids = link_qids[known_qids_mask] + + link_tokens_tensor = torch.from_numpy(link_tokens.astype(np.int32, copy=False)) + link_qids_tensor = torch.from_numpy(link_qids.astype(np.int64, copy=False)) + dataset = torch.utils.data.TensorDataset(link_tokens_tensor, link_qids_tensor) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2 + ) + + data_len = len(dataset) + if data_len 
== 0: + print(f"Skipping {link_file.name}: dataset length is zero after filtering") + continue + + link_tokens_list = np.zeros((data_len, link_tokens.shape[1]), dtype=np.int32) + y = np.zeros(data_len, dtype=np.int32) + qids = np.zeros((data_len, K), dtype=np.int32) + output_index = 0 + + for batch_tokens, batch_qids in tqdm( + dataloader, desc=f"Creating dataset for {link_file.name}", total=len(dataloader) + ): + attention_mask = create_attention_mask(batch_tokens) + + with torch.inference_mode(): + batch_embs = ( + model(batch_tokens.to(device), attention_mask.to(device)) + .to(torch.float16) + .cpu() + ) + + top_qids = searcher.find(batch_embs.numpy(), num_neighbors=K) + batch_labels = np.empty((len(batch_qids)), dtype=np.int32) + + batch_qids = batch_qids.cpu().numpy() + + for i, (bq, tq_r) in enumerate(zip(batch_qids, top_qids)): + if bq not in tq_r: + top_qids[i][-1] = bq # Ensure positive is in top K + batch_labels[i] = K - 1 + else: + batch_labels[i] = np.where(tq_r == bq)[0][0] + + batch_tokens_np = batch_tokens.cpu().numpy().astype(np.int32, copy=False) + data_size = len(batch_tokens) + + link_tokens_list[output_index : output_index + data_size] = batch_tokens_np + y[output_index : output_index + data_size] = batch_labels + qids[output_index : output_index + data_size] = top_qids + + output_index += data_size + + output_file = output_path / f"{link_file.stem}_dataset.npz" + + print(f"Saving dataset for {link_file.name} -> {output_file.name} | ") + + np.savez( + output_file, + link_tokens=link_tokens_list, + y=y, + qids=qids, + ) + + +def create_default_binary_dataset(): + index_embs_dir = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + index_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages" + ) + link_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/links" + ) + 
embedding_model_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + output_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids" + ) + model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + + create_binary_dataset( + index_embs_dir, + index_tokens_path, + link_tokens_path, + model_name, + embedding_model_path, + output_path, + batch_size=2560, + ) + + +def create_default_multiclass_dataset(): + index_embs_dir = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + index_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages" + ) + link_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/links" + ) + embedding_model_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + output_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids_multiclass" + ) + model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + + create_multiclass_dataset( + 7, + index_embs_dir, + link_tokens_path, + model_name, + embedding_model_path, + output_path, + batch_size=2560, + ) + + +if __name__ == "__main__": + index_embs_dir = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + index_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages" + ) + link_tokens_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/links" + ) + embedding_model_path = Path( + 
"/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + output_path = Path( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids" + ) + model_name = "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + + create_binary_dataset( + index_embs_dir, + index_tokens_path, + link_tokens_path, + model_name, + embedding_model_path, + output_path, + batch_size=2048, + ) diff --git a/src/reranking/models/__init__.py b/src/reranking/models/__init__.py new file mode 100644 index 0000000..ec8a8dc --- /dev/null +++ b/src/reranking/models/__init__.py @@ -0,0 +1,7 @@ +from .pairwise_mlp import PairwiseMLPReranker +from .pairwise_mlp_with_retrieval_score import PairwiseMLPRerankerWithRetrievalScore + +__all__ = [ + "PairwiseMLPReranker", + "PairwiseMLPRerankerWithRetrievalScore", +] diff --git a/src/reranking/models/base.py b/src/reranking/models/base.py new file mode 100644 index 0000000..14dde8d --- /dev/null +++ b/src/reranking/models/base.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Sequence + +import torch +from torch import nn + + +class BaseRerankingModel(nn.Module, ABC): + """Abstract base class for reranking models.""" + + def train_step(self, data: Dict[str, Any]) -> torch.Tensor: + """Update the EMA parameters, then run a single training step on the provided batch data and return the loss.""" + self.update_ema() + return self.train_step_imp(data) + + @abstractmethod + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + """Run a single training step on the provided batch data and return the loss.""" + raise NotImplementedError + + @abstractmethod + def update_ema(self) -> None: + """Update the EMA (Exponential Moving Average) of model parameters.""" + raise NotImplementedError + + @abstractmethod + def score( + self, mention: str | 
Sequence[str] + ) -> torch.Tensor | float: + """Compute similarity-based probability for one or more mention/entity pairs.""" + raise NotImplementedError + + @abstractmethod + def score_from_tokens(self, mention: Any, entity_description: Any) -> torch.Tensor: + """Compute similarity-based probabilities for tokenized mention/entity pairs.""" + raise NotImplementedError + + @abstractmethod + def save(self, path: str) -> None: + """Save the model to the specified path.""" + raise NotImplementedError + + @abstractmethod + def load(self, path: str) -> None: + """Load the model from the specified path.""" + raise NotImplementedError diff --git a/src/reranking/models/context_emb_with_attention.py b/src/reranking/models/context_emb_with_attention.py new file mode 100644 index 0000000..2f62e47 --- /dev/null +++ b/src/reranking/models/context_emb_with_attention.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Mapping, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer 
output dimension from the provided base model.") + + +class GPTLayer(nn.Module): + def __init__(self, model_width: int, dropout: float) -> None: + super().__init__() + self.linear1 = nn.Linear(model_width, model_width * 4) + self.activation = nn.GELU() + self.linear2 = nn.Linear(model_width * 4, model_width) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = self.activation(x) + x = self.linear2(x) + x = self.dropout(x) + return x + + +class _CATransformerBlock(nn.Module): + def __init__(self, model_width: int, dropout: float, num_heads: int) -> None: + super().__init__() + self.layer_norm1 = nn.LayerNorm(model_width) + self.self_attention = nn.MultiheadAttention( + embed_dim=model_width, num_heads=num_heads, dropout=dropout + ) + self.layer_norm2 = nn.LayerNorm(model_width) + self.feed_forward = GPTLayer(model_width, dropout) + + def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: + # Self-attention block + residual = x + x = self.layer_norm1(x) + x, _ = self.self_attention(x, context, x) + x = x + residual + + # Feed-forward block + residual = x + x = self.layer_norm2(x) + x = self.feed_forward(x) + x = x + residual + + return x + + +class _DoubleCrossAttention(nn.Module): + def __init__(self, model_width: int, dropout: float, num_heads: int) -> None: + super().__init__() + self.transformer1 = _CATransformerBlock(model_width, dropout, num_heads) + self.transformer2 = _CATransformerBlock(model_width, dropout, num_heads) + + def forward( + self, main_input: torch.Tensor, context_input1: torch.Tensor, context_input2: torch.Tensor + ) -> torch.Tensor: + x = self.transformer1(main_input, context_input1) + x = self.transformer2(x, context_input2) + return x + + +class _Classifier(nn.Module): + def __init__( + self, + model_width: int, + dropout: float, + num_heads: int, + num_layers: int, + n_tokens: int, + base_dim: int, + paraphrase_dim: int, + ) -> None: + 
super().__init__() + self.double_cross_attention = nn.ModuleList( + [_DoubleCrossAttention(model_width, dropout, num_heads) for _ in range(num_layers)] + ) + self.mewsli_to_tokens = nn.Linear(base_dim, model_width * n_tokens) + self.base_to_tokens = nn.Linear(base_dim, model_width * n_tokens) + self.paraphrase_to_tokens = nn.Linear(paraphrase_dim, model_width * n_tokens) + + self.final_projection = nn.Linear(model_width * n_tokens, 1) + self.model_width = model_width + + def forward( + self, mewsli_embs: torch.Tensor, base_embs: torch.Tensor, paraphrase_embs: torch.Tensor + ) -> torch.Tensor: + mewsli_tokens = self.mewsli_to_tokens(mewsli_embs) + base_tokens = self.base_to_tokens(base_embs) + paraphrase_tokens = self.paraphrase_to_tokens(paraphrase_embs) + + mewsli_tokens = rearrange(mewsli_tokens, "b (n d) -> n b d", n=self.model_width) + base_tokens = rearrange(base_tokens, "b (n d) -> n b d", n=self.model_width) + paraphrase_tokens = rearrange(paraphrase_tokens, "b (n d) -> n b d", n=self.model_width) + + for layer in self.double_cross_attention: + mewsli_tokens = layer(mewsli_tokens, base_tokens, paraphrase_tokens) + + mewsli_tokens = rearrange(mewsli_tokens, "n b d -> b (n d)") + logits = self.final_projection(mewsli_tokens).squeeze(-1) + return logits + + +class ContextEmbWithAttention(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head that uses paraphrase embedding to get more context.""" + + def __init__( + self, + model_name_or_path: str, + qid_to_paraphrase_emb: Dict[int, torch.Tensor], + qid_to_base_emb: Dict[int, torch.Tensor], + *, + state_dict_path: str | None = None, + target_dim: int | None = None, + output_type: ModelOutputType | str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + model_width: int = 64, + num_heads: int = 16, + num_layers: int = 2, + n_tokens: int = 64, + ) -> None: + super().__init__() + self.ema_decay = ema_decay + + 
resolved_output_type = _maybe_convert_output_type(output_type) + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + target_dim=target_dim, + output_type=resolved_output_type, + ) + self.base_model.eval() + self.base_model.requires_grad_(False) + + self.paraphrase_model_embedding_dim = next(iter(qid_to_paraphrase_emb.values())).shape[0] + self.base_model_embedding_dim = next(iter(qid_to_base_emb.values())).shape[0] + + self.classifier = _Classifier( + model_width=model_width, + dropout=dropout, + num_heads=num_heads, + num_layers=num_layers, + n_tokens=n_tokens, + base_dim=self.base_model_embedding_dim, + paraphrase_dim=self.paraphrase_model_embedding_dim, + ) + self.classifier.to(dtype=torch.float16) + + self.classifier_ema = deepcopy(self.classifier) + + self.model = _PairwiseMLPReranker( + self.base_model, self.classifier, qid_to_paraphrase_emb, qid_to_base_emb + ) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + return self.model(mention_tokens, entity_tokens) + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + labels = data["labels"].float().view(-1) + qids = data["qids"].view(-1) + + logits = self.model(mention_tokens, qids).view(-1) + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip( + self.classifier.parameters(), self.classifier_ema.parameters() + ): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.classifier_ema.state_dict(), ema_path) + torch.save(self.classifier.state_dict(), 
path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + self.classifier_ema.load_state_dict(state_dict) + self.classifier.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + assert False, "Not correct, TODO FIX this" + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." 
+ ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, qids: torch.Tensor) -> torch.Tensor: + logits = self.model(mentions, qids) + probability = torch.sigmoid(logits).reshape(-1) + return probability + + def classifier_forward(self, x): + return self.model.classifier_forward(x) + + +class _PairwiseMLPReranker(nn.Module): + def __init__( + self, + base_model: nn.Module, + classifier: nn.Module, + qid_to_paraphrase_emb: Dict[int, torch.Tensor], + qid_to_base_emb: Dict[int, torch.Tensor], + ) -> None: + super().__init__() + self.base_model = base_model + self.classifier = classifier + self.qid_to_paraphrase_emb = qid_to_paraphrase_emb + self.qid_to_base_emb = qid_to_base_emb + + def forward( + self, + mention_tokens: torch.Tensor, + qids: torch.Tensor, + return_embeddings: bool = False, + ) -> torch.Tensor: + + mention_embeddings = self._encode(mention_tokens).clone().detach() + paraphrase_embeddings = torch.stack( + [self.qid_to_paraphrase_emb[int(qid)] for qid in qids], dim=0 + ).to(mention_embeddings.device) + base_embeddings = torch.stack([self.qid_to_base_emb[int(qid)] for qid in qids], dim=0).to( + mention_embeddings.device + ) + + logits = self.classifier( + mention_embeddings, base_embeddings, paraphrase_embeddings + ).squeeze(-1) + if return_embeddings: + return logits, mention_embeddings, paraphrase_embeddings, base_embeddings + return logits + + def train(self, mode: bool 
= True) -> _PairwiseMLPReranker: + super().train(mode) + # Make sure that base model is never trained. + self.base_model.eval() + return self + + def classifier_forward(self, x): + return self.classifier(x) + + @torch.inference_mode() + def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: + if isinstance(tokens, Mapping): + input_ids = tokens["input_ids"] + attention_mask = tokens.get("attention_mask") + if attention_mask is None: + attention_mask = create_attention_mask(input_ids) + else: + input_ids = tokens + attention_mask = create_attention_mask(tokens) + + return self.base_model(input_ids=input_ids, attention_mask=attention_mask) diff --git a/src/reranking/models/full_lealla.py b/src/reranking/models/full_lealla.py new file mode 100644 index 0000000..3f014e8 --- /dev/null +++ b/src/reranking/models/full_lealla.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from reranking.models.context_emb_with_attention import GPTLayer +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return 
int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +class _Model(nn.Module): + def __init__(self, base_model: nn.Module, embedding_dim: int, dropout: float) -> None: + super().__init__() + self.base_model = base_model + self.embedding_dim = embedding_dim + self.gpt_layer = GPTLayer(model_width=embedding_dim, dropout=dropout) + self.final_layer = nn.Linear(embedding_dim, 1) + + def forward(self, ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + # print(ids.sh1ape, attention_mask.shape) + base_embeddings = self.base_model(ids, attention_mask) + # print("base_embeddings", base_embeddings.shape) + x = self.gpt_layer(base_embeddings) + # print("x", x.shape) + logits = self.final_layer(x).squeeze(-1) + # print("logits", logits.shape) + return logits + + +class FullLEALLAReranker(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head.""" + + def __init__( + self, + model_name_or_path: str, + *, + state_dict_path: str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + embedding_dim: int | None = None, + ) -> None: + super().__init__() + + self.ema_decay = ema_decay + + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + ) + if embedding_dim is None: + self.embedding_dim = _infer_output_dim(self.base_model) + else: + self.embedding_dim = embedding_dim + + self.model = _Model( + base_model=self.base_model, + embedding_dim=self.embedding_dim, + dropout=dropout, + ) + self.model_ema = deepcopy(self.model) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.BCEWithLogitsLoss() + + self.tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + + def forward( + self, + 
mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + return self.model(ids, attention_mask) + + def prepare_for_forward(self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor): + ids = torch.cat([mention_tokens, entity_tokens], dim=1) + attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) + return ids, attention_mask + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + entity_tokens = data["entity_tokens"] + labels = data["labels"].float() + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + + print(f"Min token ID: {ids.min()}, Max token ID: {ids.max()}") + print(f"Model vocab size: {self.base_model.model.config.vocab_size}") + print(ids) + + assert ( + self.embedding_dim == ids.shape[-1] + ), f"Expected embedding dimension {self.embedding_dim}, but got {ids.shape[-1]}" + + logits = self.model(ids, attention_mask) + + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip(self.model.parameters(), self.model_ema.parameters()): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.model_ema.state_dict(), ema_path) + torch.save(self.model.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("_orig_mod.module.model."): + new_k = k.replace("_orig_mod.module.model.", "") + elif k.startswith("module."): + new_k = k.replace("module.", "") + else: + new_k = k + new_state_dict[new_k] = v + state_dict = new_state_dict + self.model_ema.load_state_dict(state_dict) + 
self.model.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." 
+ ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mentions, entities) + logits = self.model(ids, attention_mask) + probability = torch.sigmoid(logits).reshape(-1) + return probability diff --git a/src/reranking/models/full_lealla_multiclass.py b/src/reranking/models/full_lealla_multiclass.py new file mode 100644 index 0000000..4a62e00 --- /dev/null +++ b/src/reranking/models/full_lealla_multiclass.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from reranking.models.context_emb_with_attention import GPTLayer +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): 
+ return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer output dimension from the provided base model.") + + +class _Model(nn.Module): + def __init__( + self, base_model: nn.Module, embedding_dim: int, dropout: float, output_size: int + ) -> None: + super().__init__() + self.base_model = base_model + self.embedding_dim = embedding_dim + self.gpt_layer = GPTLayer(model_width=embedding_dim, dropout=dropout) + self.final_layer = nn.Linear(embedding_dim, output_size) + + def forward(self, ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + # print(ids.sh1ape, attention_mask.shape) + base_embeddings = self.base_model(ids, attention_mask) + # print("base_embeddings", base_embeddings.shape) + x = self.gpt_layer(base_embeddings) + # print("x", x.shape) + return self.final_layer(x) + + +class FullLEALLARerankerMulticlass(BaseRerankingModel): + """Reranking model that augments a LEALLA encoder with an MLP head.""" + + def __init__( + self, + model_name_or_path: str, + *, + state_dict_path: str | None = None, + tokenizer_name_or_path: str | None = None, + dropout: float = 0.1, + ema_decay: float = 0.9999, + embedding_dim: int | None = None, + output_size: int = 7, + ) -> None: + super().__init__() + + self.ema_decay = ema_decay + + self.base_model = ModelFactory.auto_load_from_file( + model_name_or_path, + state_dict_path=state_dict_path, + ) + if embedding_dim is None: + self.embedding_dim = _infer_output_dim(self.base_model) + else: + self.embedding_dim = embedding_dim + + self.model = _Model( + base_model=self.base_model, + embedding_dim=self.embedding_dim, + dropout=dropout, + output_size=output_size, + ) + self.model_ema = deepcopy(self.model) + + tokenizer_id = tokenizer_name_or_path or model_name_or_path + self.tokenizer = 
AutoTokenizer.from_pretrained(tokenizer_id) + self.loss_fn = nn.CrossEntropyLoss() + + self.tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + ) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + return self.model(ids, attention_mask) + + def prepare_for_forward(self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor): + entity_tokens = torch.cat(entity_tokens, dim=1) + ids = torch.cat([mention_tokens, entity_tokens], dim=1) + attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device) + return ids, attention_mask + + def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor: + self.train() + + mention_tokens = data["mention_tokens"] + entity_tokens = data["entity_tokens"] + labels = data["labels"].float() + ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens) + + logits = self.model(ids, attention_mask) + + loss = self.loss_fn(logits, labels) + return loss + + def update_ema(self) -> None: + with torch.no_grad(): + for param, ema_param in zip(self.model.parameters(), self.model_ema.parameters()): + ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay) + + def save(self, path: str) -> None: + ema_path = path.replace(".pth", "_ema.pth") + torch.save(self.model_ema.state_dict(), ema_path) + torch.save(self.model.state_dict(), path) + + def load(self, path: str) -> None: + state_dict = torch.load(path, map_location="cpu") + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("_orig_mod.module.model."): + new_k = k.replace("_orig_mod.module.model.", "") + elif k.startswith("module."): + new_k = k.replace("module.", "") + else: + new_k = k + new_state_dict[new_k] = v + state_dict = new_state_dict + self.model_ema.load_state_dict(state_dict) + 
self.model.load_state_dict(state_dict) + + @torch.inference_mode() + def score( + self, mention: str | Sequence[str], entity_description: str | Sequence[str] + ) -> float | torch.Tensor: + single_pair = isinstance(mention, str) and isinstance(entity_description, str) + + if isinstance(mention, str): + mention_batch = [mention] + elif isinstance(mention, Sequence): + mention_batch = list(mention) + else: + raise TypeError("Mentions must be a string or a sequence of strings.") + + if isinstance(entity_description, str): + entity_batch = [entity_description] + elif isinstance(entity_description, Sequence): + entity_batch = list(entity_description) + else: + raise TypeError("Entity descriptions must be a string or a sequence of strings.") + + if len(mention_batch) != len(entity_batch): + if len(mention_batch) == 1: + mention_batch = mention_batch * len(entity_batch) + single_pair = False + elif len(entity_batch) == 1: + entity_batch = entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." 
+ ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + ids, attention_mask = self.prepare_for_forward(mentions, entities) + logits = self.model(ids, attention_mask) + probability = torch.sigmoid(logits).reshape(-1) + return probability diff --git a/src/reranking/models/fusion.py b/src/reranking/models/fusion.py new file mode 100644 index 0000000..060d045 --- /dev/null +++ b/src/reranking/models/fusion.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Sequence + +import torch +from einops import rearrange +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from reranking.models.context_emb_with_attention import GPTLayer +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = getattr(model, "model") + if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"): + return int(nested_model.config.hidden_size) + raise ValueError("Unable to infer 
output dimension from the provided base model.") + + +class SimpleCrossAttentionLayer(nn.Module): + """A simple wrapper for PyTorch's MultiheadAttention to perform cross-attention.""" + + def __init__(self, query_dim, key_value_dim, num_heads=8, dropout=0.1): + super().__init__() + # The MHA layer that handles everything, including projections! + # Note: We must set batch_first=True for BERT-style inputs. + self.attention = nn.MultiheadAttention( + embed_dim=query_dim, + num_heads=num_heads, + kdim=key_value_dim, + vdim=key_value_dim, + dropout=dropout, + batch_first=True, # Important for [batch, seq_len, dim] tensors + ) + self.norm = nn.LayerNorm(query_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, query, key_value, attention_mask=None): + # query: The tensor that asks the questions (from the target model). + # key_value: The tensor that provides the context (from the source model). + + # The MHA layer returns the attention output and weights. We only need the output. + attn_output, _ = self.attention( + query=query, + key=key_value, + value=key_value, + key_padding_mask=attention_mask, # Optional mask for padded tokens + ) + + # Add residual connection and layer norm + output = self.norm(query + self.dropout(attn_output)) + return output + + +class BertFusionModel(nn.Module): + def __init__(self, model1, model2): + super().__init__() + self.base = model1 # The 24-layer model (dim 192) + self.paraphrase = model2 # The 12-layer model (dim 384) + + # Simpler cross-attention layers using the wrapper + self.cross_layers_1_to_2 = nn.ModuleList( + [SimpleCrossAttentionLayer(query_dim=384, key_value_dim=192) for _ in range(12)] + ) + self.cross_layers_2_to_1 = nn.ModuleList( + [SimpleCrossAttentionLayer(query_dim=192, key_value_dim=384) for _ in range(12)] + ) + + def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2): + # 1. 
Get initial embeddings + hidden_states1 = self.base.embeddings(input_ids=input_ids1) + hidden_states2 = self.paraphrase.embeddings(input_ids=input_ids2) + + # (Note: You may need to adapt the huggingface attention_mask format for MHA's key_padding_mask) + + # 2. Iteratively process through layers with fusion + for i in range(12): + # Process layers for each model + hidden_states1 = self.base.encoder.layer[2 * i](hidden_states1)[0] + hidden_states1 = self.base.encoder.layer[2 * i + 1](hidden_states1)[0] + hidden_states2 = self.paraphrase.encoder.layer[i](hidden_states2)[0] + + # Store states before the cross-attention step to avoid in-place modification issues + temp_states1, temp_states2 = hidden_states1, hidden_states2 + + # Cross-attention exchange + hidden_states1 = self.cross_layers_2_to_1[i](query=temp_states1, key_value=temp_states2) + hidden_states2 = self.cross_layers_1_to_2[i](query=temp_states2, key_value=temp_states1) + + # 3. Get final outputs + pooled_output1 = self.base.pooler(hidden_states1) + pooled_output2 = self.paraphrase.pooler(hidden_states2) + + return torch.cat([pooled_output1, pooled_output2], dim=-1) + + +class _Model(nn.Module): + def __init__(self, fusion_model: BertFusionModel, dropout: float = 0.1) -> None: + super().__init__() + self.fusion_model = fusion_model + self.embedding_dim = 384 + 192 + self.gpt_layer = GPTLayer(model_width=self.embedding_dim, dropout=dropout) + self.final_layer = nn.Linear(self.embedding_dim, 1) + + def forward( + self, + input_ids1: torch.Tensor, + attention_mask1: torch.Tensor, + input_ids2: torch.Tensor, + attention_mask2: torch.Tensor, + ) -> torch.Tensor: + # print(ids.shape, attention_mask.shape) + base_embeddings = self.fusion_model( + input_ids1, attention_mask1, input_ids2, attention_mask2 + ) + # print("base_embeddings", base_embeddings.shape) + x = self.gpt_layer(base_embeddings) + # print("x", x.shape) + logits = self.final_layer(x).squeeze(-1) + # print("logits", logits.shape) + return logits + + 
class FusionReranker(BaseRerankingModel):
    """Reranker fusing a base (entity-linking) encoder with a paraphrase
    encoder through ``BertFusionModel`` to score mention/entity pairs.

    Per-entity paraphrase tokens are looked up in ``qid_to_para_toks``.
    NOTE(review): the annotation declares ``Dict[str, torch.Tensor]`` but
    ``train_step_imp`` indexes the mapping with numpy integers taken from
    ``data["qids"]`` — confirm the actual key type.
    """

    def __init__(
        self,
        base_model_name_or_path: str,
        paraphrase_model_name_or_path: str,
        qid_to_para_toks: Dict[str, torch.Tensor],
        *,
        base_state_dict_path: str | None = None,
        tokenizer_name_or_path: str | None = None,
        dropout: float = 0.1,
        ema_decay: float = 0.9999,
    ) -> None:
        super().__init__()

        self.qid_to_para_toks = qid_to_para_toks
        self.ema_decay = ema_decay

        # Base encoder, optionally warm-started from a fine-tuned state dict.
        self.base_model = ModelFactory.auto_load_from_file(
            base_model_name_or_path,
            state_dict_path=base_state_dict_path,
        ).model
        # Paraphrase encoder loaded with a sentence-transformer output head.
        self.paraphrase_model = ModelFactory.auto_load_from_file(
            paraphrase_model_name_or_path,
            output_type="sentence_transformer",
        ).model

        fusion_model = BertFusionModel(self.base_model, self.paraphrase_model)

        self.model = _Model(
            fusion_model=fusion_model,
            dropout=dropout,
        )
        # EMA shadow of the trainable model; refreshed by update_ema().
        self.model_ema = deepcopy(self.model)

        tokenizer_id = tokenizer_name_or_path or base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(
        self,
        mention_tokens: torch.Tensor,
        entity_tokens: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE(review): this path calls the model WITHOUT paraphrase inputs,
        # while train_step_imp / score_from_tokens_and_qids pass them —
        # confirm _Model accepts the two-argument form.
        ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens)
        return self.model(ids, attention_mask)

    def prepare_for_forward(self, mention_tokens: torch.Tensor, entity_tokens: torch.Tensor):
        """Concatenate mention and entity ids along the sequence dimension
        and build a matching attention mask (same dtype/device as the ids)."""
        ids = rearrange([mention_tokens, entity_tokens], "d b n -> b (d n)")
        attention_mask = create_attention_mask(ids).to(dtype=ids.dtype, device=ids.device)
        return ids, attention_mask

    def _paraphrase_inputs(self, qids, device):
        """Stack the paraphrase token rows for *qids* and build their mask."""
        para_tokens = torch.stack(
            [self.qid_to_para_toks[qid] for qid in qids], dim=0
        ).to(device=device)
        para_attention_mask = create_attention_mask(para_tokens).to(
            dtype=para_tokens.dtype, device=para_tokens.device
        )
        return para_tokens, para_attention_mask

    def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor:
        """One training step: BCE-with-logits loss over a batch dict carrying
        ``mention_tokens``, ``entity_tokens``, ``labels`` and ``qids``."""
        self.train()

        mention_tokens = data["mention_tokens"]
        entity_tokens = data["entity_tokens"]
        labels = data["labels"].float()
        qids = data["qids"].numpy()

        para_tokens, para_attention_mask = self._paraphrase_inputs(
            qids, mention_tokens.device
        )

        # prepare_for_forward already casts the mask to the ids' dtype/device;
        # the original re-cast it a second time for no effect.
        ids, attention_mask = self.prepare_for_forward(mention_tokens, entity_tokens)
        logits = self.model(ids, attention_mask, para_tokens, para_attention_mask)

        return self.loss_fn(logits, labels)

    def update_ema(self) -> None:
        """Exponential-moving-average update of ``model_ema`` towards ``model``."""
        with torch.no_grad():
            for param, ema_param in zip(self.model.parameters(), self.model_ema.parameters()):
                ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay)

    def save(self, path: str) -> None:
        """Save trainable weights to *path* and EMA weights next to it."""
        ema_path = path.replace(".pth", "_ema.pth")
        torch.save(self.model_ema.state_dict(), ema_path)
        torch.save(self.model.state_dict(), path)

    def load(self, path: str) -> None:
        """Load a checkpoint into both the live and the EMA model, stripping
        torch.compile (``_orig_mod.``) and DDP (``module.``) name prefixes."""
        state_dict = torch.load(path, map_location="cpu")
        cleaned = {}
        for key, value in state_dict.items():
            if key.startswith("_orig_mod.module.model."):
                cleaned[key.replace("_orig_mod.module.model.", "")] = value
            elif key.startswith("module."):
                cleaned[key.replace("module.", "")] = value
            else:
                cleaned[key] = value
        self.model_ema.load_state_dict(cleaned)
        self.model.load_state_dict(cleaned)

    @torch.inference_mode()
    def score(
        self, mention: str | Sequence[str], entity_description: str | Sequence[str]
    ) -> float | torch.Tensor:
        raise NotImplementedError("Use score_from_tokens_and_qids for FusionReranker.")

    @torch.inference_mode()
    def score_from_tokens_and_qids(
        self, mentions: torch.Tensor, entities: torch.Tensor, qids: torch.Tensor
    ) -> torch.Tensor:
        """Return a probability per (mention, entity, qid) triple on CPU."""
        para_tokens, para_attention_mask = self._paraphrase_inputs(
            qids.numpy(), mentions.device
        )
        ids, attention_mask = self.prepare_for_forward(mentions, entities)
        logits = self.model(ids, attention_mask, para_tokens, para_attention_mask)
        # sigmoid always yields a Tensor: the original's isinstance check and
        # second reshape were dead code.
        return torch.sigmoid(logits).reshape(-1).detach().cpu()

    @torch.inference_mode()
    def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Use score_from_tokens_and_qids for FusionReranker.")
class PairwiseMLPReranker(BaseRerankingModel):
    """Reranker pairing a frozen LEALLA encoder with a trainable MLP head
    over the concatenated (mention, entity) embeddings."""

    def __init__(
        self,
        model_name_or_path: str,
        *,
        state_dict_path: str | None = None,
        target_dim: int | None = None,
        output_type: ModelOutputType | str | None = None,
        tokenizer_name_or_path: str | None = None,
        mlp_hidden_dim: int | None = None,
        dropout: float = 0.1,
        emb_noise: float = 0,
        ema_decay: float = 0.9999,
    ) -> None:
        super().__init__()

        self.ema_decay = ema_decay

        resolved_output_type = _maybe_convert_output_type(output_type)
        self.base_model = ModelFactory.auto_load_from_file(
            model_name_or_path,
            state_dict_path=state_dict_path,
            target_dim=target_dim,
            output_type=resolved_output_type,
        )
        # The encoder is frozen; only the MLP head below is trained.
        self.base_model.eval()
        self.base_model.requires_grad_(False)

        self.embedding_dim = _infer_output_dim(self.base_model)
        hidden_dim = mlp_hidden_dim or self.embedding_dim

        self.classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, 1),
        )

        # EMA shadow of the classifier; refreshed by update_ema().
        self.classifier_ema = deepcopy(self.classifier)

        self.model = _PairwiseMLPReranker(self.base_model, self.classifier, emb_noise)

        tokenizer_id = tokenizer_name_or_path or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(
        self,
        mention_tokens: torch.Tensor,
        entity_tokens: torch.Tensor,
    ) -> torch.Tensor:
        return self.model.forward(mention_tokens, entity_tokens)

    def classifier_forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.classifier_forward(x)

    def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor:
        """One training step: BCE-with-logits on the pair logits."""
        self.train()

        mention_tokens = data["mention_tokens"]
        entity_tokens = data["entity_tokens"]
        labels = data["labels"].float().view(-1)

        logits = self.forward(mention_tokens, entity_tokens).view(-1)
        return self.loss_fn(logits, labels)

    def update_ema(self) -> None:
        """Exponential-moving-average update of the EMA classifier."""
        with torch.no_grad():
            for param, ema_param in zip(
                self.classifier.parameters(), self.classifier_ema.parameters()
            ):
                ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay)

    def save(self, path: str) -> None:
        """Persist classifier weights to *path* and EMA weights alongside.

        The EMA path is derived from the *suffix* of ``path``: the previous
        ``path.replace(".pth", "_ema.pth")`` rewrote the first ``.pth``
        anywhere in the path (e.g. inside a directory name) and, when the
        suffix was absent, silently saved both checkpoints to the same file.
        """
        if path.endswith(".pth"):
            ema_path = path[: -len(".pth")] + "_ema.pth"
        else:
            ema_path = path + "_ema.pth"
        torch.save(self.classifier_ema.state_dict(), ema_path)
        torch.save(self.classifier.state_dict(), path)

    def load(self, path: str) -> None:
        """Load classifier weights into both the live and the EMA head."""
        state_dict = torch.load(path, map_location="cpu")
        self.classifier_ema.load_state_dict(state_dict)
        self.classifier.load_state_dict(state_dict)

    @torch.inference_mode()
    def score(
        self, mention: str | Sequence[str], entity_description: str | Sequence[str]
    ) -> float | torch.Tensor:
        """Score raw strings; broadcasts a single mention against many
        descriptions (or vice versa). Returns a float for one pair, else a
        1-D tensor of probabilities."""
        single_pair = isinstance(mention, str) and isinstance(entity_description, str)

        if isinstance(mention, str):
            mention_batch = [mention]
        elif isinstance(mention, Sequence):
            mention_batch = list(mention)
        else:
            raise TypeError("Mentions must be a string or a sequence of strings.")

        if isinstance(entity_description, str):
            entity_batch = [entity_description]
        elif isinstance(entity_description, Sequence):
            entity_batch = list(entity_description)
        else:
            raise TypeError("Entity descriptions must be a string or a sequence of strings.")

        if len(mention_batch) != len(entity_batch):
            if len(mention_batch) == 1:
                mention_batch = mention_batch * len(entity_batch)
                single_pair = False
            elif len(entity_batch) == 1:
                entity_batch = entity_batch * len(mention_batch)
                single_pair = False
            else:
                raise ValueError(
                    "Mention and entity batches must be the same length or broadcastable."
                )

        mention_tokens = self.tokenizer(
            mention_batch,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )["input_ids"]
        entity_tokens = self.tokenizer(
            entity_batch,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )["input_ids"]

        probabilities = self.score_from_tokens(mention_tokens, entity_tokens)
        if not isinstance(probabilities, torch.Tensor):
            probabilities = torch.as_tensor(probabilities)

        probabilities = probabilities.reshape(-1).detach().cpu()

        if single_pair:
            return float(probabilities[0])
        return probabilities

    @torch.inference_mode()
    def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
        """Sigmoid probability per (mention, entity) token-id pair."""
        logits = self.model.forward(mentions, entities)
        return torch.sigmoid(logits).reshape(-1)
+ ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: + logits = self.model.forward(mentions, entities) + probability = torch.sigmoid(logits).reshape(-1) + return probability + + +class _PairwiseMLPReranker(nn.Module): + def __init__(self, base_model: nn.Module, classifier: nn.Module, emb_noise: float) -> None: + super().__init__() + self.base_model = base_model + self.classifier = classifier + self.emb_noise = emb_noise + self.noise_layer = _GaussianNoiseLayer(emb_noise) + + def forward( + self, + mention_tokens: torch.Tensor, + entity_tokens: torch.Tensor, + return_embeddings: bool = False, + ) -> torch.Tensor: + + mention_embeddings = self._encode(mention_tokens) + entity_embeddings = self._encode(entity_tokens) + + if self.emb_noise > 0: + mention_embeddings = self.noise_layer(mention_embeddings) + entity_embeddings = self.noise_layer(entity_embeddings) + + combined = torch.cat([mention_embeddings, entity_embeddings], dim=-1) + logits = self.classifier(combined).squeeze(-1) + if return_embeddings: + return logits, mention_embeddings, entity_embeddings + return logits + + def train(self, mode: bool = True) -> _PairwiseMLPReranker: + super().train(mode) + # Make sure that base model is never trained. 
+ self.base_model.eval() + return self + + def classifier_forward(self, x: torch.Tensor) -> torch.Tensor: + return self.classifier(x) + + @torch.inference_mode() + def _encode(self, tokens: Mapping[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: + if isinstance(tokens, Mapping): + input_ids = tokens["input_ids"] + attention_mask = tokens.get("attention_mask") + if attention_mask is None: + attention_mask = create_attention_mask(input_ids) + else: + input_ids = tokens + attention_mask = create_attention_mask(tokens) + + return self.base_model(input_ids=input_ids, attention_mask=attention_mask) + + +class _GaussianNoiseLayer(nn.Module): + def __init__(self, std): + super().__init__() + self.std = std + + def forward(self, x): + if self.training: + noise = torch.randn_like(x) * self.std + return noise + x + return x diff --git a/src/reranking/models/pairwise_mlp_with_large_context_emb.py b/src/reranking/models/pairwise_mlp_with_large_context_emb.py new file mode 100644 index 0000000..858c66e --- /dev/null +++ b/src/reranking/models/pairwise_mlp_with_large_context_emb.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Mapping, Sequence + +import torch +from torch import nn +from transformers import AutoTokenizer + +from reranking.models.base import BaseRerankingModel +from utils.embeddings import create_attention_mask +from utils.model_factory import ModelFactory, ModelOutputType + + +def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None: + if output_type is None or isinstance(output_type, ModelOutputType): + return output_type + return ModelOutputType(output_type) + + +def _infer_output_dim(model: nn.Module) -> int: + if hasattr(model, "output_dim"): + return int(getattr(model, "output_dim")) + if hasattr(model, "config") and hasattr(model.config, "hidden_size"): + return int(model.config.hidden_size) + if hasattr(model, "model"): + nested_model = 
def _maybe_convert_output_type(output_type: ModelOutputType | str | None) -> ModelOutputType | None:
    """Coerce a string into ``ModelOutputType``; pass through None/enum values."""
    if output_type is None or isinstance(output_type, ModelOutputType):
        return output_type
    return ModelOutputType(output_type)


def _infer_output_dim(model: nn.Module) -> int:
    """Best-effort lookup of the encoder's embedding width, checking, in
    order: an ``output_dim`` attribute, ``config.hidden_size``, and a nested
    ``model.config.hidden_size``."""
    if hasattr(model, "output_dim"):
        return int(model.output_dim)
    if hasattr(model, "config") and hasattr(model.config, "hidden_size"):
        return int(model.config.hidden_size)
    if hasattr(model, "model"):
        nested_model = model.model
        if hasattr(nested_model, "config") and hasattr(nested_model.config, "hidden_size"):
            return int(nested_model.config.hidden_size)
    raise ValueError("Unable to infer output dimension from the provided base model.")


class PairwiseMLPRerankerWithLargeContextEmb(BaseRerankingModel):
    """MLP reranker over [mention emb ; paraphrase emb ; base entity emb].

    Entity-side embeddings are precomputed and supplied as QID-keyed dicts,
    so scoring takes (mention tokens, QIDs) rather than entity token ids.
    """

    def __init__(
        self,
        model_name_or_path: str,
        qid_to_paraphrase_emb: Dict[int, torch.Tensor],
        qid_to_base_emb: Dict[int, torch.Tensor],
        *,
        state_dict_path: str | None = None,
        target_dim: int | None = None,
        output_type: ModelOutputType | str | None = None,
        tokenizer_name_or_path: str | None = None,
        dropout: float = 0.1,
        ema_decay: float = 0.9999,
    ) -> None:
        super().__init__()
        self.ema_decay = ema_decay

        resolved_output_type = _maybe_convert_output_type(output_type)
        self.base_model = ModelFactory.auto_load_from_file(
            model_name_or_path,
            state_dict_path=state_dict_path,
            target_dim=target_dim,
            output_type=resolved_output_type,
        )
        # The encoder is frozen; only the MLP head below is trained.
        self.base_model.eval()
        self.base_model.requires_grad_(False)

        # Embedding widths are inferred from the first entry of each dict.
        self.paraphrase_model_embedding_dim = next(iter(qid_to_paraphrase_emb.values())).shape[0]
        self.base_model_embedding_dim = next(iter(qid_to_base_emb.values())).shape[0]

        hidden_dim = 2 * self.base_model_embedding_dim + self.paraphrase_model_embedding_dim

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, 1),
        )

        # EMA shadow of the classifier; refreshed by update_ema().
        self.classifier_ema = deepcopy(self.classifier)

        self.model = _PairwiseMLPReranker(
            self.base_model, self.classifier, qid_to_paraphrase_emb, qid_to_base_emb
        )

        tokenizer_id = tokenizer_name_or_path or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(
        self,
        mention_tokens: torch.Tensor,
        entity_tokens: torch.Tensor,
    ) -> torch.Tensor:
        return self.model.forward(mention_tokens, entity_tokens)

    def train_step_imp(self, data: Dict[str, Any]) -> torch.Tensor:
        """One training step: BCE-with-logits on (mention tokens, QIDs)."""
        self.train()

        mention_tokens = data["mention_tokens"]
        labels = data["labels"].float().view(-1)
        qids = data["qids"].view(-1)

        logits = self.model.forward(mention_tokens, qids).view(-1)
        return self.loss_fn(logits, labels)

    def update_ema(self) -> None:
        """Exponential-moving-average update of the EMA classifier."""
        with torch.no_grad():
            for param, ema_param in zip(
                self.classifier.parameters(), self.classifier_ema.parameters()
            ):
                ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay)

    def save(self, path: str) -> None:
        """Persist classifier weights to *path* and EMA weights alongside.

        Suffix-aware EMA naming: the previous ``path.replace(".pth", ...)``
        rewrote the first ``.pth`` anywhere in the path and collided with
        *path* when the suffix was missing.
        """
        if path.endswith(".pth"):
            ema_path = path[: -len(".pth")] + "_ema.pth"
        else:
            ema_path = path + "_ema.pth"
        torch.save(self.classifier_ema.state_dict(), ema_path)
        torch.save(self.classifier.state_dict(), path)

    def load(self, path: str) -> None:
        """Load classifier weights into both the live and the EMA head."""
        state_dict = torch.load(path, map_location="cpu")
        self.classifier_ema.load_state_dict(state_dict)
        self.classifier.load_state_dict(state_dict)

    @torch.inference_mode()
    def score(
        self, mention: str | Sequence[str], entity_description: str | Sequence[str]
    ) -> float | torch.Tensor:
        # The original tokenized the descriptions and forwarded the token
        # matrices into score_from_tokens, which expects QIDs and indexes
        # the embedding dicts with int(qid) — a guaranteed failure. Raw
        # descriptions cannot be mapped back to QIDs here, so fail loudly
        # (consistent with FusionReranker.score).
        raise NotImplementedError(
            "Use score_from_tokens(mentions, qids) for "
            "PairwiseMLPRerankerWithLargeContextEmb; raw descriptions cannot "
            "be mapped to the precomputed QID-keyed embeddings."
        )

    @torch.inference_mode()
    def score_from_tokens(self, mentions: torch.Tensor, qids: torch.Tensor) -> torch.Tensor:
        """Sigmoid probability per (mention tokens, QID) pair."""
        logits = self.model.forward(mentions, qids)
        return torch.sigmoid(logits).reshape(-1)

    def classifier_forward(self, x):
        return self.model.classifier_forward(x)
entity_batch * len(mention_batch) + single_pair = False + else: + raise ValueError( + "Mention and entity batches must be the same length or broadcastable." + ) + + mention_tokens = self.tokenizer( + mention_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + entity_tokens = self.tokenizer( + entity_batch, + padding=True, + truncation=True, + return_tensors="pt", + )["input_ids"] + + probabilities = self.score_from_tokens(mention_tokens, entity_tokens) + if not isinstance(probabilities, torch.Tensor): + probabilities = torch.as_tensor(probabilities) + + probabilities = probabilities.reshape(-1).detach().cpu() + + if single_pair: + return float(probabilities[0]) + return probabilities + + @torch.inference_mode() + def score_from_tokens(self, mentions: torch.Tensor, qids: torch.Tensor) -> torch.Tensor: + logits = self.model.forward(mentions, qids) + probability = torch.sigmoid(logits).reshape(-1) + return probability + + def classifier_forward(self, x): + return self.model.classifier_forward(x) + + +class _PairwiseMLPReranker(nn.Module): + def __init__( + self, + base_model: nn.Module, + classifier: nn.Module, + qid_to_paraphrase_emb: Dict[int, torch.Tensor], + qid_to_base_emb: Dict[int, torch.Tensor], + ) -> None: + super().__init__() + self.base_model = base_model + self.classifier = classifier + self.qid_to_paraphrase_emb = qid_to_paraphrase_emb + self.qid_to_base_emb = qid_to_base_emb + + def forward( + self, + mention_tokens: torch.Tensor, + qids: torch.Tensor, + return_embeddings: bool = False, + ) -> torch.Tensor: + + mention_embeddings = self._encode(mention_tokens) + paraphrase_embeddings = torch.stack( + [self.qid_to_paraphrase_emb[int(qid)] for qid in qids], dim=0 + ).to(mention_embeddings.device) + base_embeddings = torch.stack([self.qid_to_base_emb[int(qid)] for qid in qids], dim=0).to( + mention_embeddings.device + ) + + combined = torch.cat([mention_embeddings, paraphrase_embeddings, base_embeddings], dim=-1) + logits = 
class PairwiseMLPRerankerWithRetrievalScore(PairwiseMLPReranker):
    """Pairwise MLP reranker whose score averages the MLP probability with a
    retrieval-style probability from the embedding dot product."""

    def __init__(
        self,
        model_name_or_path: str,
        *,
        state_dict_path: str | None = None,
        target_dim: int | None = None,
        output_type: ModelOutputType | str | None = None,
        tokenizer_name_or_path: str | None = None,
        mlp_hidden_dim: int | None = None,
        dropout: float = 0.1,
    ) -> None:
        super().__init__(
            model_name_or_path=model_name_or_path,
            state_dict_path=state_dict_path,
            target_dim=target_dim,
            output_type=output_type,
            tokenizer_name_or_path=tokenizer_name_or_path,
            mlp_hidden_dim=mlp_hidden_dim,
            dropout=dropout,
        )

    @torch.inference_mode()
    def score(
        self, mention: str | Sequence[str], entity_description: str | Sequence[str]
    ) -> float | torch.Tensor:
        # Delegates to the parent; overridden only to pin inference_mode here.
        return super().score(mention, entity_description)

    @torch.inference_mode()
    def score_from_tokens(self, mentions: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
        """Average of the MLP probability and the sigmoid of the per-pair
        embedding dot product.

        The original built the full B x B similarity matrix and then kept
        only its diagonal (and leaked a debug ``print``); the row-wise dot
        product below computes the same diagonal directly in O(B*d).
        """
        logits, mention_embeddings, entity_embeddings = self.model.forward(
            mentions, entities, return_embeddings=True
        )
        mlp_probability = torch.sigmoid(logits)
        retrieval_probability = torch.sigmoid(
            (mention_embeddings * entity_embeddings).sum(dim=-1)
        )
        return (mlp_probability + retrieval_probability) / 2
+ """ + + def __init__( + self, + data_dir: str | Path = "~/troja/outputs/reranking_test/reranker_dataset_with_qids", + ) -> None: + super().__init__() + self.data_dir = Path(data_dir).expanduser() + if not self.data_dir.is_dir(): + raise FileNotFoundError(f"Dataset directory not found: {self.data_dir}") + + self._files: List[Path] = sorted(self.data_dir.glob("*.npz")) + if not self._files: + raise FileNotFoundError( + f"No NPZ files found in {self.data_dir}; expected reranking shards" + ) + + def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + worker_info = get_worker_info() + if worker_info is None: + worker_id = 0 + num_workers = 1 + else: + worker_id = worker_info.id + num_workers = worker_info.num_workers + + file_paths = self._files[worker_id::num_workers] + + for file_path in file_paths: + with np.load(file_path, allow_pickle=False) as data: + qids = torch.from_numpy(data["qids"]).long() + labels = torch.from_numpy(data["y"]).float() + description_tokens = torch.from_numpy(data["description_tokens"]).long() + link_tokens = torch.from_numpy(data["link_tokens"]).long() + + if not (len(qids) == len(labels) == len(description_tokens) == len(link_tokens)): + raise ValueError( + "Mismatched array lengths in NPZ file " + f"{file_path}: qids={len(qids)} labels={len(labels)} " + f"description_tokens={len(description_tokens)} link_tokens={len(link_tokens)}" + ) + + permutation = torch.randperm(len(qids)) + qids = qids[permutation] + labels = labels[permutation] + description_tokens = description_tokens[permutation] + link_tokens = link_tokens[permutation] + + for idx in range(len(qids)): + print(f"Yielding sample idx={idx} from file {file_path}") + yield ( + link_tokens[idx], + description_tokens[idx], + labels[idx], + qids[idx], + ) + + +class RerankingQIDsToDescriptionsIterableDataset(IterableDataset): + """Yield ``(link_tokens, description_tokens, labels, qids)`` from NPZ files. 
+ + Each worker spawned by ``DataLoader`` receives a disjoint stride of the + underlying NPZ shards, ensuring the samples are not duplicated when + ``num_workers > 0``. + """ + + def __init__( + self, + data_dir: ( + str | Path + ) = "~/troja/outputs/reranking_test/reranker_dataset_with_qids_multiclass", + index_dir: ( + str | Path + ) = "/lnet/work/home-students-external/farhan/troja/outputs/descriptions_paraphrase_after_multiling_dataset/descs_pages", + ) -> None: + super().__init__() + self.data_dir = Path(data_dir).expanduser() + if not self.data_dir.is_dir(): + raise FileNotFoundError(f"Dataset directory not found: {self.data_dir}") + + self._files: List[Path] = sorted(self.data_dir.glob("*.npz")) + if not self._files: + raise FileNotFoundError( + f"No NPZ files found in {self.data_dir}; expected reranking shards" + ) + self.qid_to_tokens = map_qids_to_token_matrix(index_dir, verbose=True) + + def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + worker_info = get_worker_info() + if worker_info is None: + worker_id = 0 + num_workers = 1 + else: + worker_id = worker_info.id + num_workers = worker_info.num_workers + + file_paths = self._files[worker_id::num_workers] + + for file_path in file_paths: + with np.load(file_path, allow_pickle=False) as data: + qids = torch.from_numpy(data["qids"]).long() + labels = torch.from_numpy(data["y"]).float() + link_tokens = torch.from_numpy(data["link_tokens"]) + + if not (len(qids) == len(labels) == len(link_tokens)): + raise ValueError( + "Mismatched array lengths in NPZ file " + f"{file_path}: qids={len(qids)} labels={len(labels)} " + f"link_tokens={len(link_tokens)}" + ) + + permutation = torch.randperm(len(qids)) + qids = qids[permutation] + labels = labels[permutation] + link_tokens = link_tokens[permutation] + + for idx in range(len(qids)): + yield ( + link_tokens[idx], + torch.from_numpy(self.qid_to_tokens[qids[idx].item()].toarray()), + labels[idx], + qids[idx], + ) diff 
"""DDP training entry point for reranking models."""

import logging
import os
from itertools import islice

import torch

import wandb

# TF32 matmuls + cudnn autotune: faster on Ampere+ GPUs at slightly reduced
# matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from reranking.models.pairwise_mlp import PairwiseMLPReranker
from reranking.training.training_configs import (
    TrainingConfig,
    get_config_from_name,
)

# NOTE(review): the previous `from tests.utils.test_embeddings import model`
# import was removed — it dragged test-only code into the trainer and the
# imported `model` name was immediately shadowed inside _ddp_train.

# Settings ===========================================


_logger = logging.getLogger("reranking.train.trainer")


if torch.cuda.is_available():
    _logger.debug("Running on CUDA.")
    device = torch.device("cuda")
else:
    _logger.debug("CUDA is not available.")
    device = torch.device("cpu")


def setup(rank, world_size, master_port: str = "12355"):
    """Initialize the NCCL process group for this rank."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = master_port

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Tear down the process group."""
    dist.destroy_process_group()


SEED = 0
torch.manual_seed(SEED)


def _ddp_train(
    rank: int,
    world_size: int,
    config_name: str,
    gradient_clip=1.0,
    master_port: str = "12355",
):
    """Per-rank training loop: AMP + grad clipping + periodic save/validate.

    Only rank 0 logs to wandb, saves checkpoints, and runs validation.
    """
    setup(rank, world_size, master_port)

    is_the_main_process = rank == 0

    _logger.info("Loading training configuration")
    training_config = get_config_from_name(config_name)
    _logger.info(f"Training configuration loaded: {training_config.config_name}")

    if is_the_main_process:
        wandb.init(
            project="EL-reranking_train_ddp_process_0",
            config={
                "config_name": training_config.config_name,
                "batch_size": training_config.batch_size,
                "save_each": training_config.save_each,
                "validate_each": training_config.validate_each,
                "validation_size": training_config.validation_size,
                "output_dir": training_config.output_dir,
                "epochs": training_config.epochs,
            },
        )

    model = training_config.model
    dataset = training_config.dataset
    optimizer = training_config.optimizer

    # NOTE(review): no per-rank sharding is applied to the IterableDataset,
    # so every rank appears to iterate the same stream — confirm whether the
    # dataset shards by rank internally.
    dataloader = DataLoader(
        dataset,
        batch_size=training_config.batch_size,
        pin_memory=True,
        num_workers=4,
    )

    # NOTE(review): validation batches are a prefix of the same stream the
    # training loop re-iterates, so they overlap the training data.
    num_validation_batches = training_config.validation_size // training_config.batch_size
    validation_batches = list(islice(iter(dataloader), num_validation_batches))
    val_dataloader = validation_batches

    model.to(rank)
    model = torch.compile(model)
    # Only the inner module is DDP-wrapped; the outer wrapper keeps helper
    # methods (save/score_from_tokens) callable directly.
    model.model = DDP(model.model, device_ids=[rank])

    scaler = torch.amp.GradScaler("cuda")

    @torch.compile
    def step(current_loss):
        """Scaled backward + clipped optimizer step."""
        scaler.scale(current_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    global_step = 0

    for epoch in range(training_config.epochs):
        if is_the_main_process:
            _logger.info(f"Starting epoch {epoch + 1}/{training_config.epochs}")

        for links, entities, labels, qids in dataloader:
            global_step += 1
            links = links.to(rank, non_blocking=True)
            entities = entities.to(rank, non_blocking=True)
            labels = labels.to(rank, non_blocking=True)
            qids = qids.to(rank, non_blocking=True)

            batch_data = {
                "mention_tokens": links,
                "entity_tokens": entities,
                "labels": labels,
                "qids": qids,
            }

            with torch.autocast(device_type="cuda"):
                loss = model.train_step(batch_data)
            step(loss)
            if is_the_main_process and global_step % training_config.save_each == 0:
                path = training_config.get_output_path(global_step)
                _logger.info(f"Saving model at step {global_step} to {path}")
                model.save(path)
            if is_the_main_process and global_step % 500 == 0:
                wandb.log({"train/loss": loss.item()}, step=global_step)
                _logger.info(f"Step {global_step}, loss: {loss.item():.4f}")
            if is_the_main_process and global_step % training_config.validate_each == 0:
                model.eval()

                correct = 0
                total = 0

                total_loss = 0.0
                val_steps = 0
                for links, entities, labels, qids in val_dataloader:
                    links = links.to(rank, non_blocking=True)
                    entities = entities.to(rank, non_blocking=True)
                    labels = labels.to(rank, non_blocking=True)
                    qids = qids.to(rank, non_blocking=True)

                    val_steps += 1

                    with torch.inference_mode():
                        # Token-pair models score (links, entities); the
                        # QID-based models score (links, qids).
                        if (
                            training_config.config_name == "pairwise_mlp"
                            or training_config.config_name == "pairwise_mlp_debug"
                            or training_config.config_name == "full_lealla"
                        ):
                            probs = model.score_from_tokens(links, entities)
                        else:
                            probs = model.score_from_tokens(links, qids)
                    loss = torch.nn.functional.binary_cross_entropy(probs, labels.float())
                    total_loss += loss.item()
                    predictions = (probs > 0.5).long()
                    correct += (predictions == labels).sum().item()
                    total += labels.size(0)
                if is_the_main_process:
                    _logger.info(f"Validation loss: {total_loss / val_steps:.4f}")
                    _logger.info(f"Validation accuracy: {correct / total:.4f}")

                model.train()

        if is_the_main_process:
            _logger.info(f"Epoch {epoch + 1} finished.")
            model.save(training_config.get_output_path(global_step))

    cleanup()


# Training ===========================================
def train_ddp(config_name: str, master_port: str = "12355"):
    """Spawn one training process per visible GPU."""
    _logger.info("Starting DDP training")
    gradient_clip = 1.0
    world_size = torch.cuda.device_count()
    _logger.debug(f"Using {world_size} GPUs for training")

    mp.spawn(
        _ddp_train,
        args=(
            world_size,
            config_name,
            gradient_clip,
            master_port,
        ),
        nprocs=world_size,
    )


if __name__ == "__main__":
    train_ddp("pairwise_mlp")
"""Single-device training entry point for reranking models."""

import logging
from contextlib import nullcontext
from itertools import islice

import torch

import wandb

# TF32 matmuls + cudnn autotune: faster on Ampere+ GPUs at slightly reduced
# matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
from torch.utils.data import DataLoader

from reranking.training.training_configs import get_config_from_name

# Settings ===========================================


_logger = logging.getLogger("reranking.train.trainer")


if torch.cuda.is_available():
    _logger.debug("Running on CUDA.")
    device = torch.device("cuda")
else:
    _logger.debug("CUDA is not available.")
    device = torch.device("cpu")


SEED = 0
torch.manual_seed(SEED)


def train(
    config_name: str,
    gradient_clip: float = 1.0,
):
    """Train the model named by *config_name* on a single device.

    Uses AMP + gradient clipping on CUDA and falls back to plain fp32 on
    CPU. This module never calls ``wandb.init`` itself, so all ``wandb.log``
    calls are guarded by ``wandb.run is not None`` — previously the first
    unconditional log at step 500 raised ``wandb.Error`` unless the caller
    had initialised a run.
    """
    _logger.info("Loading training configuration")
    training_config = get_config_from_name(config_name)
    _logger.info(f"Training configuration loaded: {training_config.config_name}")

    model = training_config.model
    dataset = training_config.dataset
    optimizer = training_config.optimizer

    dataloader = DataLoader(
        dataset,
        batch_size=training_config.batch_size,
        pin_memory=True,
        num_workers=1,
    )

    # NOTE(review): validation batches are a prefix of the same stream the
    # training loop re-iterates, so they overlap the training data.
    num_validation_batches = training_config.validation_size // training_config.batch_size
    validation_batches = list(islice(iter(dataloader), num_validation_batches))
    val_dataloader = validation_batches

    model.to(device)
    # model = torch.compile(model)

    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler(device.type) if use_amp else None

    def step(current_loss):
        """Backward + clipped optimizer step, scaled when AMP is active."""
        if scaler is not None:
            scaler.scale(current_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            current_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            optimizer.step()
        optimizer.zero_grad()

    global_step = 0

    for epoch in range(training_config.epochs):
        _logger.info(f"Starting epoch {epoch + 1}/{training_config.epochs}")

        for links, entities, labels, qids in dataloader:
            global_step += 1
            links = links.to(device, non_blocking=True)
            entities = entities.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # qids deliberately stay on CPU (models convert them to numpy).

            batch_data = {
                "mention_tokens": links,
                "entity_tokens": entities,
                "labels": labels,
                "qids": qids,
            }

            autocast_context = torch.autocast(device_type="cuda") if use_amp else nullcontext()
            with autocast_context:
                loss = model.train_step(batch_data)
            step(loss)
            if global_step % training_config.save_each == 0:
                path = training_config.get_output_path(global_step)
                _logger.info(f"Saving model at step {global_step} to {path}")
                model.save(path)
            if global_step % 500 == 0:
                if wandb.run is not None:
                    wandb.log({"train/loss": loss.item()}, step=global_step)
                _logger.info(f"Step {global_step}, loss: {loss.item():.4f}")
            if global_step % training_config.validate_each == 0:
                model.eval()

                correct = 0
                total = 0

                total_loss = 0.0
                val_steps = 0
                for links, entities, labels, qids in val_dataloader:
                    links = links.to(device, non_blocking=True)
                    entities = entities.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)
                    # qids deliberately stay on CPU (see above).

                    val_steps += 1

                    with torch.inference_mode():
                        # Dispatch on the scoring signature each config uses.
                        if (
                            training_config.config_name == "pairwise_mlp"
                            or training_config.config_name == "pairwise_mlp_debug"
                            or training_config.config_name == "full_lealla"
                            or training_config.config_name == "full_lealla_r"
                        ):
                            probs = model.score_from_tokens(links, entities)
                        elif training_config.config_name == "fusion":
                            probs = model.score_from_tokens_and_qids(links, entities, qids)
                        else:
                            probs = model.score_from_tokens(links, qids)
                    loss = torch.nn.functional.binary_cross_entropy(probs, labels.float())
                    total_loss += loss.item()
                    predictions = (probs > 0.5).long()
                    correct += (predictions == labels).sum().item()
                    total += labels.size(0)
                val_loss = total_loss / max(val_steps, 1)
                accuracy = correct / max(total, 1)
                _logger.info(f"Validation loss: {val_loss:.4f}")
                _logger.info(f"Validation accuracy: {accuracy:.4f}")
                if wandb.run is not None:
                    wandb.log(
                        {
                            "validation/loss": val_loss,
                            "validation/accuracy": accuracy,
                        },
                        step=global_step,
                    )

                model.train()

        _logger.info(f"Epoch {epoch + 1} finished.")
        model.save(training_config.get_output_path(global_step))


# Training ===========================================
def train_ddp(config_name: str, master_port: str = "12355"):
    """Deprecated shim kept for callers of the old DDP entry point."""
    _logger.warning(
        "train_ddp is deprecated and now runs single-device training. master_port is ignored."
    )
    train(config_name)


if __name__ == "__main__":
    train("pairwise_mlp")
@dataclass
class TrainingConfig:
    """Bundle of everything a single training run needs.

    Holds the instantiated model, its data stream, the optimizer, and the
    bookkeeping cadence (save / validate intervals, batch size, epochs).
    """

    config_name: str
    model: "BaseRerankingModel"
    dataset: torch.utils.data.Dataset
    optimizer: torch.optim.Optimizer
    batch_size: int
    output_dir: str
    save_each: int
    validate_each: int
    validation_size: int = 100000
    epochs: int = 1

    def get_output_path(self, step: int) -> str:
        """Return (and create the directory for) the checkpoint path at *step*."""
        checkpoint_dir = Path(self.output_dir) / self.config_name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        return str(checkpoint_dir / f"{step}.pth")


def pairwise_mlp(
    LR: float = 0.0001,
    SAVE_EACH: int = 5000,
    BATCH_SIZE: int = 1024,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
    DROPOUT: float = 0.5,
) -> TrainingConfig:
    """Pairwise MLP reranker warm-started from the retriever EMA checkpoint."""
    config_name = "pairwise_mlp"
    reranker = PairwiseMLPReranker(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        mlp_hidden_dim=2048,
        dropout=DROPOUT,
    )

    return TrainingConfig(
        config_name=config_name,
        model=reranker,
        dataset=RerankingIterableDataset(),
        optimizer=torch.optim.AdamW(reranker.parameters(), lr=LR, fused=True),
        output_dir="/lnet/work/home-students-external/farhan/troja/outputs/reranking_models",
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )


def full_lealla(
    LR: float = 0.0001,
    SAVE_EACH: int = 5000,
    BATCH_SIZE: int = 1536,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
    DROPOUT: float = 0.1,
) -> TrainingConfig:
    """Full-LEALLA cross-encoder reranker trained from the base checkpoint."""
    config_name = "full_lealla"
    reranker = FullLEALLAReranker(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        dropout=DROPOUT,
    )

    return TrainingConfig(
        config_name=config_name,
        model=reranker,
        dataset=RerankingIterableDataset(),
        optimizer=torch.optim.AdamW(reranker.parameters(), lr=LR, fused=True),
        output_dir="/lnet/work/home-students-external/farhan/troja/outputs/reranking_models",
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )
def full_lealla_r(
    LR: float = 0.00005,
    SAVE_EACH: int = 10000,
    BATCH_SIZE: int = 1300,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
    DROPOUT: float = 0.1,
) -> TrainingConfig:
    """Full-LEALLA reranker warm-started from the retriever EMA checkpoint."""
    config_name = "full_lealla_r"
    reranker = FullLEALLAReranker(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        dropout=DROPOUT,
    )

    return TrainingConfig(
        config_name=config_name,
        model=reranker,
        dataset=RerankingIterableDataset(),
        optimizer=torch.optim.AdamW(reranker.parameters(), lr=LR, fused=True),
        output_dir="/lnet/work/home-students-external/farhan/troja/outputs/reranking_models",
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )


def full_lealla_r_192(
    LR: float = 0.00005,
    SAVE_EACH: int = 10000,
    BATCH_SIZE: int = 1000,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
    DROPOUT: float = 0.1,
) -> TrainingConfig:
    """Variant for 192-token descriptions (embedding dim 192 + 64 context)."""
    config_name = "full_lealla_r_192"

    # Sanity check: the on-disk dataset must actually carry 192-token descriptions.
    archive = np.load(
        "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids/mentions_5_dataset.npz"
    )
    assert (
        archive["description_tokens"].shape[1] == 192
    ), "Expected description tokens to have length 192"

    reranker = FullLEALLAReranker(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        dropout=DROPOUT,
        embedding_dim=192 + 64,
    )

    return TrainingConfig(
        config_name=config_name,
        model=reranker,
        dataset=RerankingIterableDataset(),
        optimizer=torch.optim.AdamW(reranker.parameters(), lr=LR, fused=True),
        output_dir="/lnet/work/home-students-external/farhan/troja/outputs/reranking_models",
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )


def full_lealla_r_multiclass(
    LR: float = 0.00005,
    SAVE_EACH: int = 10000,
    BATCH_SIZE: int = 1,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
    DROPOUT: float = 0.1,
) -> TrainingConfig:
    """Multiclass head over per-QID description groups (one group per batch)."""
    config_name = "full_lealla_r_multiclass"

    # Sanity check: the on-disk dataset must actually carry 192-token descriptions.
    archive = np.load(
        "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids/mentions_5_dataset.npz"
    )
    assert (
        archive["description_tokens"].shape[1] == 192
    ), "Expected description tokens to have length 192"

    reranker = FullLEALLARerankerMulticlass(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        dropout=DROPOUT,
        embedding_dim=8 * 64,
    )

    return TrainingConfig(
        config_name=config_name,
        model=reranker,
        dataset=RerankingQIDsToDescriptionsIterableDataset(),
        optimizer=torch.optim.AdamW(reranker.parameters(), lr=LR, fused=True),
        output_dir="/lnet/work/home-students-external/farhan/troja/outputs/reranking_models",
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )
def full_lealla_r_128(
    LR: float = 0.00005,
    SAVE_EACH: int = 10000,
    BATCH_SIZE: int = 1300,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
    DROPOUT: float = 0.1,
) -> TrainingConfig:
    """Full-LEALLA reranker for 128-token descriptions (embedding dim 128 + 64).

    FIX: config_name was "full_lealla_r_1", which mislabeled checkpoints and did
    not match the "full_lealla_r_128" key used by get_config_from_name.
    """
    name = "full_lealla_r_128"

    # FIX: np.load does not expand "~"; expanduser is required for this path.
    d = np.load(
        os.path.expanduser(
            "~/troja/outputs/reranking_test/reranker_dataset_with_qids/mentions_5_dataset.npz"
        )
    )
    description_tokens = d["description_tokens"]
    assert description_tokens.shape[1] == 128, "Expected description tokens to have length 128"

    model = FullLEALLAReranker(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        dropout=DROPOUT,
        embedding_dim=128 + 64,
    )

    dataset = RerankingIterableDataset()

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True)
    output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models"
    return TrainingConfig(
        config_name=name,
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        output_dir=output_dir,
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )


def fusion(
    LR: float = 0.0001,
    SAVE_EACH: int = 10000,
    BATCH_SIZE: int = 128,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
    DROPOUT: float = 0.1,
) -> TrainingConfig:
    """Fusion reranker combining the base encoder with paraphrase description tokens.

    Builds a qid -> paraphrase-token map from every .npz shard in the
    descriptions directory before constructing the model.
    """
    name = "fusion"

    qid_to_para_toks = {}
    from tqdm import tqdm

    descs_dir = (
        "/lnet/work/home-students-external/farhan/troja/outputs/"
        "descriptions_paraphrase_after_multiling_dataset/descs_pages"
    )
    # FIX: list the directory once instead of calling os.listdir twice
    # (tqdm infers the total from the materialized list).
    shard_files = os.listdir(descs_dir)
    for file in tqdm(shard_files):
        if file.endswith(".npz"):
            d = np.load(os.path.join(descs_dir, file))
            qids = d["qids"]
            tokens = torch.from_numpy(d["tokens"]).to(torch.int32)
            for qid, token in zip(qids, tokens):
                qid_to_para_toks[qid] = token

    model = FusionReranker(
        base_model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        paraphrase_model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        base_state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        qid_to_para_toks=qid_to_para_toks,
        dropout=DROPOUT,
    )

    dataset = RerankingIterableDataset()

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True)
    output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models"
    return TrainingConfig(
        config_name=name,
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        output_dir=output_dir,
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )


def _load_noise_tensor_dataset() -> torch.utils.data.TensorDataset:
    """Load the (link, description, label) tensors shared by the noise configs."""
    data = np.load(
        "/lnet/work/home-students-external/farhan/troja/outputs/reranking_test/reranker_dataset_with_qids.npz"
    )
    # qids = torch.from_numpy(data["qids"]).long()
    labels = torch.from_numpy(data["y"]).float()
    description_tokens = torch.from_numpy(data["description_tokens"]).long()
    link_tokens = torch.from_numpy(data["link_tokens"]).long()
    return torch.utils.data.TensorDataset(link_tokens, description_tokens, labels)


def pairwise_mlp_noise(
    LR: float = 0.0001,
    SAVE_EACH: int = 20000,
    BATCH_SIZE: int = 1024,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
) -> TrainingConfig:
    """Pairwise MLP reranker regularized with embedding noise."""
    name = "pairwise_mlp_noise"
    model = PairwiseMLPReranker(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        mlp_hidden_dim=2048,
        emb_noise=0.1,
    )
    dataset = _load_noise_tensor_dataset()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True)
    output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models"
    return TrainingConfig(
        config_name=name,
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        output_dir=output_dir,
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )


def pairwise_mlp_noise_dropout(
    LR: float = 0.0001,
    SAVE_EACH: int = 20000,
    BATCH_SIZE: int = 1024,
    VALIDATE_EACH: int = 10000,
    VALIDATION_SIZE: int = 10000,
) -> TrainingConfig:
    """Pairwise MLP reranker with both embedding noise and dropout."""
    name = "pairwise_mlp_noise_dropout"
    model = PairwiseMLPReranker(
        model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base",
        state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth",
        mlp_hidden_dim=2048,
        emb_noise=0.1,
        dropout=0.5,
    )
    dataset = _load_noise_tensor_dataset()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True)
    output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models"
    return TrainingConfig(
        config_name=name,
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        output_dir=output_dir,
        save_each=SAVE_EACH,
        batch_size=BATCH_SIZE,
        validate_each=VALIDATE_EACH,
        validation_size=VALIDATION_SIZE,
    )
"/lnet/work/home-students-external/farhan/troja/outputs/paraphrase_multilig_index" + ) + paraphrase_embs = torch.from_numpy(paraphrase_embs) + base_embs, base_qids = load_embs_and_qids( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + base_embs = torch.from_numpy(base_embs) + + dataset = RerankingIterableDataset() + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + + qid_to_paraphrase_emb = {qid: emb for qid, emb in zip(paraphrase_qids, paraphrase_embs)} + qid_to_base_emb = {qid: emb for qid, emb in zip(base_qids, base_embs)} + + model = PairwiseMLPRerankerWithLargeContextEmb( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + qid_to_paraphrase_emb=qid_to_paraphrase_emb, + qid_to_base_emb=qid_to_base_emb, + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def context_emb_with_attention( + LR: float = 0.0001, + SAVE_EACH: int = 5000, + BATCH_SIZE: int = 1024, + VALIDATE_EACH: int = 1000, + VALIDATION_SIZE: int = 10000, + DROPOUT: float = 0.1, +) -> TrainingConfig: + name = "context_emb_with_attention" + + paraphrase_embs, paraphrase_qids = load_embs_and_qids( + "/lnet/work/home-students-external/farhan/troja/outputs/paraphrase_multilig_index" + ) + paraphrase_embs = torch.from_numpy(paraphrase_embs) + base_embs, base_qids = load_embs_and_qids( + "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + ) + base_embs = 
torch.from_numpy(base_embs) + + dataset = RerankingIterableDataset() + output_dir = "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models" + + qid_to_paraphrase_emb = {qid: emb for qid, emb in zip(paraphrase_qids, paraphrase_embs)} + qid_to_base_emb = {qid: emb for qid, emb in zip(base_qids, base_embs)} + + model = ContextEmbWithAttention( + model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + dropout=DROPOUT, + qid_to_paraphrase_emb=qid_to_paraphrase_emb, + qid_to_base_emb=qid_to_base_emb, + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, fused=True) + + return TrainingConfig( + config_name=name, + model=model, + dataset=dataset, + optimizer=optimizer, + output_dir=output_dir, + save_each=SAVE_EACH, + batch_size=BATCH_SIZE, + validate_each=VALIDATE_EACH, + validation_size=VALIDATION_SIZE, + ) + + +def get_config_from_name(config_name: str) -> TrainingConfig: + if config_name == "pairwise_mlp": + return pairwise_mlp() + if config_name == "pairwise_mlp_debug": + return pairwise_mlp(VALIDATE_EACH=1000, SAVE_EACH=1000000000000) + if config_name == "pairwise_mlp_noise": + return pairwise_mlp_noise() + if config_name == "pairwise_mlp_noise_dropout": + return pairwise_mlp_noise_dropout() + if config_name == "pairwise_mlp_paraphrase": + return pairwise_mlp_paraphrase() + if config_name == "context_emb_with_attention": + return context_emb_with_attention() + if config_name == "full_lealla": + return full_lealla() + if config_name == "full_lealla_r": + return full_lealla_r() + if config_name == "full_lealla_debug": + return full_lealla(VALIDATE_EACH=1000, SAVE_EACH=1000000000000, BATCH_SIZE=128) + if config_name == "fusion": + return fusion() + if config_name == "full_lealla_r_128": + return full_lealla_r_128() + if config_name == "full_lealla_r_192": + return 
full_lealla_r_192() + if config_name == "full_lealla_r_multiclass": + return full_lealla_r_multiclass() + else: + raise ValueError(f"Unknown training configuration: {config_name}") diff --git a/src/run_action_gin.py b/src/run_action_gin.py index 48298a9..9727994 100644 --- a/src/run_action_gin.py +++ b/src/run_action_gin.py @@ -21,7 +21,14 @@ from finetunings.generate_epochs.generate import generate from multilingual_dataset.combine_embs import combine_embs_by_qid from multilingual_dataset.creator import create_multilingual_dataset, run_kb_creator +from reranking.dataset.create_dataset import ( + create_default_binary_dataset, + create_default_multiclass_dataset, +) +from reranking.training.trainer import train_ddp as reranking_train_ddp +from reranking.training.trainer_simple import train as reranking_train from tokenization.runner import ( + run_damuel_description, run_damuel_description_context, run_damuel_description_mention, run_damuel_link_context, @@ -91,10 +98,20 @@ def choose_action(action): return run_damuel_link_mention case "run_damuel_mention": return run_damuel_mention + case "run_damuel_description": + return run_damuel_description case "olpeat": return olpeat case "find_candidates": return find_candidates + case "reranking_train_ddp": + return reranking_train_ddp + case "create_default_binary_dataset": + return create_default_binary_dataset + case "create_default_multiclass_dataset": + return create_default_multiclass_dataset + case "reranking_train": + return reranking_train case _: raise ValueError(f"Unknown action: {action}") diff --git a/src/scripts/benchmarking/dp_vs_manual.py b/src/scripts/benchmarking/dp_vs_manual.py new file mode 100644 index 0000000..84f6759 --- /dev/null +++ b/src/scripts/benchmarking/dp_vs_manual.py @@ -0,0 +1,118 @@ +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F + +torch._dynamo.config.recompile_limit = 1000 +torch.backends.cuda.matmul.allow_tf32 = True 
+torch.backends.cudnn.allow_tf32 = True + + +class _WrappedSearcher(nn.Module): + def __init__(self, kb_embs, num_neighbors): + super().__init__() + self.register_buffer("kb_embs", kb_embs) + self.num_neighbors: int = num_neighbors + + def forward(self, x): + dot_product = F.linear(x, self.kb_embs) + _, top_indices = dot_product.topk(self.num_neighbors) + return top_indices + + +def benchmark(f, input_data, warmup_iters=20, profile_iters=100): + for _ in range(warmup_iters): + with torch.inference_mode(): + f(input_data) + torch.cuda.synchronize() + + time_sum = 0.0 + for _ in range(profile_iters): + start = time.time() + with torch.inference_mode(): + f(input_data) + torch.cuda.synchronize() + end = time.time() + time_sum += end - start + return time_sum / profile_iters + + +num_cuda_devices = torch.cuda.device_count() +print(f"Number of available CUDA devices: {num_cuda_devices}") + +cuda_devices = [torch.device(f"cuda:{i}") for i in range(num_cuda_devices)] +print(f"CUDA devices: {cuda_devices}") + +kb = torch.randn(1000, 768) +data = torch.randn(32, 768).to(cuda_devices[0]) + +searcher_dp = _WrappedSearcher(kb, num_neighbors=10) +searcher_dp = nn.DataParallel(searcher_dp, device_ids=cuda_devices) +searcher_dp = torch.compile(searcher_dp) +searcher_dp.to("cuda") + + +class ManualSearcher: + def __init__(self, kb, device_ids, num_neighbors): + self.device_ids = device_ids + self.searchers = [] + for i, device in enumerate(device_ids): + searcher = _WrappedSearcher(kb, num_neighbors=num_neighbors).to(device) + self.searchers.append(searcher) + + def find(self, x): + # Split input data across available devices + inputs = nn.parallel.scatter(x, self.device_ids) + # Compute on each device + outputs = [ + searcher(input_chunk.to(device)) + for searcher, input_chunk, device in zip(self.searchers, inputs, self.device_ids) + ] + gathered = nn.parallel.gather(outputs, self.device_ids[0]) + return gathered + + +searcher_manual = ManualSearcher(kb, device_ids=cuda_devices, 
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def format_instruction(instruction, query, doc):
    """Render one (instruction, query, document) triple in the Qwen3-Reranker format."""
    if instruction is None:
        instruction = "Given a web search query, retrieve relevant passages that answer the query"
    # FIX: the <Instruct>/<Query>/<Document> tags required by the official
    # Qwen3-Reranker prompt were stripped from this copy of the snippet.
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
        instruction=instruction, query=query, doc=doc
    )
    return output


def process_inputs(pairs):
    """Tokenize prompts and frame each with the chat prefix/suffix token ids."""
    inputs = tokenizer(
        pairs,
        padding=False,
        truncation="longest_first",
        return_attention_mask=False,
        max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
    )
    for i, ele in enumerate(inputs["input_ids"]):
        inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs


@torch.no_grad()
def compute_logits(inputs, **kwargs):
    """Return P(answer == "yes") per pair from the last-token yes/no logits."""
    batch_scores = model(**inputs).logits[:, -1, :]
    true_vector = batch_scores[:, token_true_id]
    false_vector = batch_scores[:, token_false_id]
    batch_scores = torch.stack([false_vector, true_vector], dim=1)
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
    scores = batch_scores[:, 1].exp().tolist()
    return scores


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
# We recommend enabling flash_attention_2 for better acceleration and memory saving.
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B", torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
max_length = 8192

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
# FIX: the empty-think block from the official suffix had its tags stripped
# ("assistant\n\n\n\n\n"); restored per the Qwen3-Reranker usage example.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

# FIX: the mention markers were stripped ("marked by and"), which also made the
# two few-shot examples identical with contradictory answers, and the first
# example was missing its trailing newline. Restored with the [M] markers used
# by the sibling reranker.py — confirm they match the training data.
task = (
    "Your task is to determine if the provided Wikipedia description correctly corresponds "
    "to the entity mention found in the query. The entity mention is marked by [M] and [M]. "
    "Check if the description matches the entity. Answer strictly with 'yes' or 'no'.\n"
    "Example:\n"
    " Query: 'What is the capital of [M] France [M]?'\n"
    " Description: '[M] Paris [M] is the capital and largest city of France...'\n"
    " Answer: no\n"
    " Query: 'What is the [M] capital [M] of France?'\n"
    " Description: '[M] Paris [M] is the capital and largest city of France...'\n"
    " Answer: yes"
)

queries = [
    "What is the capital of China?",
    "In order to save Troy, Paris had to be sacrificed.",
    "2. prezident republiky byl zdrcen dohodou z Mnichova.",
] * 2

documents = [
    "Peking (zvuk výslovnost, čínsky v českém přepisu Pej-ťing, pchin-jinem Běijīng, znaky 北京) je hlavní město Čínské lidové republiky. S více než 21 miliony obyvatel je jedním z nejlidnatějších hlavních měst na světě,[2][3] a po Šanghaji druhým nejlidnatějším městem v Číně.",
    "Paris (Ancient Greek: Πάρις, romanized: Páris), also known as Alexander (Ancient Greek: Ἀλέξανδρος, romanized: Aléxandros), is a mythological figure in the story of the Trojan War.",
    "Edvard Beneš (původním jménem Eduard;[pozn. 2] 28. května 1884 Kožlany[4] – 3. září 1948 Sezimovo Ústí) byl československý politik a státník, druhý československý prezident v letech 1935–1948, resp. v letech 1935–1938 a 1945–1948. V období tzv. Druhé republiky (po Mnichovské dohodě ze dne 29. září 1938 do 15. března 1939) a následné německé okupace do května 1945 žil a politicky působil v exilu. Od roku 1940 až do osvobození Československa byl mezinárodně (nejen protihitlerovskou koalicí) uznaným vrcholným představitelem československého odboje a exilovým prezidentem republiky. Úřadujícím československým prezidentem byl opět v letech 1945–1948.",
    "Shanghai[a] is a direct-administered municipality and the most populous urban area in China. The city is located on the Chinese shoreline on the southern estuary of the Yangtze River, with the Huangpu River flowing through it. The population of the city proper is the second largest in the world with around 24.87 million inhabitants in 2023, while the urban area is the most populous in China, with 29.87 million residents.",
    "Paris[a] is the capital and largest city of France, with an estimated city center population of 2,048,472, and a metropolitan population of 13,171,056 as of January 2025[3] in an area of more than 105 km2 (41 sq mi). It is located in the centre of the Île-de-France region. Paris is the fourth-most populous city in the European Union. Nicknamed the City of Light, Paris has been one of the world's major centres of finance, diplomacy, commerce, culture, fashion, and gastronomy since the 17th century. ",
    "Tomáš Garrigue Masaryk, označovaný T. G. M., TGM nebo Prezident Osvoboditel (7. března 1850 Hodonín[1] – 14. září 1937 Lány[2]), byl československý státník, filozof, sociolog a pedagog, první prezident Československé republiky. K jeho osmdesátým narozeninám byl roku 1930 přijat zákon o zásluhách T. G. Masaryka, obsahující větu „Tomáš Garrigue Masaryk zasloužil se o stát“, a po odchodu z funkce roku 1935 ho parlament znovu ocenil a odměnil za jeho osvoboditelské a budovatelské dílo.",
]

pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)]

# Tokenize the input texts
inputs = process_inputs(pairs)
scores = compute_logits(inputs)

print("scores: ", scores)
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def format_instruction(instruction, query, doc):
    """Render one (instruction, query, document) triple in the Qwen3-Reranker format."""
    if instruction is None:
        instruction = "Given a web search query, retrieve relevant passages that answer the query"
    # FIX: the <Instruct>/<Query>/<Document> tags required by the official
    # Qwen3-Reranker prompt were stripped from this copy of the snippet.
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
        instruction=instruction, query=query, doc=doc
    )
    return output


def process_inputs(pairs):
    """Tokenize prompts and frame each with the chat prefix/suffix token ids."""
    inputs = tokenizer(
        pairs,
        padding=False,
        truncation="longest_first",
        return_attention_mask=False,
        max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
    )
    for i, ele in enumerate(inputs["input_ids"]):
        inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs


@torch.no_grad()
def compute_logits(inputs, **kwargs):
    """Return P(answer == "yes") per pair from the last-token yes/no logits."""
    batch_scores = model(**inputs).logits[:, -1, :]
    true_vector = batch_scores[:, token_true_id]
    false_vector = batch_scores[:, token_false_id]
    batch_scores = torch.stack([false_vector, true_vector], dim=1)
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
    scores = batch_scores[:, 1].exp().tolist()
    return scores


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
# We recommend enabling flash_attention_2 for better acceleration and memory saving.
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B", torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
max_length = 8192

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
# FIX: the empty-think block from the official suffix had its tags stripped
# ("assistant\n\n\n\n\n"); restored per the Qwen3-Reranker usage example.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

task = "Given a web search query, retrieve relevant passages that answer the query"

queries = [
    "What is the capital of China?",
    "Explain gravity",
]

documents = [
    "In the Troyan war, the Greeks besieged the city of Troy for ten years. The war ended with the clever use of a wooden horse, known as the Trojan Horse, which allowed Greek soldiers to enter the city and defeat the Trojans.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)]

# Tokenize the input texts
inputs = process_inputs(pairs)
scores = compute_logits(inputs)

print("scores: ", scores)
It gives weight to physical objects and is responsible for the movement of planets around the sun.", +] + +pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)] + +# Tokenize the input texts +inputs = process_inputs(pairs) +scores = compute_logits(inputs) + +print("scores: ", scores) diff --git a/src/scripts/qwen/reranker.py b/src/scripts/qwen/reranker.py new file mode 100644 index 0000000..a9de663 --- /dev/null +++ b/src/scripts/qwen/reranker.py @@ -0,0 +1,94 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class Reranker: + """Lightweight wrapper around Qwen reranker for mention-description scoring.""" + + _DEFAULT_INSTRUCTION = ( + "Your task is to determine if the provided Wikipedia description correctly corresponds " + "to the entity mention found in the query. The entity mention is marked by [M] and [M]. " + "Check if the description matches the entity. Answer strictly with 'yes' or 'no'.\n" + "Note that the language of the query and description may differ.\n" + "Do NOT consider the language; the goal is to tell whether description matches the entity in the query.\n" + "Example:\n" + " Query: 'What is the capital of [M] France [M]?'\n" + " Description: '[M] Paris [M] is the capital and largest city of France...'\n" + " Answer: no\n" + " Query: 'What is the [M] capital [M] of France?'\n" + " Description: '[M] Paris [M] is the capital and largest city of France...'\n" + " Answer: yes" + ) + # _DEFAULT_INSTRUCTION = ( + # "Given a web search query, retrieve relevant passages that answer the query" + # ) + _SYSTEM_PROMPT = ( + "<|im_start|>system\n" + "Judge whether the Document meets the requirements based on the Query and the Instruct " + 'provided. 
Note that the answer can only be "yes" or "no".<|im_end|>\n' + "<|im_start|>user\n" + ) + _ASSISTANT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def __init__( + self, + model_name: str = "Qwen/Qwen3-Reranker-8B", + max_length: int = 8192, + instruction: str | None = None, + ) -> None: + self.model_name = model_name + self.max_length = max_length + self.instruction = instruction or self._DEFAULT_INSTRUCTION + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + self.model = AutoModelForCausalLM.from_pretrained(model_name).eval() + + self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") + self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") + self.prefix_tokens = self.tokenizer.encode(self._SYSTEM_PROMPT, add_special_tokens=False) + self.suffix_tokens = self.tokenizer.encode(self._ASSISTANT_SUFFIX, add_special_tokens=False) + + def score(self, mention: str, description: str, instruction: str | None = None) -> float: + """Return probability that description matches mention.""" + formatted_instruction = instruction or self.instruction + prompt = self._format_instruction(formatted_instruction, mention, description) + inputs = self._process_inputs([prompt]) + probabilities = self._compute_probabilities(inputs) + return probabilities[0] + + def _format_instruction(self, instruction: str, query: str, document: str) -> str: + return ": {instruction}\n: {query}\n: {doc}".format( + instruction=instruction, + query=query, + doc=document, + ) + + def _process_inputs(self, prompts: list[str]): + tokenized = self.tokenizer( + prompts, + padding=False, + truncation="longest_first", + return_attention_mask=False, + max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens), + ) + + for i, ids in enumerate(tokenized["input_ids"]): + tokenized["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens + + tokenized = self.tokenizer.pad( + tokenized, padding=True, return_tensors="pt", 
max_length=self.max_length + ) + + for key in tokenized: + tokenized[key] = tokenized[key].to(self.model.device) + + return tokenized + + @torch.no_grad() + def _compute_probabilities(self, inputs): + logits = self.model(**inputs).logits[:, -1, :] + true_vector = logits[:, self.token_true_id] + false_vector = logits[:, self.token_false_id] + stacked = torch.stack([false_vector, true_vector], dim=1) + log_probs = torch.nn.functional.log_softmax(stacked, dim=1) + return log_probs[:, 1].exp().tolist() diff --git a/src/scripts/qwen/reranking2.py b/src/scripts/qwen/reranking2.py new file mode 100644 index 0000000..a801242 --- /dev/null +++ b/src/scripts/qwen/reranking2.py @@ -0,0 +1,166 @@ +import argparse +from pathlib import Path + +import torch +from torch.utils.data import DataLoader, Dataset +from transformers import AutoTokenizer + +from models.searchers.brute_force_searcher import BruteForceSearcher +from reranking.models.pairwise_mlp import PairwiseMLPReranker +from scripts.qwen.reranker import Reranker +from utils.embeddings import create_attention_mask +from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir +from utils.model_factory import ModelFactory + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def main(): + parser = argparse.ArgumentParser(description="Reranking for entity linking") + parser.add_argument( + "--damuel_token", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages", + help="Path to damuel token file or directory", + ) + parser.add_argument( + "--damuel_embs", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6", + help="Path to damuel embeddings directory or .npz file", + ) + parser.add_argument( + "--mewsli_tokens", + type=str, + 
default="/lnet/work/home-students-external/farhan/troja/outputs/tokens_mewsli_finetuning/en/mentions_1641252057782057661.npz", + help="Path to mewsli token file or directory", + ) + parser.add_argument( + "--qwen_model_name", + type=str, + default="Qwen/Qwen3-Reranker-0.6B", + help="Name of the QWEN model", + ) + parser.add_argument( + "--reranking_model_path", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + parser.add_argument( + "--num_neighbors", + type=int, + default=10, + help="Number of neighbors to retrieve from the searcher", + ) + args = parser.parse_args() + + reranking_model = ModelFactory.auto_load_from_file( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + args.reranking_model_path, + ) + reranking_model.eval() + reranking_model.to(device) + + reranking_tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + + reranker = PairwiseMLPReranker( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path=args.reranking_model_path, + mlp_hidden_dim=2048, + ) + state_dict = torch.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp_noise/20000.pth", + map_location=device, + ) + reranker.load_state_dict(state_dict) + reranker.eval() + + mewsli_tokens, mewsli_qids = load_tokens_qids(args.mewsli_tokens) + mewsli_tokens = torch.from_numpy(mewsli_tokens) + mewsli_qids = torch.from_numpy(mewsli_qids) + + decode_mewsli_example = reranking_tokenizer.decode(mewsli_tokens[0], skip_special_tokens=True) + print("Decoded mewsli example:", decode_mewsli_example) + + # Resolve tokens and embeddings (directories or files) + damuel_tokens, damuel_token_qids = load_tokens_qids_from_dir(args.damuel_token, verbose=True) + damuel_embs, damuel_qids = load_embs_and_qids(args.damuel_embs) + + qid_to_damuel_token = 
{qid: token for qid, token in zip(damuel_token_qids, damuel_tokens)} + + del damuel_token_qids + + # Take first four names (tokens) from each as a quick smoke-test + damuel_tokens_preview = damuel_tokens[:4] + mewsli_tokens_preview = mewsli_tokens[:4] + + print("First 4 damuel tokens:", damuel_tokens_preview) + print("First 4 mewsli tokens:", mewsli_tokens_preview) + + # Create searcher using damuel embeddings and damuel qids + searcher = BruteForceSearcher(damuel_embs, damuel_qids) + print("Searcher created.") + + # Initialize reranker model (actual reranking logic to be implemented later) + # reranker = Reranker(model_name=args.qwen_model_name) + # print(f"Reranker initialized with model: {reranker.model_name}") + + # Stub for QWEN model loading + print("QWEN model to be used:", args.qwen_model_name) + + # TODO: implement reranking logic + print("Reranking logic not implemented. Exiting.") + + print("jupi") + + mewsli_dataset = torch.utils.data.TensorDataset(mewsli_tokens, mewsli_qids) + mewsli_loader = DataLoader(mewsli_dataset, batch_size=1, shuffle=False) + + good = 0 + total = 0 + + for batch in mewsli_loader: + mewsli_token = batch[0] + qid = batch[1].item() # ensure scalar for a fair comparison with predicted_qid + # print(f"Processing Mewsli token: {mewsli_token}, QID: {qid}") + + tokens = mewsli_token.to(device, dtype=torch.int64) + + with torch.inference_mode(): + attention_mask = create_attention_mask(tokens).to(device) + mewsli_emb = ( + reranking_model(tokens, attention_mask) + .to("cpu", dtype=torch.float16) + .detach() + .numpy() + ) + + neighbor_qids = searcher.find(mewsli_emb, num_neighbors=args.num_neighbors) + + damuel_candidates = [qid_to_damuel_token[nq] for nq in neighbor_qids[0]] + damuel_candidates_str = [ + reranking_tokenizer.decode(dc, skip_special_tokens=True) for dc in damuel_candidates + ] + + mewsli_str = reranking_tokenizer.decode(mewsli_token[0], skip_special_tokens=True) + + scores = [] + # print("Mewsli mention:", mewsli_str) + 
for dc in damuel_candidates_str: + with torch.inference_mode(): + # print(dc) + score = reranker.score(mewsli_str, dc) + scores.append(score) + # print(scores) + predicted_qid = int(neighbor_qids[0][scores.index(max(scores))]) + if predicted_qid == qid: + good += 1 + total += 1 + + print("Current accuracy:", round(good / total * 100, 4)) + + +if __name__ == "__main__": + main() diff --git a/src/scripts/qwen/reranking3.py b/src/scripts/qwen/reranking3.py new file mode 100644 index 0000000..1ea6015 --- /dev/null +++ b/src/scripts/qwen/reranking3.py @@ -0,0 +1,273 @@ +import argparse +from pathlib import Path + +import numpy as np +import torch +from einops import rearrange, repeat +from torch.utils.data import DataLoader, Dataset +from transformers import AutoTokenizer + +from models.searchers.brute_force_searcher import BruteForceSearcher +from reranking.models.full_lealla import FullLEALLAReranker +from reranking.models.pairwise_mlp import PairwiseMLPReranker +from reranking.models.pairwise_mlp_with_large_context_emb import ( + PairwiseMLPRerankerWithLargeContextEmb, +) +from reranking.models.pairwise_mlp_with_retrieval_score import PairwiseMLPRerankerWithRetrievalScore +from scripts.qwen.reranker import Reranker +from utils.embeddings import create_attention_mask +from utils.loaders import load_embs_and_qids, load_tokens_qids, load_tokens_qids_from_dir +from utils.model_factory import ModelFactory + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +LOGIT_MULTIPLIER = 20.0 + + +def main(): + parser = argparse.ArgumentParser(description="Reranking for entity linking") + parser.add_argument( + "--damuel_token", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal_filtered/descs_pages", + help="Path to damuel token file or directory", + ) + parser.add_argument( + "--damuel_embs", + type=str, + 
default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6", + help="Path to damuel embeddings directory or .npz file", + ) + parser.add_argument( + "--mewsli_tokens", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/tokens_mewsli_finetuning", + help="Path to mewsli token file or directory", + ) + parser.add_argument( + "--qwen_model_name", + type=str, + default="Qwen/Qwen3-Reranker-0.6B", + help="Name of the QWEN model", + ) + parser.add_argument( + "--reranking_model_path", + type=str, + default="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + ) + parser.add_argument( + "--num_neighbors", + type=int, + default=10, + help="Number of neighbors to retrieve from the searcher", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size for DataLoader", + ) + args = parser.parse_args() + + reranking_model = ModelFactory.auto_load_from_file( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + args.reranking_model_path, + ) + reranking_model.eval() + reranking_model.to(device) + + reranking_tokenizer = AutoTokenizer.from_pretrained( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base" + ) + + # paraphrase_embs, paraphrase_qids = load_embs_and_qids( + # "/lnet/work/home-students-external/farhan/troja/outputs/paraphrase_multilig_index" + # ) + # paraphrase_embs = torch.from_numpy(paraphrase_embs) + # base_embs, base_qids = load_embs_and_qids( + # "/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/damuel_for_index_6" + # ) + # base_embs = torch.from_numpy(base_embs) + + reranker = PairwiseMLPReranker( + "/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + state_dict_path=args.reranking_model_path, + mlp_hidden_dim=2048, + ) + # qid_to_paraphrase_emb = {qid: emb for qid, emb 
in zip(paraphrase_qids, paraphrase_embs)} + # qid_to_base_emb = {qid: emb for qid, emb in zip(base_qids, base_embs)} + # reranker = PairwiseMLPRerankerWithLargeContextEmb( + # model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + # state_dict_path="/lnet/work/home-students-external/farhan/troja/outputs/workdirs/asi_se_to_rozbilo_init_all/models_5/ema.pth", + # qid_to_paraphrase_emb=qid_to_paraphrase_emb, + # qid_to_base_emb=qid_to_base_emb, + # ) + # reranker = FullLEALLAReranker( + # model_name_or_path="/lnet/work/home-students-external/farhan/troja/outputs/models/LEALLA-base", + # ) + reranker.load( + "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/pairwise_mlp/70000.pth", + ) + # reranker.load( + # "/lnet/work/home-students-external/farhan/troja/outputs/reranking_models/full_lealla_r/380000.pth", + # ) + reranker.eval() + reranker.to(device) + + # Resolve tokens and embeddings (directories or files) + damuel_tokens, damuel_token_qids = load_tokens_qids_from_dir(args.damuel_token, verbose=True) + damuel_embs, damuel_qids = load_embs_and_qids(args.damuel_embs) + + qid_to_damuel_emb = {qid: emb for qid, emb in zip(damuel_qids, damuel_embs)} + qid_to_damuel_token = {qid: token for qid, token in zip(damuel_token_qids, damuel_tokens)} + + del damuel_token_qids + + # Take first four names (tokens) from damuel as a quick smoke-test + damuel_tokens_preview = damuel_tokens[:4] + print("First 4 damuel tokens:", damuel_tokens_preview) + + # Create searcher using damuel embeddings and damuel qids + searcher = BruteForceSearcher(damuel_embs, damuel_qids) + print("Searcher created.") + + # Initialize reranker model (actual reranking logic to be implemented later) + # reranker = Reranker(model_name=args.qwen_model_name) + # print(f"Reranker initialized with model: {reranker.model_name}") + + # Stub for QWEN model loading + print("QWEN model to be used:", args.qwen_model_name) + + print("jupi") + + mewsli_root 
= Path(args.mewsli_tokens) + if not mewsli_root.exists(): + raise FileNotFoundError(f"MEWSLI tokens path not found: {mewsli_root}") + + language_paths: list[tuple[str, Path]] = [] + if mewsli_root.is_file(): + language_name = mewsli_root.parent.name or mewsli_root.stem + language_paths.append((language_name, mewsli_root)) + else: + for subdir in sorted(p for p in mewsli_root.iterdir() if p.is_dir()): + language_paths.append((subdir.name, subdir)) + + if not language_paths: + raise ValueError(f"No MEWSLI language directories or files found under {mewsli_root}") + + for language, mewsli_path in language_paths: + print(f"\nEvaluating language: {language}") + + if mewsli_path.is_dir(): + tokens_np, qids_np = load_tokens_qids_from_dir(mewsli_path, verbose=False) + else: + tokens_np, qids_np = load_tokens_qids(mewsli_path) + + if tokens_np.size == 0: + print(f"No samples found for language {language}, skipping.") + continue + + mewsli_tokens = torch.from_numpy(tokens_np) + mewsli_qids = torch.from_numpy(qids_np) + + mewsli_dataset = torch.utils.data.TensorDataset(mewsli_tokens, mewsli_qids) + mewsli_loader = DataLoader(mewsli_dataset, batch_size=args.batch_size, shuffle=False) + + good = 0 + upper_bound_hits = 0 + total = 0 + + for mewsli_tokens_batch, qids_batch in mewsli_loader: + tokens = mewsli_tokens_batch.to(device, dtype=torch.int64) + qids_batch = qids_batch.view(-1).to(torch.long) + + with torch.inference_mode(): + attention_mask = create_attention_mask(tokens).to(device) + mewsli_embs = reranking_model(tokens, attention_mask) + + neighbor_qids = searcher.find( + mewsli_embs.to("cpu").numpy().astype(np.float16), num_neighbors=args.num_neighbors + ) + neighbor_qids = torch.as_tensor(neighbor_qids, dtype=torch.long) + + retrieval_hits = (neighbor_qids == qids_batch.view(-1, 1)).any(dim=1) + upper_bound_hits += retrieval_hits.sum().item() + + candidate_embs_lists = [] + candidate_tokens_lists = [] + for row in neighbor_qids.tolist(): + for nq in row: + emb = 
qid_to_damuel_emb[int(nq)] + token = qid_to_damuel_token[int(nq)] + candidate_embs_lists.append(emb) + candidate_tokens_lists.append(token) + + # assert len(candidate_embs_lists) == neighbor_qids.size(0) * neighbor_qids.size(1) + + candidate_embs = torch.as_tensor( + candidate_embs_lists, dtype=torch.float16, device=device + ) + together = torch.cat( + (mewsli_embs.repeat_interleave(neighbor_qids.size(1), dim=0), candidate_embs), + dim=-1, + ) + together = together.to(device) + candidate_tokens = torch.as_tensor( + candidate_tokens_lists, dtype=torch.int64, device=device + ) + with torch.inference_mode(): + if isinstance(reranker, PairwiseMLPRerankerWithLargeContextEmb): + scores = reranker.score_from_tokens( + repeat(tokens, "b d -> (b n) d", n=neighbor_qids.size(1)), + rearrange(neighbor_qids, "b n -> (b n)"), + ) + scores = rearrange(scores, "(b n) -> b n", n=neighbor_qids.size(1)) + elif isinstance(reranker, FullLEALLAReranker): + score = reranker.score_from_tokens( + repeat(tokens, "b d -> (b n) d", n=neighbor_qids.size(1)), candidate_tokens + ) + scores = rearrange(score, "(b n) -> b n", n=neighbor_qids.size(1)) + else: + logits = reranker.classifier_forward(together) + scores = torch.sigmoid(logits).reshape( + neighbor_qids.size(0), neighbor_qids.size(1) + ) + if ( + isinstance(reranker, PairwiseMLPRerankerWithRetrievalScore) + # or isinstance(reranker, FullLEALLAReranker) + # or isinstance(reranker, PairwiseMLPReranker) + # or isinstance(reranker, PairwiseMLPRerankerWithLargeContextEmb) + ): + candidate_embs = candidate_embs.reshape( + neighbor_qids.size(0), neighbor_qids.size(1), -1 + ) + out = ( + torch.einsum( + "abc,ac->ab", + candidate_embs.to(torch.bfloat16), + mewsli_embs.to(torch.bfloat16), + ) + * LOGIT_MULTIPLIER + ) + out = out.reshape(neighbor_qids.size(0), neighbor_qids.size(1)) + scores = (scores + torch.sigmoid(out)) / 2 + + max_indices = scores.argmax(dim=1).cpu() + predicted_qids = neighbor_qids[torch.arange(neighbor_qids.size(0)), 
max_indices] + + good += (predicted_qids.cpu() == qids_batch.cpu()).sum().item() + total += qids_batch.numel() + + if total == 0: + print(f"No valid predictions for language {language}.") + continue + + final_accuracy = round(good / total * 100, 4) + retrieval_upper_bound = round(upper_bound_hits / total * 100, 4) + print( + f"Final accuracy for {language}: {final_accuracy} (retrieval upper bound: {retrieval_upper_bound})" + ) + + +if __name__ == "__main__": + main() diff --git a/src/scripts/reranking/change_dataset_tokens.py b/src/scripts/reranking/change_dataset_tokens.py new file mode 100644 index 0000000..f6b8778 --- /dev/null +++ b/src/scripts/reranking/change_dataset_tokens.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import fire +import numpy as np +from scipy.sparse import csr_matrix +from tqdm import tqdm + +from utils.loaders import map_qids_to_token_matrix + + +def update_tokens_in_file(file_path: Path, qid_to_new_tokens: csr_matrix) -> None: + """Updates tokens in a single .npz file in place. + + Replaces each entry in "description_tokens" with the corresponding row from + qid_to_new_tokens indexed by the file's qids. Raises ValueError if a qid has + no corresponding tokens in qid_to_new_tokens. + """ + print(file_path) + with np.load(file_path) as data: + tokens = data["description_tokens"] + qids = data["qids"] + save_data = dict(data) + + # tokens = np.empty((tokens.shape[0], qid_to_new_tokens.shape[1]), dtype=tokens.dtype) + + tokens = qid_to_new_tokens[qids].toarray() + + # for i, qid in enumerate(qids): + # if qid_to_new_tokens[qid].nnz > 0: + # tokens[i] = qid_to_new_tokens[qid].toarray()[0] + # else: + # # We could also pad/truncate here if needed but this code should not really happen. 
+ raise ValueError(f"No new tokens found for qid {qid}") + + save_data["description_tokens"] = tokens + np.savez(file_path, **save_data) + + +def process_directory(dataset_dir: Path, new_tokens_dir: Path) -> None: + """Orchestrates the token update process for an entire directory.""" + qid_to_new_tokens = map_qids_to_token_matrix(new_tokens_dir, verbose=True) + files_to_process = list(dataset_dir.glob("*.npz")) + + for fp in tqdm(files_to_process, desc="Updating files"): + update_tokens_in_file(fp, qid_to_new_tokens) + + +def main(dataset_dir, new_tokens_dir): + dataset_dir = Path(dataset_dir) + new_tokens_dir = Path(new_tokens_dir) + print("Starting token update process...") + process_directory(dataset_dir, new_tokens_dir) + print("Update process completed.") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/src/scripts/train/asi_se_to_rozbilo_init_all.sh b/src/scripts/train/asi_se_to_rozbilo_init_all.sh new file mode 100755 index 0000000..091c21f --- /dev/null +++ b/src/scripts/train/asi_se_to_rozbilo_init_all.sh @@ -0,0 +1,195 @@ +#!/bin/bash + +# Runs the complete finetuning process. +# Expects tokens to be in the dirs specified below. +# Additionally, one can specify additional parameters.
+# For running, please also set up/fix the path to venv in run_finetuning_action.sh + +set -ueo pipefail + +cd ../../ + +echo "Running all_langs.sh" +echo "Current directory: $(pwd)" + +MODEL_CONFIG_PATH="../configs/lealla_m.gin" +TRAIN_CONFIG_PATH="../configs/train.gin" + +DAMUEL_DESCS_TOKENS_RAW="$OUTPUTS/v2_normal_filtered/descs_pages" +DAMUEL_LINKS_TOKENS_RAW="$OUTPUTS/v2_normal_filtered/links" +MEWSLI_TOKENS_RAW="$OUTPUTS/tokens_mewsli_finetuning" +WORKDIR="$OUTPUTS/workdirs/asi_se_to_rozbilo_init_all" +N_OF_ROUNDS=6 + +run_ml_finetuning_round() { + local DAMUEL_DESCS_TOKENS_RAW=$1 + local DAMUEL_LINKS_TOKENS_RAW=$2 + local MEWSLI_TOKENS_RAW=$3 + local WORKDIR=$4 + local STATE_DICT=${5:-"None"} + local ROUND_ID=${6:-"0"} + local N_OF_ROUNDS=${7} + + local STEPS_PER_EPOCH=1000 + + local LINKS_PER_ROUND=$(($STEPS_PER_EPOCH * 250000 * 8)) + echo "LPR $LINKS_PER_ROUND" + + local ACTION_SCRIPT="run_action_gin.py $MODEL_CONFIG_PATH $TRAIN_CONFIG_PATH" + + ENV="../venv/bin/activate" + source $ENV + + # ====================TOKENS COPY==================== + + local DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! "$(ls -A $DAMUEL_LINKS_TOKENS)" ]; then + ../venv/bin/python $ACTION_SCRIPT "copy" \ + --source="$DAMUEL_LINKS_TOKENS_RAW" \ + --dest="$DAMUEL_LINKS_TOKENS" \ + --m="$N_OF_ROUNDS" \ + --r="$ROUND_ID" \ + --max_to_copy="$LINKS_PER_ROUND" + fi + + # ====================DAMUEL DESC EMBS==================== + + local DAMUEL_FOR_INDEX_DIR="$WORKDIR/damuel_for_index_$ROUND_ID" + + mkdir -p "$DAMUEL_FOR_INDEX_DIR" + + if [ ! 
"$(ls -A $DAMUEL_FOR_INDEX_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================DAMUEL LINKS EMBEDDING==================== + + local DAMUEL_LINKS_DIR="$WORKDIR/links_embs_$ROUND_ID" + + mkdir -p "$DAMUEL_LINKS_DIR" + + if [ ! "$(ls -A $DAMUEL_LINKS_DIR)" ]; then + echo "Running embs generating for damuel links" + ../venv/bin/python $ACTION_SCRIPT "embed_links_for_generation" \ + --links_tokens_dir_path="$DAMUEL_LINKS_TOKENS" \ + --dest_dir_path="$DAMUEL_LINKS_DIR" \ + --state_dict_path="$STATE_DICT" + fi + + # ====================GENERATING BATCHES==================== + + local BATCH_DIR="$WORKDIR/batches_$ROUND_ID" + + mkdir -p "$BATCH_DIR" + if [ ! "$(ls -A $BATCH_DIR)" ]; then + echo "Running batches generating for damuel" + # ../venv/bin/python -m cProfile -o "generate.prof" $ACTION_SCRIPT "generate" \ + ../venv/bin/python $ACTION_SCRIPT "generate" \ + --LINKS_EMBS_DIR="$DAMUEL_LINKS_DIR" \ + --INDEX_TOKENS_DIR="$DAMUEL_DESCS_TOKENS_RAW" \ + --INDEX_EMBS_QIDS_DIR="$DAMUEL_FOR_INDEX_DIR" \ + --OUTPUT_DIR="$BATCH_DIR" + fi + + # ====================TRAINING MODEL==================== + + local MODELS_DIR="$WORKDIR/models_$ROUND_ID" + + mkdir -p $MODELS_DIR + + if [ ! 
"$(ls -A $MODELS_DIR)" ]; then + echo "Running training for damuel" + #../venv/bin/python -m cProfile -o "train_ddp.prof" $ACTION_SCRIPT "train_ddp" \ + echo $ACTION_SCRIPT "train_ddp" \ + --DATASET_DIR="$BATCH_DIR" \ + --MODEL_SAVE_DIR="$MODELS_DIR" \ + --STATE_DICT_PATH="$STATE_DICT" + ../venv/bin/python $ACTION_SCRIPT "train_ddp" \ + --DATASET_DIR="$BATCH_DIR" \ + --MODEL_SAVE_DIR="$MODELS_DIR" \ + --STATE_DICT_PATH="$STATE_DICT" + fi + + # ====================EVALUATION==================== + + local NEXT_INDEX=$(($ROUND_ID + 1)) + local DAMUEL_FOR_INDEX_NEW_DIR="$WORKDIR/damuel_for_index_$NEXT_INDEX" + mkdir -p "$DAMUEL_FOR_INDEX_NEW_DIR" + + if [ ! "$(ls -A $DAMUEL_FOR_INDEX_NEW_DIR)" ]; then + echo "Running embs generating for damuel" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$DAMUEL_DESCS_TOKENS_RAW" \ + --dest_path="$DAMUEL_FOR_INDEX_NEW_DIR" \ + --state_dict_path="$MODELS_DIR/ema.pth" + fi + + local LANGUAGES=("ar" "de" "en" "es" "ja" "fa" "sr" "ta" "tr") + + for LANG in "${LANGUAGES[@]}"; do + echo "Processing language: $LANG" + + local LANG_TOKEN_DIR="$MEWSLI_TOKENS_RAW/$LANG" + local MEWSLI_EMBS_DIR="$WORKDIR/mewsli_embs_${LANG}_$ROUND_ID" + + mkdir -p "$MEWSLI_EMBS_DIR" + + if [ ! 
"$(ls -A $MEWSLI_EMBS_DIR)" ]; then + echo "Running embs generating for mewsli - Language: $LANG" + ../venv/bin/python $ACTION_SCRIPT "embs_from_tokens_model_name_and_state_dict" \ + --source_path="$LANG_TOKEN_DIR" \ + --dest_path="$MEWSLI_EMBS_DIR" \ + --state_dict_path="$MODELS_DIR/ema.pth" + fi + + # ../venv/bin/python $ACTION_SCRIPT "recalls" \ + # --damuel_dir="$DAMUEL_FOR_INDEX_NEW_DIR" \ + # --mewsli_dir="$MEWSLI_EMBS_DIR" + + # echo "Completed processing for language: $LANG" + echo "----------------------------------------" + done + + ../venv/bin/python $ACTION_SCRIPT "evaluate" \ + --root_dir="$WORKDIR" \ + --finetuning_round=$ROUND_ID + + rm -r $BATCH_DIR $DAMUEL_LINKS_DIR $DAMUEL_FOR_INDEX_DIR $DAMUEL_LINKS_TOKENS +} + +if [ ! -L "$WORKDIR" ]; then + mkdir -p "$WORKDIR" +fi + +DAMUEL_DESCS_TOKENS="$WORKDIR/damuel_descs_together_tokens" +if [ ! -L "$DAMUEL_DESCS_TOKENS" ]; then + mkdir -p "$DAMUEL_DESCS_TOKENS" +fi + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + DAMUEL_LINKS_TOKENS="$WORKDIR/damuel_links_together_tokens_$ROUND_ID" + if [ ! -L "$DAMUEL_LINKS_TOKENS" ]; then + mkdir -p "$DAMUEL_LINKS_TOKENS" + fi +done + +STATE_DICT="None" + +for ((ROUND_ID=0; ROUND_ID<$N_OF_ROUNDS; ROUND_ID++)) +do + if [ ! 
-e "$WORKDIR/models_$ROUND_ID/final.pth" ]; then + echo "Running round $ROUND_ID" + + run_ml_finetuning_round "$DAMUEL_DESCS_TOKENS_RAW" "$DAMUEL_LINKS_TOKENS_RAW" \ + "$MEWSLI_TOKENS_RAW" \ + "$WORKDIR" "$STATE_DICT" \ + "$ROUND_ID" "$N_OF_ROUNDS" + fi + + STATE_DICT="$WORKDIR/models_$ROUND_ID/ema.pth" +done diff --git a/src/tokenization/runner.py b/src/tokenization/runner.py index 55a5711..1f7b2ff 100644 --- a/src/tokenization/runner.py +++ b/src/tokenization/runner.py @@ -175,6 +175,9 @@ def run_damuel_description_context( ) -> None: tokenizer = AutoTokenizer.from_pretrained(model_path) + print(languages) + print(type(languages)) + for lang in languages: os.makedirs(os.path.join(output_base_dir, lang, "descs_pages"), exist_ok=True) @@ -359,5 +362,32 @@ def run_damuel_mention( os.rename(file, new_name) +@gin.configurable +def run_damuel_description( + model_path: str, + expected_size: int, + output_base_dir: str, + languages: List[str], + damuel_base_path: str, + compress: bool, + remainder_mod: int, + num_processes: int, +) -> None: + run_damuel_description_context( + model_path=model_path, + expected_size=expected_size, + output_base_dir=output_base_dir, + languages=languages, + damuel_base_path=damuel_base_path, + # The goal of this pipeline is to tokenize descriptions without wrapping anything in the label token (e.g., [M]). + # The label token is required by the tokenization class, so we provide an empty string + # so that it has no effect. 
+ label_token="", + compress=compress, + remainder_mod=remainder_mod, + num_processes=num_processes, + ) + + if __name__ == "__main__": run_mewsli_mention() diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py index faa1838..300d944 100644 --- a/src/utils/embeddings.py +++ b/src/utils/embeddings.py @@ -208,7 +208,7 @@ def embs_from_tokens_model_name_and_state_dict( target_dim: int | None = None, output_type: str | None = None, ): - model = ModelFactory.auto_load_from_file(model_name, state_dict_path, target_dim) + model = ModelFactory.auto_load_from_file(model_name, state_dict_path, target_dim, output_type) embs_from_tokens_and_model(source_path, model, batch_size, dest_path) diff --git a/src/utils/loaders.py b/src/utils/loaders.py index 9da0f55..38b8956 100644 --- a/src/utils/loaders.py +++ b/src/utils/loaders.py @@ -5,6 +5,8 @@ import gin import numpy as np import pandas as pd +from scipy.sparse import coo_matrix, csr_matrix +from tqdm import tqdm # from tokenization.pipeline import DamuelAliasTablePipeline from tokenization.runner import run_alias_table_damuel @@ -114,24 +116,37 @@ def load_qids_npy(file_path: str | Path) -> np.ndarray: @_sort_by_output(1) @qid_filter(qids_index=1) @remap_qids_decorator(qids_index=1, json_path=gin.REQUIRED) -def load_tokens_qids_from_dir(dir_path: str | Path) -> tuple[np.ndarray, np.ndarray]: +def load_tokens_qids_from_dir( + dir_path: str | Path, verbose=False, max_items_to_load: int | None = None +) -> tuple[np.ndarray, np.ndarray]: """ Loads mention tokens and query IDs from all .npz files in a given directory. Args: dir_path (str | Path): Path to the directory containing .npz files. + verbose (bool): If True, displays a progress bar while loading files. Returns: tuple[np.ndarray, np.ndarray]: A tuple containing two numpy arrays: - tokens: Array of mention tokens loaded from the files. - qids: Array of query IDs loaded from the files. 
""" + if type(dir_path) == str: + dir_path = Path(dir_path) tokens, qids = [], [] - for file in dir_path.iterdir(): + iterator = dir_path.iterdir() + if verbose: + total = sum(1 for itm in dir_path.iterdir() if itm.is_file() and itm.suffix == ".npz") + iterator = tqdm( + dir_path.iterdir(), desc=f"Loading tokens and qids from {dir_path}", total=total + ) + for file in iterator: if file.is_file() and file.suffix == ".npz": d = np.load(file) tokens.extend(d["tokens"]) qids.extend(d["qids"]) + if max_items_to_load is not None and len(tokens) >= max_items_to_load: + break return np.array(tokens), np.array(qids) @@ -142,6 +157,49 @@ def load_tokens_and_qids(file_path: str | Path) -> tuple[np.ndarray, np.ndarray] return d["tokens"], d["qids"] +def map_qids_to_token_matrix( + dir_path: str | Path, verbose: bool = False, max_items_to_load: int | None = None +) -> csr_matrix: + """Builds a memory-efficient sparse matrix mapping qids to their token vectors. + + Args: + dir_path (str | Path): Directory containing data files with 'tokens' and 'qids'. + verbose (bool): Forwarded to `load_tokens_qids_from_dir` to toggle progress output. + max_items_to_load (int | None): Optional cap on the number of token rows to read. + + Returns: + scipy.sparse.csr_matrix: A CSR matrix where a row index corresponds to a qid + and the row's data is the token vector. Use + `matrix[qid]` to retrieve a vector. 
+ """ + tokens, qids = load_tokens_qids_from_dir( + dir_path=dir_path, verbose=verbose, max_items_to_load=max_items_to_load + ) + # remove duplicated items by qids + print("Original number of items:", tokens.shape[0]) + unique_qids, unique_indices = np.unique(qids, return_index=True) + tokens = tokens[unique_indices] + qids = unique_qids + print("New number of items after removing duplicates:", tokens.shape[0]) + + num_items, vector_len = tokens.shape + + assert num_items == qids.shape[0], "Mismatch between number of token rows and qids" + + row_indices = np.repeat(qids, vector_len) + col_indices = np.tile(np.arange(vector_len), num_items) + data = tokens.flatten() + + print("MAX TOKENS", np.max(data)) + + shape = (qids.max() + 1, vector_len) + + coo = coo_matrix((data, (row_indices, col_indices)), shape=shape, dtype=tokens.dtype) + csr = coo.tocsr() + print("MAX TOKENS", csr.data.max()) + return csr + + class AliasTableLoader: """ This class provides methods to load and process alias tables from two different sources: diff --git a/src/utils/model_factory.py b/src/utils/model_factory.py index 6d56afb..9d4e41b 100644 --- a/src/utils/model_factory.py +++ b/src/utils/model_factory.py @@ -45,6 +45,18 @@ def auto_load_from_file( target_dim: int | None = None, output_type: ModelOutputType | None = None, ) -> torch.nn.Module: + """ + Automatically loads a model from a specified file, optionally applying a state dictionary, target dimension, and output type. + + Args: + file_path (str): Path to the model definition file. + state_dict_path (str | None, optional): Path to the state dictionary file to load model weights. If None, weights are not loaded. + target_dim (int | None, optional): Target output dimension for the model. If None, uses the default dimension. + output_type (ModelOutputType | None, optional): Type of model output. If None, defaults to ModelOutputType.PoolerOutput. + + Returns: + torch.nn.Module: The constructed and optionally weight-loaded model instance. 
+ """ builder = ModelBuilder(file_path) if output_type is None: output_type = ModelOutputType.PoolerOutput # the original/old default diff --git a/tests/models/test_brute_force_searcher.py b/tests/models/test_brute_force_searcher.py index f70a08e..5cca1c1 100644 --- a/tests/models/test_brute_force_searcher.py +++ b/tests/models/test_brute_force_searcher.py @@ -5,6 +5,8 @@ from models.searchers.brute_force_searcher import ( BruteForceSearcher, DPBruteForceSearcher, + DPBruteForceSearcherPT, + ManualSyncBruteForceSearcher, ) # torch.compiler.disable(BruteForceSearcher.find) @@ -129,3 +131,114 @@ def test_dataparallel_initialization(self, small_embs): searcher = DPBruteForceSearcher(small_embs, np.arange(len(small_embs))) searcher.find(np.random.random((1, 3)), 2) # This should initialize module_searcher assert isinstance(searcher.module_searcher, torch.nn.DataParallel) + + +class TestDPBruteForceSearcherPT: + @pytest.fixture + def small_embs(self): + return torch.tensor( + [ + [0.9, 0.9, 0.9], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + def test_search_present(self, small_embs): + searcher = DPBruteForceSearcherPT(small_embs, np.arange(4)) + for i, e in enumerate(small_embs): + res = searcher.find(e.unsqueeze(0), 2) + assert res[0][0] == i + assert res[0][1] != i + assert len(res[0]) == 2 + + def test_search_missing(self): + embs = torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + searcher = DPBruteForceSearcherPT(embs, np.arange(4)) + res = searcher.find(torch.tensor([[1.0, 0.0, 1.0]]), 2) + assert res[0][0] == 0 + + def test_device_selection(self, small_embs): + searcher = DPBruteForceSearcherPT(small_embs, np.arange(len(small_embs))) + expected_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + assert searcher.device == expected_device + + def test_changing_num_neighbors(self, small_embs): + searcher = DPBruteForceSearcherPT(small_embs, 
np.arange(len(small_embs))) + searcher.find( + torch.from_numpy(np.random.random((1, 3))).to(torch.float32), 2 + ) # Initialize with 2 neighbors + # with pytest.raises(Exception): + # Does nothing: + searcher.find( + torch.from_numpy(np.random.random((1, 3))).to(torch.float32), 3 + ) # Try to change to 3 neighbors + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") +class TestManualSyncBruteForceSearcher: + @pytest.fixture + def small_embs(self): + return np.array( + [ + [0.9, 0.9, 0.9], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + @pytest.fixture + def large_embs(self): + embs = np.random.random((10000, 128)) + return embs / np.linalg.norm(embs, ord=2, axis=1, keepdims=True) + + def test_search_present(self, small_embs): + searcher = ManualSyncBruteForceSearcher(small_embs, np.arange(4)) + for i, e in enumerate(small_embs): + res = searcher.find(np.array([e]), 2) + assert res[0][0] == i + assert res[0][1] != i + assert len(res[0]) == 2 + + def test_search_missing(self): + embs = np.array( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + searcher = ManualSyncBruteForceSearcher(embs, np.arange(4)) + res = searcher.find(np.array([[1.0, 0.0, 1.0]]), 2) + assert res[0][0] == 0 + + @pytest.mark.slow + def test_search_large(self, large_embs): + searcher = ManualSyncBruteForceSearcher(large_embs, np.arange(len(large_embs))) + neg = 7 + for _ in range(10): # Reduced iterations for faster testing + batch = np.random.random((32, 128)) + batch = batch / np.linalg.norm(batch, ord=2, axis=1, keepdims=True) + res = searcher.find(batch, neg) + for j, emb in enumerate(batch): + neighbor_embs = large_embs[res[j]] + dists = [np.dot(emb, ne) for ne in neighbor_embs] + dists_order = [dists[i] >= dists[i + 1] for i in range(len(dists) - 1)] + assert all(dists_order) + + def test_changing_num_neighbors(self, small_embs): + searcher = ManualSyncBruteForceSearcher(small_embs, 
np.arange(len(small_embs))) + searcher.find(np.random.random((1, 3)), 2) # Initialize with 2 neighbors + # with pytest.raises(Exception): + # Does nothing: + searcher.find(np.random.random((1, 3)), 3) # Try to change to 3 neighbors diff --git a/tests/models/test_faiss_searcher.py b/tests/models/test_faiss_searcher.py index ea55b8d..b89b428 100644 --- a/tests/models/test_faiss_searcher.py +++ b/tests/models/test_faiss_searcher.py @@ -23,6 +23,7 @@ def assert_equal_results(faiss_results, brute_results): np.testing.assert_array_equal(faiss_results, brute_results) +@pytest.mark.skip(reason="FAISS is not currently supported") def test_small(generate_data): from models.searchers.faiss_searcher import FaissSearcher @@ -36,6 +37,7 @@ def test_small(generate_data): assert_equal_results(faiss_out, brute_out) +@pytest.mark.skip(reason="FAISS is not currently supported") def test_large(generate_data): from models.searchers.faiss_searcher import FaissSearcher diff --git a/tests/reranking/models/test_pairwise_mlp.py b/tests/reranking/models/test_pairwise_mlp.py new file mode 100644 index 0000000..b13e92a --- /dev/null +++ b/tests/reranking/models/test_pairwise_mlp.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import math +from typing import Sequence + +import pytest +import torch +from torch import nn +from transformers.tokenization_utils_base import BatchEncoding + +from reranking.models import PairwiseMLPReranker +from utils.embeddings import create_attention_mask + +MENTION_TEXTS = ["dummy mention positive", "dummy mention negative"] +ENTITY_TEXTS = ["dummy entity positive", "dummy entity negative"] + + +class DummyEmbeddingModel(nn.Module): + output_dim = 2 + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: # type: ignore[override] + masked = input_ids.float() * attention_mask + summed = masked.sum(dim=1, keepdim=True) + return torch.cat([summed, 2 * summed], dim=1) + + +class DummyTokenizer: + _TOKEN_MAP = { + 
MENTION_TEXTS[0]: [1, 1, 0, 0], + MENTION_TEXTS[1]: [1, 0, 0, 0], + ENTITY_TEXTS[0]: [3, 0, 0, 0], + ENTITY_TEXTS[1]: [0, 0, 0, 0], + } + + def __call__( + self, + texts: str | Sequence[str], + *, + padding: bool = True, + truncation: bool = True, + return_tensors: str = "pt", + ) -> BatchEncoding: + del padding, truncation, return_tensors + items = [texts] if isinstance(texts, str) else list(texts) + encodings = [self._TOKEN_MAP.get(text, [1, 0, 0, 0]) for text in items] + tensor = torch.tensor(encodings, dtype=torch.long) + attention_mask = (tensor != 0).long() + return BatchEncoding({"input_ids": tensor, "attention_mask": attention_mask}) + + +def _configure_classifier(model: PairwiseMLPReranker) -> None: + hidden_layer = model.classifier[0] + output_layer = model.classifier[-1] + + with torch.no_grad(): + hidden_layer.weight.copy_(torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float)) + hidden_layer.bias.copy_(torch.tensor([0.5], dtype=torch.float)) + output_layer.weight.copy_(torch.tensor([[1.0]], dtype=torch.float)) + output_layer.bias.zero_() + + +def _prepare_labels(size: int, *, positive_index: int = 0) -> torch.Tensor: + labels = torch.zeros(size, dtype=torch.float) + labels[positive_index] = 1.0 + return labels + + +def _run_common_checks( + model: PairwiseMLPReranker, + mentions: Sequence[str], + entities: Sequence[str], + labels: torch.Tensor, +): + tokenizer = model.tokenizer + mention_batch = tokenizer(mentions, padding=True, truncation=True, return_tensors="pt") + mention_batch["attention_mask"] = create_attention_mask(mention_batch["input_ids"]) + entity_batch = tokenizer(entities, padding=True, truncation=True, return_tensors="pt") + entity_batch["attention_mask"] = create_attention_mask(entity_batch["input_ids"]) + + encode_out = model.model._encode(mention_batch["input_ids"]) + + loss = model.train_step( + { + "mention_tokens": dict(mention_batch), + "entity_tokens": dict(entity_batch), + "labels": labels, + } + ) + + with torch.no_grad(): + 
manual_mention_embeddings = model.base_model( + input_ids=mention_batch["input_ids"], + attention_mask=mention_batch.get("attention_mask"), + ) + manual_entity_embeddings = model.base_model( + input_ids=entity_batch["input_ids"], + attention_mask=entity_batch.get("attention_mask"), + ) + manual_logits = model.classifier( + torch.cat([manual_mention_embeddings, manual_entity_embeddings], dim=-1) + ).squeeze(-1) + expected_loss = model.loss_fn(manual_logits, labels) + + score = model.score(mentions[0], entities[0]) + + return { + "encode": encode_out, + "loss": loss.detach(), + "score": score, + "manual_mention_embeddings": manual_mention_embeddings, + "manual_entity_embeddings": manual_entity_embeddings, + "manual_logits": manual_logits, + "expected_loss": expected_loss, + } + + +@pytest.fixture() +def dummy_model(monkeypatch) -> PairwiseMLPReranker: + dummy_embedding_model = DummyEmbeddingModel() + + def _mock_model_loader(*_args, **_kwargs): + return dummy_embedding_model + + monkeypatch.setattr( + "reranking.models.pairwise_mlp.ModelFactory.auto_load_from_file", _mock_model_loader + ) + monkeypatch.setattr( + "reranking.models.pairwise_mlp.AutoTokenizer.from_pretrained", + lambda *_args, **_kwargs: DummyTokenizer(), + ) + + model = PairwiseMLPReranker( + "dummy-model", + mlp_hidden_dim=1, + dropout=0.0, + ) + _configure_classifier(model) + return model + + +def test_pairwise_mlp_with_dummy_backbone(dummy_model: PairwiseMLPReranker) -> None: + labels = _prepare_labels(len(MENTION_TEXTS)) + results = _run_common_checks(dummy_model, MENTION_TEXTS, ENTITY_TEXTS, labels) + + assert torch.allclose( + results["encode"], + results["manual_mention_embeddings"], + ) + + assert torch.allclose(results["loss"], results["expected_loss"]) + + positive_logit = results["manual_logits"][0].item() + assert math.isclose( + results["score"], torch.sigmoid(torch.tensor(positive_logit)).item(), rel_tol=1e-5 + ) + + +@pytest.mark.slow +def test_pairwise_mlp_with_lealla_backbone() -> 
None: + model = PairwiseMLPReranker( + "setu4993/LEALLA-base", + mlp_hidden_dim=1, + dropout=0.0, + ) + + labels = _prepare_labels(len(MENTION_TEXTS)) + results = _run_common_checks(model, MENTION_TEXTS, ENTITY_TEXTS, labels) + + assert torch.allclose( + results["encode"], + results["manual_mention_embeddings"], + ) + + assert torch.allclose(results["loss"], results["expected_loss"], atol=1e-6) + + positive_probability = torch.sigmoid(results["manual_logits"][0]).item() + assert 0.0 <= results["score"] <= 1.0 + assert math.isclose(results["score"], positive_probability, rel_tol=1e-5) + + +@pytest.mark.slow +def test_pairwise_mlp_with_lealla_backbone_more_iters() -> None: + model = PairwiseMLPReranker( + "setu4993/LEALLA-base", + mlp_hidden_dim=1, + dropout=0.0, + ) + + for _ in range(3): + labels = _prepare_labels(len(MENTION_TEXTS)) + results = _run_common_checks(model, MENTION_TEXTS, ENTITY_TEXTS, labels) + + assert torch.allclose( + results["encode"], + results["manual_mention_embeddings"], + ) + + assert torch.allclose(results["loss"], results["expected_loss"], atol=1e-6) + + positive_probability = torch.sigmoid(results["manual_logits"][0]).item() + assert 0.0 <= results["score"] <= 1.0 + assert math.isclose(results["score"], positive_probability, rel_tol=1e-5) diff --git a/tests/reranking/models/test_pairwise_with_retrieval_score.py b/tests/reranking/models/test_pairwise_with_retrieval_score.py new file mode 100644 index 0000000..668306d --- /dev/null +++ b/tests/reranking/models/test_pairwise_with_retrieval_score.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import math +from typing import Sequence + +import pytest +import torch +from torch import nn +from transformers.tokenization_utils_base import BatchEncoding + +from reranking.models import PairwiseMLPRerankerWithRetrievalScore +from utils.embeddings import create_attention_mask + +MENTION_TEXTS = ["dummy mention positive", "dummy mention negative"] +ENTITY_TEXTS = ["dummy entity positive", 
"dummy entity negative"] + + +class DummyEmbeddingModel(nn.Module): + output_dim = 2 + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: # type: ignore[override] + masked = input_ids.float() * attention_mask + summed = masked.sum(dim=1, keepdim=True) + return torch.cat([summed, 2 * summed], dim=1) + + +class DummyTokenizer: + _TOKEN_MAP = { + MENTION_TEXTS[0]: [1, 1, 0, 0], + MENTION_TEXTS[1]: [1, 0, 0, 0], + ENTITY_TEXTS[0]: [3, 0, 0, 0], + ENTITY_TEXTS[1]: [0, 0, 0, 0], + } + + def __call__( + self, + texts: str | Sequence[str], + *, + padding: bool = True, + truncation: bool = True, + return_tensors: str = "pt", + ) -> BatchEncoding: + del padding, truncation, return_tensors + items = [texts] if isinstance(texts, str) else list(texts) + encodings = [self._TOKEN_MAP.get(text, [1, 0, 0, 0]) for text in items] + tensor = torch.tensor(encodings, dtype=torch.long) + attention_mask = (tensor != 0).long() + return BatchEncoding({"input_ids": tensor, "attention_mask": attention_mask}) + + +def _configure_classifier(model: PairwiseMLPRerankerWithRetrievalScore) -> None: + hidden_layer = model.classifier[0] + output_layer = model.classifier[-1] + + with torch.no_grad(): + hidden_layer.weight.copy_(torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float)) + hidden_layer.bias.copy_(torch.tensor([0.5], dtype=torch.float)) + output_layer.weight.copy_(torch.tensor([[1.0]], dtype=torch.float)) + output_layer.bias.zero_() + + +def _prepare_labels(size: int, *, positive_index: int = 0) -> torch.Tensor: + labels = torch.zeros(size, dtype=torch.float) + labels[positive_index] = 1.0 + return labels + + +@pytest.fixture() +def dummy_model(monkeypatch) -> PairwiseMLPRerankerWithRetrievalScore: + dummy_embedding_model = DummyEmbeddingModel() + + def _mock_model_loader(*_args, **_kwargs): + return dummy_embedding_model + + monkeypatch.setattr( + "reranking.models.pairwise_mlp.ModelFactory.auto_load_from_file", _mock_model_loader + ) + 
monkeypatch.setattr( + "reranking.models.pairwise_mlp.AutoTokenizer.from_pretrained", + lambda *_args, **_kwargs: DummyTokenizer(), + ) + + model = PairwiseMLPRerankerWithRetrievalScore( + "dummy-model", + mlp_hidden_dim=1, + dropout=0.0, + ) + _configure_classifier(model) + return model + + +def test_score_not_failing(dummy_model: PairwiseMLPRerankerWithRetrievalScore) -> None: + score = dummy_model.score(MENTION_TEXTS[0], ENTITY_TEXTS[0]) + assert isinstance(score, float) + assert score >= 0.0 + assert score <= 1.0 diff --git a/tests/reranking/training/test_reranking_iterable_dataset.py b/tests/reranking/training/test_reranking_iterable_dataset.py new file mode 100644 index 0000000..81dba02 --- /dev/null +++ b/tests/reranking/training/test_reranking_iterable_dataset.py @@ -0,0 +1,24 @@ +import numpy as np +import torch + +from src.reranking.training.reranking_iterable_dataset import RerankingIterableDataset + + +def test_reranking_iterable_dataset_iterates_samples(tmp_path): + file_path = tmp_path / "part-000.npz" + data = { + "qids": np.array([10, 20], dtype=np.int64), + "y": np.array([1.0, 0.0], dtype=np.float32), + "description_tokens": np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64), + "link_tokens": np.array([[7, 8], [9, 10]], dtype=np.int64), + } + np.savez(file_path, **data) + + dataset = RerankingIterableDataset(tmp_path) + + samples = list(dataset) + assert len(samples) == 2 + + first = samples[0] + assert isinstance(first, tuple) + assert torch.equal(first[2], torch.tensor(1.0, dtype=torch.float32)) diff --git a/tests/reranking/training/test_training_configs.py b/tests/reranking/training/test_training_configs.py new file mode 100644 index 0000000..8ed34a5 --- /dev/null +++ b/tests/reranking/training/test_training_configs.py @@ -0,0 +1,25 @@ +from unittest.mock import MagicMock + +import pytest + +from reranking.training.training_configs import TrainingConfig + + +def test_get_output_path_creates_directory(tmp_path): + output_root = tmp_path / "outputs" 
+ config = TrainingConfig( + config_name="test_config", + model=MagicMock(), + dataset=MagicMock(), + optimizer=MagicMock(), + save_each=100, + batch_size=1, + output_dir=str(output_root), + validate_each=50, + ) + + path = config.get_output_path(step=5) + + expected_dir = output_root / "test_config" + assert expected_dir.exists() and expected_dir.is_dir() + assert path == f"{expected_dir}/5.pth" diff --git a/tests/scripts/reranking/test_change_dataset_tokens.py b/tests/scripts/reranking/test_change_dataset_tokens.py new file mode 100644 index 0000000..613d33f --- /dev/null +++ b/tests/scripts/reranking/test_change_dataset_tokens.py @@ -0,0 +1,70 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +from scipy.sparse import coo_matrix, csr_matrix + +from scripts.reranking.change_dataset_tokens import process_directory, update_tokens_in_file + + +def _create_npz_file(path: Path, filename: str, **kwargs): + """Helper to create an .npz file for testing.""" + np.savez(path / filename, **kwargs) + + +def _create_sparse_matrix(data_dict: dict, shape: tuple[int, int], dtype=np.float32) -> csr_matrix: + """Helper to build a CSR matrix from a dictionary of {qid: vector}.""" + rows, cols, data = [], [], [] + vector_len = shape[1] + for qid, vector in data_dict.items(): + rows.extend([qid] * vector_len) + cols.extend(range(vector_len)) + data.extend(vector) + return coo_matrix((data, (rows, cols)), shape=shape, dtype=dtype).tocsr() + + +def test_update_tokens_when_match_exists(tmp_path: Path): + """ + Tests that the file is modified. 
+ """ + test_file = tmp_path / "data.npz" + original_tokens = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]) + original_qids = np.array([100, 200, 300]) + _create_npz_file( + tmp_path, "data.npz", description_tokens=original_tokens.copy(), qids=original_qids + ) + + new_token_vector = np.array([99.0, 99.0]) + update_matrix = _create_sparse_matrix( + {100: original_tokens[0], 200: new_token_vector, 300: original_tokens[2]}, shape=(400, 2) + ) + update_tokens_in_file(test_file, update_matrix) + + loaded_data = np.load(test_file) + updated_tokens = loaded_data["description_tokens"] + + assert np.array_equal(updated_tokens[1], new_token_vector) + assert np.array_equal(updated_tokens[0], original_tokens[0]) + assert np.array_equal(updated_tokens[2], original_tokens[2]) + + +@patch("scripts.reranking.change_dataset_tokens.update_tokens_in_file") +@patch("scripts.reranking.change_dataset_tokens.map_qids_to_token_matrix") +def test_process_directory_orchestration( + mock_map_qids: MagicMock, mock_update_file: MagicMock, tmp_path: Path +): + dataset_dir = tmp_path / "dataset" + tokens_dir = tmp_path / "tokens" + dataset_dir.mkdir() + tokens_dir.mkdir() + + file_paths = [dataset_dir / "a.npz", dataset_dir / "b.npz", dataset_dir / "c.npz"] + + mock_map_qids.return_value = csr_matrix((3, 2)) + mock_update_file.side_effect = [True, False, True] + + with patch.object(Path, "glob", return_value=file_paths): + process_directory(dataset_dir, tokens_dir) + + mock_map_qids.assert_called_once_with(tokens_dir, verbose=True) + assert mock_update_file.call_count == len(file_paths) diff --git a/tests/utils/test_loaders.py b/tests/utils/test_loaders.py index dff323c..9c34b6f 100644 --- a/tests/utils/test_loaders.py +++ b/tests/utils/test_loaders.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +from scipy.sparse import csr_matrix from utils.loaders import ( AliasTableLoader, @@ -13,6 +14,8 @@ load_qids, load_qids_npy, load_tokens_qids, + load_tokens_qids_from_dir, + 
map_qids_to_token_matrix, ) @@ -20,8 +23,14 @@ def mock_remap_qids(qids, _): return qids +def _create_tokens_qids_npz(dir_path: Path, file_name: str, tokens: np.ndarray, qids: np.ndarray): + file_path = Path(dir_path) / file_name + np.savez_compressed(file_path, tokens=np.array(tokens), qids=np.array(qids)) + return file_path + + @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_load_mentions_with_path_object(mock_qids_remap): +def test_load_tokens_qids_with_path_object(mock_qids_remap): with tempfile.TemporaryDirectory() as temp_dir: file_path = Path(temp_dir) / "mentions_2.npz" @@ -37,7 +46,7 @@ def test_load_mentions_with_path_object(mock_qids_remap): @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) -def test_load_mentions_with_string_path(mock_qids_remap): +def test_load_tokens_qids_with_string_path(mock_qids_remap): with tempfile.TemporaryDirectory() as temp_dir: file_path = str(Path(temp_dir) / "mentions_1.npz") @@ -273,6 +282,101 @@ def test_embs_qids_tokens_from_file(mock_qids_remap, use_string_path, file_name) assert len(loaded_data) == len(test_data) +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_single_file(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens = np.array([[1, 2, 3], [4, 5, 6]]) + qids = np.array([10, 20]) + _create_tokens_qids_npz(dir_path, "data_0.npz", tokens, qids) + + loaded_tokens, loaded_qids = load_tokens_qids_from_dir(dir_path) + + sort_indices = np.argsort(qids, kind="stable") + assert np.array_equal(loaded_qids, qids[sort_indices]) + assert np.array_equal(loaded_tokens, tokens[sort_indices]) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_multiple_files(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens_a = np.array([[1, 1, 1], [2, 2, 2]]) + qids_a = 
np.array([20, 10]) + tokens_b = np.array([[3, 3, 3]]) + qids_b = np.array([30]) + _create_tokens_qids_npz(dir_path, "data_a.npz", tokens_a, qids_a) + _create_tokens_qids_npz(dir_path, "data_b.npz", tokens_b, qids_b) + + loaded_tokens, loaded_qids = load_tokens_qids_from_dir(dir_path) + + expected_tokens = np.vstack([tokens_a, tokens_b]) + expected_qids = np.concatenate([qids_a, qids_b]) + sort_indices = np.argsort(expected_qids, kind="stable") + + assert np.array_equal(loaded_qids, expected_qids[sort_indices]) + assert np.array_equal(loaded_tokens, expected_tokens[sort_indices]) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_max_limit_overflow(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens_a = np.array([[1, 0], [2, 0], [3, 0]]) + qids_a = np.array([300, 100, 200]) + tokens_b = np.array([[4, 0], [5, 0], [6, 0]]) + qids_b = np.array([600, 500, 400]) + _create_tokens_qids_npz(dir_path, "data_a.npz", tokens_a, qids_a) + _create_tokens_qids_npz(dir_path, "data_b.npz", tokens_b, qids_b) + + max_items = 4 + loaded_tokens, loaded_qids = load_tokens_qids_from_dir( + dir_path, max_items_to_load=max_items + ) + + expected_tokens = np.vstack([tokens_a, tokens_b]) + expected_qids = np.concatenate([qids_a, qids_b]) + sort_indices = np.argsort(expected_qids, kind="stable") + + assert len(loaded_tokens) == len(expected_tokens) + assert len(loaded_tokens) > max_items + assert np.array_equal(loaded_qids, expected_qids[sort_indices]) + assert np.array_equal(loaded_tokens, expected_tokens[sort_indices]) + + +@patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) +def test_load_tokens_qids_from_dir_large_max_limit_loads_all(mock_qids_remap): + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens_a = np.array([[10, 10], [20, 20]]) + qids_a = np.array([200, 100]) + tokens_b = np.array([[30, 30]]) + qids_b = np.array([300]) 
+ tokens_c = np.array([[40, 40]]) + qids_c = np.array([400]) + _create_tokens_qids_npz(dir_path, "data_a.npz", tokens_a, qids_a) + _create_tokens_qids_npz(dir_path, "data_b.npz", tokens_b, qids_b) + _create_tokens_qids_npz(dir_path, "data_c.npz", tokens_c, qids_c) + + max_items = 100 + loaded_tokens, loaded_qids = load_tokens_qids_from_dir( + dir_path, max_items_to_load=max_items + ) + + expected_tokens = np.vstack([tokens_a, tokens_b, tokens_c]) + expected_qids = np.concatenate([qids_a, qids_b, qids_c]) + sort_indices = np.argsort(expected_qids, kind="stable") + + assert len(loaded_tokens) == len(expected_tokens) + assert len(loaded_tokens) < max_items + assert np.array_equal(loaded_qids, expected_qids[sort_indices]) + assert np.array_equal(loaded_tokens, expected_tokens[sort_indices]) + + @pytest.mark.parametrize("use_string_path", [True, False]) @patch("utils.qids_remap.qids_remap", side_effect=mock_remap_qids) def test_load_qids(mock_qids_remap, use_string_path: bool) -> None: @@ -311,6 +415,33 @@ def test_load_qids_npy(mock_qids_remap, use_string_path: bool) -> None: assert isinstance(loaded_qids, np.ndarray) +def test_map_qids_to_token_matrix() -> None: + """ + Tests that qids are correctly mapped to their token vectors in a sparse matrix. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + dir_path = Path(temp_dir) + + tokens = np.array([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]]) + qids = np.array([300, 100, 200]) + _create_tokens_qids_npz(dir_path, "tokens_qids.npz", tokens, qids) + + token_matrix = map_qids_to_token_matrix(dir_path) + + assert isinstance(token_matrix, csr_matrix) + assert token_matrix.shape == (301, 2) + + for i, qid in enumerate(qids): + original_vector = tokens[i] + retrieved_vector = token_matrix[qid].toarray()[0] + + assert np.array_equal(original_vector, retrieved_vector) + + non_existent_qid = 150 + zero_vector = token_matrix[non_existent_qid].toarray()[0] + assert np.array_equal(zero_vector, np.zeros(tokens.shape[1])) + + @pytest.mark.parametrize("lowercase", [True, False]) class TestAliasTableLoader: def setup_method(self, lowercase): diff --git a/uv.lock b/uv.lock index af4a589..a4b6bf2 100644 --- a/uv.lock +++ b/uv.lock @@ -17,6 +17,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "asttokens" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978, upload-time = "2024-11-30T04:30:14.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, +] + [[package]] name = "certifi" version = 
"2025.8.3" @@ -100,6 +109,125 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "coverage" +version = "7.10.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/26/d22c300112504f5f9a9fd2297ce33c35f3d353e4aeb987c8419453b2a7c2/coverage-7.10.7.tar.gz", hash = "sha256:f4ab143ab113be368a3e9b795f9cd7906c5ef407d6173fe9675a902e1fffc239", size = 827704, upload-time = "2025-09-21T20:03:56.815Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/5d/c1a17867b0456f2e9ce2d8d4708a4c3a089947d0bec9c66cdf60c9e7739f/coverage-7.10.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a609f9c93113be646f44c2a0256d6ea375ad047005d7f57a5c15f614dc1b2f59", size = 218102, upload-time = "2025-09-21T20:01:16.089Z" }, + { url = "https://files.pythonhosted.org/packages/54/f0/514dcf4b4e3698b9a9077f084429681bf3aad2b4a72578f89d7f643eb506/coverage-7.10.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:65646bb0359386e07639c367a22cf9b5bf6304e8630b565d0626e2bdf329227a", size = 218505, upload-time = "2025-09-21T20:01:17.788Z" }, + { url = "https://files.pythonhosted.org/packages/20/f6/9626b81d17e2a4b25c63ac1b425ff307ecdeef03d67c9a147673ae40dc36/coverage-7.10.7-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5f33166f0dfcce728191f520bd2692914ec70fac2713f6bf3ce59c3deacb4699", size = 248898, upload-time = "2025-09-21T20:01:19.488Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ef/bd8e719c2f7417ba03239052e099b76ea1130ac0cbb183ee1fcaa58aaff3/coverage-7.10.7-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = 
"sha256:35f5e3f9e455bb17831876048355dca0f758b6df22f49258cb5a91da23ef437d", size = 250831, upload-time = "2025-09-21T20:01:20.817Z" }, + { url = "https://files.pythonhosted.org/packages/a5/b6/bf054de41ec948b151ae2b79a55c107f5760979538f5fb80c195f2517718/coverage-7.10.7-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4da86b6d62a496e908ac2898243920c7992499c1712ff7c2b6d837cc69d9467e", size = 252937, upload-time = "2025-09-21T20:01:22.171Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e5/3860756aa6f9318227443c6ce4ed7bf9e70bb7f1447a0353f45ac5c7974b/coverage-7.10.7-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6b8b09c1fad947c84bbbc95eca841350fad9cbfa5a2d7ca88ac9f8d836c92e23", size = 249021, upload-time = "2025-09-21T20:01:23.907Z" }, + { url = "https://files.pythonhosted.org/packages/26/0f/bd08bd042854f7fd07b45808927ebcce99a7ed0f2f412d11629883517ac2/coverage-7.10.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4376538f36b533b46f8971d3a3e63464f2c7905c9800db97361c43a2b14792ab", size = 250626, upload-time = "2025-09-21T20:01:25.721Z" }, + { url = "https://files.pythonhosted.org/packages/8e/a7/4777b14de4abcc2e80c6b1d430f5d51eb18ed1d75fca56cbce5f2db9b36e/coverage-7.10.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:121da30abb574f6ce6ae09840dae322bef734480ceafe410117627aa54f76d82", size = 248682, upload-time = "2025-09-21T20:01:27.105Z" }, + { url = "https://files.pythonhosted.org/packages/34/72/17d082b00b53cd45679bad682fac058b87f011fd8b9fe31d77f5f8d3a4e4/coverage-7.10.7-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:88127d40df529336a9836870436fc2751c339fbaed3a836d42c93f3e4bd1d0a2", size = 248402, upload-time = "2025-09-21T20:01:28.629Z" }, + { url = "https://files.pythonhosted.org/packages/81/7a/92367572eb5bdd6a84bfa278cc7e97db192f9f45b28c94a9ca1a921c3577/coverage-7.10.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:ba58bbcd1b72f136080c0bccc2400d66cc6115f3f906c499013d065ac33a4b61", size = 249320, upload-time = "2025-09-21T20:01:30.004Z" }, + { url = "https://files.pythonhosted.org/packages/2f/88/a23cc185f6a805dfc4fdf14a94016835eeb85e22ac3a0e66d5e89acd6462/coverage-7.10.7-cp311-cp311-win32.whl", hash = "sha256:972b9e3a4094b053a4e46832b4bc829fc8a8d347160eb39d03f1690316a99c14", size = 220536, upload-time = "2025-09-21T20:01:32.184Z" }, + { url = "https://files.pythonhosted.org/packages/fe/ef/0b510a399dfca17cec7bc2f05ad8bd78cf55f15c8bc9a73ab20c5c913c2e/coverage-7.10.7-cp311-cp311-win_amd64.whl", hash = "sha256:a7b55a944a7f43892e28ad4bc0561dfd5f0d73e605d1aa5c3c976b52aea121d2", size = 221425, upload-time = "2025-09-21T20:01:33.557Z" }, + { url = "https://files.pythonhosted.org/packages/51/7f/023657f301a276e4ba1850f82749bc136f5a7e8768060c2e5d9744a22951/coverage-7.10.7-cp311-cp311-win_arm64.whl", hash = "sha256:736f227fb490f03c6488f9b6d45855f8e0fd749c007f9303ad30efab0e73c05a", size = 220103, upload-time = "2025-09-21T20:01:34.929Z" }, + { url = "https://files.pythonhosted.org/packages/13/e4/eb12450f71b542a53972d19117ea5a5cea1cab3ac9e31b0b5d498df1bd5a/coverage-7.10.7-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7bb3b9ddb87ef7725056572368040c32775036472d5a033679d1fa6c8dc08417", size = 218290, upload-time = "2025-09-21T20:01:36.455Z" }, + { url = "https://files.pythonhosted.org/packages/37/66/593f9be12fc19fb36711f19a5371af79a718537204d16ea1d36f16bd78d2/coverage-7.10.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:18afb24843cbc175687225cab1138c95d262337f5473512010e46831aa0c2973", size = 218515, upload-time = "2025-09-21T20:01:37.982Z" }, + { url = "https://files.pythonhosted.org/packages/66/80/4c49f7ae09cafdacc73fbc30949ffe77359635c168f4e9ff33c9ebb07838/coverage-7.10.7-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:399a0b6347bcd3822be369392932884b8216d0944049ae22925631a9b3d4ba4c", size = 250020, upload-time = 
"2025-09-21T20:01:39.617Z" }, + { url = "https://files.pythonhosted.org/packages/a6/90/a64aaacab3b37a17aaedd83e8000142561a29eb262cede42d94a67f7556b/coverage-7.10.7-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:314f2c326ded3f4b09be11bc282eb2fc861184bc95748ae67b360ac962770be7", size = 252769, upload-time = "2025-09-21T20:01:41.341Z" }, + { url = "https://files.pythonhosted.org/packages/98/2e/2dda59afd6103b342e096f246ebc5f87a3363b5412609946c120f4e7750d/coverage-7.10.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c41e71c9cfb854789dee6fc51e46743a6d138b1803fab6cb860af43265b42ea6", size = 253901, upload-time = "2025-09-21T20:01:43.042Z" }, + { url = "https://files.pythonhosted.org/packages/53/dc/8d8119c9051d50f3119bb4a75f29f1e4a6ab9415cd1fa8bf22fcc3fb3b5f/coverage-7.10.7-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc01f57ca26269c2c706e838f6422e2a8788e41b3e3c65e2f41148212e57cd59", size = 250413, upload-time = "2025-09-21T20:01:44.469Z" }, + { url = "https://files.pythonhosted.org/packages/98/b3/edaff9c5d79ee4d4b6d3fe046f2b1d799850425695b789d491a64225d493/coverage-7.10.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a6442c59a8ac8b85812ce33bc4d05bde3fb22321fa8294e2a5b487c3505f611b", size = 251820, upload-time = "2025-09-21T20:01:45.915Z" }, + { url = "https://files.pythonhosted.org/packages/11/25/9a0728564bb05863f7e513e5a594fe5ffef091b325437f5430e8cfb0d530/coverage-7.10.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:78a384e49f46b80fb4c901d52d92abe098e78768ed829c673fbb53c498bef73a", size = 249941, upload-time = "2025-09-21T20:01:47.296Z" }, + { url = "https://files.pythonhosted.org/packages/e0/fd/ca2650443bfbef5b0e74373aac4df67b08180d2f184b482c41499668e258/coverage-7.10.7-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:5e1e9802121405ede4b0133aa4340ad8186a1d2526de5b7c3eca519db7bb89fb", size = 249519, upload-time = 
"2025-09-21T20:01:48.73Z" }, + { url = "https://files.pythonhosted.org/packages/24/79/f692f125fb4299b6f963b0745124998ebb8e73ecdfce4ceceb06a8c6bec5/coverage-7.10.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d41213ea25a86f69efd1575073d34ea11aabe075604ddf3d148ecfec9e1e96a1", size = 251375, upload-time = "2025-09-21T20:01:50.529Z" }, + { url = "https://files.pythonhosted.org/packages/5e/75/61b9bbd6c7d24d896bfeec57acba78e0f8deac68e6baf2d4804f7aae1f88/coverage-7.10.7-cp312-cp312-win32.whl", hash = "sha256:77eb4c747061a6af8d0f7bdb31f1e108d172762ef579166ec84542f711d90256", size = 220699, upload-time = "2025-09-21T20:01:51.941Z" }, + { url = "https://files.pythonhosted.org/packages/ca/f3/3bf7905288b45b075918d372498f1cf845b5b579b723c8fd17168018d5f5/coverage-7.10.7-cp312-cp312-win_amd64.whl", hash = "sha256:f51328ffe987aecf6d09f3cd9d979face89a617eacdaea43e7b3080777f647ba", size = 221512, upload-time = "2025-09-21T20:01:53.481Z" }, + { url = "https://files.pythonhosted.org/packages/5c/44/3e32dbe933979d05cf2dac5e697c8599cfe038aaf51223ab901e208d5a62/coverage-7.10.7-cp312-cp312-win_arm64.whl", hash = "sha256:bda5e34f8a75721c96085903c6f2197dc398c20ffd98df33f866a9c8fd95f4bf", size = 220147, upload-time = "2025-09-21T20:01:55.2Z" }, + { url = "https://files.pythonhosted.org/packages/9a/94/b765c1abcb613d103b64fcf10395f54d69b0ef8be6a0dd9c524384892cc7/coverage-7.10.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:981a651f543f2854abd3b5fcb3263aac581b18209be49863ba575de6edf4c14d", size = 218320, upload-time = "2025-09-21T20:01:56.629Z" }, + { url = "https://files.pythonhosted.org/packages/72/4f/732fff31c119bb73b35236dd333030f32c4bfe909f445b423e6c7594f9a2/coverage-7.10.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73ab1601f84dc804f7812dc297e93cd99381162da39c47040a827d4e8dafe63b", size = 218575, upload-time = "2025-09-21T20:01:58.203Z" }, + { url = 
"https://files.pythonhosted.org/packages/87/02/ae7e0af4b674be47566707777db1aa375474f02a1d64b9323e5813a6cdd5/coverage-7.10.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a8b6f03672aa6734e700bbcd65ff050fd19cddfec4b031cc8cf1c6967de5a68e", size = 249568, upload-time = "2025-09-21T20:01:59.748Z" }, + { url = "https://files.pythonhosted.org/packages/a2/77/8c6d22bf61921a59bce5471c2f1f7ac30cd4ac50aadde72b8c48d5727902/coverage-7.10.7-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10b6ba00ab1132a0ce4428ff68cf50a25efd6840a42cdf4239c9b99aad83be8b", size = 252174, upload-time = "2025-09-21T20:02:01.192Z" }, + { url = "https://files.pythonhosted.org/packages/b1/20/b6ea4f69bbb52dac0aebd62157ba6a9dddbfe664f5af8122dac296c3ee15/coverage-7.10.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c79124f70465a150e89340de5963f936ee97097d2ef76c869708c4248c63ca49", size = 253447, upload-time = "2025-09-21T20:02:02.701Z" }, + { url = "https://files.pythonhosted.org/packages/f9/28/4831523ba483a7f90f7b259d2018fef02cb4d5b90bc7c1505d6e5a84883c/coverage-7.10.7-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:69212fbccdbd5b0e39eac4067e20a4a5256609e209547d86f740d68ad4f04911", size = 249779, upload-time = "2025-09-21T20:02:04.185Z" }, + { url = "https://files.pythonhosted.org/packages/a7/9f/4331142bc98c10ca6436d2d620c3e165f31e6c58d43479985afce6f3191c/coverage-7.10.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7ea7c6c9d0d286d04ed3541747e6597cbe4971f22648b68248f7ddcd329207f0", size = 251604, upload-time = "2025-09-21T20:02:06.034Z" }, + { url = "https://files.pythonhosted.org/packages/ce/60/bda83b96602036b77ecf34e6393a3836365481b69f7ed7079ab85048202b/coverage-7.10.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b9be91986841a75042b3e3243d0b3cb0b2434252b977baaf0cd56e960fe1e46f", size = 249497, upload-time = 
"2025-09-21T20:02:07.619Z" }, + { url = "https://files.pythonhosted.org/packages/5f/af/152633ff35b2af63977edd835d8e6430f0caef27d171edf2fc76c270ef31/coverage-7.10.7-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:b281d5eca50189325cfe1f365fafade89b14b4a78d9b40b05ddd1fc7d2a10a9c", size = 249350, upload-time = "2025-09-21T20:02:10.34Z" }, + { url = "https://files.pythonhosted.org/packages/9d/71/d92105d122bd21cebba877228990e1646d862e34a98bb3374d3fece5a794/coverage-7.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:99e4aa63097ab1118e75a848a28e40d68b08a5e19ce587891ab7fd04475e780f", size = 251111, upload-time = "2025-09-21T20:02:12.122Z" }, + { url = "https://files.pythonhosted.org/packages/a2/9e/9fdb08f4bf476c912f0c3ca292e019aab6712c93c9344a1653986c3fd305/coverage-7.10.7-cp313-cp313-win32.whl", hash = "sha256:dc7c389dce432500273eaf48f410b37886be9208b2dd5710aaf7c57fd442c698", size = 220746, upload-time = "2025-09-21T20:02:13.919Z" }, + { url = "https://files.pythonhosted.org/packages/b1/b1/a75fd25df44eab52d1931e89980d1ada46824c7a3210be0d3c88a44aaa99/coverage-7.10.7-cp313-cp313-win_amd64.whl", hash = "sha256:cac0fdca17b036af3881a9d2729a850b76553f3f716ccb0360ad4dbc06b3b843", size = 221541, upload-time = "2025-09-21T20:02:15.57Z" }, + { url = "https://files.pythonhosted.org/packages/14/3a/d720d7c989562a6e9a14b2c9f5f2876bdb38e9367126d118495b89c99c37/coverage-7.10.7-cp313-cp313-win_arm64.whl", hash = "sha256:4b6f236edf6e2f9ae8fcd1332da4e791c1b6ba0dc16a2dc94590ceccb482e546", size = 220170, upload-time = "2025-09-21T20:02:17.395Z" }, + { url = "https://files.pythonhosted.org/packages/bb/22/e04514bf2a735d8b0add31d2b4ab636fc02370730787c576bb995390d2d5/coverage-7.10.7-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a0ec07fd264d0745ee396b666d47cef20875f4ff2375d7c4f58235886cc1ef0c", size = 219029, upload-time = "2025-09-21T20:02:18.936Z" }, + { url = 
"https://files.pythonhosted.org/packages/11/0b/91128e099035ece15da3445d9015e4b4153a6059403452d324cbb0a575fa/coverage-7.10.7-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd5e856ebb7bfb7672b0086846db5afb4567a7b9714b8a0ebafd211ec7ce6a15", size = 219259, upload-time = "2025-09-21T20:02:20.44Z" }, + { url = "https://files.pythonhosted.org/packages/8b/51/66420081e72801536a091a0c8f8c1f88a5c4bf7b9b1bdc6222c7afe6dc9b/coverage-7.10.7-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f57b2a3c8353d3e04acf75b3fed57ba41f5c0646bbf1d10c7c282291c97936b4", size = 260592, upload-time = "2025-09-21T20:02:22.313Z" }, + { url = "https://files.pythonhosted.org/packages/5d/22/9b8d458c2881b22df3db5bb3e7369e63d527d986decb6c11a591ba2364f7/coverage-7.10.7-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ef2319dd15a0b009667301a3f84452a4dc6fddfd06b0c5c53ea472d3989fbf0", size = 262768, upload-time = "2025-09-21T20:02:24.287Z" }, + { url = "https://files.pythonhosted.org/packages/f7/08/16bee2c433e60913c610ea200b276e8eeef084b0d200bdcff69920bd5828/coverage-7.10.7-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:83082a57783239717ceb0ad584de3c69cf581b2a95ed6bf81ea66034f00401c0", size = 264995, upload-time = "2025-09-21T20:02:26.133Z" }, + { url = "https://files.pythonhosted.org/packages/20/9d/e53eb9771d154859b084b90201e5221bca7674ba449a17c101a5031d4054/coverage-7.10.7-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:50aa94fb1fb9a397eaa19c0d5ec15a5edd03a47bf1a3a6111a16b36e190cff65", size = 259546, upload-time = "2025-09-21T20:02:27.716Z" }, + { url = "https://files.pythonhosted.org/packages/ad/b0/69bc7050f8d4e56a89fb550a1577d5d0d1db2278106f6f626464067b3817/coverage-7.10.7-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2120043f147bebb41c85b97ac45dd173595ff14f2a584f2963891cbcc3091541", size = 262544, upload-time = 
"2025-09-21T20:02:29.216Z" }, + { url = "https://files.pythonhosted.org/packages/ef/4b/2514b060dbd1bc0aaf23b852c14bb5818f244c664cb16517feff6bb3a5ab/coverage-7.10.7-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:2fafd773231dd0378fdba66d339f84904a8e57a262f583530f4f156ab83863e6", size = 260308, upload-time = "2025-09-21T20:02:31.226Z" }, + { url = "https://files.pythonhosted.org/packages/54/78/7ba2175007c246d75e496f64c06e94122bdb914790a1285d627a918bd271/coverage-7.10.7-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:0b944ee8459f515f28b851728ad224fa2d068f1513ef6b7ff1efafeb2185f999", size = 258920, upload-time = "2025-09-21T20:02:32.823Z" }, + { url = "https://files.pythonhosted.org/packages/c0/b3/fac9f7abbc841409b9a410309d73bfa6cfb2e51c3fada738cb607ce174f8/coverage-7.10.7-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4b583b97ab2e3efe1b3e75248a9b333bd3f8b0b1b8e5b45578e05e5850dfb2c2", size = 261434, upload-time = "2025-09-21T20:02:34.86Z" }, + { url = "https://files.pythonhosted.org/packages/ee/51/a03bec00d37faaa891b3ff7387192cef20f01604e5283a5fabc95346befa/coverage-7.10.7-cp313-cp313t-win32.whl", hash = "sha256:2a78cd46550081a7909b3329e2266204d584866e8d97b898cd7fb5ac8d888b1a", size = 221403, upload-time = "2025-09-21T20:02:37.034Z" }, + { url = "https://files.pythonhosted.org/packages/53/22/3cf25d614e64bf6d8e59c7c669b20d6d940bb337bdee5900b9ca41c820bb/coverage-7.10.7-cp313-cp313t-win_amd64.whl", hash = "sha256:33a5e6396ab684cb43dc7befa386258acb2d7fae7f67330ebb85ba4ea27938eb", size = 222469, upload-time = "2025-09-21T20:02:39.011Z" }, + { url = "https://files.pythonhosted.org/packages/49/a1/00164f6d30d8a01c3c9c48418a7a5be394de5349b421b9ee019f380df2a0/coverage-7.10.7-cp313-cp313t-win_arm64.whl", hash = "sha256:86b0e7308289ddde73d863b7683f596d8d21c7d8664ce1dee061d0bcf3fbb4bb", size = 220731, upload-time = "2025-09-21T20:02:40.939Z" }, + { url = 
"https://files.pythonhosted.org/packages/23/9c/5844ab4ca6a4dd97a1850e030a15ec7d292b5c5cb93082979225126e35dd/coverage-7.10.7-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b06f260b16ead11643a5a9f955bd4b5fd76c1a4c6796aeade8520095b75de520", size = 218302, upload-time = "2025-09-21T20:02:42.527Z" }, + { url = "https://files.pythonhosted.org/packages/f0/89/673f6514b0961d1f0e20ddc242e9342f6da21eaba3489901b565c0689f34/coverage-7.10.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:212f8f2e0612778f09c55dd4872cb1f64a1f2b074393d139278ce902064d5b32", size = 218578, upload-time = "2025-09-21T20:02:44.468Z" }, + { url = "https://files.pythonhosted.org/packages/05/e8/261cae479e85232828fb17ad536765c88dd818c8470aca690b0ac6feeaa3/coverage-7.10.7-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3445258bcded7d4aa630ab8296dea4d3f15a255588dd535f980c193ab6b95f3f", size = 249629, upload-time = "2025-09-21T20:02:46.503Z" }, + { url = "https://files.pythonhosted.org/packages/82/62/14ed6546d0207e6eda876434e3e8475a3e9adbe32110ce896c9e0c06bb9a/coverage-7.10.7-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bb45474711ba385c46a0bfe696c695a929ae69ac636cda8f532be9e8c93d720a", size = 252162, upload-time = "2025-09-21T20:02:48.689Z" }, + { url = "https://files.pythonhosted.org/packages/ff/49/07f00db9ac6478e4358165a08fb41b469a1b053212e8a00cb02f0d27a05f/coverage-7.10.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:813922f35bd800dca9994c5971883cbc0d291128a5de6b167c7aa697fcf59360", size = 253517, upload-time = "2025-09-21T20:02:50.31Z" }, + { url = "https://files.pythonhosted.org/packages/a2/59/c5201c62dbf165dfbc91460f6dbbaa85a8b82cfa6131ac45d6c1bfb52deb/coverage-7.10.7-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:93c1b03552081b2a4423091d6fb3787265b8f86af404cff98d1b5342713bdd69", size = 249632, upload-time = 
"2025-09-21T20:02:51.971Z" }, + { url = "https://files.pythonhosted.org/packages/07/ae/5920097195291a51fb00b3a70b9bbd2edbfe3c84876a1762bd1ef1565ebc/coverage-7.10.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:cc87dd1b6eaf0b848eebb1c86469b9f72a1891cb42ac7adcfbce75eadb13dd14", size = 251520, upload-time = "2025-09-21T20:02:53.858Z" }, + { url = "https://files.pythonhosted.org/packages/b9/3c/a815dde77a2981f5743a60b63df31cb322c944843e57dbd579326625a413/coverage-7.10.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:39508ffda4f343c35f3236fe8d1a6634a51f4581226a1262769d7f970e73bffe", size = 249455, upload-time = "2025-09-21T20:02:55.807Z" }, + { url = "https://files.pythonhosted.org/packages/aa/99/f5cdd8421ea656abefb6c0ce92556709db2265c41e8f9fc6c8ae0f7824c9/coverage-7.10.7-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:925a1edf3d810537c5a3abe78ec5530160c5f9a26b1f4270b40e62cc79304a1e", size = 249287, upload-time = "2025-09-21T20:02:57.784Z" }, + { url = "https://files.pythonhosted.org/packages/c3/7a/e9a2da6a1fc5d007dd51fca083a663ab930a8c4d149c087732a5dbaa0029/coverage-7.10.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2c8b9a0636f94c43cd3576811e05b89aa9bc2d0a85137affc544ae5cb0e4bfbd", size = 250946, upload-time = "2025-09-21T20:02:59.431Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5b/0b5799aa30380a949005a353715095d6d1da81927d6dbed5def2200a4e25/coverage-7.10.7-cp314-cp314-win32.whl", hash = "sha256:b7b8288eb7cdd268b0304632da8cb0bb93fadcfec2fe5712f7b9cc8f4d487be2", size = 221009, upload-time = "2025-09-21T20:03:01.324Z" }, + { url = "https://files.pythonhosted.org/packages/da/b0/e802fbb6eb746de006490abc9bb554b708918b6774b722bb3a0e6aa1b7de/coverage-7.10.7-cp314-cp314-win_amd64.whl", hash = "sha256:1ca6db7c8807fb9e755d0379ccc39017ce0a84dcd26d14b5a03b78563776f681", size = 221804, upload-time = "2025-09-21T20:03:03.4Z" }, + { url = 
"https://files.pythonhosted.org/packages/9e/e8/71d0c8e374e31f39e3389bb0bd19e527d46f00ea8571ec7ec8fd261d8b44/coverage-7.10.7-cp314-cp314-win_arm64.whl", hash = "sha256:097c1591f5af4496226d5783d036bf6fd6cd0cbc132e071b33861de756efb880", size = 220384, upload-time = "2025-09-21T20:03:05.111Z" }, + { url = "https://files.pythonhosted.org/packages/62/09/9a5608d319fa3eba7a2019addeacb8c746fb50872b57a724c9f79f146969/coverage-7.10.7-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:a62c6ef0d50e6de320c270ff91d9dd0a05e7250cac2a800b7784bae474506e63", size = 219047, upload-time = "2025-09-21T20:03:06.795Z" }, + { url = "https://files.pythonhosted.org/packages/f5/6f/f58d46f33db9f2e3647b2d0764704548c184e6f5e014bef528b7f979ef84/coverage-7.10.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:9fa6e4dd51fe15d8738708a973470f67a855ca50002294852e9571cdbd9433f2", size = 219266, upload-time = "2025-09-21T20:03:08.495Z" }, + { url = "https://files.pythonhosted.org/packages/74/5c/183ffc817ba68e0b443b8c934c8795553eb0c14573813415bd59941ee165/coverage-7.10.7-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:8fb190658865565c549b6b4706856d6a7b09302c797eb2cf8e7fe9dabb043f0d", size = 260767, upload-time = "2025-09-21T20:03:10.172Z" }, + { url = "https://files.pythonhosted.org/packages/0f/48/71a8abe9c1ad7e97548835e3cc1adbf361e743e9d60310c5f75c9e7bf847/coverage-7.10.7-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:affef7c76a9ef259187ef31599a9260330e0335a3011732c4b9effa01e1cd6e0", size = 262931, upload-time = "2025-09-21T20:03:11.861Z" }, + { url = "https://files.pythonhosted.org/packages/84/fd/193a8fb132acfc0a901f72020e54be5e48021e1575bb327d8ee1097a28fd/coverage-7.10.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e16e07d85ca0cf8bafe5f5d23a0b850064e8e945d5677492b06bbe6f09cc699", size = 265186, upload-time = "2025-09-21T20:03:13.539Z" }, + { url = 
"https://files.pythonhosted.org/packages/b1/8f/74ecc30607dd95ad50e3034221113ccb1c6d4e8085cc761134782995daae/coverage-7.10.7-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:03ffc58aacdf65d2a82bbeb1ffe4d01ead4017a21bfd0454983b88ca73af94b9", size = 259470, upload-time = "2025-09-21T20:03:15.584Z" }, + { url = "https://files.pythonhosted.org/packages/0f/55/79ff53a769f20d71b07023ea115c9167c0bb56f281320520cf64c5298a96/coverage-7.10.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1b4fd784344d4e52647fd7857b2af5b3fbe6c239b0b5fa63e94eb67320770e0f", size = 262626, upload-time = "2025-09-21T20:03:17.673Z" }, + { url = "https://files.pythonhosted.org/packages/88/e2/dac66c140009b61ac3fc13af673a574b00c16efdf04f9b5c740703e953c0/coverage-7.10.7-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:0ebbaddb2c19b71912c6f2518e791aa8b9f054985a0769bdb3a53ebbc765c6a1", size = 260386, upload-time = "2025-09-21T20:03:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/a2/f1/f48f645e3f33bb9ca8a496bc4a9671b52f2f353146233ebd7c1df6160440/coverage-7.10.7-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:a2d9a3b260cc1d1dbdb1c582e63ddcf5363426a1a68faa0f5da28d8ee3c722a0", size = 258852, upload-time = "2025-09-21T20:03:21.007Z" }, + { url = "https://files.pythonhosted.org/packages/bb/3b/8442618972c51a7affeead957995cfa8323c0c9bcf8fa5a027421f720ff4/coverage-7.10.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a3cc8638b2480865eaa3926d192e64ce6c51e3d29c849e09d5b4ad95efae5399", size = 261534, upload-time = "2025-09-21T20:03:23.12Z" }, + { url = "https://files.pythonhosted.org/packages/b2/dc/101f3fa3a45146db0cb03f5b4376e24c0aac818309da23e2de0c75295a91/coverage-7.10.7-cp314-cp314t-win32.whl", hash = "sha256:67f8c5cbcd3deb7a60b3345dffc89a961a484ed0af1f6f73de91705cc6e31235", size = 221784, upload-time = "2025-09-21T20:03:24.769Z" }, + { url = 
"https://files.pythonhosted.org/packages/4c/a1/74c51803fc70a8a40d7346660379e144be772bab4ac7bb6e6b905152345c/coverage-7.10.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e1ed71194ef6dea7ed2d5cb5f7243d4bcd334bfb63e59878519be558078f848d", size = 222905, upload-time = "2025-09-21T20:03:26.93Z" }, + { url = "https://files.pythonhosted.org/packages/12/65/f116a6d2127df30bcafbceef0302d8a64ba87488bf6f73a6d8eebf060873/coverage-7.10.7-cp314-cp314t-win_arm64.whl", hash = "sha256:7fe650342addd8524ca63d77b2362b02345e5f1a093266787d210c70a50b471a", size = 220922, upload-time = "2025-09-21T20:03:28.672Z" }, + { url = "https://files.pythonhosted.org/packages/ec/16/114df1c291c22cac3b0c127a73e0af5c12ed7bbb6558d310429a0ae24023/coverage-7.10.7-py3-none-any.whl", hash = "sha256:f7941f6f2fe6dd6807a1208737b8a0cbcf1cc6d7b07d24998ad2d63590868260", size = 209952, upload-time = "2025-09-21T20:03:53.918Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + +[[package]] +name = "einops" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/81/df4fbe24dff8ba3934af99044188e20a98ed441ad17a274539b74e82e126/einops-0.8.1.tar.gz", hash = 
"sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84", size = 54805, upload-time = "2025-02-09T03:17:00.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359, upload-time = "2025-02-09T03:17:01.998Z" }, +] + +[[package]] +name = "executing" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, +] + [[package]] name = "filelock" version = "3.19.1" @@ -215,6 +343,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, ] +[[package]] +name = "ipython" +version = "9.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "decorator" }, + { name = "ipython-pygments-lexers" }, + { name = "jedi" }, + { name = "matplotlib-inline" }, + { name = "pexpect", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name = "stack-data" }, + { 
name = "traitlets" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/34/29b18c62e39ee2f7a6a3bba7efd952729d8aadd45ca17efc34453b717665/ipython-9.6.0.tar.gz", hash = "sha256:5603d6d5d356378be5043e69441a072b50a5b33b4503428c77b04cb8ce7bc731", size = 4396932, upload-time = "2025-09-29T10:55:53.948Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/c5/d5e07995077e48220269c28a221e168c91123ad5ceee44d548f54a057fc0/ipython-9.6.0-py3-none-any.whl", hash = "sha256:5f77efafc886d2f023442479b8149e7d86547ad0a979e9da9f045d252f648196", size = 616170, upload-time = "2025-09-29T10:55:47.676Z" }, +] + +[[package]] +name = "ipython-pygments-lexers" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393, upload-time = "2025-01-17T11:24:34.505Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, +] + +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -298,20 +472,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739, upload-time = "2024-10-18T15:21:42.784Z" }, ] +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159, upload-time = "2024-04-15T13:44:44.803Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, +] + [[package]] name = "mel" version = "0.1.0" source = { virtual = "." 
} dependencies = [ + { name = "einops" }, { name = "fire" }, { name = "gin-config" }, + { name = "ipython" }, { name = "numba" }, { name = "orjson" }, { name = "pandas" }, { name = "pytest" }, + { name = "pytest-cov" }, { name = "pytest-mock" }, { name = "python-fire" }, { name = "scann" }, + { name = "scipy" }, { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, @@ -320,15 +510,19 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "einops", specifier = ">=0.8.1" }, { name = "fire", specifier = ">=0.7.1" }, { name = "gin-config", specifier = ">=0.5.0" }, + { name = "ipython", specifier = ">=9.6.0" }, { name = "numba", specifier = ">=0.61.2" }, { name = "orjson", specifier = ">=3.11.3" }, { name = "pandas", specifier = ">=2.3.2" }, { name = "pytest", specifier = ">=8.4.2" }, + { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-mock", specifier = ">=3.15.0" }, { name = "python-fire", specifier = ">=0.1.0" }, { name = "scann", specifier = ">=1.4.2" }, + { name = "scipy", specifier = ">=1.16.2" }, { name = "torch", specifier = ">=2.8.0" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "transformers", specifier = ">=4.56.1" }, @@ -668,6 +862,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cd/d7/612123674d7b17cf345aad0a10289b2a384bff404e0463a83c4a3a59d205/pandas-2.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d2c3554bd31b731cd6490d94a28f3abb8dd770634a9e06eb6d2911b9827db370", size = 13186141, upload-time = "2025-08-21T10:28:05.377Z" }, ] +[[package]] +name = "parso" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/de/53e0bcf53d13e005bd8c92e7855142494f41171b34c2536b86187474184d/parso-0.8.5.tar.gz", hash = "sha256:034d7354a9a018bdce352f48b2a8a450f05e9d6ee85db84764e9b6bd96dafe5a", size = 401205, upload-time = "2025-08-23T15:15:28.028Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668, upload-time = "2025-08-23T15:15:25.663Z" }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, upload-time = "2023-11-25T06:56:14.81Z" }, +] + [[package]] name = "platformdirs" version = "4.4.0" @@ -686,6 +901,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + [[package]] name = "protobuf" version = "6.32.0" @@ -700,6 +927,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/f2/80ffc4677aac1bc3519b26bc7f7f5de7fce0ee2f7e36e59e27d8beb32dd1/protobuf-6.32.0-py3-none-any.whl", hash = "sha256:ba377e5b67b908c8f3072a57b63e2c6a4cbd18aea4ed98d2584350dbf46f2783", size = 169287, upload-time = "2025-08-14T21:21:23.515Z" }, ] +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762, upload-time = "2020-12-28T15:15:30.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993, upload-time = "2020-12-28T15:15:28.35Z" }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752, upload-time = "2024-07-21T12:58:21.801Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = 
"sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, +] + [[package]] name = "pydantic" version = "2.11.7" @@ -805,6 +1050,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] +[[package]] +name = "pytest-cov" +version = "7.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, +] + [[package]] name = "pytest-mock" version = "3.15.0" @@ -1003,6 +1262,77 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/b4/3decfd7039399b6bd9c9fbf0ccda39301bac01c39a09a5a791c8237f5d26/scann-1.4.2-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:c87e97f91c98d7d1f0bf985b39634e07e1149ba79c20f7dbf9b7b465c94100f2", size = 11579811, upload-time = "2025-08-29T14:32:25.101Z" }, ] +[[package]] +name = "scipy" +version = "1.16.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/4c/3b/546a6f0bfe791bbb7f8d591613454d15097e53f906308ec6f7c1ce588e8e/scipy-1.16.2.tar.gz", hash = "sha256:af029b153d243a80afb6eabe40b0a07f8e35c9adc269c019f364ad747f826a6b", size = 30580599, upload-time = "2025-09-11T17:48:08.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/ef/37ed4b213d64b48422df92560af7300e10fe30b5d665dd79932baebee0c6/scipy-1.16.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:6ab88ea43a57da1af33292ebd04b417e8e2eaf9d5aa05700be8d6e1b6501cd92", size = 36619956, upload-time = "2025-09-11T17:39:20.5Z" }, + { url = "https://files.pythonhosted.org/packages/85/ab/5c2eba89b9416961a982346a4d6a647d78c91ec96ab94ed522b3b6baf444/scipy-1.16.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c95e96c7305c96ede73a7389f46ccd6c659c4da5ef1b2789466baeaed3622b6e", size = 28931117, upload-time = "2025-09-11T17:39:29.06Z" }, + { url = "https://files.pythonhosted.org/packages/80/d1/eed51ab64d227fe60229a2d57fb60ca5898cfa50ba27d4f573e9e5f0b430/scipy-1.16.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:87eb178db04ece7c698220d523c170125dbffebb7af0345e66c3554f6f60c173", size = 20921997, upload-time = "2025-09-11T17:39:34.892Z" }, + { url = "https://files.pythonhosted.org/packages/be/7c/33ea3e23bbadde96726edba6bf9111fb1969d14d9d477ffa202c67bec9da/scipy-1.16.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:4e409eac067dcee96a57fbcf424c13f428037827ec7ee3cb671ff525ca4fc34d", size = 23523374, upload-time = "2025-09-11T17:39:40.846Z" }, + { url = "https://files.pythonhosted.org/packages/96/0b/7399dc96e1e3f9a05e258c98d716196a34f528eef2ec55aad651ed136d03/scipy-1.16.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e574be127bb760f0dad24ff6e217c80213d153058372362ccb9555a10fc5e8d2", size = 33583702, upload-time = "2025-09-11T17:39:49.011Z" }, + { url = 
"https://files.pythonhosted.org/packages/1a/bc/a5c75095089b96ea72c1bd37a4497c24b581ec73db4ef58ebee142ad2d14/scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f5db5ba6188d698ba7abab982ad6973265b74bb40a1efe1821b58c87f73892b9", size = 35883427, upload-time = "2025-09-11T17:39:57.406Z" }, + { url = "https://files.pythonhosted.org/packages/ab/66/e25705ca3d2b87b97fe0a278a24b7f477b4023a926847935a1a71488a6a6/scipy-1.16.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec6e74c4e884104ae006d34110677bfe0098203a3fec2f3faf349f4cb05165e3", size = 36212940, upload-time = "2025-09-11T17:40:06.013Z" }, + { url = "https://files.pythonhosted.org/packages/d6/fd/0bb911585e12f3abdd603d721d83fc1c7492835e1401a0e6d498d7822b4b/scipy-1.16.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:912f46667d2d3834bc3d57361f854226475f695eb08c08a904aadb1c936b6a88", size = 38865092, upload-time = "2025-09-11T17:40:15.143Z" }, + { url = "https://files.pythonhosted.org/packages/d6/73/c449a7d56ba6e6f874183759f8483cde21f900a8be117d67ffbb670c2958/scipy-1.16.2-cp311-cp311-win_amd64.whl", hash = "sha256:91e9e8a37befa5a69e9cacbe0bcb79ae5afb4a0b130fd6db6ee6cc0d491695fa", size = 38687626, upload-time = "2025-09-11T17:40:24.041Z" }, + { url = "https://files.pythonhosted.org/packages/68/72/02f37316adf95307f5d9e579023c6899f89ff3a051fa079dbd6faafc48e5/scipy-1.16.2-cp311-cp311-win_arm64.whl", hash = "sha256:f3bf75a6dcecab62afde4d1f973f1692be013110cad5338007927db8da73249c", size = 25503506, upload-time = "2025-09-11T17:40:30.703Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8d/6396e00db1282279a4ddd507c5f5e11f606812b608ee58517ce8abbf883f/scipy-1.16.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:89d6c100fa5c48472047632e06f0876b3c4931aac1f4291afc81a3644316bb0d", size = 36646259, upload-time = "2025-09-11T17:40:39.329Z" }, + { url = 
"https://files.pythonhosted.org/packages/3b/93/ea9edd7e193fceb8eef149804491890bde73fb169c896b61aa3e2d1e4e77/scipy-1.16.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ca748936cd579d3f01928b30a17dc474550b01272d8046e3e1ee593f23620371", size = 28888976, upload-time = "2025-09-11T17:40:46.82Z" }, + { url = "https://files.pythonhosted.org/packages/91/4d/281fddc3d80fd738ba86fd3aed9202331180b01e2c78eaae0642f22f7e83/scipy-1.16.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:fac4f8ce2ddb40e2e3d0f7ec36d2a1e7f92559a2471e59aec37bd8d9de01fec0", size = 20879905, upload-time = "2025-09-11T17:40:52.545Z" }, + { url = "https://files.pythonhosted.org/packages/69/40/b33b74c84606fd301b2915f0062e45733c6ff5708d121dd0deaa8871e2d0/scipy-1.16.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:033570f1dcefd79547a88e18bccacff025c8c647a330381064f561d43b821232", size = 23553066, upload-time = "2025-09-11T17:40:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/55/a7/22c739e2f21a42cc8f16bc76b47cff4ed54fbe0962832c589591c2abec34/scipy-1.16.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ea3421209bf00c8a5ef2227de496601087d8f638a2363ee09af059bd70976dc1", size = 33336407, upload-time = "2025-09-11T17:41:06.796Z" }, + { url = "https://files.pythonhosted.org/packages/53/11/a0160990b82999b45874dc60c0c183d3a3a969a563fffc476d5a9995c407/scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f66bd07ba6f84cd4a380b41d1bf3c59ea488b590a2ff96744845163309ee8e2f", size = 35673281, upload-time = "2025-09-11T17:41:15.055Z" }, + { url = "https://files.pythonhosted.org/packages/96/53/7ef48a4cfcf243c3d0f1643f5887c81f29fdf76911c4e49331828e19fc0a/scipy-1.16.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5e9feab931bd2aea4a23388c962df6468af3d808ddf2d40f94a81c5dc38f32ef", size = 36004222, upload-time = "2025-09-11T17:41:23.868Z" }, + { url = 
"https://files.pythonhosted.org/packages/49/7f/71a69e0afd460049d41c65c630c919c537815277dfea214031005f474d78/scipy-1.16.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:03dfc75e52f72cf23ec2ced468645321407faad8f0fe7b1f5b49264adbc29cb1", size = 38664586, upload-time = "2025-09-11T17:41:31.021Z" }, + { url = "https://files.pythonhosted.org/packages/34/95/20e02ca66fb495a95fba0642fd48e0c390d0ece9b9b14c6e931a60a12dea/scipy-1.16.2-cp312-cp312-win_amd64.whl", hash = "sha256:0ce54e07bbb394b417457409a64fd015be623f36e330ac49306433ffe04bc97e", size = 38550641, upload-time = "2025-09-11T17:41:36.61Z" }, + { url = "https://files.pythonhosted.org/packages/92/ad/13646b9beb0a95528ca46d52b7babafbe115017814a611f2065ee4e61d20/scipy-1.16.2-cp312-cp312-win_arm64.whl", hash = "sha256:2a8ffaa4ac0df81a0b94577b18ee079f13fecdb924df3328fc44a7dc5ac46851", size = 25456070, upload-time = "2025-09-11T17:41:41.3Z" }, + { url = "https://files.pythonhosted.org/packages/c1/27/c5b52f1ee81727a9fc457f5ac1e9bf3d6eab311805ea615c83c27ba06400/scipy-1.16.2-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:84f7bf944b43e20b8a894f5fe593976926744f6c185bacfcbdfbb62736b5cc70", size = 36604856, upload-time = "2025-09-11T17:41:47.695Z" }, + { url = "https://files.pythonhosted.org/packages/32/a9/15c20d08e950b540184caa8ced675ba1128accb0e09c653780ba023a4110/scipy-1.16.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:5c39026d12edc826a1ef2ad35ad1e6d7f087f934bb868fc43fa3049c8b8508f9", size = 28864626, upload-time = "2025-09-11T17:41:52.642Z" }, + { url = "https://files.pythonhosted.org/packages/4c/fc/ea36098df653cca26062a627c1a94b0de659e97127c8491e18713ca0e3b9/scipy-1.16.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e52729ffd45b68777c5319560014d6fd251294200625d9d70fd8626516fc49f5", size = 20855689, upload-time = "2025-09-11T17:41:57.886Z" }, + { url = 
"https://files.pythonhosted.org/packages/dc/6f/d0b53be55727f3e6d7c72687ec18ea6d0047cf95f1f77488b99a2bafaee1/scipy-1.16.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:024dd4a118cccec09ca3209b7e8e614931a6ffb804b2a601839499cb88bdf925", size = 23512151, upload-time = "2025-09-11T17:42:02.303Z" }, + { url = "https://files.pythonhosted.org/packages/11/85/bf7dab56e5c4b1d3d8eef92ca8ede788418ad38a7dc3ff50262f00808760/scipy-1.16.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7a5dc7ee9c33019973a470556081b0fd3c9f4c44019191039f9769183141a4d9", size = 33329824, upload-time = "2025-09-11T17:42:07.549Z" }, + { url = "https://files.pythonhosted.org/packages/da/6a/1a927b14ddc7714111ea51f4e568203b2bb6ed59bdd036d62127c1a360c8/scipy-1.16.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c2275ff105e508942f99d4e3bc56b6ef5e4b3c0af970386ca56b777608ce95b7", size = 35681881, upload-time = "2025-09-11T17:42:13.255Z" }, + { url = "https://files.pythonhosted.org/packages/c1/5f/331148ea5780b4fcc7007a4a6a6ee0a0c1507a796365cc642d4d226e1c3a/scipy-1.16.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:af80196eaa84f033e48444d2e0786ec47d328ba00c71e4299b602235ffef9acb", size = 36006219, upload-time = "2025-09-11T17:42:18.765Z" }, + { url = "https://files.pythonhosted.org/packages/46/3a/e991aa9d2aec723b4a8dcfbfc8365edec5d5e5f9f133888067f1cbb7dfc1/scipy-1.16.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9fb1eb735fe3d6ed1f89918224e3385fbf6f9e23757cacc35f9c78d3b712dd6e", size = 38682147, upload-time = "2025-09-11T17:42:25.177Z" }, + { url = "https://files.pythonhosted.org/packages/a1/57/0f38e396ad19e41b4c5db66130167eef8ee620a49bc7d0512e3bb67e0cab/scipy-1.16.2-cp313-cp313-win_amd64.whl", hash = "sha256:fda714cf45ba43c9d3bae8f2585c777f64e3f89a2e073b668b32ede412d8f52c", size = 38520766, upload-time = "2025-09-11T17:43:25.342Z" }, + { url = 
"https://files.pythonhosted.org/packages/1b/a5/85d3e867b6822d331e26c862a91375bb7746a0b458db5effa093d34cdb89/scipy-1.16.2-cp313-cp313-win_arm64.whl", hash = "sha256:2f5350da923ccfd0b00e07c3e5cfb316c1c0d6c1d864c07a72d092e9f20db104", size = 25451169, upload-time = "2025-09-11T17:43:30.198Z" }, + { url = "https://files.pythonhosted.org/packages/09/d9/60679189bcebda55992d1a45498de6d080dcaf21ce0c8f24f888117e0c2d/scipy-1.16.2-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:53d8d2ee29b925344c13bda64ab51785f016b1b9617849dac10897f0701b20c1", size = 37012682, upload-time = "2025-09-11T17:42:30.677Z" }, + { url = "https://files.pythonhosted.org/packages/83/be/a99d13ee4d3b7887a96f8c71361b9659ba4ef34da0338f14891e102a127f/scipy-1.16.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:9e05e33657efb4c6a9d23bd8300101536abd99c85cca82da0bffff8d8764d08a", size = 29389926, upload-time = "2025-09-11T17:42:35.845Z" }, + { url = "https://files.pythonhosted.org/packages/bf/0a/130164a4881cec6ca8c00faf3b57926f28ed429cd6001a673f83c7c2a579/scipy-1.16.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:7fe65b36036357003b3ef9d37547abeefaa353b237e989c21027b8ed62b12d4f", size = 21381152, upload-time = "2025-09-11T17:42:40.07Z" }, + { url = "https://files.pythonhosted.org/packages/47/a6/503ffb0310ae77fba874e10cddfc4a1280bdcca1d13c3751b8c3c2996cf8/scipy-1.16.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:6406d2ac6d40b861cccf57f49592f9779071655e9f75cd4f977fa0bdd09cb2e4", size = 23914410, upload-time = "2025-09-11T17:42:44.313Z" }, + { url = "https://files.pythonhosted.org/packages/fa/c7/1147774bcea50d00c02600aadaa919facbd8537997a62496270133536ed6/scipy-1.16.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ff4dc42bd321991fbf611c23fc35912d690f731c9914bf3af8f417e64aca0f21", size = 33481880, upload-time = "2025-09-11T17:42:49.325Z" }, + { url = 
"https://files.pythonhosted.org/packages/6a/74/99d5415e4c3e46b2586f30cdbecb95e101c7192628a484a40dd0d163811a/scipy-1.16.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:654324826654d4d9133e10675325708fb954bc84dae6e9ad0a52e75c6b1a01d7", size = 35791425, upload-time = "2025-09-11T17:42:54.711Z" }, + { url = "https://files.pythonhosted.org/packages/1b/ee/a6559de7c1cc710e938c0355d9d4fbcd732dac4d0d131959d1f3b63eb29c/scipy-1.16.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63870a84cd15c44e65220eaed2dac0e8f8b26bbb991456a033c1d9abfe8a94f8", size = 36178622, upload-time = "2025-09-11T17:43:00.375Z" }, + { url = "https://files.pythonhosted.org/packages/4e/7b/f127a5795d5ba8ece4e0dce7d4a9fb7cb9e4f4757137757d7a69ab7d4f1a/scipy-1.16.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:fa01f0f6a3050fa6a9771a95d5faccc8e2f5a92b4a2e5440a0fa7264a2398472", size = 38783985, upload-time = "2025-09-11T17:43:06.661Z" }, + { url = "https://files.pythonhosted.org/packages/3e/9f/bc81c1d1e033951eb5912cd3750cc005943afa3e65a725d2443a3b3c4347/scipy-1.16.2-cp313-cp313t-win_amd64.whl", hash = "sha256:116296e89fba96f76353a8579820c2512f6e55835d3fad7780fece04367de351", size = 38631367, upload-time = "2025-09-11T17:43:14.44Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5e/2cc7555fd81d01814271412a1d59a289d25f8b63208a0a16c21069d55d3e/scipy-1.16.2-cp313-cp313t-win_arm64.whl", hash = "sha256:98e22834650be81d42982360382b43b17f7ba95e0e6993e2a4f5b9ad9283a94d", size = 25787992, upload-time = "2025-09-11T17:43:19.745Z" }, + { url = "https://files.pythonhosted.org/packages/8b/ac/ad8951250516db71619f0bd3b2eb2448db04b720a003dd98619b78b692c0/scipy-1.16.2-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:567e77755019bb7461513c87f02bb73fb65b11f049aaaa8ca17cfaa5a5c45d77", size = 36595109, upload-time = "2025-09-11T17:43:35.713Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/f6/5779049ed119c5b503b0f3dc6d6f3f68eefc3a9190d4ad4c276f854f051b/scipy-1.16.2-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:17d9bb346194e8967296621208fcdfd39b55498ef7d2f376884d5ac47cec1a70", size = 28859110, upload-time = "2025-09-11T17:43:40.814Z" }, + { url = "https://files.pythonhosted.org/packages/82/09/9986e410ae38bf0a0c737ff8189ac81a93b8e42349aac009891c054403d7/scipy-1.16.2-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:0a17541827a9b78b777d33b623a6dcfe2ef4a25806204d08ead0768f4e529a88", size = 20850110, upload-time = "2025-09-11T17:43:44.981Z" }, + { url = "https://files.pythonhosted.org/packages/0d/ad/485cdef2d9215e2a7df6d61b81d2ac073dfacf6ae24b9ae87274c4e936ae/scipy-1.16.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:d7d4c6ba016ffc0f9568d012f5f1eb77ddd99412aea121e6fa8b4c3b7cbad91f", size = 23497014, upload-time = "2025-09-11T17:43:49.074Z" }, + { url = "https://files.pythonhosted.org/packages/a7/74/f6a852e5d581122b8f0f831f1d1e32fb8987776ed3658e95c377d308ed86/scipy-1.16.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9702c4c023227785c779cba2e1d6f7635dbb5b2e0936cdd3a4ecb98d78fd41eb", size = 33401155, upload-time = "2025-09-11T17:43:54.661Z" }, + { url = "https://files.pythonhosted.org/packages/d9/f5/61d243bbc7c6e5e4e13dde9887e84a5cbe9e0f75fd09843044af1590844e/scipy-1.16.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d1cdf0ac28948d225decdefcc45ad7dd91716c29ab56ef32f8e0d50657dffcc7", size = 35691174, upload-time = "2025-09-11T17:44:00.101Z" }, + { url = "https://files.pythonhosted.org/packages/03/99/59933956331f8cc57e406cdb7a483906c74706b156998f322913e789c7e1/scipy-1.16.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:70327d6aa572a17c2941cdfb20673f82e536e91850a2e4cb0c5b858b690e1548", size = 36070752, upload-time = "2025-09-11T17:44:05.619Z" }, + { url = 
"https://files.pythonhosted.org/packages/c6/7d/00f825cfb47ee19ef74ecf01244b43e95eae74e7e0ff796026ea7cd98456/scipy-1.16.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5221c0b2a4b58aa7c4ed0387d360fd90ee9086d383bb34d9f2789fafddc8a936", size = 38701010, upload-time = "2025-09-11T17:44:11.322Z" }, + { url = "https://files.pythonhosted.org/packages/e4/9f/b62587029980378304ba5a8563d376c96f40b1e133daacee76efdcae32de/scipy-1.16.2-cp314-cp314-win_amd64.whl", hash = "sha256:f5a85d7b2b708025af08f060a496dd261055b617d776fc05a1a1cc69e09fe9ff", size = 39360061, upload-time = "2025-09-11T17:45:09.814Z" }, + { url = "https://files.pythonhosted.org/packages/82/04/7a2f1609921352c7fbee0815811b5050582f67f19983096c4769867ca45f/scipy-1.16.2-cp314-cp314-win_arm64.whl", hash = "sha256:2cc73a33305b4b24556957d5857d6253ce1e2dcd67fa0ff46d87d1670b3e1e1d", size = 26126914, upload-time = "2025-09-11T17:45:14.73Z" }, + { url = "https://files.pythonhosted.org/packages/51/b9/60929ce350c16b221928725d2d1d7f86cf96b8bc07415547057d1196dc92/scipy-1.16.2-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:9ea2a3fed83065d77367775d689401a703d0f697420719ee10c0780bcab594d8", size = 37013193, upload-time = "2025-09-11T17:44:16.757Z" }, + { url = "https://files.pythonhosted.org/packages/2a/41/ed80e67782d4bc5fc85a966bc356c601afddd175856ba7c7bb6d9490607e/scipy-1.16.2-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:7280d926f11ca945c3ef92ba960fa924e1465f8d07ce3a9923080363390624c4", size = 29390172, upload-time = "2025-09-11T17:44:21.783Z" }, + { url = "https://files.pythonhosted.org/packages/c4/a3/2f673ace4090452696ccded5f5f8efffb353b8f3628f823a110e0170b605/scipy-1.16.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:8afae1756f6a1fe04636407ef7dbece33d826a5d462b74f3d0eb82deabefd831", size = 21381326, upload-time = "2025-09-11T17:44:25.982Z" }, + { url = 
"https://files.pythonhosted.org/packages/42/bf/59df61c5d51395066c35836b78136accf506197617c8662e60ea209881e1/scipy-1.16.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:5c66511f29aa8d233388e7416a3f20d5cae7a2744d5cee2ecd38c081f4e861b3", size = 23915036, upload-time = "2025-09-11T17:44:30.527Z" }, + { url = "https://files.pythonhosted.org/packages/91/c3/edc7b300dc16847ad3672f1a6f3f7c5d13522b21b84b81c265f4f2760d4a/scipy-1.16.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efe6305aeaa0e96b0ccca5ff647a43737d9a092064a3894e46c414db84bc54ac", size = 33484341, upload-time = "2025-09-11T17:44:35.981Z" }, + { url = "https://files.pythonhosted.org/packages/26/c7/24d1524e72f06ff141e8d04b833c20db3021020563272ccb1b83860082a9/scipy-1.16.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7f3a337d9ae06a1e8d655ee9d8ecb835ea5ddcdcbd8d23012afa055ab014f374", size = 35790840, upload-time = "2025-09-11T17:44:41.76Z" }, + { url = "https://files.pythonhosted.org/packages/aa/b7/5aaad984eeedd56858dc33d75efa59e8ce798d918e1033ef62d2708f2c3d/scipy-1.16.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bab3605795d269067d8ce78a910220262711b753de8913d3deeaedb5dded3bb6", size = 36174716, upload-time = "2025-09-11T17:44:47.316Z" }, + { url = "https://files.pythonhosted.org/packages/fd/c2/e276a237acb09824822b0ada11b028ed4067fdc367a946730979feacb870/scipy-1.16.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b0348d8ddb55be2a844c518cd8cc8deeeb8aeba707cf834db5758fc89b476a2c", size = 38790088, upload-time = "2025-09-11T17:44:53.011Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b4/5c18a766e8353015439f3780f5fc473f36f9762edc1a2e45da3ff5a31b21/scipy-1.16.2-cp314-cp314t-win_amd64.whl", hash = "sha256:26284797e38b8a75e14ea6631d29bda11e76ceaa6ddb6fdebbfe4c4d90faf2f9", size = 39457455, upload-time = "2025-09-11T17:44:58.899Z" }, + { url = 
"https://files.pythonhosted.org/packages/97/30/2f9a5243008f76dfc5dee9a53dfb939d9b31e16ce4bd4f2e628bfc5d89d2/scipy-1.16.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d2a4472c231328d4de38d5f1f68fdd6d28a615138f842580a8a321b5845cf779", size = 26448374, upload-time = "2025-09-11T17:45:03.45Z" }, +] + [[package]] name = "sentry-sdk" version = "2.37.1" @@ -1043,6 +1373,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, ] +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707, upload-time = "2023-09-30T13:58:05.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -1089,6 +1433,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/9b/0e0bf82214ee20231845b127aa4a8015936ad5a46779f30865d10e404167/tokenizers-0.22.0-cp39-abi3-win_amd64.whl", hash = "sha256:c78174859eeaee96021f248a56c801e36bfb6bd5b067f2e95aa82445ca324f00", size = 2680494, upload-time = "2025-08-29T10:25:35.14Z" }, ] +[[package]] +name = "tomli" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175, upload-time = "2024-11-27T22:38:36.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077, upload-time = "2024-11-27T22:37:54.956Z" }, + { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429, upload-time = "2024-11-27T22:37:56.698Z" }, + { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067, upload-time = "2024-11-27T22:37:57.63Z" }, + { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030, upload-time = "2024-11-27T22:37:59.344Z" }, + { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898, upload-time = "2024-11-27T22:38:00.429Z" }, + { url = 
"https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894, upload-time = "2024-11-27T22:38:02.094Z" }, + { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319, upload-time = "2024-11-27T22:38:03.206Z" }, + { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273, upload-time = "2024-11-27T22:38:04.217Z" }, + { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310, upload-time = "2024-11-27T22:38:05.908Z" }, + { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309, upload-time = "2024-11-27T22:38:06.812Z" }, + { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762, upload-time = "2024-11-27T22:38:07.731Z" }, + { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453, upload-time = "2024-11-27T22:38:09.384Z" }, + { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486, upload-time = "2024-11-27T22:38:10.329Z" }, + { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349, upload-time = "2024-11-27T22:38:11.443Z" }, + { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159, upload-time = "2024-11-27T22:38:13.099Z" }, + { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243, upload-time = "2024-11-27T22:38:14.766Z" }, + { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645, upload-time = "2024-11-27T22:38:15.843Z" }, + { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584, upload-time = "2024-11-27T22:38:17.645Z" }, + { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875, upload-time = "2024-11-27T22:38:19.159Z" }, + { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418, upload-time = "2024-11-27T22:38:20.064Z" }, + { url = "https://files.pythonhosted.org/packages/04/90/2ee5f2e0362cb8a0b6499dc44f4d7d48f8fff06d28ba46e6f1eaa61a1388/tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7", size = 132708, upload-time = "2024-11-27T22:38:21.659Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ec/46b4108816de6b385141f082ba99e315501ccd0a2ea23db4a100dd3990ea/tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c", size = 123582, upload-time = "2024-11-27T22:38:22.693Z" }, + { url = "https://files.pythonhosted.org/packages/a0/bd/b470466d0137b37b68d24556c38a0cc819e8febe392d5b199dcd7f578365/tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13", size = 232543, upload-time = "2024-11-27T22:38:24.367Z" }, + { url = "https://files.pythonhosted.org/packages/d9/e5/82e80ff3b751373f7cead2815bcbe2d51c895b3c990686741a8e56ec42ab/tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281", size = 241691, upload-time = 
"2024-11-27T22:38:26.081Z" }, + { url = "https://files.pythonhosted.org/packages/05/7e/2a110bc2713557d6a1bfb06af23dd01e7dde52b6ee7dadc589868f9abfac/tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272", size = 251170, upload-time = "2024-11-27T22:38:27.921Z" }, + { url = "https://files.pythonhosted.org/packages/64/7b/22d713946efe00e0adbcdfd6d1aa119ae03fd0b60ebed51ebb3fa9f5a2e5/tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140", size = 236530, upload-time = "2024-11-27T22:38:29.591Z" }, + { url = "https://files.pythonhosted.org/packages/38/31/3a76f67da4b0cf37b742ca76beaf819dca0ebef26d78fc794a576e08accf/tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2", size = 258666, upload-time = "2024-11-27T22:38:30.639Z" }, + { url = "https://files.pythonhosted.org/packages/07/10/5af1293da642aded87e8a988753945d0cf7e00a9452d3911dd3bb354c9e2/tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744", size = 243954, upload-time = "2024-11-27T22:38:31.702Z" }, + { url = "https://files.pythonhosted.org/packages/5b/b9/1ed31d167be802da0fc95020d04cd27b7d7065cc6fbefdd2f9186f60d7bd/tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec", size = 98724, upload-time = "2024-11-27T22:38:32.837Z" }, + { url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383, upload-time = "2024-11-27T22:38:34.455Z" }, + { url = 
"https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, +] + [[package]] name = "torch" version = "2.8.0" @@ -1148,6 +1531,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621, upload-time = "2024-04-19T11:11:49.746Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, +] + [[package]] name = "transformers" version = "4.56.1" @@ -1250,3 +1642,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/d3/faa6ddb792a158c154fb704b25c96d0478e71eabf96e3f17529fb23b6894/wandb-0.21.3-py3-none-win32.whl", hash = "sha256:45aa3d8ad53c6ee06f37490d7a329ed7d0f5ca4dbd5d05bb0c01d5da22f14691", size = 18709408, upload-time = "2025-08-30T18:21:50.859Z" }, { url = "https://files.pythonhosted.org/packages/d8/2d/7ef56e25f78786e59fefd9b19867c325f9686317d9f7b93b5cb340360a3e/wandb-0.21.3-py3-none-win_amd64.whl", hash = "sha256:56d5a5697766f552a9933d8c6a564202194768eb0389bd5f9fe9a99cd4cee41e", size = 18709411, 
upload-time = "2025-08-30T18:21:52.874Z" }, ] + +[[package]] +name = "wcwidth" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, +]