Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
fe17fc0
feat: :sparkles: add verbose option to load_tokens_qids_from_dir for …
Yokto13 Sep 30, 2025
8fbe834
feat: :sparkles: add new scripts for Qwen reranking and processing in…
Yokto13 Sep 30, 2025
a8038d8
feat: :sparkles: add max_items_to_load parameter to load_tokens_qids_…
Yokto13 Oct 1, 2025
266fe7a
feat: :wip: reranking simple model
Yokto13 Oct 3, 2025
0585685
fix(rerank): :bug: calling super in .train override
Yokto13 Oct 5, 2025
8ed60da
feat(rerank): :construction: new trainer code that will work more eas…
Yokto13 Oct 5, 2025
b4e8ed7
feat(reranking): :sparkles: add training configs
Yokto13 Oct 5, 2025
45364df
feat(models): :sparkles: add searcher that expects inputs to be tenso…
Yokto13 Oct 6, 2025
78e4462
feat(utils): :sparkles: benchmark whether DataParallel really helps
Yokto13 Oct 6, 2025
d22eebd
skip unused tests
Yokto13 Oct 6, 2025
4aa3de1
Refactor and enhance the reranking and searcher modules
Yokto13 Oct 6, 2025
41cc070
add ipython and pytest-cov
Yokto13 Oct 6, 2025
aeb0dff
feat(rerank): add binary dataset creation functions
Yokto13 Oct 9, 2025
725576e
feat(dependencies): add einops package to project dependencies
Yokto13 Oct 12, 2025
b3ec79c
feat(config): add run_kb_creator.langs to multilingual dataset config…
Yokto13 Oct 12, 2025
0707fec
feat(rerank): add create_default_binary_dataset and reranking_train a…
Yokto13 Oct 12, 2025
7e801b6
feat(creator): enhance dataset creation functions with detailed docst…
Yokto13 Oct 12, 2025
3584f42
feat(rerank): enhance create_binary_dataset function with improved li…
Yokto13 Oct 12, 2025
ea5ee20
feat(config): add model path for run_damuel_description in paraphrase…
Yokto13 Oct 12, 2025
6aa3544
Refactor training scripts and add new models
Yokto13 Oct 12, 2025
d1d5ea8
feat(embeddings): update model loading to include output type in embs…
Yokto13 Oct 12, 2025
1b2bba8
feat(scripts): add another finetuning script
Yokto13 Oct 12, 2025
3d32e0d
Add mapping of qids to token matrix and corresponding tests
Yokto13 Oct 16, 2025
d7a91e7
feat(train): add ManualSyncBruteForceSearcher for improved CUDA suppo…
Yokto13 Oct 16, 2025
f681cfc
feat(train): enable fused optimization in AdamW for improved performance
Yokto13 Oct 16, 2025
2e655c6
fix(reranking): :bug: make updating dataset tokens work with differen…
Yokto13 Oct 20, 2025
70f3cb5
feat(dataset): add multiclass dataset creation and corresponding iter…
Yokto13 Oct 21, 2025
05c8532
fix tests
Yokto13 Oct 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions configs/multilingual_dataset.gin
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ dest_dir="/lnet/work/home-students-external/farhan/troja/outputs/v2_normal/"
create_multilingual_dataset.source_dir=%source_dir
create_multilingual_dataset.langs=%langs
create_multilingual_dataset.dest_dir=%dest_dir

run_kb_creator.langs=%langs

1 change: 1 addition & 0 deletions configs/paraphrase.gin
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ train_ddp.FOUNDATION_MODEL_PATH=%model_path
run_mewsli_mention.model_path=%model_path
run_damuel_mention.model_path=%model_path
run_damuel_description_context.model_path=%model_path
run_damuel_description.model_path=%model_path
run_damuel_link_context.model_path=%model_path
run_mewsli_context.model_path=%model_path
2 changes: 1 addition & 1 deletion configs/train.gin
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
training_batch_size=2688
training_batch_size=3712
#epochs=300
epochs=100
logit_mutliplier=20
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"einops>=0.8.1",
"fire>=0.7.1",
"gin-config>=0.5.0",
"ipython>=9.6.0",
"numba>=0.61.2",
"orjson>=3.11.3",
"pandas>=2.3.2",
"pytest>=8.4.2",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.0",
"python-fire>=0.1.0",
"scann>=1.4.2",
"scipy>=1.16.2",
"torch>=2.8.0",
"tqdm>=4.67.1",
"transformers>=4.56.1",
Expand Down
2 changes: 1 addition & 1 deletion src/finetunings/finetune_model/train_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _ddp_train(
},
)

optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, fused=True)
criterion = nn.CrossEntropyLoss()

scaler = torch.amp.GradScaler("cuda")
Expand Down
5 changes: 3 additions & 2 deletions src/finetunings/generate_epochs/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from finetunings.generate_epochs.datasets import BatcherDataset, DamuelNeighborsIterator
from models.batch_sampler import BatchSampler
from models.negative_sampler import NegativeSamplingType
from models.searchers.brute_force_searcher import DPBruteForceSearcher
from models.searchers.brute_force_searcher import DPBruteForceSearcher, ManualSyncBruteForceSearcher
from utils.calculate_qids_distribution import calculate_qids_distribution_from_links
from utils.loaders import load_embs_and_qids
from utils.multifile_dataset import MultiFileDataset
Expand Down Expand Up @@ -95,8 +95,9 @@ def generate(
batch_sampler = BatchSampler(
index_embs,
index_qids,
DPBruteForceSearcher,
# DPBruteForceSearcher,
# BruteForceSearcher,
ManualSyncBruteForceSearcher,
NegativeSamplingType(NEGATIVE_SAMPLING_TYPE),
**negative_sampler_kwargs,
)
Expand Down
97 changes: 96 additions & 1 deletion src/models/searchers/brute_force_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def build(self) -> None:
class _WrappedSearcher(nn.Module):
def __init__(self, kb_embs, num_neighbors):
super().__init__()
self.kb_embs: torch.Tensor = nn.Parameter(kb_embs)
# Replace Parameter with register_buffer
self.register_buffer("kb_embs", kb_embs)
self.num_neighbors: int = num_neighbors

# @torch.compile
Expand Down Expand Up @@ -106,6 +107,10 @@ def find(self, batch: np.ndarray, num_neighbors: int, mask=None) -> np.ndarray:
_WrappedSearcher(torch.from_numpy(self.embs), num_neighbors)
)
self.module_searcher.to(self.device)
# Set module to eval() and disable gradients
self.module_searcher.eval()
for param in self.module_searcher.parameters():
param.requires_grad = False
self.required_num_neighbors = num_neighbors
top_indices: torch.Tensor = self.module_searcher(
torch.from_numpy(batch).to(self.device)
Expand All @@ -116,3 +121,93 @@ def find(self, batch: np.ndarray, num_neighbors: int, mask=None) -> np.ndarray:

def build(self):
pass


class DPBruteForceSearcherPT(Searcher):
    """DataParallel brute-force searcher whose `find` takes torch.Tensor batches.

    The compiled, DataParallel-wrapped ``_WrappedSearcher`` is built lazily on
    the first :meth:`find` call (see that method's notes for the pitfalls).
    """

    def __init__(self, embs: np.ndarray, results: np.ndarray, run_build_from_init: bool = True):
        if torch.cuda.is_available():
            _logger.info("Running on CUDA.")
            self.device: torch.device = torch.device("cuda")
        else:
            _logger.info("CUDA is not available.")
            self.device: torch.device = torch.device("cpu")
        # Both are created lazily on the first find() call.
        self.module_searcher: Optional[nn.DataParallel] = None
        self.required_num_neighbors: Optional[int] = None
        super().__init__(embs, results, run_build_from_init)

    @torch.inference_mode()
    def find(self, batch: torch.Tensor, num_neighbors: int) -> np.ndarray:
        """
        Finds the nearest neighbors for a given batch of input data.
        CAREFUL: This is an optimized version that comes with potential pitfalls to get better performance.
        Read Notes for details!

        Args:
            batch (torch.Tensor): A batch of input data for which neighbors are to be found.
            num_neighbors (int): The number of nearest neighbors to retrieve.
        Returns:
            np.ndarray: An array containing the results corresponding to the nearest neighbors.
        Raises:
            TypeError: Re-raised when an unexpected TypeError occurs while calling an
                already-initialized `module_searcher` (only the first call's
                "None is not callable" TypeError is handled internally).
        Notes:
            - It is not possible to change num_neighbors after the first call to find.
              If you need to do that, you need to reinitialize this object. If you call find with different
              num_neighbors, it will not raise an error and will fail silently.
            - The first call to find will be slow, because the module_searcher will be initialized and torch.compile is called.
        """
        # A try/except trick to avoid the overhead of checking whether module_searcher
        # is None on every call: the common (initialized) path pays nothing, and the
        # single first call lands in the except branch via "NoneType is not callable".
        try:
            with torch.amp.autocast(device_type="cuda"):
                top_indices: torch.Tensor = self.module_searcher(batch)
        except TypeError as e:
            if self.module_searcher is not None:
                # module_searcher exists, so this TypeError is a real error.
                raise e
            # torch.as_tensor accepts np.ndarray (zero-copy) or an existing Tensor;
            # _WrappedSearcher's register_buffer requires a Tensor, so passing
            # self.embs (np.ndarray) directly would fail.
            self.module_searcher = torch.compile(
                nn.DataParallel(_WrappedSearcher(torch.as_tensor(self.embs), num_neighbors))
            )
            self.module_searcher.to(self.device)
            # Inference only: eval mode and no gradients anywhere.
            self.module_searcher.eval()
            for param in self.module_searcher.parameters():
                param.requires_grad = False
            self.required_num_neighbors = num_neighbors
            # Use the same autocast context as the steady-state path so the
            # first call produces numerically consistent results.
            with torch.amp.autocast(device_type="cuda"):
                top_indices: torch.Tensor = self.module_searcher(batch)

        return self.results[top_indices.cpu().numpy()]

    def build(self):
        # Building is deferred to the first find() call; nothing to do here.
        pass


class ManualSyncBruteForceSearcher(Searcher):
    """Multi-GPU brute-force searcher using manual scatter/gather (no DataParallel).

    Keeps one ``_WrappedSearcher`` replica per CUDA device; each batch is split
    across devices, searched in parallel, and gathered back on device 0.
    """

    def __init__(self, embs: np.ndarray, results: np.ndarray, run_build_from_init: bool = False):
        assert torch.cuda.is_available(), "This class requires CUDA."
        assert run_build_from_init is False, "This class does not support building from init."
        # Per-device replicas, created lazily on the first find() call so that
        # num_neighbors is known when they are built.
        self.searchers = []
        # CUDA availability was asserted above, so no CPU fallback is needed.
        self.num_devices = torch.cuda.device_count()
        self.cuda_devices = [torch.device(f"cuda:{i}") for i in range(self.num_devices)]
        super().__init__(torch.from_numpy(embs), results, run_build_from_init)

    @torch.inference_mode()
    def find(self, batch: np.ndarray, num_neighbors: int) -> np.ndarray:
        """Return the results for the `num_neighbors` nearest embeddings of each row.

        Note:
            num_neighbors is fixed by the first call; the replicas are built once,
            so later calls with a different value are silently ignored.
        """
        if len(self.searchers) == 0:
            for device in self.cuda_devices:
                self.searchers.append(
                    _WrappedSearcher(self.embs, num_neighbors=num_neighbors).to(device)
                )
        batch_t = torch.from_numpy(batch)
        # scatter() already places chunk i on cuda_devices[i]; for batches smaller
        # than the device count it yields fewer chunks, and zip truncates to match.
        chunks = nn.parallel.scatter(batch_t, self.cuda_devices)
        outputs = [
            searcher(chunk) for searcher, chunk in zip(self.searchers, chunks)
        ]
        gathered = nn.parallel.gather(outputs, self.cuda_devices[0])
        # Map raw KB row indices through self.results for parity with the other
        # brute-force searchers (callers expect results, not indices).
        return self.results[gathered.cpu().numpy()]

    def build(self):
        # Replicas are built lazily in find(); nothing to do here.
        pass
35 changes: 35 additions & 0 deletions src/multilingual_dataset/creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,48 @@ def create_multilingual_dataset(
dest_dir: Union[str, Path],
max_links_per_qid: int,
) -> None:
"""Create a multilingual dataset by mixing links and building a language-filtered KB.

- Links: copies and intermixes link shards from the given languages into `dest_dir/links`,
limiting per-QID occurrences to `max_links_per_qid`. Outputs NPZ files with arrays
`tokens` and `qids`.
- KB pages: writes a subset of description/page shards to `dest_dir/descs_pages`,
assigning up to one language per QID by default.

Args:
source_dir: Root directory of the DAMUEL dataset to read from.
langs: Language codes to include.
dest_dir: Output directory; creates `links/` and `descs_pages/` subfolders.
max_links_per_qid: Maximum number of link samples retained per QID.

Notes:
Uses parallel mixing and threaded I/O for performance.
"""
MultilingualDatasetCreator(Path(source_dir), langs, Path(dest_dir), max_links_per_qid).run()


@gin.configurable
def run_kb_creator(
    source_dir: Union[str, Path],
    langs: list[str],
    dest_dir: Union[str, Path],
    langs_per_qid: int,
) -> None:
    """
    Build a language-filtered KB (descriptions/pages) subset from a DAMUEL dataset.

    Each QID is assigned up to `langs_per_qid` languages, ranked by link frequency
    (ties broken by overall language size). Only the chosen languages' pages are
    copied into `dest_dir/descs_pages`, written as compressed NPZ shards named
    `mentions_{lang}_{i}.npz`.

    Args:
        source_dir: Root path of the DAMUEL dataset to read (links and pages).
        langs: Language codes to consider.
        dest_dir: Output directory; a 'descs_pages' subfolder is created inside.
        langs_per_qid: Maximum number of languages assigned to each QID.

    Notes:
        - Only KB creation is affected; link files are left untouched.
        - I/O is parallelized with a ThreadPoolExecutor for speed.
    """
    damuel_paths = DamuelPaths(source_dir)
    kb_creator = _KBCreator(damuel_paths, langs, Path(dest_dir), langs_per_qid)
    kb_creator.run()
Loading
Loading