diff --git a/notebooks/AmazonBeautyDatasetStatistics.ipynb b/notebooks/AmazonBeautyDatasetStatistics.ipynb index 6d34ff2..379239d 100644 --- a/notebooks/AmazonBeautyDatasetStatistics.ipynb +++ b/notebooks/AmazonBeautyDatasetStatistics.ipynb @@ -405,7 +405,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -419,7 +419,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/notebooks/LsvdDownload.ipynb b/notebooks/LsvdDownload.ipynb new file mode 100644 index 0000000..c57e1af --- /dev/null +++ b/notebooks/LsvdDownload.ipynb @@ -0,0 +1,574 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "SbkKok0dfjjS" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import polars as pl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.8.2'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pl.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"metadata/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n", + "\n", + "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"subsamples/ur0.01_ir0.01/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Разбиение сабсэмплов на базовую, гэп, вал и тест части\n", + "\n", + "Добавляется колонка original_order чтобы сохранять порядок внутри каждой из частей" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/plain": [ + "(8, 1, 1, 1, 11, 9)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "subsample_name = 'ur0.01_ir0.01'\n", + "content_embedding_size = 256\n", + "DATASET_PATH = \"/home/jovyan/IRec/sigir/lsvd_data/raw\"\n", + "\n", + "metadata_files = ['metadata/users_metadata.parquet',\n", + " 'metadata/items_metadata.parquet',\n", + " 'metadata/item_embeddings.npz']\n", + "\n", + "BASE_WEEKS = (15, 23)\n", + "GAP_WEEKS = (23, 24) #увеличить гэп\n", + "VAL_WEEKS = (24, 25)\n", + "\n", + "base_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n", + " for i in range(BASE_WEEKS[0], BASE_WEEKS[1])]\n", + "\n", + "gap_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n", + " for i in range(GAP_WEEKS[0], GAP_WEEKS[1])]\n", + "\n", + "val_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n", + " for i in range(VAL_WEEKS[0], VAL_WEEKS[1])]\n", + "\n", + "test_interactions_files = [f'subsamples/{subsample_name}/validation/week_25.parquet']\n", + "\n", + "all_interactions_files = base_interactions_files + gap_interactions_files + val_interactions_files + test_interactions_files\n", + "\n", + "base_with_gap_interactions_files = base_interactions_files + gap_interactions_files\n", + "\n", + "len(base_interactions_files), len(gap_interactions_files), len(val_interactions_files), len(test_interactions_files), len(all_interactions_files), len(base_with_gap_interactions_files)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def get_parquet_interactions(data_files):\n", + " data_interactions = pl.concat([pl.scan_parquet(f'{DATASET_PATH}/{file}')\n", + " for file in data_files])\n", + " data_interactions = data_interactions.collect(streaming=True)\n", + " data_interactions = data_interactions.with_row_index(\"original_order\")\n", + " return data_interactions\n" + ] + 
}, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "base_interactions = get_parquet_interactions(base_interactions_files)\n", + "gap_interactions = get_parquet_interactions(gap_interactions_files)\n", + "val_interactions = get_parquet_interactions(val_interactions_files)\n", + "test_interactions = get_parquet_interactions(test_interactions_files)\n", + "all_data_interactions = get_parquet_interactions(all_interactions_files)\n", + "base_with_gap_interactions = get_parquet_interactions(base_with_gap_interactions_files)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Загрузка и фильтрация эмбеддингов" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "all_data_users = all_data_interactions.select('user_id').unique()\n", + "all_data_items = all_data_interactions.select('item_id').unique()\n", + "\n", + "item_ids = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['item_id']\n", + "item_embeddings = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['embedding']\n", + "\n", + "mask = np.isin(item_ids, all_data_items.to_numpy())\n", + "item_ids = item_ids[mask]\n", + "item_embeddings = item_embeddings[mask]\n", + "item_embeddings = item_embeddings[:, :content_embedding_size]\n", + "\n", + "users_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/users_metadata.parquet\")\n", + "items_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/items_metadata.parquet\")\n", + "\n", + "users_metadata = users_metadata.join(all_data_users, on='user_id')\n", + "items_metadata = items_metadata.join(all_data_items, on='item_id')\n", + "items_metadata = items_metadata.join(pl.DataFrame({'item_id': item_ids, \n", + " 'embedding': item_embeddings}), on='item_id')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Сжатие айтем айди и ремапинг" + ] + }, + { + "cell_type": "code", + "execution_count": 11, 
+ "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total users: 79074, Total items: 62758\n" + ] + } + ], + "source": [ + "all_data_items = all_data_interactions.select('item_id').unique()\n", + "all_data_users = all_data_interactions.select('user_id').unique()\n", + "\n", + "unique_items_sorted = all_data_items.sort('item_id').with_row_index('new_item_id')\n", + "global_item_mapping = dict(zip(unique_items_sorted['item_id'], unique_items_sorted['new_item_id']))\n", + "\n", + "print(f\"Total users: {all_data_users.shape[0]}, Total items: {len(global_item_mapping)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def remap_interactions(df, mapping):\n", + " return df.with_columns(\n", + " pl.col('item_id')\n", + " .map_elements(lambda x: mapping.get(x, None), return_dtype=pl.UInt32)\n", + " )\n", + "\n", + "base_interactions_remapped = remap_interactions(base_interactions, global_item_mapping)\n", + "gap_interactions_remapped = remap_interactions(gap_interactions, global_item_mapping)\n", + "test_interactions_remapped = remap_interactions(test_interactions, global_item_mapping)\n", + "val_interactions_remapped = remap_interactions(val_interactions, global_item_mapping)\n", + "all_data_interactions_remapped = remap_interactions(all_data_interactions, global_item_mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "del base_interactions, gap_interactions, test_interactions, val_interactions, all_data_interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "base_with_gap_interactions_remapped = remap_interactions(base_with_gap_interactions, global_item_mapping)\n", + "del base_with_gap_interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + 
"items_metadata_remapped = remap_interactions(items_metadata, global_item_mapping)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Группировка по юзер айди" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interactions count: (1244323, 13)\n", + "users count: (74862, 3)\n", + "interactions count: (176791, 13)\n", + "users count: (44444, 3)\n", + "interactions count: (170111, 13)\n", + "users count: (43370, 3)\n", + "interactions count: (164151, 13)\n", + "users count: (43233, 3)\n", + "interactions count: (1755376, 13)\n", + "users count: (79074, 3)\n" + ] + } + ], + "source": [ + "def get_grouped_interactions(data_interactions):\n", + " print(f\"interactions count: {data_interactions.shape}\")\n", + " data_res = (\n", + " data_interactions\n", + " .select(['original_order', 'user_id', 'item_id'])\n", + " .group_by('user_id')\n", + " .agg(\n", + " pl.col('item_id')\n", + " .sort_by(pl.col('original_order'))\n", + " .alias('item_ids'),\n", + " pl.col('original_order').alias('timestamps')\n", + " )\n", + " .rename({'user_id': 'uid'})\n", + " )\n", + " print(f\"users count: {data_res.shape}\")\n", + " return data_res\n", + "base_interactions_grouped = get_grouped_interactions(base_interactions_remapped)\n", + "gap_interactions_grouped = get_grouped_interactions(gap_interactions_remapped)\n", + "test_interactions_grouped = get_grouped_interactions(test_interactions_remapped)\n", + "val_interactions_grouped = get_grouped_interactions(val_interactions_remapped)\n", + "all_data_interactions_grouped = get_grouped_interactions(all_data_interactions_remapped)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1, 3)
uiditem_idstimestamps
u32list[u32]list[u32]
2655558[16621, 42990, … 51285][46109, 59132, … 1209536]
" + ], + "text/plain": [ + "shape: (1, 3)\n", + "┌─────────┬─────────────────────────┬───────────────────────────┐\n", + "│ uid ┆ item_ids ┆ timestamps │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ list[u32] ┆ list[u32] │\n", + "╞═════════╪═════════════════════════╪═══════════════════════════╡\n", + "│ 2655558 ┆ [16621, 42990, … 51285] ┆ [46109, 59132, … 1209536] │\n", + "└─────────┴─────────────────────────┴───────────────────────────┘" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "base_interactions_grouped.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "del base_interactions_remapped, gap_interactions_remapped, test_interactions_remapped, val_interactions_remapped" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interactions count: (1421114, 13)\n", + "users count: (76483, 3)\n" + ] + } + ], + "source": [ + "base_with_gap_interactions_grouped = get_grouped_interactions(base_with_gap_interactions_remapped)\n", + "del base_with_gap_interactions_remapped" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Сохранение" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Сохранён маппинг: /home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/global_item_mapping.json\n" + ] + } + ], + "source": [ + "import json\n", + "OUTPUT_DIR = \"/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows\"\n", + "\n", + "mapping_output_path = f\"{OUTPUT_DIR}/global_item_mapping.json\"\n", + "\n", + "with open(mapping_output_path, 'w') as f:\n", + " json.dump({str(k): v for k, v in global_item_mapping.items()}, f, indent=2)\n", + "\n", + "print(f\"Сохранён маппинг: {mapping_output_path}\")" + ] + }, + { + "cell_type": 
"code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Сохранен файл: items_metadata_remapped\n", + "Сохранен файл: items_metadata_remapped_old\n", + "Сохранен файл: base_interactions_grouped\n", + "Сохранен файл: gap_interactions_grouped\n", + "Сохранен файл: test_interactions_grouped\n", + "Сохранен файл: val_interactions_grouped\n", + "Сохранен файл: base_with_gap_interactions_grouped\n", + "Сохранен файл: all_data_interactions_grouped\n", + "Сохранен файл: all_data_interactions_remapped\n" + ] + } + ], + "source": [ + "def write_parquet(output_dir, data, file_name):\n", + " output_parquet_path = f\"{output_dir}/{file_name}.parquet\"\n", + " data.write_parquet(output_parquet_path)\n", + " print(f\"Сохранен файл: {file_name}\")\n", + "\n", + "write_parquet(OUTPUT_DIR, items_metadata_remapped, \"items_metadata_remapped\")\n", + "write_parquet(OUTPUT_DIR, items_metadata, \"items_metadata_remapped_old\")\n", + "\n", + "write_parquet(OUTPUT_DIR, base_interactions_grouped, \"base_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, gap_interactions_grouped, \"gap_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, test_interactions_grouped, \"test_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, val_interactions_grouped, \"val_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, base_with_gap_interactions_grouped, \"base_with_gap_interactions_grouped\")\n", + "\n", + "write_parquet(OUTPUT_DIR, all_data_interactions_grouped, \"all_data_interactions_grouped\")\n", + "\n", + "write_parquet(OUTPUT_DIR, all_data_interactions_remapped, \"all_data_interactions_remapped\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(64, 64)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"len(list(items_metadata_remapped.head(1)['embedding'].item())), len(list(items_metadata.head(1)['embedding'].item()))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(62758, 5)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_metadata_remapped.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1, 5)
item_idauthor_iddurationtrain_interactions_rankembedding
u32u32u8u32array[f32, 64]
012494249771612[-0.503418, 0.201538, … 0.007988]
" + ], + "text/plain": [ + "shape: (1, 5)\n", + "┌─────────┬───────────┬──────────┬─────────────────────────┬─────────────────────────────────┐\n", + "│ item_id ┆ author_id ┆ duration ┆ train_interactions_rank ┆ embedding │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ u32 ┆ u8 ┆ u32 ┆ array[f32, 64] │\n", + "╞═════════╪═══════════╪══════════╪═════════════════════════╪═════════════════════════════════╡\n", + "│ 0 ┆ 1249424 ┆ 9 ┆ 771612 ┆ [-0.503418, 0.201538, … 0.0079… │\n", + "└─────────┴───────────┴──────────┴─────────────────────────┴─────────────────────────────────┘" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_metadata_remapped.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
uiditem_idstimestamps
u32list[u32]list[u32]
4465123[28298, 3829, … 28995][257260, 272293, … 1390041]
3043171[8638, 23487, … 15086][6628, 11364, … 1370935]
2757146[56345, 56828, … 37056][194522, 217739, … 1390752]
1148408[40326, 42152][427153, 1367211]
2537065[27766, 39966, … 19887][9428, 35459, … 1214991]
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────────────────────┬─────────────────────────────┐\n", + "│ uid ┆ item_ids ┆ timestamps │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ list[u32] ┆ list[u32] │\n", + "╞═════════╪═════════════════════════╪═════════════════════════════╡\n", + "│ 4465123 ┆ [28298, 3829, … 28995] ┆ [257260, 272293, … 1390041] │\n", + "│ 3043171 ┆ [8638, 23487, … 15086] ┆ [6628, 11364, … 1370935] │\n", + "│ 2757146 ┆ [56345, 56828, … 37056] ┆ [194522, 217739, … 1390752] │\n", + "│ 1148408 ┆ [40326, 42152] ┆ [427153, 1367211] │\n", + "│ 2537065 ┆ [27766, 39966, … 19887] ┆ [9428, 35459, … 1214991] │\n", + "└─────────┴─────────────────────────┴─────────────────────────────┘" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "base_with_gap_interactions_grouped.head()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/scripts/plum-lsvd/callbacks.py b/scripts/plum-lsvd/callbacks.py new file mode 100644 index 0000000..43ec460 --- /dev/null +++ b/scripts/plum-lsvd/callbacks.py @@ -0,0 +1,64 @@ +import torch + +import irec.callbacks as cb +from irec.runners import TrainingRunner, TrainingRunnerContext + +class InitCodebooks(cb.TrainingCallback): + def __init__(self, dataloader): + super().__init__() + self._dataloader = dataloader + + @torch.no_grad() + def before_run(self, runner: TrainingRunner): + for i in range(len(runner.model.codebooks)): + X = next(iter(self._dataloader))['embedding'] + idx = 
torch.randperm(X.shape[0], device=X.device)[:len(runner.model.codebooks[i])] + remainder = runner.model.encoder(X[idx]) + + for j in range(i): + codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j]) + codebook_vectors = runner.model.codebooks[j][codebook_indices] + remainder = remainder - codebook_vectors + + runner.model.codebooks[i].data = remainder.detach() + + +class FixDeadCentroids(cb.TrainingCallback): + def __init__(self, dataloader): + super().__init__() + self._dataloader = dataloader + + def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext): + for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)): + context.metrics[f'num_dead/{i}'] = num_fixed + + @torch.no_grad() + def fix_dead_codebooks(self, runner: TrainingRunner): + num_fixed = [] + for codebook_idx, codebook in enumerate(runner.model.codebooks): + centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device) + random_batch = next(iter(self._dataloader))['embedding'] + + for batch in self._dataloader: + remainder = runner.model.encoder(batch['embedding']) + for l in range(codebook_idx): + ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l]) + remainder = remainder - runner.model.codebooks[l][ind] + + indices = runner.model.get_codebook_indices(remainder, codebook) + centroid_counts.scatter_add_(0, indices, torch.ones_like(indices)) + + dead_mask = (centroid_counts == 0) + num_dead = int(dead_mask.sum().item()) + num_fixed.append(num_dead) + if num_dead == 0: + continue + + remainder = runner.model.encoder(random_batch) + for l in range(codebook_idx): + ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l]) + remainder = remainder - runner.model.codebooks[l][ind] + remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead] + codebook[dead_mask] = remainder.detach() + + return num_fixed diff --git 
import json
from collections import Counter, defaultdict


class CoocMappingDataset:
    """Training sessions together with a window-based item co-occurrence table.

    Instances are normally built through :meth:`create` (JSON input) or
    :meth:`create_from_split_part` (parquet input).
    """

    def __init__(
        self,
        train_sampler,
        num_items,
        cooccur_counter_mapping=None
    ):
        # train_sampler: list of {'user_ids': [uid], 'item_ids': sequence} dicts
        self._train_sampler = train_sampler
        # num_items: largest observed item id + 1
        self._num_items = num_items
        # cooccur_counter_mapping: item id -> Counter of neighbouring item ids
        self._cooccur_counter_mapping = cooccur_counter_mapping

    @classmethod
    def create(cls, inter_json_path, window_size):
        """Build the dataset from a JSON file mapping user id -> list of item ids.

        The last two items of every history are dropped from the training part.
        A warning line is printed for every user whose history is shorter than
        five items, because the message assumes a core-5 dataset.

        Args:
            inter_json_path: path to the JSON file with user histories.
            window_size: number of positions on each side of an item that
                count as co-occurring with it.
        """
        with open(inter_json_path, 'r') as f:
            user_interactions = json.load(f)

        max_item_id = 0
        train_dataset = []
        for user_id_str, item_ids in user_interactions.items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))
            # Warn only about users that violate the core-5 assumption; the
            # previous `>= 5` check fired for every valid user instead.
            if len(item_ids) < 5:
                print(f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items')
            train_dataset.append({
                'user_ids': [user_id],
                'item_ids': item_ids[:-2],
            })

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @classmethod
    def create_from_split_part(
        cls,
        train_inter_parquet_path,
        window_size,
    ):
        """Build the dataset from an already split parquet file of grouped sessions.

        Every session yielded by ``InteractionsDatasetParquet`` is used in full
        (nothing is held out here).

        Args:
            train_inter_parquet_path: parquet file readable by
                ``InteractionsDatasetParquet``.
            window_size: co-occurrence window radius, see
                :meth:`build_cooccur_counter_mapping`.
        """
        # Imported lazily so the JSON constructor and the co-occurrence
        # builder stay usable without the parquet stack that `data` pulls in.
        from data import InteractionsDatasetParquet

        max_item_id = 0
        train_dataset = []

        train_interactions = InteractionsDatasetParquet(train_inter_parquet_path)

        actions_num = 0
        for session in train_interactions:
            user_id, item_ids = int(session['user_id']), session['item_ids']
            # Test for emptiness, not truthiness of the values: `.any()` is
            # False for a history that only contains item id 0.
            if len(item_ids) > 0:
                # int() keeps num_items a plain Python int instead of a NumPy scalar.
                max_item_id = max(max_item_id, int(max(item_ids)))
            actions_num += len(item_ids)
            train_dataset.append({
                'user_ids': [user_id],
                'item_ids': item_ids,
            })

        print(f'Train: {len(train_dataset)} users')
        print(f'Max item ID: {max_item_id}')
        print(f"Actions num: {actions_num}")

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
            train_dataset,
            window_size=window_size
        )

        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):
        """Count, per item, how often every other position falls inside its window.

        Args:
            train_dataset: iterable of dicts with an ``'item_ids'`` sequence.
            window_size: positions ``i - window_size .. i + window_size``
                (excluding ``i`` itself) count as neighbours of position ``i``.

        Returns:
            defaultdict[item, Counter]: neighbour counts per item.  Items that
            never have a neighbour get no entry.
        """
        cooccur_counts = defaultdict(Counter)
        for session in train_dataset:
            items = session['item_ids']
            session_len = len(items)
            for i, item_i in enumerate(items):
                window_start = max(0, i - window_size)
                window_stop = min(session_len, i + window_size + 1)
                for j in range(window_start, window_stop):
                    if i != j:
                        cooccur_counts[item_i][items[j]] += 1
        max_hist_len = max((len(counter) for counter in cooccur_counts.values()), default=0)
        print(f"Max cooccurrence history length is {max_hist_len}")
        return cooccur_counts

    @property
    def cooccur_counter_mapping(self):
        """Mapping item id -> Counter of co-occurring item ids (may be ``None``)."""
        return self._cooccur_counter_mapping
self.df.row(idx, named=True) + return { + 'user_id': row['uid'], + 'item_ids': np.array(row['item_ids'], dtype=np.uint32), + } + + def __len__(self): + return len(self.df) + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + +class EmbeddingDatasetParquet(BaseDataset): + def __init__(self, data_path): + self.df = pl.read_parquet(data_path) + self.item_ids = np.array(self.df['item_id'], dtype=np.int64) + self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32) + print(f"embedding dim: {self.embeddings[0].shape}") + + def __getitem__(self, idx): + index = self.item_ids[idx] + tensor_emb = self.embeddings[idx] + return { + 'item_id': index, + 'embedding': tensor_emb, + 'embedding_dim': len(tensor_emb) + } + + def __len__(self): + return len(self.embeddings) + + +class EmbeddingDataset(BaseDataset): + def __init__(self, data_path): + self.data_path = data_path + with open(data_path, 'rb') as f: + self.data = pickle.load(f) + + self.item_ids = np.array(self.data['item_id'], dtype=np.int64) + self.embeddings = np.array(self.data['embedding'], dtype=np.float32) + + def __getitem__(self, idx): + index = self.item_ids[idx] + tensor_emb = self.embeddings[idx] + return { + 'item_id': index, + 'embedding': tensor_emb, + 'embedding_dim': len(tensor_emb) + } + + def __len__(self): + return len(self.embeddings) + + +class ProcessEmbeddings(Transform): + def __init__(self, embedding_dim, keys): + self.embedding_dim = embedding_dim + self.keys = keys + + def __call__(self, batch): + for key in self.keys: + batch[key] = batch[key].reshape(-1, self.embedding_dim) + return batch \ No newline at end of file diff --git a/scripts/plum-lsvd/infer_default.py b/scripts/plum-lsvd/infer_default.py new file mode 100644 index 0000000..b15fb6d --- /dev/null +++ b/scripts/plum-lsvd/infer_default.py @@ -0,0 +1,146 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from 
import json
import os
from collections import defaultdict

import numpy as np
import torch
from loguru import logger

import irec.callbacks as cb
from irec.data.dataloader import DataLoader
from irec.data.transforms import Collate, ToTorch, ToDevice
from irec.runners import InferenceRunner
from irec.utils import fix_random_seed

from data import EmbeddingDataset, ProcessEmbeddings
from models import PlumRQVAE

# Paths
IREC_PATH = '/home/jovyan/IRec/'
EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_plum_rqvae_beauty_ws_2_best_0.0051.pth'
RESULTS_PATH = os.path.join(IREC_PATH, 'results')

WINDOW_SIZE = 2

EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'

# Everything else

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

BATCH_SIZE = 1024

INPUT_DIM = 4096
HIDDEN_DIM = 32
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 3

BETA = 0.25


def _resolve_collisions(mappings, codebook_size):
    """Turn per-item cluster codes into unique, level-offset semantic ids.

    Items sharing the same cluster tuple receive distinct extra codes drawn
    from a random permutation of ``range(codebook_size)``.  Afterwards the code
    at position ``i`` is shifted by ``codebook_size * i`` so every level uses
    its own id range.  The ``clusters`` lists inside ``mappings`` are modified
    in place.

    Args:
        mappings: iterable of dicts with ``'item_id'`` and ``'clusters'``.
        codebook_size: number of codes available per level.

    Returns:
        dict[int, list[int]]: item id -> final code sequence.

    Raises:
        ValueError: if more than ``codebook_size`` items share one cluster
            tuple, so no unique extra code can be assigned.
    """
    codes_by_item = {}
    items_by_code = defaultdict(list)
    for mapping in mappings:
        item_id = int(mapping['item_id'])
        clusters = mapping['clusters']
        codes_by_item[item_id] = clusters
        items_by_code[tuple(clusters)].append(item_id)

    for items in items_by_code.values():
        if len(items) > codebook_size:
            raise ValueError(
                f'{len(items)} items share one semantic id, '
                f'but only {codebook_size} collision codes exist'
            )
        collision_solvers = np.random.permutation(codebook_size)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            codes = codes_by_item[item_id]
            codes.append(collision_solver)
            for level in range(len(codes)):
                codes[level] += codebook_size * level
    return codes_by_item


def main():
    """Run the trained RQ-VAE over all item embeddings and export semantic ids.

    Writes ``<EXPERIMENT_NAME>_clusters.json`` through ``cb.InferenceSaver``
    and then ``<EXPERIMENT_NAME>_clusters_colisionless.json`` with the
    collision-resolved codes.
    """
    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDataset(
        data_path=EMBEDDINGS_PATH
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # Both output files live here; make sure the directory exists up front.
    os.makedirs(RESULTS_PATH, exist_ok=True)
    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')

    callbacks = [
        cb.LoadModel(MODEL_PATH),

        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='valid'),

        cb.MetricAccumulator(
            accumulators={
                'valid/loss': cb.MeanAccumulator(),
                'valid/recon_loss': cb.MeanAccumulator(),
                'valid/rqvae_loss': cb.MeanAccumulator(),
                'valid/con_loss': cb.MeanAccumulator(),
            },
        ),

        cb.Logger().every_num_steps(len(dataloader)),

        cb.InferenceSaver(
            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
            save_path=clusters_path,
            format='json'
        )
    ]

    logger.debug('Everything is ready for inference process!')

    runner = InferenceRunner(
        model=model,
        dataset=dataloader,
        callbacks=callbacks,
    )
    runner.run()

    with open(clusters_path, 'r') as f:
        mappings = json.load(f)

    inter = _resolve_collisions(mappings, CODEBOOK_SIZE)

    with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
        json.dump(inter, f, indent=2)


if __name__ == '__main__':
    main()
import json
import os
from collections import defaultdict

import numpy as np
import torch
from loguru import logger

import irec.callbacks as cb
from irec.data.dataloader import DataLoader
from irec.data.transforms import Collate, ToTorch, ToDevice
from irec.runners import InferenceRunner
from irec.utils import fix_random_seed

from data import EmbeddingDatasetParquet, ProcessEmbeddings
from models import PlumRQVAE

# Experiment with the full history
IREC_PATH = '../../'
MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e35_best_0.0096.pth'
EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/items_metadata_remapped.parquet"

RESULTS_PATH = os.path.join(IREC_PATH, 'results')

WINDOW_SIZE = 2
CODEBOOK_SIZE = 512
K = 2000
EXPERIMENT_NAME = f'4-1_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e_35'
# Everything else

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


BATCH_SIZE = 1024

INPUT_DIM = 64
HIDDEN_DIM = 32
NUM_CODEBOOKS = 3
BETA = 0.25


def _resolve_collisions(mappings, codebook_size):
    """Turn per-item cluster codes into unique, level-offset semantic ids.

    Items sharing the same cluster tuple receive distinct extra codes drawn
    from a random permutation of ``range(codebook_size)``.  Afterwards the code
    at position ``i`` is shifted by ``codebook_size * i`` so every level uses
    its own id range.  The ``clusters`` lists inside ``mappings`` are modified
    in place.

    Args:
        mappings: iterable of dicts with ``'item_id'`` and ``'clusters'``.
        codebook_size: number of codes available per level.

    Returns:
        dict[int, list[int]]: item id -> final code sequence.

    Raises:
        ValueError: if more than ``codebook_size`` items share one cluster
            tuple, so no unique extra code can be assigned.
    """
    codes_by_item = {}
    items_by_code = defaultdict(list)
    for mapping in mappings:
        item_id = int(mapping['item_id'])
        clusters = mapping['clusters']
        codes_by_item[item_id] = clusters
        items_by_code[tuple(clusters)].append(item_id)

    for items in items_by_code.values():
        if len(items) > codebook_size:
            raise ValueError(
                f'{len(items)} items share one semantic id, '
                f'but only {codebook_size} collision codes exist'
            )
        collision_solvers = np.random.permutation(codebook_size)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            codes = codes_by_item[item_id]
            codes.append(collision_solver)
            for level in range(len(codes)):
                codes[level] += codebook_size * level
    return codes_by_item


def main():
    """Run the trained RQ-VAE over the parquet item embeddings and export semantic ids.

    Writes ``<EXPERIMENT_NAME>_clusters.json`` through ``cb.InferenceSaver``
    and then ``<EXPERIMENT_NAME>_clusters_colisionless.json`` with the
    collision-resolved codes.
    """
    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # Both output files live here; make sure the directory exists up front.
    os.makedirs(RESULTS_PATH, exist_ok=True)
    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')

    callbacks = [
        cb.LoadModel(MODEL_PATH),

        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='valid'),

        cb.MetricAccumulator(
            accumulators={
                'valid/loss': cb.MeanAccumulator(),
                'valid/recon_loss': cb.MeanAccumulator(),
                'valid/rqvae_loss': cb.MeanAccumulator(),
                'valid/con_loss': cb.MeanAccumulator(),
            },
        ),

        cb.Logger().every_num_steps(len(dataloader)),

        cb.InferenceSaver(
            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
            save_path=clusters_path,
            format='json'
        )
    ]

    logger.debug('Everything is ready for inference process!')

    runner = InferenceRunner(
        model=model,
        dataset=dataloader,
        callbacks=callbacks,
    )
    runner.run()

    with open(clusters_path, 'r') as f:
        mappings = json.load(f)

    inter = _resolve_collisions(mappings, CODEBOOK_SIZE)

    with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
        json.dump(inter, f, indent=2)


if __name__ == '__main__':
    main()
open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/infer_plum_4.2.py b/scripts/plum-lsvd/infer_plum_4.2.py new file mode 100644 index 0000000..977c0b5 --- /dev/null +++ b/scripts/plum-lsvd/infer_plum_4.2.py @@ -0,0 +1,146 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE + +# ЭКСПЕРИМЕНТ С ОБРЕЗАННОЙ ИСТОРИЕЙ +IREC_PATH = '../../' +MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-2_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e35_best_0.0096.pth' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/items_metadata_remapped.parquet" + +RESULTS_PATH = os.path.join(IREC_PATH, 'results') + +WINDOW_SIZE = 2 +CODEBOOK_SIZE = 512 +K = 2000 +EXPERIMENT_NAME = f'4-2_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e_35' +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +NUM_CODEBOOKS = 3 +BETA = 0.25 + + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, 
keys=['embedding'])) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + 
collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/infer_rqvae.py b/scripts/plum-lsvd/infer_rqvae.py new file mode 100644 index 0000000..53a587c --- /dev/null +++ b/scripts/plum-lsvd/infer_rqvae.py @@ -0,0 +1,161 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from collections import Counter +from models import PlumRQVAE + +# ПУТИ +IREC_PATH = '/home/jovyan/IRec/' +EMBEDDINGS_PATH = '/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet' +MODEL_PATH = '/home/jovyan/IRec/checkpoints/rqvae_vk_lsvd_cz_512_8-weeks_best_0.009.pth' +RESULTS_PATH = os.path.join(IREC_PATH, 'results') + +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'rqvae_vk_lsvd_cz_512_8-weeks' + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 + +BETA = 0.25 + + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + 
dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + collision_stats = [] + for mapping in mappings: + item_id = 
mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + collision_stats.append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + if collision_stats: + max_col_tok = max(collision_stats) + avg_col_tok = np.mean(collision_stats) + collision_distribution = Counter(collision_stats) + + print(f"Max collision token: {max_col_tok}") + print(f"Avg collision token: {avg_col_tok:.2f}") + print(f"Total items with collisions: {len(collision_stats)}") + print(f"Collision solver distribution: {dict(collision_distribution)}") + else: + print("No collisions detected") + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/models.py b/scripts/plum-lsvd/models.py new file mode 100644 index 0000000..d475712 --- /dev/null +++ b/scripts/plum-lsvd/models.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class PlumRQVAE(nn.Module): + def __init__( + self, + input_dim, + num_codebooks, + codebook_size, + embedding_dim, + beta=0.25, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=0.0, + ): + super().__init__() + self.register_buffer('beta', torch.tensor(beta)) + self.temperature = temperature + + self.input_dim = input_dim + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.quant_loss_weight = quant_loss_weight + + self.contrastive_loss_weight = contrastive_loss_weight + + self.encoder = 
self.make_encoding_tower(input_dim, embedding_dim) + self.decoder = self.make_encoding_tower(embedding_dim, input_dim) + + self.codebooks = torch.nn.ParameterList() + for _ in range(num_codebooks): + cb = torch.FloatTensor(codebook_size, embedding_dim) + #nn.init.normal_(cb) + self.codebooks.append(cb) + + @staticmethod + def make_encoding_tower(d1, d2, bias=False): + return torch.nn.Sequential( + nn.Linear(d1, d1), + nn.ReLU(), + nn.Linear(d1, d2), + nn.ReLU(), + nn.Linear(d2, d2, bias=bias) + ) + + @staticmethod + def get_codebook_indices(remainder, codebook): + dist = torch.cdist(remainder, codebook) + return dist.argmin(dim=-1) + + def _quantize_representation(self, latent_vector): + latent_restored = 0 + remainder = latent_vector + + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + quantized = codebook[codebook_indices] + codebook_vectors = remainder + (quantized - remainder).detach() + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + return latent_restored + + def contrastive_loss(self, p_i, p_i_star): + N_b = p_i.size(0) + + p_i = F.normalize(p_i, p=2, dim=-1) #TODO посмотреть без нормалайза + p_i_star = F.normalize(p_i_star, p=2, dim=-1) + + similarities = torch.matmul(p_i, p_i_star.T) / self.temperature + + labels = torch.arange(N_b, dtype=torch.long, device=p_i.device) + + loss = F.cross_entropy(similarities, labels) + + return loss + + def forward(self, inputs): + latent_vector = self.encoder(inputs['embedding']) + item_ids = inputs['item_id'] + + latent_restored = 0 + rqvae_loss = 0 + clusters = [] + remainder = latent_vector + + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + clusters.append(codebook_indices) + + quantized = codebook[codebook_indices] + codebook_vectors = remainder + (quantized - remainder).detach() + + rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) + 
rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + embeddings_restored = self.decoder(latent_restored) + recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding']) + + if 'cooccurrence_embedding' in inputs: + cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device)) + cooccurrence_restored = self._quantize_representation(cooccurrence_latent) + con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored) + else: + con_loss = torch.as_tensor(0.0, device=latent_vector.device) + + loss = ( + recon_loss + + self.quant_loss_weight * rqvae_loss + + self.contrastive_loss_weight * con_loss + ).mean() + + clusters_counts = [] + for cluster in clusters: + clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + + return loss, { + 'loss': loss.item(), + 'recon_loss': recon_loss.mean().item(), + 'rqvae_loss': rqvae_loss.mean().item(), + 'con_loss': con_loss.item(), + + 'clusters_counts': clusters_counts, + 'clusters': torch.stack(clusters).T, + 'embedding_hat': embeddings_restored, + } \ No newline at end of file diff --git a/scripts/plum-lsvd/train_plum_4.1.py b/scripts/plum-lsvd/train_plum_4.1.py new file mode 100644 index 0000000..85027b3 --- /dev/null +++ b/scripts/plum-lsvd/train_plum_4.1.py @@ -0,0 +1,180 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddingsVectorized +from cooc_data import CoocMappingDataset + +# ЭКСПЕРИМЕНТ С 
ПОЛНОЙ ИСТОРИЕЙ +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 +K=2000 + +EXPERIMENT_NAME = f'4-1_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e{NUM_EPOCHS}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/base_with_gap_interactions_grouped.parquet" +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet" +IREC_PATH = '../../' + +print(INTER_TRAIN_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH, + ) + + data = CoocMappingDataset.create_from_split_part( + train_inter_parquet_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE) + all_item_ids.append(item_id) + + # add_cooc_transform = AddWeightedCooccurrenceEmbeddings(data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids, K) + add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized( + cooccur_counts=data.cooccur_counter_mapping, + item_id_to_embedding=item_id_to_embedding, + all_item_ids=all_item_ids, + device=DEVICE, + max_neighbors=K, + seed=42 + ) + + train_dataloader = DataLoader( #call в основном потоке делается нужно исправить + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform + ).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + 
).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': 
cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/train_plum_4.2.py b/scripts/plum-lsvd/train_plum_4.2.py new file mode 100644 index 0000000..de1864d --- /dev/null +++ b/scripts/plum-lsvd/train_plum_4.2.py @@ -0,0 +1,180 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddingsVectorized +from cooc_data import CoocMappingDataset + +# ЭКСПЕРИМЕНТ С ОБРЕЗАННОЙ ИСТОРИЕЙ +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 +K=2000 + +EXPERIMENT_NAME = f'4-2_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e{NUM_EPOCHS}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/base_interactions_grouped.parquet" +EMBEDDINGS_PATH = 
"/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet" +IREC_PATH = '../../' + +print(INTER_TRAIN_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH, + ) + + data = CoocMappingDataset.create_from_split_part( + train_inter_parquet_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE) + all_item_ids.append(item_id) + + # add_cooc_transform = AddWeightedCooccurrenceEmbeddings(data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids, K) + add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized( + cooccur_counts=data.cooccur_counter_mapping, + item_id_to_embedding=item_id_to_embedding, + all_item_ids=all_item_ids, + device=DEVICE, + max_neighbors=K, + seed=42 + ) + + train_dataloader = DataLoader( #call в основном потоке делается нужно исправить + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform + ).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p 
in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + 
optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/train_rqvae.py b/scripts/plum-lsvd/train_rqvae.py new file mode 100644 index 0000000..ea41b74 --- /dev/null +++ b/scripts/plum-lsvd/train_rqvae.py @@ -0,0 +1,174 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddings +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 15 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 3 +MAX_LEN = 500 +K=100 + +# EXPERIMENT_NAME = f'4-1_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_ml_{MAX_LEN}' +EXPERIMENT_NAME = f'rqvae_vk_lsvd_cz_512_8-weeks' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet" +IREC_PATH = '../../' + +# print(INTER_TRAIN_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + # data = CoocMappingDataset.create_from_split_part( + # train_inter_parquet_path=INTER_TRAIN_PATH, + # window_size=WINDOW_SIZE, + # max_items=MAX_LEN + # ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE) + all_item_ids.append(item_id) + + 
# add_cooc_transform = AddWeightedCooccurrenceEmbeddings(data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) + + train_dataloader = DataLoader( #call в основном потоке делается нужно исправить + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + # ).map(add_cooc_transform + ).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + # ).map(add_cooc_transform) + ) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': 
cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/transforms.py b/scripts/plum-lsvd/transforms.py new file mode 100644 index 0000000..143002b --- /dev/null +++ b/scripts/plum-lsvd/transforms.py @@ -0,0 +1,287 @@ +import numpy as np +import pickle +import torch +from typing import Dict, List +import time +from collections import defaultdict, Counter + +class AddWeightedCooccurrenceEmbeddings: + def __init__(self, cooccur_counts, item_id_to_embedding, all_item_ids, top_k): + self.cooccur_counts = cooccur_counts + self.item_id_to_embedding = item_id_to_embedding + self.all_item_ids = all_item_ids + self.call_count = 0 + self.top_k = top_k + + # Предвычисляем top_k для каждого item_id + self._top_k_cache = {} + self._build_top_k_cache() + + def _build_top_k_cache(self): 
class AddWeightedCooccurrenceEmbeddingsVectorized:
    """Attach the embedding of a sampled co-occurring item to every batch row.

    Item ids are mapped to dense row indices (their position in the sorted,
    de-duplicated ``all_item_ids``) and three ``[num_items, K]`` tables are
    precomputed on ``device``:

    * ``neighbors_matrix`` - row indices of the K most frequent neighbours;
    * ``probs_matrix``     - sampling probability of every neighbour slot;
    * ``valid_mask``       - ``True`` for the slots of a row that are filled.

    Items that have no known neighbour fall back to a uniform distribution
    over all items (or over a random subset of them when K < num_items).

    Args:
        cooccur_counts: ``item_id -> {neighbour_id -> count}``.
        item_id_to_embedding: ``item_id -> 1-D embedding`` (tensor or array-like).
        all_item_ids: every item id that may appear in a batch.
        device: device on which the lookup tables are stored.
        limit_neighbors: cap the number of neighbours per item at ``max_neighbors``.
        max_neighbors: the cap used when ``limit_neighbors`` is true.
        seed: value used to re-seed the global torch and numpy generators.
        verbose: print progress and timing information.

    Raises:
        ValueError: if ``all_item_ids`` or ``item_id_to_embedding`` is empty,
            or if ``max_neighbors`` is smaller than 1 while limiting is on.
        KeyError: if an embedding is given for an id missing from ``all_item_ids``.
    """

    def __init__(
        self,
        cooccur_counts: Dict[int, Dict[int, int]],
        item_id_to_embedding: Dict[int, torch.Tensor],
        all_item_ids: List[int],
        device: torch.device,
        limit_neighbors: bool = True,
        max_neighbors: int = 256,
        seed: int = 42,
        verbose: bool = True
    ):
        if len(all_item_ids) == 0:
            raise ValueError("all_item_ids must not be empty")
        if len(item_id_to_embedding) == 0:
            raise ValueError("item_id_to_embedding must not be empty")
        if limit_neighbors and max_neighbors < 1:
            raise ValueError("max_neighbors must be >= 1 when limit_neighbors is True")

        self.device = device
        self.call_count = 0
        self.limit_neighbors = limit_neighbors
        self.max_neighbors = max_neighbors
        self.seed = seed
        self.verbose = verbose

        # NOTE: this re-seeds the *global* generators, exactly as before.
        torch.manual_seed(seed)
        np.random.seed(seed)

        if self.verbose:
            print(f"\n{'='*80}")
            print(f"Initializing AddWeightedCooccurrenceEmbeddingsVectorized")
            print(f"{'='*80}")
        init_start = time.time()

        # De-duplicate so that every row of the tables belongs to exactly one id.
        all_item_ids_sorted = sorted(set(all_item_ids))
        self.item_id_to_idx = {item_id: idx for idx, item_id in enumerate(all_item_ids_sorted)}
        # Sorted ids on the device: used by __call__ for a vectorised id -> row lookup.
        self.idx_to_item_id = torch.tensor(all_item_ids_sorted, device=device, dtype=torch.long)

        if self.verbose:
            print(f"[INIT] Sorted {len(all_item_ids)} item IDs and created mappings")

        num_items = len(all_item_ids_sorted)
        embedding_dim = next(iter(item_id_to_embedding.values())).shape[0]

        if self.verbose:
            print(f"[INIT] Num items: {num_items}, Embedding dim: {embedding_dim}")

        self.embedding_matrix = torch.zeros(
            size=(num_items, embedding_dim),
            device=device,
            dtype=torch.float32,
            requires_grad=False
        )

        emb_load_start = time.time()
        for item_id, emb in item_id_to_embedding.items():
            idx = self.item_id_to_idx[item_id]
            if isinstance(emb, torch.Tensor):
                self.embedding_matrix[idx] = emb.detach().to(device=device, dtype=torch.float32)
            else:
                self.embedding_matrix[idx] = torch.tensor(emb, device=device, dtype=torch.float32)

        if self.verbose:
            emb_load_time = time.time() - emb_load_start
            print(f"[INIT] Loaded {len(item_id_to_embedding)} embeddings in {emb_load_time*1000:.2f}ms")

        self._build_cooccurrence_tables(cooccur_counts, num_items)

        if self.verbose:
            init_time = time.time() - init_start
            print(f"[INIT] Total initialization time: {init_time*1000:.2f}ms")
            print(f"{'='*80}\n")

    def _build_cooccurrence_tables(self, cooccur_counts: Dict, num_items: int):
        """Build ``neighbors_matrix``, ``probs_matrix`` and ``valid_mask``.

        Ids that are not part of ``item_id_to_idx`` are dropped, both as rows
        and as neighbours. Neighbours are ranked by co-occurrence count, so
        truncation to the table width keeps the most frequent ones.
        """
        build_start = time.time()
        if self.verbose:
            print(f"\n[BUILD] Building cooccurrence tables...")

        # Re-key the counts from item ids to dense row indices.
        indexed_cooccur_counts = {}
        for item_id, neighbors in cooccur_counts.items():
            row = self.item_id_to_idx.get(item_id)
            if row is None:
                continue
            indexed_neighbors = {
                self.item_id_to_idx[neighbor_id]: count
                for neighbor_id, count in neighbors.items()
                if neighbor_id in self.item_id_to_idx
            }
            if indexed_neighbors:
                indexed_cooccur_counts[row] = indexed_neighbors

        if self.verbose:
            items_with_cooc = len(indexed_cooccur_counts)
            print(f"[BUILD] Items with cooccurrences: {items_with_cooc}/{num_items}")
            total_pairs = sum(len(neighbors) for neighbors in indexed_cooccur_counts.values())
            print(f"[BUILD] Total cooccurrence pairs: {total_pairs}")

        # Table width: the widest row (rows without neighbours want all items),
        # optionally capped at max_neighbors.
        max_actual_neighbors = 0
        for row in range(num_items):
            counter = indexed_cooccur_counts.get(row)
            width = len(counter) if counter else num_items
            max_actual_neighbors = max(max_actual_neighbors, width)
        if self.limit_neighbors:
            max_actual_neighbors = min(max_actual_neighbors, self.max_neighbors)

        if self.verbose:
            print(f"[BUILD] Max neighbors per item: {max_actual_neighbors}")

        table_shape = (num_items, max_actual_neighbors)
        neighbors_matrix = torch.zeros(table_shape, dtype=torch.long, device=self.device)
        probs_matrix = torch.zeros(table_shape, dtype=torch.float32, device=self.device)
        valid_mask = torch.zeros(table_shape, dtype=torch.bool, device=self.device)

        matrix_fill_start = time.time()

        for row in range(num_items):
            counter = indexed_cooccur_counts.get(row)

            if counter:
                # Rank by count (not by id) and keep the most frequent neighbours.
                ranked = sorted(counter.items(), key=lambda pair: pair[1], reverse=True)
                ranked = ranked[:max_actual_neighbors]
                neighbor_rows, freqs = zip(*ranked)
                num_neighbors = len(neighbor_rows)

                freqs = np.array(freqs, dtype=np.float32)
                probs = freqs / freqs.sum()

                neighbors_matrix[row, :num_neighbors] = torch.tensor(
                    neighbor_rows, dtype=torch.long, device=self.device
                )
                probs_matrix[row, :num_neighbors] = torch.tensor(
                    probs, dtype=torch.float32, device=self.device
                )
                valid_mask[row, :num_neighbors] = True
            elif max_actual_neighbors >= num_items:
                # No neighbours and every item fits: uniform over all items.
                neighbors_matrix[row, :num_items] = torch.arange(num_items, device=self.device)
                probs_matrix[row, :num_items] = 1.0 / num_items
                valid_mask[row, :num_items] = True
            else:
                # No neighbours and not everything fits: uniform over a random subset.
                perm = torch.randperm(num_items, device=self.device)[:max_actual_neighbors]
                neighbors_matrix[row] = perm
                probs_matrix[row] = 1.0 / max_actual_neighbors
                valid_mask[row] = True

        if self.verbose:
            matrix_fill_time = time.time() - matrix_fill_start
            print(f"[BUILD] Filled matrices in {matrix_fill_time*1000:.2f}ms")

        self.neighbors_matrix = neighbors_matrix
        self.probs_matrix = probs_matrix
        self.valid_mask = valid_mask

        if self.verbose:
            print(f"[BUILD] neighbors_matrix shape: {neighbors_matrix.shape}")
            print(f"[BUILD] probs_matrix shape: {probs_matrix.shape}")
            print(f"[BUILD] valid_mask shape: {valid_mask.shape}")
            build_time = time.time() - build_start
            print(f"[BUILD] Total build time: {build_time*1000:.2f}ms")

    def __call__(self, batch):
        """Sample one neighbour per row and store its embedding in the batch.

        Reads ``batch['item_id']`` and writes ``batch['cooccurrence_embedding']``
        (one embedding row per item id). Ids that are unknown to the transform
        are treated as the item stored in row 0, as before.
        """
        self.call_count += 1
        call_start = time.time()

        item_ids = batch['item_id']
        if not isinstance(item_ids, torch.Tensor):
            item_ids = torch.tensor(item_ids, device=self.device, dtype=torch.long)
        else:
            item_ids = item_ids.to(device=self.device, dtype=torch.long)
        item_ids = item_ids.reshape(-1)

        batch_size = item_ids.shape[0]

        # Vectorised id -> row lookup on the sorted id tensor; this avoids a
        # Python loop with one .item() host synchronisation per element.
        positions = torch.searchsorted(self.idx_to_item_id, item_ids)
        positions = positions.clamp(max=self.idx_to_item_id.numel() - 1)
        is_known = self.idx_to_item_id[positions] == item_ids
        indexed_item_ids = torch.where(is_known, positions, torch.zeros_like(positions))

        probs = self.probs_matrix[indexed_item_ids]
        mask = self.valid_mask[indexed_item_ids]

        masked_probs = probs.masked_fill(~mask, 0.0)
        row_sums = masked_probs.sum(dim=1, keepdim=True)
        row_sums = row_sums.masked_fill(row_sums == 0, 1.0)
        masked_probs = masked_probs / row_sums

        neighbor_indices = torch.multinomial(masked_probs, num_samples=1, replacement=True)
        neighbor_indices = neighbor_indices.squeeze(1)

        cooc_indexed_ids = self.neighbors_matrix[indexed_item_ids, neighbor_indices]
        batch['cooccurrence_embedding'] = self.embedding_matrix[cooc_indexed_ids]

        if self.verbose and self.call_count % 1000 == 0:
            call_time = time.time() - call_start
            print(f"Call #{self.call_count}: batch_size={batch_size}, {call_time*1000:.2f}ms")

        return batch
class CoocMappingDataset:
    """Training sessions plus a window-based item co-occurrence mapping.

    Attributes are stored privately; the co-occurrence mapping is exposed
    through the ``cooccur_counter_mapping`` property.
    """

    def __init__(
        self,
        train_sampler,
        num_items,
        cooccur_counter_mapping=None
    ):
        self._train_sampler = train_sampler
        self._num_items = num_items
        self._cooccur_counter_mapping = cooccur_counter_mapping

    @staticmethod
    def _load_interactions(path):
        """Read the ``{user_id: [item_id, ...]}`` JSON file at *path*."""
        with open(path, 'r') as f:
            return json.load(f)

    @classmethod
    def create(cls, inter_json_path, window_size):
        """Build the dataset from a full interaction file.

        The last two items of every user are excluded from the training
        session. Every user must have at least five interactions.

        Raises:
            AssertionError: if a user has fewer than five items.
        """
        max_item_id = 0
        train_dataset = []

        for user_id_str, item_ids in cls._load_interactions(inter_json_path).items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))
            assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items'
            train_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids[:-2],
            })

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @classmethod
    def create_from_split_part(
        cls,
        train_inter_json_path,
        window_size
    ):
        """Build the dataset from a file that already holds only the train part.

        Item sequences are used as they are, without dropping trailing items.
        """
        max_item_id = 0
        train_dataset = []

        for user_id_str, item_ids in cls._load_interactions(train_inter_json_path).items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))

            train_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids,
            })

        logger.debug(f'Train: {len(train_dataset)} users')
        logger.debug(f'Max item ID: {max_item_id}')

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
            train_dataset,
            window_size=window_size
        )

        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):  # TODO: pass timestamps and build the window from them
        """Count, for every item, the items seen within ``window_size`` positions.

        The window is positional and symmetric. Only the position itself is
        skipped, so an item repeated inside the window counts as its own
        neighbour.

        Returns:
            ``defaultdict(Counter)`` mapping ``item_id -> Counter(neighbour_id -> count)``.
        """
        cooccur_counts = defaultdict(Counter)
        for session in train_dataset:
            items = session['item.ids']
            session_len = len(items)
            for center, item in enumerate(items):
                window_start = max(0, center - window_size)
                window_end = min(session_len, center + window_size + 1)
                for pos in range(window_start, window_end):
                    if pos != center:
                        cooccur_counts[item][items[pos]] += 1
        return cooccur_counts

    @property
    def cooccur_counter_mapping(self):
        """The co-occurrence mapping passed to the constructor (may be ``None``)."""
        return self._cooccur_counter_mapping
class PlumRQVAE(nn.Module):
    """Residual-quantised VAE with an optional contrastive co-occurrence loss.

    An MLP encoder maps the input embedding to a latent vector, which is
    quantised by ``num_codebooks`` codebooks applied to successive residuals.
    An MLP decoder reconstructs the input from the sum of the selected code
    vectors.

    Args:
        input_dim: size of the input embedding.
        num_codebooks: number of residual quantisation levels.
        codebook_size: number of vectors in every codebook.
        embedding_dim: size of the latent space and of the code vectors.
        beta: weight of the commitment term of the quantisation loss.
        quant_loss_weight: weight of the quantisation loss in the total loss.
        contrastive_loss_weight: weight of the contrastive loss in the total loss.
        temperature: softmax temperature of the contrastive loss. It must be
            positive whenever the contrastive loss is evaluated; the default
            of ``0.0`` is kept for backward compatibility.
    """

    def __init__(
        self,
        input_dim,
        num_codebooks,
        codebook_size,
        embedding_dim,
        beta=0.25,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=0.0,
    ):
        super().__init__()
        self.register_buffer('beta', torch.tensor(beta))
        self.temperature = temperature

        self.input_dim = input_dim
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.embedding_dim = embedding_dim
        self.quant_loss_weight = quant_loss_weight

        self.contrastive_loss_weight = contrastive_loss_weight

        self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
        self.decoder = self.make_encoding_tower(embedding_dim, input_dim)

        # Deterministic zero placeholders instead of uninitialised memory
        # (which may hold NaN/inf). Real values are expected to come from a
        # codebook-initialisation step or from a loaded checkpoint.
        # torch.zeros does not consume random numbers, so the seeded
        # initialisation of the other layers is unaffected.
        self.codebooks = torch.nn.ParameterList(
            nn.Parameter(torch.zeros(codebook_size, embedding_dim))
            for _ in range(num_codebooks)
        )

    @staticmethod
    def make_encoding_tower(d1, d2, bias=False):
        """Return a three-layer MLP ``d1 -> d1 -> d2 -> d2``.

        ``bias`` only controls the bias of the last linear layer.
        """
        return torch.nn.Sequential(
            nn.Linear(d1, d1),
            nn.ReLU(),
            nn.Linear(d1, d2),
            nn.ReLU(),
            nn.Linear(d2, d2, bias=bias)
        )

    @staticmethod
    def get_codebook_indices(remainder, codebook):
        """Return, for every row of ``remainder``, the index of the closest code vector."""
        dist = torch.cdist(remainder, codebook)
        return dist.argmin(dim=-1)

    def _quantize_representation(self, latent_vector):
        """Quantise ``latent_vector`` with all codebooks and return the restored latent.

        Uses the straight-through estimator, so gradients flow to
        ``latent_vector`` as if quantisation were the identity.
        """
        latent_restored = 0
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            quantized = codebook[codebook_indices]
            codebook_vectors = remainder + (quantized - remainder).detach()
            latent_restored += codebook_vectors
            remainder = remainder - codebook_vectors

        return latent_restored

    def contrastive_loss(self, p_i, p_i_star):
        """In-batch contrastive loss between items and their co-occurring items.

        Row ``k`` of ``p_i`` is the positive pair of row ``k`` of ``p_i_star``;
        all other rows of the batch act as negatives.

        Raises:
            ValueError: if ``self.temperature`` is not positive (dividing by
                it would silently produce NaN/inf losses).
        """
        if not self.temperature > 0:
            raise ValueError(
                f'temperature must be positive to compute the contrastive loss, got {self.temperature}'
            )

        N_b = p_i.size(0)

        p_i = F.normalize(p_i, p=2, dim=-1)  # TODO: try without normalisation
        p_i_star = F.normalize(p_i_star, p=2, dim=-1)

        similarities = torch.matmul(p_i, p_i_star.T) / self.temperature
        labels = torch.arange(N_b, dtype=torch.long, device=p_i.device)

        return F.cross_entropy(similarities, labels)

    def forward(self, inputs):
        """Compute the total loss and diagnostics for one batch.

        Args:
            inputs: dict with ``'embedding'`` and, optionally,
                ``'cooccurrence_embedding'`` of the same shape. The
                contrastive term is added only when the latter is present.

        Returns:
            ``(loss, info)`` where ``info`` holds the scalar losses, the
            per-codebook usage counts (``'clusters_counts'``), the assigned
            codes (``'clusters'``, one column per codebook) and the
            reconstruction (``'embedding_hat'``).
        """
        latent_vector = self.encoder(inputs['embedding'])

        latent_restored = 0
        rqvae_loss = 0
        clusters = []
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            clusters.append(codebook_indices)

            quantized = codebook[codebook_indices]
            # Straight-through estimator: forward value is `quantized`,
            # gradient is passed to `remainder`.
            codebook_vectors = remainder + (quantized - remainder).detach()

            rqvae_loss += self.beta * F.mse_loss(remainder, quantized.detach())
            rqvae_loss += F.mse_loss(quantized, remainder.detach())

            latent_restored += codebook_vectors
            remainder = remainder - codebook_vectors

        embeddings_restored = self.decoder(latent_restored)
        recon_loss = F.mse_loss(embeddings_restored, inputs['embedding'])

        if 'cooccurrence_embedding' in inputs:
            cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device))
            cooccurrence_restored = self._quantize_representation(cooccurrence_latent)
            con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored)
        else:
            con_loss = torch.as_tensor(0.0, device=latent_vector.device)

        loss = (
            recon_loss
            + self.quant_loss_weight * rqvae_loss
            + self.contrastive_loss_weight * con_loss
        ).mean()

        clusters_counts = [
            torch.bincount(cluster, minlength=self.codebook_size)
            for cluster in clusters
        ]

        return loss, {
            'loss': loss.item(),
            'recon_loss': recon_loss.mean().item(),
            'rqvae_loss': rqvae_loss.mean().item(),
            'con_loss': con_loss.item(),

            'clusters_counts': clusters_counts,
            'clusters': torch.stack(clusters).T,
            'embedding_hat': embeddings_restored,
        }
__init__(self, cooccur_counts, item_id_to_embedding, all_item_ids): + self.cooccur_counts = cooccur_counts + self.item_id_to_embedding = item_id_to_embedding + self.all_item_ids = all_item_ids + self.call_count = 0 + + self.cooc_probs_cache = {} + self._precompute_probabilities() + + def _precompute_probabilities(self): + for item_id, counter in self.cooccur_counts.items(): + if counter and len(counter) > 0: + cooc_ids, freqs = zip(*counter.items()) + freqs_array = np.array(freqs, dtype=np.float32) + probs = freqs_array / freqs_array.sum() + self.cooc_probs_cache[item_id] = (cooc_ids, probs) + + def __call__(self, batch): + self.call_count += 1 + item_ids = batch['item_id'] + cooccurrence_embeddings = [] + + for idx, item_id in enumerate(item_ids): + item_id_val = int(item_id.item()) if torch.is_tensor(item_id) else int(item_id) + + if item_id_val in self.cooc_probs_cache: + cooc_ids, probs = self.cooc_probs_cache[item_id_val] + cooc_id = np.random.choice(cooc_ids, p=probs) + else: + cooc_id = np.random.choice(self.all_item_ids) + if self.call_count % 10 == 0 and idx < 5: + print(f" idx={idx}: item_id={item_id_val} fallback random") + + cooc_emb = self.item_id_to_embedding.get(cooc_id, batch['embedding'][0]) + cooccurrence_embeddings.append(cooc_emb) + + batch['cooccurrence_embedding'] = torch.stack(cooccurrence_embeddings) + return batch + +class AddWeightedCooccurrenceEmbeddingsVectorized: + + def __init__( + self, + cooccur_counts: Dict[int, Dict[int, int]], + item_id_to_embedding: Dict[int, torch.Tensor], + all_item_ids: List[int], + device: torch.device, + limit_neighbors: bool = True, + max_neighbors: int = 256 + ): + """ + limit_neighbors: если True, ограничиваем до max_neighbors (для экономии памяти) + max_neighbors: максимум соседей (используется только если limit_neighbors=True) + """ + self.device = device + self.call_count = 0 + self.limit_neighbors = limit_neighbors + self.max_neighbors = max_neighbors + + max_item_id = max(item_id_to_embedding.keys()) 
+ embedding_dim = next(iter(item_id_to_embedding.values())).shape[0] + + self.embedding_matrix = torch.zeros( + (max_item_id + 1, embedding_dim), + device=device, + dtype=torch.float32, + requires_grad=False + ) + + print("Building embedding matrix") + for item_id, emb in item_id_to_embedding.items(): + if isinstance(emb, torch.Tensor): + self.embedding_matrix[item_id] = emb.detach() + else: + self.embedding_matrix[item_id] = torch.tensor(emb, device=device, dtype=torch.float32) + + self.all_item_ids_tensor = torch.tensor( + all_item_ids, + device=device, + dtype=torch.long, + requires_grad=False + ) + + print("Building cooccurrence tables") + self._build_cooccurrence_tables(cooccur_counts, max_item_id, len(all_item_ids)) + + def _build_cooccurrence_tables(self, cooccur_counts: Dict, max_item_id: int, num_all_items: int): + """ + - neighbors_matrix: [max_item_id+1, num_neighbors] + - probs_matrix: [max_item_id+1, num_neighbors] + Если у item_id нет соседей, neighbors и probs заполняются равномерно из all_items + """ + neighbor_counts = {} + for item_id in range(max_item_id + 1): + counter = cooccur_counts.get(item_id) + if counter and len(counter) > 0: + num_neighbors = len(counter) + if self.limit_neighbors: + num_neighbors = min(num_neighbors, self.max_neighbors) + else: + num_neighbors = num_all_items + + neighbor_counts[item_id] = num_neighbors + + max_num_neighbors = max(neighbor_counts.values()) + actual_max_neighbors = min(max_num_neighbors, self.max_neighbors) if self.limit_neighbors else max_num_neighbors + + print(f"Max neighbors per item: {actual_max_neighbors}") + + neighbors_matrix = torch.zeros( + (max_item_id + 1, actual_max_neighbors), + dtype=torch.long, + device=self.device, + requires_grad=False + ) + + probs_matrix = torch.zeros( + (max_item_id + 1, actual_max_neighbors), + dtype=torch.float32, + device=self.device, + requires_grad=False + ) + + num_items_with_cooc = 0 + + # Заполняем матрицы + for item_id in range(max_item_id + 1): + counter = 
cooccur_counts.get(item_id) + + if counter and len(counter) > 0: + # === Есть соседи: используем реальные вероятности === + num_items_with_cooc += 1 + + # Извлекаем соседей и их counts, сортируем по частоте + cooc_ids, freqs = zip(*sorted(counter.items(), key=lambda x: x[1], reverse=True)) + cooc_ids = list(cooc_ids) + freqs = np.array(freqs, dtype=np.float32) + + # Берем только топ + num_neighbors = min(len(cooc_ids), actual_max_neighbors) + cooc_ids = cooc_ids[:num_neighbors] + freqs = freqs[:num_neighbors] + + # Нормализуем + probs = freqs / freqs.sum() + + neighbors_matrix[item_id, :num_neighbors] = torch.tensor( + cooc_ids, dtype=torch.long, device=self.device + ) + probs_matrix[item_id, :num_neighbors] = torch.tensor( + probs, dtype=torch.float32, device=self.device + ) + + else: + # Нет соседей: равномерное распределение на all_items + if actual_max_neighbors >= num_all_items: + # Можем поместить всех айтемов + neighbors_matrix[item_id, :num_all_items] = self.all_item_ids_tensor + probs_matrix[item_id, :num_all_items] = 1.0 / num_all_items + else: + # Выбираем случайное подмножество + indices = torch.randperm(num_all_items, device=self.device)[:actual_max_neighbors] + neighbors_matrix[item_id] = self.all_item_ids_tensor[indices] + probs_matrix[item_id] = 1.0 / actual_max_neighbors + + self.neighbors_matrix = neighbors_matrix + self.probs_matrix = probs_matrix + + print(f"Cooccurrence tables built: {num_items_with_cooc}/{max_item_id + 1} items have real neighbors") + + def __call__(self, batch): + self.call_count += 1 + + item_ids = batch['item_id'] # [batch_size] + batch_size = item_ids.shape[0] + + # Берем вероятности для items в батче + probs = self.probs_matrix[item_ids] # [batch_size, max_neighbors] + + # Выбираем индекс соседа для каждого item + # torch.multinomial: выбирает из max_neighbors категорий по вероятностям + # Результат: [batch_size, 1] - индексы в диапазоне [0, max_neighbors) + neighbor_indices = torch.multinomial(probs, num_samples=1, 
replacement=True) + neighbor_indices = neighbor_indices.squeeze(1) # [batch_size] + + # neighbors_matrix[item_ids, neighbor_indices] -> [batch_size] + cooc_ids = self.neighbors_matrix[item_ids, neighbor_indices] + + # Lookup эмбеддингов + cooccurrence_embeddings = self.embedding_matrix[cooc_ids] # [batch_size, embedding_dim] + + batch['cooccurrence_embedding'] = cooccurrence_embeddings + + # if self.call_count % 500 == 0: + # print( + # f"Call #{self.call_count}: {batch_size} samples, " + # f"cooc_embeddings shape: {cooccurrence_embeddings.shape}" + # ) + + return batch \ No newline at end of file diff --git a/scripts/plum-yambda/yambda_4_1_train_plum.py b/scripts/plum-yambda/yambda_4_1_train_plum.py new file mode 100644 index 0000000..86e0c2b --- /dev/null +++ b/scripts/plum-yambda/yambda_4_1_train_plum.py @@ -0,0 +1,186 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddingsVectorized +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +MAX_NEIGHBOURS_COUNT = 1000 + +INPUT_DIM = 128 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'4-1_filtered_yambda_gpu_week_ws_{WINDOW_SIZE}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/data/Yambda/week-splits/merged_for_exps_filtered/exp_4-1_0.9_inter_semantics_train.json" #отсекать старое (может и нет) +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet" 
def main():
    """Train PlumRQVAE on the Yambda item embeddings with a co-occurrence contrastive loss.

    Builds the co-occurrence mapping from the train interactions, wraps the
    embedding dataset into train/validation loaders that add a sampled
    co-occurring embedding to every batch, and runs the training loop with
    codebook initialisation, dead-centroid fixing, periodic validation,
    logging, profiling and early stopping.
    """
    fix_random_seed(SEED_VALUE)

    data = CoocMappingDataset.create_from_split_part(
        train_inter_json_path=INTER_TRAIN_PATH,
        window_size=WINDOW_SIZE
    )

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH
    )

    item_id_to_embedding = {}
    all_item_ids = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        item_id = int(sample['item_id'])
        item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE)
        all_item_ids.append(item_id)

    add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized(
        cooccur_counts=data.cooccur_counter_mapping,
        item_id_to_embedding=item_id_to_embedding,
        all_item_ids=all_item_ids,
        device=DEVICE,
        limit_neighbors=True,
        max_neighbors=MAX_NEIGHBOURS_COUNT
    )

    # NOTE: the transform is called in the main thread; this should be fixed.
    train_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).map(add_cooc_transform
    ).repeat(NUM_EPOCHS)

    valid_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).map(add_cooc_transform)

    # Steps per epoch. Clamped to 1 so that a dataset smaller than one batch
    # (drop_last=True) cannot produce a zero interval.
    # NOTE(review): assumes a zero interval is invalid for every_num_steps - confirm.
    log_every_num_steps = max(1, int(len(train_dataloader) // NUM_EPOCHS))

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # The fused Adam kernel is requested only on CUDA: DEVICE may fall back to
    # CPU, where older PyTorch releases reject fused=True.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=DEVICE.type == 'cuda')

    def loss_metrics(model_outputs, batch):
        """Pick the scalar losses reported by the model."""
        return {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }

    train_accumulators = {
        'train/loss': cb.MeanAccumulator(),
        'train/recon_loss': cb.MeanAccumulator(),
        'train/rqvae_loss': cb.MeanAccumulator(),
        'train/con_loss': cb.MeanAccumulator(),
    }
    # One dead-centroid counter per codebook, matching what FixDeadCentroids reports.
    for codebook_idx in range(NUM_CODEBOOKS):
        train_accumulators[f'num_dead/{codebook_idx}'] = cb.MeanAccumulator()

    callbacks = [
        InitCodebooks(valid_dataloader),

        cb.BatchMetrics(metrics=loss_metrics, name='train'),

        FixDeadCentroids(valid_dataloader),

        cb.MetricAccumulator(
            accumulators=train_accumulators,
            reset_every_num_steps=log_every_num_steps
        ),

        cb.Validation(
            dataset=valid_dataloader,
            callbacks=[
                cb.BatchMetrics(metrics=loss_metrics, name='valid'),
                cb.MetricAccumulator(
                    accumulators={
                        'valid/loss': cb.MeanAccumulator(),
                        'valid/recon_loss': cb.MeanAccumulator(),
                        'valid/rqvae_loss': cb.MeanAccumulator(),
                        'valid/con_loss': cb.MeanAccumulator()
                    }
                ),
            ],
        ).every_num_steps(log_every_num_steps),

        cb.Logger().every_num_steps(log_every_num_steps),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),

        cb.Profiler(
            wait=10,
            warmup=10,
            active=10,
            logdir=os.path.join(IREC_PATH, 'tensorboard_logs')
        ),

        cb.EarlyStopping(
            metric='valid/recon_loss',
            patience=40,
            minimize=True,
            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
        ).every_num_steps(log_every_num_steps),
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()
embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, 
collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/beauty-exps/4_1_train_plum.py b/scripts/plum/beauty-exps/4_1_train_plum.py new file mode 100644 index 0000000..357fc19 --- /dev/null +++ b/scripts/plum/beauty-exps/4_1_train_plum.py @@ -0,0 +1,169 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddings +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 500 +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'4-1_yambda_quantile_ws_{WINDOW_SIZE}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json" +EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl" +IREC_PATH = '../../' + +print(INTER_TRAIN_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + data = CoocMappingDataset.create_from_split_part( + train_inter_json_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = 
dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + add_cooc_transform = AddWeightedCooccurrenceEmbeddings( + data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) + + train_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 
'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/beauty-exps/4_2_train_plum.py b/scripts/plum/beauty-exps/4_2_train_plum.py new file mode 100644 index 0000000..96cfda9 --- /dev/null +++ b/scripts/plum/beauty-exps/4_2_train_plum.py @@ -0,0 +1,169 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE +from 
transforms import AddWeightedCooccurrenceEmbeddings +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 500 +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'4-2_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json" +EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl" +IREC_PATH = '../../' + +print(INTER_TRAIN_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + data = CoocMappingDataset.create_from_split_part( + train_inter_json_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + add_cooc_transform = AddWeightedCooccurrenceEmbeddings( + data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) + + train_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + 
codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + 
patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/beauty-exps/4_3_train_plum.py b/scripts/plum/beauty-exps/4_3_train_plum.py new file mode 100644 index 0000000..ac6cfb6 --- /dev/null +++ b/scripts/plum/beauty-exps/4_3_train_plum.py @@ -0,0 +1,169 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddings +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 500 +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'4-3_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json" +EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl" +IREC_PATH = '../../' + +print(INTER_TRAIN_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + data = CoocMappingDataset.create_from_split_part( + train_inter_json_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + 
item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + add_cooc_transform = AddWeightedCooccurrenceEmbeddings( + data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) + + train_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': 
cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/beauty-exps/infer_4.1_default.py b/scripts/plum/beauty-exps/infer_4.1_default.py new file mode 100644 index 0000000..fff61d3 --- /dev/null +++ b/scripts/plum/beauty-exps/infer_4.1_default.py @@ -0,0 +1,145 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDataset, 
ProcessEmbeddings +from models import PlumRQVAE + +# ПУТИ +IREC_PATH = '/home/jovyan/IRec/' +EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' +MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_updated_quantile_plum_rqvae_beauty_ws_2_best_0.0052.pth' +RESULTS_PATH = os.path.join(IREC_PATH, 'results_sigir') + +WINDOW_SIZE = 2 +EXPERIMENT_NAME = f'4-1_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}' + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 + +BETA = 0.25 + + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': 
model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/beauty-exps/infer_4.2_default.py b/scripts/plum/beauty-exps/infer_4.2_default.py new file mode 100644 index 0000000..c5c7c02 --- /dev/null +++ b/scripts/plum/beauty-exps/infer_4.2_default.py @@ -0,0 +1,145 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb 
+from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE + +# ПУТИ +IREC_PATH = '/home/jovyan/IRec/' +EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' +MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-2_updated_quantile_plum_rqvae_beauty_ws_2_best_0.0051.pth' +RESULTS_PATH = os.path.join(IREC_PATH, 'results_sigir') + +WINDOW_SIZE = 2 +EXPERIMENT_NAME = f'4-2_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}' + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 + +BETA = 0.25 + + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + 
cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/beauty-exps/infer_4.3_default.py 
b/scripts/plum/beauty-exps/infer_4.3_default.py new file mode 100644 index 0000000..c7fca80 --- /dev/null +++ b/scripts/plum/beauty-exps/infer_4.3_default.py @@ -0,0 +1,145 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE + +# ПУТИ +IREC_PATH = '/home/jovyan/IRec/' +EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' +MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-3_updated_quantile_plum_rqvae_beauty_ws_2_best_0.005.pth' +RESULTS_PATH = os.path.join(IREC_PATH, 'results_sigir') + +WINDOW_SIZE = 2 +EXPERIMENT_NAME = f'4-3_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}' + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 + +BETA = 0.25 + + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = 
sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with 
open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/beauty-exps/infer_default.py b/scripts/plum/beauty-exps/infer_default.py new file mode 100644 index 0000000..af8df34 --- /dev/null +++ b/scripts/plum/beauty-exps/infer_default.py @@ -0,0 +1,152 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddings +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 + +BETA = 0.25 +MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth' + +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}' + +IREC_PATH = '/home/jovyan/IRec/' + + +def main(): + fix_random_seed(SEED_VALUE) + + data = CoocMappingDataset.create( + inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'), + max_sequence_length=20, + sampler_type='sasrec', + window_size=WINDOW_SIZE + ) + + dataset = EmbeddingDataset( + data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + add_cooc_transform = AddWeightedCooccurrenceEmbeddings( + data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) + 
+ dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = 
mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/cooc_data.py b/scripts/plum/cooc_data.py index b11e6f0..50f2bdd 100644 --- a/scripts/plum/cooc_data.py +++ b/scripts/plum/cooc_data.py @@ -13,21 +13,15 @@ class CoocMappingDataset: def __init__( self, train_sampler, - validation_sampler, - test_sampler, num_items, - max_sequence_length, cooccur_counter_mapping=None ): self._train_sampler = train_sampler - self._validation_sampler = validation_sampler - self._test_sampler = test_sampler self._num_items = num_items - self._max_sequence_length = max_sequence_length self._cooccur_counter_mapping = cooccur_counter_mapping @classmethod - def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size): + def create(cls, inter_json_path, window_size): max_item_id = 0 train_dataset, validation_dataset, test_dataset = [], [], [] @@ -43,31 +37,59 @@ def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size) 'user.ids': [user_id], 'item.ids': item_ids[:-2], }) - validation_dataset.append({ - 'user.ids': [user_id], - 'item.ids': item_ids[:-1], - }) - test_dataset.append({ - 'user.ids': [user_id], - 'item.ids': item_ids, - }) cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size) logger.debug(f'Computed window-based co-occurrence mapping for 
{len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}') train_sampler = train_dataset - validation_sampler = validation_dataset - test_sampler = test_dataset return cls( train_sampler=train_sampler, - validation_sampler=validation_sampler, - test_sampler=test_sampler, num_items=max_item_id + 1, - max_sequence_length=max_sequence_length, cooccur_counter_mapping=cooccur_counter_mapping ) + @classmethod + def create_from_split_part( + cls, + train_inter_json_path, + window_size + ): + + max_item_id = 0 + train_dataset = [] + + with open(train_inter_json_path, 'r') as f: + train_interactions = json.load(f) + + # Обрабатываем TRAIN + for user_id_str, item_ids in train_interactions.items(): + user_id = int(user_id_str) + if item_ids: + max_item_id = max(max_item_id, max(item_ids)) + + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids, + }) + + logger.debug(f'Train: {len(train_dataset)} users') + logger.debug(f'Max item ID: {max_item_id}') + + cooccur_counter_mapping = cls.build_cooccur_counter_mapping( + train_dataset, + window_size=window_size + ) + + logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items') + + return cls( + train_sampler=train_dataset, + num_items=max_item_id + 1, + cooccur_counter_mapping=cooccur_counter_mapping + ) + + @staticmethod def build_cooccur_counter_mapping(train_dataset, window_size): #TODO передавать время и по нему строить окно cooccur_counts = defaultdict(Counter) @@ -80,16 +102,6 @@ def build_cooccur_counter_mapping(train_dataset, window_size): #TODO перед cooccur_counts[item_i][items[j]] += 1 return cooccur_counts - def get_datasets(self): - return self._train_sampler, self._validation_sampler, self._test_sampler - - @property - def num_items(self): - return self._num_items - - @property - def max_sequence_length(self): - return self._max_sequence_length @property def cooccur_counter_mapping(self): diff --git a/scripts/plum/data.py b/scripts/plum/data.py 
index 0ffef82..9c15b70 100644 --- a/scripts/plum/data.py +++ b/scripts/plum/data.py @@ -5,6 +5,29 @@ from irec.data.transforms import Transform +import polars as pl +import torch + +class EmbeddingDatasetParquet(BaseDataset): + def __init__(self, data_path): + self.df = pl.read_parquet(data_path) + self.item_ids = np.array(self.df['item_id'], dtype=np.int64) + self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32) + print(f"embedding dim: {self.embeddings[0].shape}") + + def __getitem__(self, idx): + index = self.item_ids[idx] + tensor_emb = self.embeddings[idx] + return { + 'item_id': index, + 'embedding': tensor_emb, + 'embedding_dim': len(tensor_emb) + } + + def __len__(self): + return len(self.embeddings) + + class EmbeddingDataset(BaseDataset): def __init__(self, data_path): self.data_path = data_path diff --git a/scripts/plum/infer_default.py b/scripts/plum/infer_default.py index af8df34..b15fb6d 100644 --- a/scripts/plum/infer_default.py +++ b/scripts/plum/infer_default.py @@ -12,8 +12,18 @@ from data import EmbeddingDataset, ProcessEmbeddings from models import PlumRQVAE -from transforms import AddWeightedCooccurrenceEmbeddings -from cooc_data import CoocMappingDataset + +# ПУТИ +IREC_PATH = '/home/jovyan/IRec/' +EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' +MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_plum_rqvae_beauty_ws_2_best_0.0051.pth' +RESULTS_PATH = os.path.join(IREC_PATH, 'results') + +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}' + +# ОСТАЛЬНОЕ SEED_VALUE = 42 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') @@ -26,29 +36,16 @@ NUM_CODEBOOKS = 3 BETA = 0.25 -MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth' -WINDOW_SIZE = 2 - -EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}' - -IREC_PATH = '/home/jovyan/IRec/' def main(): fix_random_seed(SEED_VALUE) - data = 
CoocMappingDataset.create( - inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'), - max_sequence_length=20, - sampler_type='sasrec', - window_size=WINDOW_SIZE - ) - dataset = EmbeddingDataset( - data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' + data_path=EMBEDDINGS_PATH ) - + item_id_to_embedding = {} all_item_ids = [] for idx in range(len(dataset)): @@ -57,15 +54,12 @@ def main(): item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) all_item_ids.append(item_id) - add_cooc_transform = AddWeightedCooccurrenceEmbeddings( - data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) - dataloader = DataLoader( dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, - ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) model = PlumRQVAE( input_dim=INPUT_DIM, @@ -106,8 +100,8 @@ def main(): cb.Logger().every_num_steps(len(dataloader)), cb.InferenceSaver( - metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, - save_path=f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), format='json' ) ] @@ -125,9 +119,9 @@ def main(): from collections import defaultdict import numpy as np - with open(f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', 'r') as f: + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: mappings = json.load(f) - + inter = {} sem_2_ids = defaultdict(list) for mapping in mappings: @@ -143,8 +137,8 @@ def main(): inter[item_id].append(collision_solver) for i in 
range(len(inter[item_id])): inter[item_id][i] += CODEBOOK_SIZE * i - - with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: json.dump(inter, f, indent=2) diff --git a/scripts/plum/train_plum.py b/scripts/plum/train_plum.py index 5a00bc3..ffa9e43 100644 --- a/scripts/plum/train_plum.py +++ b/scripts/plum/train_plum.py @@ -41,8 +41,6 @@ def main(): data = CoocMappingDataset.create( inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'), - max_sequence_length=20, - sampler_type='sasrec', window_size=WINDOW_SIZE ) diff --git a/scripts/plum/train_plum_timestamp_based.py b/scripts/plum/train_plum_timestamp_based.py new file mode 100644 index 0000000..e755d95 --- /dev/null +++ b/scripts/plum/train_plum_timestamp_based.py @@ -0,0 +1,168 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddings +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 500 +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'4-1_plum_rqvae_beauty_ws_{WINDOW_SIZE}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json" +EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl" +IREC_PATH = '../../' + +def main(): + 
fix_random_seed(SEED_VALUE) + + data = CoocMappingDataset.create_from_split_part( + train_inter_json_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + add_cooc_transform = AddWeightedCooccurrenceEmbeddings( + data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) + + train_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 
'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum/transforms.py b/scripts/plum/transforms.py index 0af1dda..8887115 100644 --- a/scripts/plum/transforms.py +++ b/scripts/plum/transforms.py @@ -2,12 +2,6 @@ import pickle import torch -from irec.data.base import BaseDataset -from irec.data.transforms import Transform - -from cooc_data import CoocMappingDataset - - class 
AddWeightedCooccurrenceEmbeddings: def __init__(self, cooccur_counts, item_id_to_embedding, all_item_ids): self.cooccur_counts = cooccur_counts @@ -32,7 +26,7 @@ def __call__(self, batch): else: cooc_id = np.random.choice(self.all_item_ids) - if self.call_count % 10 == 0 and idx < 5: + if self.call_count % 500 == 0 and idx < 5: print(f" idx={idx}: item_id={item_id_val} fallback random") cooc_emb = self.item_id_to_embedding.get(cooc_id, batch['embedding'][0]) diff --git a/scripts/rqvae-yambda/callbacks.py b/scripts/rqvae-yambda/callbacks.py new file mode 100644 index 0000000..43ec460 --- /dev/null +++ b/scripts/rqvae-yambda/callbacks.py @@ -0,0 +1,64 @@ +import torch + +import irec.callbacks as cb +from irec.runners import TrainingRunner, TrainingRunnerContext + +class InitCodebooks(cb.TrainingCallback): + def __init__(self, dataloader): + super().__init__() + self._dataloader = dataloader + + @torch.no_grad() + def before_run(self, runner: TrainingRunner): + for i in range(len(runner.model.codebooks)): + X = next(iter(self._dataloader))['embedding'] + idx = torch.randperm(X.shape[0], device=X.device)[:len(runner.model.codebooks[i])] + remainder = runner.model.encoder(X[idx]) + + for j in range(i): + codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j]) + codebook_vectors = runner.model.codebooks[j][codebook_indices] + remainder = remainder - codebook_vectors + + runner.model.codebooks[i].data = remainder.detach() + + +class FixDeadCentroids(cb.TrainingCallback): + def __init__(self, dataloader): + super().__init__() + self._dataloader = dataloader + + def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext): + for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)): + context.metrics[f'num_dead/{i}'] = num_fixed + + @torch.no_grad() + def fix_dead_codebooks(self, runner: TrainingRunner): + num_fixed = [] + for codebook_idx, codebook in enumerate(runner.model.codebooks): + centroid_counts = 
torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device) + random_batch = next(iter(self._dataloader))['embedding'] + + for batch in self._dataloader: + remainder = runner.model.encoder(batch['embedding']) + for l in range(codebook_idx): + ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l]) + remainder = remainder - runner.model.codebooks[l][ind] + + indices = runner.model.get_codebook_indices(remainder, codebook) + centroid_counts.scatter_add_(0, indices, torch.ones_like(indices)) + + dead_mask = (centroid_counts == 0) + num_dead = int(dead_mask.sum().item()) + num_fixed.append(num_dead) + if num_dead == 0: + continue + + remainder = runner.model.encoder(random_batch) + for l in range(codebook_idx): + ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l]) + remainder = remainder - runner.model.codebooks[l][ind] + remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead] + codebook[dead_mask] = remainder.detach() + + return num_fixed diff --git a/scripts/rqvae-yambda/data.py b/scripts/rqvae-yambda/data.py new file mode 100644 index 0000000..6c213ee --- /dev/null +++ b/scripts/rqvae-yambda/data.py @@ -0,0 +1,35 @@ +import numpy as np +import polars as pl + +from irec.data.base import BaseDataset +from irec.data.transforms import Transform + + +class EmbeddingDatasetParquet(BaseDataset): + def __init__(self, data_path): + self.df = pl.read_parquet(data_path) + self.item_ids = np.array(self.df['item_id'], dtype=np.int64) + self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32) + print(f"embedding dim: {self.embeddings[0].shape}") + + def __getitem__(self, idx): + index = self.item_ids[idx] + tensor_emb = self.embeddings[idx] + return { + 'item_id': index, + 'embedding': tensor_emb, + 'embedding_dim': len(tensor_emb) + } + + def __len__(self): + return len(self.embeddings) + +class ProcessEmbeddings(Transform): + def __init__(self, embedding_dim, 
keys): + self.embedding_dim = embedding_dim + self.keys = keys + + def __call__(self, batch): + for key in self.keys: + batch[key] = batch[key].reshape(-1, self.embedding_dim) + return batch \ No newline at end of file diff --git a/scripts/rqvae-yambda/infer_yambda.py b/scripts/rqvae-yambda/infer_yambda.py new file mode 100644 index 0000000..7daf42f --- /dev/null +++ b/scripts/rqvae-yambda/infer_yambda.py @@ -0,0 +1,128 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import RQVAE + + +IREC_PATH = '/home/jovyan/IRec/' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet" +MODEL_PATH = '/home/jovyan/IRec/checkpoints/rqvae_yambda_hd_128_cz_512_best_0.0014.pth' +RESULTS_PATH = '/home/jovyan/IRec/rqvae-yambda-sem-ids' +EXPERIMENT_NAME = 'rqvae_yambda_hd_128_cz_512' + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +BATCH_SIZE = 1024 + +INPUT_DIM = 128 +HIDDEN_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 + +BETA = 0.25 + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = RQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if 
p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git 
a/scripts/rqvae-yambda/models.py b/scripts/rqvae-yambda/models.py new file mode 100644 index 0000000..87c2241 --- /dev/null +++ b/scripts/rqvae-yambda/models.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class RQVAE(nn.Module): + def __init__( + self, + input_dim, + num_codebooks, + codebook_size, + embedding_dim, + beta=0.25, + quant_loss_weight=1.0, + ): + super().__init__() + self.register_buffer('beta', torch.tensor(beta)) + + self.input_dim = input_dim + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.quant_loss_weight = quant_loss_weight + + + self.encoder = self.make_encoding_tower(input_dim, embedding_dim) + self.decoder = self.make_encoding_tower(embedding_dim, input_dim) + + self.codebooks = torch.nn.ParameterList() + for _ in range(num_codebooks): + cb = torch.FloatTensor(codebook_size, embedding_dim) + #nn.init.normal_(cb) + self.codebooks.append(cb) + + @staticmethod + def make_encoding_tower(d1, d2, bias=False): + return torch.nn.Sequential( + nn.Linear(d1, d1), + nn.ReLU(), + nn.Linear(d1, d2), + nn.ReLU(), + nn.Linear(d2, d2, bias=bias) + ) + + @staticmethod + def get_codebook_indices(remainder, codebook): + dist = torch.cdist(remainder, codebook) + return dist.argmin(dim=-1) + + def forward(self, inputs): + latent_vector = self.encoder(inputs['embedding']) + item_ids = inputs['item_id'] + + latent_restored = 0 + rqvae_loss = 0 + clusters = [] + remainder = latent_vector + + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + clusters.append(codebook_indices) + + quantized = codebook[codebook_indices] + codebook_vectors = remainder + (quantized - remainder).detach() + + rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) + rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + + latent_restored += codebook_vectors + remainder = 
remainder - codebook_vectors + + embeddings_restored = self.decoder(latent_restored) + recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding']) + + loss = ( + recon_loss + + self.quant_loss_weight * rqvae_loss + ).mean() + + clusters_counts = [] + for cluster in clusters: + clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + + return loss, { + 'loss': loss.item(), + 'recon_loss': recon_loss.mean().item(), + 'rqvae_loss': rqvae_loss.mean().item(), + + 'clusters_counts': clusters_counts, + 'clusters': torch.stack(clusters).T, + 'embedding_hat': embeddings_restored, + } \ No newline at end of file diff --git a/scripts/rqvae-yambda/train_yambda.py b/scripts/rqvae-yambda/train_yambda.py new file mode 100644 index 0000000..71582ae --- /dev/null +++ b/scripts/rqvae-yambda/train_yambda.py @@ -0,0 +1,151 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import RQVAE + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 100 +BATCH_SIZE = 1024 + +INPUT_DIM = 128 +HIDDEN_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 + +EXPERIMENT_NAME = 'rqvae_yambda_hd_128_cz_512' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet" +IREC_PATH = '../../' + +print(EXPERIMENT_NAME, EMBEDDINGS_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + train_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + num_workers=8, + shuffle=True, + drop_last=True, + persistent_workers=True, + 
pin_memory=True + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = RQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), 
+ 'valid/rqvae_loss': cb.MeanAccumulator(), + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.Profiler( + wait=10, + warmup=10, + active=10, + logdir=os.path.join(IREC_PATH, 'tensorboard_logs') + ), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/data.py b/scripts/tiger-lsvd/data.py new file mode 100644 index 0000000..26123db --- /dev/null +++ b/scripts/tiger-lsvd/data.py @@ -0,0 +1,428 @@ +from collections import defaultdict +import json +from loguru import logger +import numpy as np +from pathlib import Path + + +import pyarrow as pa +import pyarrow.feather as feather + +import torch +import polars as pl +from irec.data.base import BaseDataset + + +class InteractionsDatasetParquet(BaseDataset): + def __init__(self, data_path, max_items=None): + self.df = pl.read_parquet(data_path) + assert 'uid' in self.df.columns, "Missing 'uid' column" + assert 'item_ids' in self.df.columns, "Missing 'item_ids' column" + print(f"Dataset loaded: {len(self.df)} users") + + if max_items is not None: + self.df = self.df.with_columns( + pl.col("item_ids").list.slice(-max_items).alias("item_ids") + ) + + def __getitem__(self, idx): + row = self.df.row(idx, named=True) + return { + 'user_id': row['uid'], + 'item_ids': np.array(row['item_ids'], dtype=np.uint32), + } + + def __len__(self): + return len(self.df) + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + + +class 
Dataset: + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_items, + max_sequence_length + ): + self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create_timestamp_based_parquet( + cls, + train_parquet_path, + validation_parquet_path, + test_parquet_path, + max_sequence_length, + sampler_type, + min_sample_len=2, + is_extended=False, + max_train_events=50 + ): + """ + Загружает данные из parquet файлов с timestamp-based сплитом. + + Ожидает структуру parquet: + - uid: int (user id) + - item_ids: list[int] (список item ids) + + Аналогично create_timestamp_based, но для parquet формата. + """ + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + print(f"started to load datasets from parquet with max train length {max_train_events}") + + # Загружаем parquet файлы + train_df = pl.read_parquet(train_parquet_path) + validation_df = pl.read_parquet(validation_parquet_path) + test_df = pl.read_parquet(test_parquet_path) + + # Проверяем наличие необходимых колонок + for df, name in [(train_df, "train"), (validation_df, "validation"), (test_df, "test")]: + assert 'uid' in df.columns, f"Missing 'uid' column in {name}" + assert 'item_ids' in df.columns, f"Missing 'item_ids' column in {name}" + + # Создаем словари для быстрого доступа + train_data = {str(row['uid']): row['item_ids'] for row in train_df.iter_rows(named=True)} + validation_data = {str(row['uid']): row['item_ids'] for row in validation_df.iter_rows(named=True)} + test_data = {str(row['uid']): row['item_ids'] for row in test_df.iter_rows(named=True)} + + all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys()) + print(f"all users count: {len(all_users)}") + + us_count = 0 + for user_id_str in all_users: + if us_count % 100 == 0: + print(f"user id 
{us_count}/{len(all_users)}: {user_id_str}") + + user_id = int(user_id_str) + + # Получаем последовательности для каждого сплита + train_items = list(train_data.get(user_id_str, [])) + validation_items = list(validation_data.get(user_id_str, [])) + test_items = list(test_data.get(user_id_str, [])) + + # Обрезаем train на последние max_train_events событий + train_items = train_items[-max_train_events:] if len(train_items) > max_train_events else train_items + + full_sequence = train_items + validation_items + test_items + if full_sequence: + max_item_id = max(max_item_id, max(full_sequence)) + + if us_count % 100 == 0: + print(f"full sequence len: {len(full_sequence)}") + + us_count += 1 + if len(full_sequence) < 4: + print(f'Core-4 dataset is used, user {user_id} has only {len(full_sequence)} items') + continue + + if is_extended: + # sample = [1, 2] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + # sample = [1, 2, 3, 4, 5, 6, 7] + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + for prefix_length in range(min_sample_len, len(train_items) + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items[:prefix_length], + }) + else: + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items, + }) + + # валидация + + # разворачиваем каждый айтем из валидации в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + + current_history = train_items.copy() + valid_small_history = 0 + for item in validation_items: + # эвал датасет сам отрезает таргет потом + sample_sequence = current_history + [item] + + if len(sample_sequence) >= min_sample_len: + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + else: + valid_small_history += 1 + current_history.append(item) + + # разворачиваем каждый айтем из теста в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4], 
Test=[5, 6] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + current_history = train_items + validation_items + test_small_history = 0 + for item in test_items: + sample_sequence = current_history + [item] + if len(sample_sequence) >= min_sample_len: + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + else: + test_small_history += 1 + current_history.append(item) + + print(f"Train dataset size: {len(train_dataset)}") + print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}") + print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}") + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + + @classmethod + def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False): + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + with open(inter_json_path, 'r') as f: + user_interactions = json.load(f) + + for user_id_str, item_ids in user_interactions.items(): + user_id = int(user_id_str) + + if item_ids: + max_item_id = max(max_item_id, max(item_ids)) + + assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items' + + # sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] (leave one out scheme, 8 - train, 9 - valid, 10 - test) + if 
is_extended: + # sample = [1, 2] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + # sample = [1, 2, 3, 4, 5, 6, 7] + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + for prefix_length in range(2, len(item_ids) - 2 + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:prefix_length], + }) + else: + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:-2], + }) + + # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9] + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:-1], + }) + + # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids, + }) + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + logger.debug(f'Max item id: {max_item_id}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + + def get_datasets(self): + return self._train_sampler, self._validation_sampler, self._test_sampler + + @property + def num_items(self): + return self._num_items + + @property + def max_sequence_length(self): + return self._max_sequence_length + + +class TrainDataset(BaseDataset): + def __init__(self, dataset, prediction_type, max_sequence_length): + self._dataset = dataset + self._prediction_type = prediction_type + self._max_sequence_length = max_sequence_length + + self._transforms 
= { + 'sasrec': self._all_items_transform, + 'tiger': self._last_item_transform + } + + def _all_items_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array(next_item_sequence, dtype=np.int64), + 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64) + } + + def _last_item_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + last_item = sample['item.ids'][-self._max_sequence_length:][-1] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([last_item], dtype=np.int64), + 'labels.length': np.array([1], dtype=np.int64), + } + + def __getitem__(self, index): + return self._transforms[self._prediction_type](self._dataset[index]) + + def __len__(self): + return len(self._dataset) + + +class EvalDataset(BaseDataset): + def __init__(self, dataset, max_sequence_length): + self._dataset = dataset + self._max_sequence_length = max_sequence_length + + @property + def dataset(self): + return self._dataset + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, index): + sample = self._dataset[index] + + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item = sample['item.ids'][-self._max_sequence_length:][-1] + + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], 
dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([next_item], dtype=np.int64), + 'labels.length': np.array([1], dtype=np.int64), + 'visited.ids': np.array(sample['item.ids'][:-1], dtype=np.int64), + 'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64), + } + + +class ArrowBatchDataset(BaseDataset): + def __init__(self, batch_dir, device='cuda', preload=False): + self.batch_dir = Path(batch_dir) + self.device = device + + all_files = list(self.batch_dir.glob('batch_*_len_*.arrow')) + + batch_files_map = defaultdict(list) + for f in all_files: + batch_id = int(f.stem.split('_')[1]) + batch_files_map[batch_id].append(f) + + for batch_id in batch_files_map: + batch_files_map[batch_id].sort() + + self.batch_indices = sorted(batch_files_map.keys()) + + if preload: + print(f"Preloading {len(self.batch_indices)} batches...") + self.cached_batches = [] + + for idx in range(len(self.batch_indices)): + batch = self._load_batch(batch_files_map[self.batch_indices[idx]]) + self.cached_batches.append(batch) + else: + self.cached_batches = None + self.batch_files_map = batch_files_map + + def _load_batch(self, arrow_files): + batch = {} + + for arrow_file in arrow_files: + table = feather.read_table(arrow_file) + metadata = table.schema.metadata or {} + + for col_name in table.column_names: + col = table.column(col_name) + + shape_key = f'{col_name}_shape' + dtype_key = f'{col_name}_dtype' + + if shape_key.encode() in metadata: + shape = eval(metadata[shape_key.encode()].decode()) + dtype = np.dtype(metadata[dtype_key.encode()].decode()) + + # Проверяем тип колонки + if pa.types.is_list(col.type) or pa.types.is_large_list(col.type): + arr = np.array(col.to_pylist(), dtype=dtype) + else: + arr = col.to_numpy().reshape(shape).astype(dtype) + else: + if pa.types.is_list(col.type) or pa.types.is_large_list(col.type): + arr = np.array(col.to_pylist()) + 
else: + arr = col.to_numpy() + + batch[col_name] = torch.from_numpy(arr.copy()).to(self.device) + + return batch + + def __len__(self): + return len(self.batch_indices) + + def __getitem__(self, idx): + if self.cached_batches is not None: + return self.cached_batches[idx] + else: + batch_id = self.batch_indices[idx] + arrow_files = self.batch_files_map[batch_id] + return self._load_batch(arrow_files) diff --git a/scripts/tiger-lsvd/lsvd_train_4.1_plum.py b/scripts/tiger-lsvd/lsvd_train_4.1_plum.py new file mode 100644 index 0000000..96a08d7 --- /dev/null +++ b/scripts/tiger-lsvd/lsvd_train_4.1_plum.py @@ -0,0 +1,230 @@ +from functools import partial +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/eval_batches/') + + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 'tiger_4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_POSITIONS = 80 +NUM_USER_HASH = 2000 +NUM_HEADS = 6 +NUM_LAYERS = 4 
+FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + +USE_MICROBATCHING = True +MICROBATCH_SIZE = 256 + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=partial( + CorrectItemsLogitsProcessor, + NUM_CODEBOOKS, + CODEBOOK_SIZE, + mappings, + NUM_BEAMS + ), + use_microbatching=USE_MICROBATCHING, + microbatch_size=MICROBATCH_SIZE + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + 
callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + 
cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=40 * 4, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/lsvd_train_4.2_plum.py b/scripts/tiger-lsvd/lsvd_train_4.2_plum.py new file mode 100644 index 0000000..991f662 --- /dev/null +++ b/scripts/tiger-lsvd/lsvd_train_4.2_plum.py @@ -0,0 +1,230 @@ +from functools import partial +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-2_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/eval_batches/') + + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 
'tiger_4-2_vk_lsvd_ods_base_cb_512_ws_2_k_2000_8w_e_35' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_POSITIONS = 80 +NUM_USER_HASH = 2000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + +USE_MICROBATCHING = True +MICROBATCH_SIZE = 256 + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=partial( + CorrectItemsLogitsProcessor, + NUM_CODEBOOKS, + CODEBOOK_SIZE, + mappings, + NUM_BEAMS + ), + use_microbatching=USE_MICROBATCHING, + microbatch_size=MICROBATCH_SIZE + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p 
in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 
'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=40 * 4, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/lsvd_train_rqvae.py b/scripts/tiger-lsvd/lsvd_train_rqvae.py new file mode 100644 index 0000000..aadf225 --- /dev/null +++ b/scripts/tiger-lsvd/lsvd_train_rqvae.py @@ -0,0 +1,230 @@ +from functools import partial +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_vk_lsvd_cz_512_8-weeks_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/valid_batches/') +EVAL_BATCHES_DIR = 
os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/eval_batches/') + + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 'tiger_rqvae_vk_lsvd_cz_512_8-weeks' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_POSITIONS = 80 +NUM_USER_HASH = 2000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + +USE_MICROBATCHING = True +MICROBATCH_SIZE = 256 + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=partial( + CorrectItemsLogitsProcessor, + NUM_CODEBOOKS, + CODEBOOK_SIZE, + 
mappings, + NUM_BEAMS + ), + use_microbatching=USE_MICROBATCHING, + microbatch_size=MICROBATCH_SIZE + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': 
model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=40 * 4, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/lsvd_varka_4.1_plum.py b/scripts/tiger-lsvd/lsvd_varka_4.1_plum.py new file mode 100644 index 0000000..cc35507 --- /dev/null +++ b/scripts/tiger-lsvd/lsvd_varka_4.1_plum.py @@ -0,0 +1,304 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +from data import Dataset + +print("tiger no arrow varka 4.1") + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/base_with_gap_interactions_grouped.parquet" +INTERACTIONS_VALID_PATH = 
"/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/val_interactions_grouped.parquet" +INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/test_interactions_grouped.parquet" + +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/eval_batches/') + + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_TRAIN_EVENTS = 500 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 2000 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? 
+ + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + max_item_id = max(int(k) for k in mapping.keys()) + print(len(list(mapping.keys())), min(int(k) for k in mapping.keys()) , max(int(k) for k in mapping.keys())) + # print(mapping["280052"]) #304781 + # assert False + data = [] + for i in 
range(max_item_id + 1): + if str(i) in mapping: + data.append(mapping[str(i)]) + else: + data.append([-1] * NUM_CODEBOOKS) + + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + missing_count = (max_item_id + 1) - len(mapping) + print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)") + + def __call__(self, batch): + for name in self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + semantic_ids = self._mapping_tensor[ids].flatten() + + assert (semantic_ids != -1).all(), \ + f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}" + + batch[f'{name}.semantic.ids'] = semantic_ids.numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for 
length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + print("варка может начать умирать") + data = Dataset.create_timestamp_based_parquet( + train_parquet_path=INTERACTIONS_TRAIN_PATH, + validation_parquet_path=INTERACTIONS_VALID_PATH, + test_parquet_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True, + max_train_events=MAX_TRAIN_EVENTS + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + print("варка не умерла") + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + 
print("tiger no arrow varka 4.1")

# PATHS

IREC_PATH = '../../'
INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/base_with_gap_interactions_grouped.parquet"
INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/val_interactions_grouped.parquet"
INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/test_interactions_grouped.parquet"

SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-2_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json')
TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/train_batches/')
VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/valid_batches/')
EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/eval_batches/')


# MISC

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


MAX_TRAIN_EVENTS = 500
MAX_SEQ_LEN = 20
TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 1024
NUM_USER_HASH = 2000
CODEBOOK_SIZE = 512
NUM_CODEBOOKS = 4

# Token id layout of the unified vocabulary: semantic-id codes first, then the
# hashed-user tokens, then 10 utility slots; the special tokens below take the
# last three ids.
UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
# Plain ints: a trailing comma here would silently turn each id into a
# one-element tuple, which only works by accident through numpy broadcasting.
PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
class ToMasked(Transform):
    """Scatter a flat, concatenated batch field into a zero-padded array plus a mask.

    Reads ``<prefix>.ids`` and ``<prefix>.length`` from the batch and writes
    ``<prefix>.padded`` and ``<prefix>.mask`` back into it.  Valid positions
    sit at the start of every row unless ``is_right_aligned`` is set, in which
    case the mask is flipped so they sit at the end.
    """

    def __init__(self, prefix, is_right_aligned=False):
        self._prefix = prefix
        self._is_right_aligned = is_right_aligned

    def __call__(self, batch):
        prefix = self._prefix
        values = batch[f'{prefix}.ids']
        lengths = batch[f'{prefix}.length']

        num_rows = lengths.shape[0]
        width = int(lengths.max())

        rank = len(values.shape)
        if rank == 1:  # only indices -> (batch_size, max_seq_len)
            padded_shape = (num_rows, width)
        else:
            assert rank == 2  # embeddings -> (batch_size, max_seq_len, emb_dim)
            padded_shape = (num_rows, width, values.shape[-1])
        padded = np.zeros(padded_shape, dtype=values.dtype)

        # True where a row still has real elements at that column.
        positions = np.arange(width)[None]
        valid = positions < lengths[:, None]
        if self._is_right_aligned:
            valid = np.flip(valid, axis=-1)

        # Boolean-mask assignment fills the True cells in row-major order,
        # which matches the order in which the rows were concatenated.
        padded[valid] = values

        batch[f'{prefix}.padded'] = padded
        batch[f'{prefix}.mask'] = valid
        return batch
class UserHashing(Transform):
    """Add ``user.hashed.ids``: each user id hashed with murmurhash into ``hash_size`` buckets."""

    def __init__(self, hash_size):
        super().__init__()
        self._hash_size = hash_size

    def __call__(self, batch):
        num_buckets = self._hash_size
        hashed = []
        for user_id in batch['user.ids']:
            # The id is hashed through its string form; the modulo folds the
            # hash value into the range of available buckets.
            bucket = murmurhash.hash(str(user_id)) % num_buckets
            hashed.append(bucket)
        batch['user.hashed.ids'] = np.array(hashed, dtype=np.int64)
        return batch
length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + print("варка может начать умирать") + data = Dataset.create_timestamp_based_parquet( + train_parquet_path=INTERACTIONS_TRAIN_PATH, + validation_parquet_path=INTERACTIONS_VALID_PATH, + test_parquet_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True, + max_train_events=MAX_TRAIN_EVENTS + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + print("варка не умерла") + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + 
print("tiger no arrow varka 4.1")

# PATHS

IREC_PATH = '../../'
INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/base_with_gap_interactions_grouped.parquet"
INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/val_interactions_grouped.parquet"
INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/test_interactions_grouped.parquet"

SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_vk_lsvd_cz_512_8-weeks_clusters_colisionless.json')
TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/train_batches/')
VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/valid_batches/')
EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/eval_batches/')


# MISC

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


MAX_TRAIN_EVENTS = 500
MAX_SEQ_LEN = 20
TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 1024
NUM_USER_HASH = 2000
CODEBOOK_SIZE = 512
NUM_CODEBOOKS = 4

# Token id layout of the unified vocabulary: semantic-id codes first, then the
# hashed-user tokens, then 10 utility slots; the special tokens below take the
# last three ids.
UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
# Plain ints: a trailing comma here would silently turn each id into a
# one-element tuple, which only works by accident through numpy broadcasting.
PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
class SemanticIdsMapper(Transform):
    """Replace item ids with their semantic-id code sequences.

    The JSON-style ``mapping`` (string item id -> list of codes) is turned
    into a dense lookup tensor indexed by integer item id; ids absent from the
    mapping are filled with ``-1`` so that a lookup of an unmapped item can be
    detected in ``__call__``.

    For every name in ``names`` whose ``<name>.ids`` is present in the batch,
    ``<name>.semantic.ids`` (flattened codes) and ``<name>.semantic.length``
    (original lengths times the code length) are added.
    """

    def __init__(self, mapping, names=None):
        super().__init__()
        self._mapping = mapping
        # `None` instead of a shared mutable `[]` default; behaves the same
        # as the old default for callers that omit `names`.
        self._names = [] if names is None else names

        # Parse the string keys once instead of once per statistic.
        item_ids = [int(k) for k in mapping.keys()]
        max_item_id = max(item_ids)
        print(len(mapping), min(item_ids), max_item_id)

        missing_row = [-1] * NUM_CODEBOOKS
        data = [
            mapping.get(str(i), missing_row)
            for i in range(max_item_id + 1)
        ]

        self._mapping_tensor = torch.tensor(data, dtype=torch.long)
        self._semantic_length = self._mapping_tensor.shape[-1]

        missing_count = (max_item_id + 1) - len(mapping)
        print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")

    def __call__(self, batch):
        for name in self._names:
            if f'{name}.ids' in batch:
                ids = batch[f'{name}.ids']
                lengths = batch[f'{name}.length']
                # Guard against ids that would index outside the lookup table
                # (negative ids would silently wrap around).
                assert ids.min() >= 0
                assert ids.max() < self._mapping_tensor.shape[0]
                semantic_ids = self._mapping_tensor[ids].flatten()

                assert (semantic_ids != -1).all(), \
                    f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"

                batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
                # Each item expands to `_semantic_length` codes.
                batch[f'{name}.semantic.length'] = lengths * self._semantic_length

        return batch
def _build_dataloader(dataset, mappings, batch_size, shuffle, drop_last, with_visited):
    """Wrap ``dataset`` in a DataLoader and attach the preprocessing transforms.

    ``with_visited`` additionally pads the ``visited`` field; the original
    pipelines did this for the validation and evaluation loaders only.
    """
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last
    ) \
        .map(Collate()) \
        .map(UserHashing(NUM_USER_HASH)) \
        .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
        .map(ToMasked('item.semantic', is_right_aligned=True)) \
        .map(ToMasked('labels.semantic', is_right_aligned=True))

    if with_visited:
        dataloader = dataloader.map(ToMasked('visited', is_right_aligned=True))

    return dataloader.map(TigerProcessing())


def main():
    """Precompute train/valid/eval batches and dump them to Arrow files.

    Loads the semantic-id mapping, builds the three datasets from the parquet
    files, runs each through the same transform pipeline and writes the
    resulting batches to TRAIN_BATCHES_DIR, VALID_BATCHES_DIR and
    EVAL_BATCHES_DIR.
    """
    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
        mappings = json.load(f)
    print("варка может начать умирать")
    data = Dataset.create_timestamp_based_parquet(
        train_parquet_path=INTERACTIONS_TRAIN_PATH,
        validation_parquet_path=INTERACTIONS_VALID_PATH,
        test_parquet_path=INTERACTIONS_TEST_PATH,
        max_sequence_length=MAX_SEQ_LEN,
        sampler_type='tiger',
        min_sample_len=2,
        is_extended=True,
        max_train_events=MAX_TRAIN_EVENTS
    )

    train_dataset, valid_dataset, eval_dataset = data.get_datasets()
    print("варка не умерла")

    train_dataloader = _build_dataloader(
        train_dataset, mappings,
        batch_size=TRAIN_BATCH_SIZE, shuffle=True, drop_last=True, with_visited=False
    )
    valid_dataloader = _build_dataloader(
        valid_dataset, mappings,
        batch_size=VALID_BATCH_SIZE, shuffle=False, drop_last=False, with_visited=True
    )
    eval_dataloader = _build_dataloader(
        eval_dataset, mappings,
        batch_size=VALID_BATCH_SIZE, shuffle=False, drop_last=False, with_visited=True
    )

    # The loaders are passed straight to the writer, which only enumerates its
    # argument: batches are written as they are produced instead of first
    # being accumulated in a list, and an already existing output directory
    # is reported before any batch has been computed.
    save_batches_to_arrow(train_dataloader, TRAIN_BATCHES_DIR)
    save_batches_to_arrow(valid_dataloader, VALID_BATCHES_DIR)
    save_batches_to_arrow(eval_dataloader, EVAL_BATCHES_DIR)
visited_items.shape + + self.index_semantic_ids = torch.tile(self.index_semantic_ids[None], dims=[batch_size, 1, 1]) # (batch_size, num_items, semantic_ids) + + index = visited_items[..., None].tile(dims=[1, 1, num_codebooks]) # (batch_size, num_rated, semantic_ids) + self.index_semantic_ids = torch.scatter( + input=self.index_semantic_ids, + dim=1, + index=index, + src=torch.zeros_like(index) + ) # (batch_size, num_items, semantic_ids) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + next_sid_codebook_num = (torch.minimum((input_ids[:, -1].max() // self.codebook_size), torch.as_tensor(self.num_codebooks - 1)).item() + 1) % self.num_codebooks + a = torch.tile(self.index_semantic_ids[:, None, :, next_sid_codebook_num], dims=[1, self.num_beams, 1]) # (batch_size, num_beams, num_items) + a = a.reshape(a.shape[0] * a.shape[1], a.shape[2]) # (batch_size * num_beams, num_items) + + if next_sid_codebook_num != 0: + b = torch.tile(self.index_semantic_ids[:, None :, :next_sid_codebook_num], dims=[1, self.num_beams, 1, 1]) # (batch_size, num_beams, num_items, sid_len) + b = b.reshape(b.shape[0] * b.shape[1], b.shape[2], b.shape[3]) # (batch_size * num_beams, num_items, sid_len) + + current_prefixes = input_ids[:, -next_sid_codebook_num:] # (batch_size * num_beams, sid_len) + possible_next_items_mask = ( + torch.eq(current_prefixes[:, None, :], b).long().sum(dim=-1) == next_sid_codebook_num + ) # (batch_size * num_beams, num_items) + a[~possible_next_items_mask] = (next_sid_codebook_num + 1) * self.codebook_size + + scores_mask = torch.zeros_like(scores).bool() # (batch_size * num_beams, num_items) + scores_mask = torch.scatter_add( + input=scores_mask, + dim=-1, + index=a, + src=torch.ones_like(a).bool() + ) + + scores[:, :next_sid_codebook_num * self.codebook_size] = -torch.inf + scores[:, (next_sid_codebook_num + 1) * self.codebook_size:] = -torch.inf + scores[~(scores_mask.bool())] = -torch.inf + + return scores + + 
class TigerModel(TorchModel):
    """T5 encoder-decoder over a unified token vocabulary.

    ``forward`` returns the training loss and, when the module is not in
    training mode, additionally runs beam-search generation and reports
    recall@k / ndcg@k for k in (5, 10, 20).
    """

    def __init__(
            self,
            embedding_dim,
            codebook_size,
            sem_id_len,
            num_positions,
            user_ids_count,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            dim_feedforward,
            num_beams=100,
            num_return_sequences=20,
            d_kv=64,
            layer_norm_eps=1e-6,
            activation='relu',
            dropout=0.1,
            initializer_range=0.02,
            logits_processor=None,
            use_microbatching=False,
            microbatch_size=128
    ):
        """Store the hyper-parameters and build the (compiled) T5 model.

        ``logits_processor``, when given, is called in ``forward`` as
        ``logits_processor(visited_items=...)`` to create the processor that
        is passed to ``generate``.  ``use_microbatching`` / ``microbatch_size``
        control whether evaluation-time generation is run on slices of the
        batch.
        """
        super().__init__()
        self._embedding_dim = embedding_dim
        self._codebook_size = codebook_size
        self._num_positions = num_positions
        self._num_heads = num_heads
        self._num_encoder_layers = num_encoder_layers
        self._num_decoder_layers = num_decoder_layers
        self._dim_feedforward = dim_feedforward
        self._num_beams = num_beams
        self._num_return_sequences = num_return_sequences
        self._d_kv = d_kv
        self._layer_norm_eps = layer_norm_eps
        self._activation = activation
        self._dropout = dropout
        self._sem_id_len = sem_id_len
        self.user_ids_count = user_ids_count
        self.logits_processor = logits_processor
        self._use_microbatching = use_microbatching
        self._microbatch_size = microbatch_size

        # Vocabulary size; the last three ids are used below as the
        # pad / eos / decoder-start tokens.
        unified_vocab_size = codebook_size * self._sem_id_len + self.user_ids_count + 10  # 10 for utilities
        self.config = T5Config(
            vocab_size=unified_vocab_size,
            d_model=self._embedding_dim,
            d_kv=self._d_kv,
            d_ff=self._dim_feedforward,
            num_layers=self._num_encoder_layers,
            num_decoder_layers=self._num_decoder_layers,
            num_heads=self._num_heads,
            dropout_rate=self._dropout,
            is_encoder_decoder=True,
            use_cache=False,
            pad_token_id=unified_vocab_size - 1,
            eos_token_id=unified_vocab_size - 2,
            decoder_start_token_id=unified_vocab_size - 3,
            layer_norm_epsilon=self._layer_norm_eps,
            feed_forward_proj=self._activation,
            tie_word_embeddings=False
        )
        self.model = T5ForConditionalGeneration(config=self.config)
        # `_init_weights` is not defined in this class; presumably inherited
        # from TorchModel - confirm.
        self._init_weights(initializer_range)

        # The compiled wrapper replaces the eager model for all later calls.
        self.model = torch.compile(
            self.model,
            mode='reduce-overhead',
            fullgraph=False,
            dynamic=True
        )

    def forward(self, inputs):
        """Compute the loss and, outside training mode, generation metrics.

        Reads ``input.data``, ``input.mask`` and ``output.data`` from
        ``inputs``; outside training mode ``visited.padded`` is read as well.

        Returns:
            ``(loss, metrics)`` where ``metrics`` always holds the detached
            loss and, outside training mode, ``recall@k`` and ``ndcg@k`` for
            k in (5, 10, 20), one value per row of the batch.
        """
        input_semantic_ids = inputs['input.data']
        attention_mask = inputs['input.mask']
        target_semantic_ids = inputs['output.data']

        # Teacher forcing: the decoder sees the target without its last token
        # and is trained to predict the target without its first token.
        decoder_input_ids = target_semantic_ids[:, :-1].contiguous()
        labels = target_semantic_ids[:, 1:].contiguous()

        model_output = self.model(
            input_ids=input_semantic_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels
        )
        loss = model_output['loss']

        metrics = {'loss': loss.detach()}

        if not self.training and not self._use_microbatching:
            # Evaluation on the whole batch in a single generate call.
            visited_batch = inputs['visited.padded']

            output = self.model.generate(
                input_ids=input_semantic_ids,
                attention_mask=attention_mask,
                num_beams=self._num_beams,
                num_return_sequences=self._num_return_sequences,
                max_length=self._sem_id_len + 1,
                decoder_start_token_id=self.config.decoder_start_token_id,
                eos_token_id=self.config.eos_token_id,
                pad_token_id=self.config.pad_token_id,
                do_sample=False,
                early_stopping=False,
                logits_processor=[self.logits_processor(visited_items=visited_batch)] if self.logits_processor is not None else [],
            )

            # Drop the first generated position and group the returned
            # sequences per input row.
            predictions = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)

            # Number of matching positions between each prediction and the label.
            all_hits = (torch.eq(predictions, labels[:, None]).sum(dim=-1))  # (batch_size, top_k)
        elif not self.training and self._use_microbatching:
            # Evaluation in slices of `_microbatch_size` rows.
            visited_batch = inputs['visited.padded']
            batch_size = input_semantic_ids.shape[0]

            inference_batch_size = self._microbatch_size  # instead of the full batch_size

            all_predictions = []
            all_labels = []
            # print(f"start to infer batch of shape {input_semantic_ids.shape} with new batch {inference_batch_size}")
            for batch_idx in range(0, batch_size, inference_batch_size):
                batch_end = min(batch_idx + inference_batch_size, batch_size)
                batch_slice = slice(batch_idx, batch_end)

                input_ids_batch = input_semantic_ids[batch_slice]
                attention_mask_batch = attention_mask[batch_slice]
                visited_batch_subset = visited_batch[batch_slice]
                labels_batch = labels[batch_slice]

                with torch.inference_mode():
                    output = self.model.generate(
                        input_ids=input_ids_batch,
                        attention_mask=attention_mask_batch,
                        num_beams=self._num_beams,
                        num_return_sequences=self._num_return_sequences,
                        max_length=self._sem_id_len + 1,
                        decoder_start_token_id=self.config.decoder_start_token_id,
                        eos_token_id=self.config.eos_token_id,
                        pad_token_id=self.config.pad_token_id,
                        do_sample=False,
                        early_stopping=False,
                        logits_processor=[self.logits_processor(visited_items=visited_batch_subset)] if self.logits_processor is not None else [],
                    )

                predictions_batch = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
                all_predictions.append(predictions_batch)
                all_labels.append(labels_batch)
                # print("end infer of batch")

            predictions = torch.cat(all_predictions, dim=0)  # (batch_size, num_return_sequences, sem_id_len)
            labels_full = torch.cat(all_labels, dim=0)  # (batch_size, sem_id_len)
            all_hits = (torch.eq(predictions, labels_full[:, None]).sum(dim=-1))  # (batch_size, top_k)

        if not self.training:
            for k in [5, 10, 20]:
                # A prediction counts as a hit only when every position matches.
                hits = (all_hits[:, :k] == self._sem_id_len).float()  # (batch_size, k)
                recall = hits.sum(dim=-1)  # (batch_size)
                # Discount 1 / log2(rank + 1) for ranks 1..k.
                discount_factor = 1 / torch.log2(torch.arange(1, k + 1, 1).float() + 1.).to(hits.device)  # (k)

                metrics[f'recall@{k}'] = recall.cpu().float()
                metrics[f'ndcg@{k}'] = torch.einsum('bk,k->b', hits, discount_factor).cpu().float()

        return loss, metrics
+ + +class Dataset: + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_items, + max_sequence_length + ): + self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create_timestamp_based( + cls, + train_json_path, + validation_json_path, + test_json_path, + max_sequence_length, + sampler_type, + min_sample_len=2, + is_extended=False, + max_train_events=50 + ): + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + print("started to load datasets") + with open(train_json_path, 'r') as f: + train_data = json.load(f) + with open(validation_json_path, 'r') as f: + validation_data = json.load(f) + with open(test_json_path, 'r') as f: + test_data = json.load(f) + + all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys()) + print(f"all users count: {len(all_users)}") + us_count = 0 + for user_id_str in all_users: + if us_count % 100 == 0: + print(f"user id {us_count}/{len(all_users)}: {user_id_str}") + user_id = int(user_id_str) + + train_items = train_data.get(user_id_str, []) + validation_items = validation_data.get(user_id_str, []) + test_items = test_data.get(user_id_str, []) + + full_sequence = train_items + validation_items + test_items + if full_sequence: + max_item_id = max(max_item_id, max(full_sequence)) + + if us_count % 100 == 0: + print(f"full sequence len: {len(full_sequence)}") + us_count += 1 + assert len(full_sequence) >= 2, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items' + + # Обрезаем train на последние max_train_events событий + train_items = train_items[-max_train_events:] if len(train_items) > max_train_events else train_items + + if is_extended: + # sample = [1, 2] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 
4, 5, 6] + # sample = [1, 2, 3, 4, 5, 6, 7] + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + for prefix_length in range(min_sample_len, len(train_items) + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items[:prefix_length], + }) + else: + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items, + }) + + # валидация + + # разворачиваем каждый айтем из валидации в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + + current_history = train_items.copy() + valid_small_history = 0 + for item in validation_items: + # эвал датасет сам отрезает таргет потом + sample_sequence = current_history + [item] + + if len(sample_sequence) >= min_sample_len: + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + else: + valid_small_history += 1 + current_history.append(item) + + # разворачиваем каждый айтем из теста в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + current_history = train_items + validation_items + test_small_history = 0 + for item in test_items: + sample_sequence = current_history + [item] + if len(sample_sequence) >= min_sample_len: + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + else: + test_small_history += 1 + current_history.append(item) + + print(f"Train dataset size: {len(train_dataset)}") + print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}") + print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}") + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = 
@classmethod
def create_timestamp_based_with_one_valid(
    cls,
    train_json_path,
    validation_json_path,
    test_json_path,
    max_sequence_length,
    sampler_type,
    min_sample_len=2,
    is_extended=False,
    max_train_events=50,
    max_valid_events=50
):
    """Build train/validation/test samplers from three per-user JSON files.

    Each file is loaded with ``json.load`` and read as a mapping from a
    user-id string to that user's list of item ids. Users are taken from
    the union of the three files; a user missing from a file contributes
    an empty list for that split.

    Args:
        train_json_path: Path to the train interactions JSON.
        validation_json_path: Path to the validation interactions JSON.
        test_json_path: Path to the test interactions JSON.
        max_sequence_length: Forwarded to the samplers and to ``cls``.
        sampler_type: Forwarded to ``TrainDataset``.
        min_sample_len: Minimum number of items a sample must contain.
        is_extended: If true, emit one train sample per train prefix of
            length ``min_sample_len`` .. ``len(train_items)``; otherwise
            emit a single sample holding the whole (truncated) history.
        max_train_events: Only the last ``max_train_events`` train items
            of a user are kept.
        max_valid_events: Only the first ``max_valid_events`` validation
            items produce validation samples. The test history still
            contains the full validation list.

    Returns:
        An instance of ``cls`` wrapping the three samplers.

    Raises:
        AssertionError: If a user has fewer than two interactions in total.
    """
    max_item_id = 0
    train_dataset, validation_dataset, test_dataset = [], [], []
    print("started to load datasets")
    with open(train_json_path, 'r') as f:
        train_data = json.load(f)
    with open(validation_json_path, 'r') as f:
        validation_data = json.load(f)
    with open(test_json_path, 'r') as f:
        test_data = json.load(f)

    all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
    print(f"all users count: {len(all_users)}")
    us_count = 0
    for user_id_str in all_users:
        # Progress output is emitted for every 100th user only.
        report_progress = us_count % 100 == 0
        if report_progress:
            print(f"user id {us_count}/{len(all_users)}: {user_id_str}")
        user_id = int(user_id_str)

        train_items = train_data.get(user_id_str, [])
        validation_items = validation_data.get(user_id_str, [])
        test_items = test_data.get(user_id_str, [])

        full_sequence = train_items + validation_items + test_items
        if full_sequence:
            max_item_id = max(max_item_id, max(full_sequence))

        if report_progress:
            print(f"full sequence len: {len(full_sequence)}")

        assert len(full_sequence) >= 2, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items'

        # Keep only the last max_train_events train items. A plain slice
        # already returns the whole list when it is shorter than the limit.
        train_items = train_items[-max_train_events:]

        if is_extended:
            # One sample per prefix: [1, 2], [1, 2, 3], ..., [1, ..., n]
            for prefix_length in range(min_sample_len, len(train_items) + 1):
                train_dataset.append({
                    'user.ids': [user_id],
                    'item.ids': train_items[:prefix_length],
                })
        elif len(train_items) >= min_sample_len:
            # Single sample with the whole history. Users that appear only
            # in the validation/test files have no (or too short a) train
            # history; emitting an empty sample for them would later fail
            # when the train sampler reads the last item of the sequence.
            train_dataset.append({
                'user.ids': [user_id],
                'item.ids': train_items,
            })

        # Validation: every validation item becomes its own sample.
        # Example: Train=[1, 2], Valid=[3, 4] -> [1, 2, 3] and [1, 2, 3, 4].
        # The eval dataset cuts the target (last item) off by itself.
        current_history = train_items.copy()
        capped_validation_items = validation_items[:max_valid_events]
        if report_progress:
            print(f"validation data length {len(capped_validation_items)}")
        us_count += 1
        for item in capped_validation_items:
            sample_sequence = current_history + [item]
            if len(sample_sequence) >= min_sample_len:
                validation_dataset.append({
                    'user.ids': [user_id],
                    'item.ids': sample_sequence,
                })
            current_history.append(item)

        # Test: every test item becomes its own sample.
        # Example: Train=[1, 2], Valid=[3, 4], Test=[5, 6]
        # -> [1, 2, 3, 4, 5] and [1, 2, 3, 4, 5, 6].
        current_history = train_items + validation_items
        for item in test_items:
            sample_sequence = current_history + [item]
            if len(sample_sequence) >= min_sample_len:
                test_dataset.append({
                    'user.ids': [user_id],
                    'item.ids': sample_sequence,
                })
            current_history.append(item)

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(validation_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")

    logger.debug(f'Train dataset size: {len(train_dataset)}')
    logger.debug(f'Validation dataset size: {len(validation_dataset)}')
    logger.debug(f'Test dataset size: {len(test_dataset)}')

    train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
    validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
    test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)

    return cls(
        train_sampler=train_sampler,
        validation_sampler=validation_sampler,
        test_sampler=test_sampler,
        num_items=max_item_id + 1,  # +1 added because our ids are 0-indexed
        max_sequence_length=max_sequence_length
    )
def get_datasets(self):
    """Return the train, validation and test samplers as a 3-tuple."""
    samplers = (
        self._train_sampler,
        self._validation_sampler,
        self._test_sampler,
    )
    return samplers
class ArrowBatchDataset(BaseDataset):
    """Dataset whose items are whole pre-collated batches stored on disk.

    Files in ``batch_dir`` matching ``batch_*_len_*.arrow`` are grouped by
    the integer found between the first and second underscore of the file
    name. Each group is one batch: all of its files are read and their
    columns merged into a single ``dict`` of tensors.

    Args:
        batch_dir: Directory containing the Feather/Arrow files.
        device: Device every loaded tensor is moved to.
        preload: If true, load all batches during construction and serve
            them from memory; otherwise read from disk on every access.
    """

    def __init__(self, batch_dir, device='cuda', preload=False):
        self.batch_dir = Path(batch_dir)
        self.device = device

        batch_files_map = defaultdict(list)
        for arrow_path in self.batch_dir.glob('batch_*_len_*.arrow'):
            batch_id = int(arrow_path.stem.split('_')[1])
            batch_files_map[batch_id].append(arrow_path)

        # Deterministic file order inside every batch.
        for batch_paths in batch_files_map.values():
            batch_paths.sort()

        self.batch_indices = sorted(batch_files_map)
        self.batch_files_map = batch_files_map

        if preload:
            print(f"Preloading {len(self.batch_indices)} batches...")
            self.cached_batches = [
                self._load_batch(batch_files_map[batch_id])
                for batch_id in self.batch_indices
            ]
        else:
            self.cached_batches = None

    def _load_batch(self, arrow_files):
        """Read the given Arrow files and return ``{column: tensor}``.

        When the table metadata holds ``<column>_shape`` and
        ``<column>_dtype`` entries, they are used to restore the array's
        dtype (and, for flat columns, its shape).
        """
        # Function-scope import: the shape string is parsed as a Python
        # literal instead of being executed with eval().
        import ast

        batch = {}

        for arrow_file in arrow_files:
            table = feather.read_table(arrow_file)
            metadata = table.schema.metadata or {}

            for col_name in table.column_names:
                col = table.column(col_name)
                is_nested = pa.types.is_list(col.type) or pa.types.is_large_list(col.type)

                shape_key = f'{col_name}_shape'.encode()
                dtype_key = f'{col_name}_dtype'.encode()

                if shape_key in metadata:
                    dtype = np.dtype(metadata[dtype_key].decode())
                    if is_nested:
                        # Nested lists already carry their 2D layout.
                        arr = np.array(col.to_pylist(), dtype=dtype)
                    else:
                        shape = ast.literal_eval(metadata[shape_key].decode())
                        arr = col.to_numpy().reshape(shape).astype(dtype)
                elif is_nested:
                    arr = np.array(col.to_pylist())
                else:
                    arr = col.to_numpy()

                # Copy before wrapping so the tensor owns its memory.
                batch[col_name] = torch.from_numpy(arr.copy()).to(self.device)

        return batch

    def __len__(self):
        return len(self.batch_indices)

    def __getitem__(self, idx):
        if self.cached_batches is not None:
            return self.cached_batches[idx]
        batch_id = self.batch_indices[idx]
        return self._load_batch(self.batch_files_map[batch_id])
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    """Mask ``scores`` so beams can only continue towards existing items.

    The codebook position to generate next is derived from the largest
    token in the last column of ``input_ids``. For every beam, only tokens
    that are the next semantic-id component of an item whose earlier
    components equal the beam's generated prefix keep their score; all
    other entries are set to ``-inf``. ``scores`` is modified in place and
    returned.

    Args:
        input_ids: Tokens generated so far, one row per beam
            (``batch_size * num_beams`` rows, beams of one sample adjacent).
        scores: Next-token scores, one row per beam.

    Returns:
        The masked ``scores`` tensor.
    """
    last_codebook = torch.minimum(
        input_ids[:, -1].max() // self.codebook_size,
        torch.as_tensor(self.num_codebooks - 1),
    ).item()
    next_sid_codebook_num = (last_codebook + 1) % self.num_codebooks

    # Token every item would emit at the next position, repeated per beam.
    allowed_tokens = torch.tile(
        self.index_semantic_ids[:, None, :, next_sid_codebook_num],
        dims=[1, self.num_beams, 1],
    )  # (batch_size, num_beams, num_items)
    allowed_tokens = allowed_tokens.reshape(
        allowed_tokens.shape[0] * allowed_tokens.shape[1], allowed_tokens.shape[2]
    )  # (batch_size * num_beams, num_items)

    # Index just past the current codebook's token range; that column is
    # wiped below, so it is a safe target for non-matching items.
    discard_index = (next_sid_codebook_num + 1) * self.codebook_size

    if next_sid_codebook_num != 0:
        # The beam axis must be inserted explicitly (``None``) so that the
        # tiled tensor is laid out sample-major, exactly like ``input_ids``.
        # Without it torch.tile repeats along the batch axis and beams get
        # compared with the item table of a different sample.
        item_prefixes = torch.tile(
            self.index_semantic_ids[:, None, :, :next_sid_codebook_num],
            dims=[1, self.num_beams, 1, 1],
        )  # (batch_size, num_beams, num_items, sid_len)
        item_prefixes = item_prefixes.reshape(
            item_prefixes.shape[0] * item_prefixes.shape[1],
            item_prefixes.shape[2],
            item_prefixes.shape[3],
        )  # (batch_size * num_beams, num_items, sid_len)

        current_prefixes = input_ids[:, -next_sid_codebook_num:]  # (batch_size * num_beams, sid_len)
        possible_next_items_mask = (
            torch.eq(current_prefixes[:, None, :], item_prefixes).long().sum(dim=-1) == next_sid_codebook_num
        )  # (batch_size * num_beams, num_items)
        allowed_tokens[~possible_next_items_mask] = discard_index

    # Mark every allowed token id; duplicates simply write True again.
    scores_mask = torch.zeros_like(scores, dtype=torch.bool)
    scores_mask.scatter_(dim=-1, index=allowed_tokens, value=True)

    # Only the current codebook's token range may survive at all.
    scores[:, :next_sid_codebook_num * self.codebook_size] = -torch.inf
    scores[:, discard_index:] = -torch.inf
    scores[~scores_mask] = -torch.inf

    return scores
def forward(self, inputs):
    """Compute the training loss and, in eval mode, ranking metrics.

    Args:
        inputs: Batch dict. Reads ``'input.data'``, ``'input.mask'`` and
            ``'output.data'``; in eval mode also ``'visited.padded'``.

    Returns:
        A ``(loss, metrics)`` tuple. ``metrics`` always holds the detached
        loss under ``'loss'``; when the module is not in training mode it
        also holds per-sample ``recall@k`` and ``ndcg@k`` for k in 5, 10, 20.
    """
    input_semantic_ids = inputs['input.data']
    attention_mask = inputs['input.mask']
    target_semantic_ids = inputs['output.data']

    # Teacher forcing: the decoder sees the target without its last token
    # and is trained to predict the target without its first token.
    # NOTE(review): presumably the first target token is the decoder start
    # token added by the batch preparation step - confirm against the caller.
    decoder_input_ids = target_semantic_ids[:, :-1].contiguous()
    labels = target_semantic_ids[:, 1:].contiguous()

    model_output = self.model(
        input_ids=input_semantic_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        labels=labels
    )
    loss = model_output['loss']

    metrics = {'loss': loss.detach()}

    if not self.training and not self._use_microbatching:
        # Evaluation on the whole batch at once.
        visited_batch = inputs['visited.padded']

        output = self.model.generate(
            input_ids=input_semantic_ids,
            attention_mask=attention_mask,
            num_beams=self._num_beams,
            num_return_sequences=self._num_return_sequences,
            max_length=self._sem_id_len + 1,
            decoder_start_token_id=self.config.decoder_start_token_id,
            eos_token_id=self.config.eos_token_id,
            pad_token_id=self.config.pad_token_id,
            do_sample=False,
            early_stopping=False,
            logits_processor=[self.logits_processor(visited_items=visited_batch)] if self.logits_processor is not None else [],
        )

        # Drop the first generated position and regroup the returned
        # sequences per input sample.
        predictions = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)

        # Number of semantic-id positions that match the label, per candidate.
        all_hits = (torch.eq(predictions, labels[:, None]).sum(dim=-1))  # (batch_size, top_k)
    elif not self.training and self._use_microbatching:
        # Same evaluation, but generation runs on slices of the batch.
        visited_batch = inputs['visited.padded']
        batch_size = input_semantic_ids.shape[0]

        inference_batch_size = self._microbatch_size  # instead of the full batch_size

        all_predictions = []
        all_labels = []
        # print(f"start to infer batch of shape {input_semantic_ids.shape} with new batch {inference_batch_size}")
        for batch_idx in range(0, batch_size, inference_batch_size):
            batch_end = min(batch_idx + inference_batch_size, batch_size)
            batch_slice = slice(batch_idx, batch_end)

            input_ids_batch = input_semantic_ids[batch_slice]
            attention_mask_batch = attention_mask[batch_slice]
            visited_batch_subset = visited_batch[batch_slice]
            labels_batch = labels[batch_slice]

            with torch.inference_mode():
                output = self.model.generate(
                    input_ids=input_ids_batch,
                    attention_mask=attention_mask_batch,
                    num_beams=self._num_beams,
                    num_return_sequences=self._num_return_sequences,
                    max_length=self._sem_id_len + 1,
                    decoder_start_token_id=self.config.decoder_start_token_id,
                    eos_token_id=self.config.eos_token_id,
                    pad_token_id=self.config.pad_token_id,
                    do_sample=False,
                    early_stopping=False,
                    logits_processor=[self.logits_processor(visited_items=visited_batch_subset)] if self.logits_processor is not None else [],
                )

            predictions_batch = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
            all_predictions.append(predictions_batch)
            all_labels.append(labels_batch)
            # print("end infer of batch")

        predictions = torch.cat(all_predictions, dim=0)  # (batch_size, num_return_sequences, sem_id_len)
        labels_full = torch.cat(all_labels, dim=0)  # (batch_size, sem_id_len)
        all_hits = (torch.eq(predictions, labels_full[:, None]).sum(dim=-1))  # (batch_size, top_k)

    if not self.training:
        for k in [5, 10, 20]:
            # A candidate is a hit only if every semantic-id position matches.
            hits = (all_hits[:, :k] == self._sem_id_len).float()  # (batch_size, k)
            recall = hits.sum(dim=-1)  # (batch_size)
            # Positional discount 1 / log2(rank + 1) for ranks 1..k.
            discount_factor = 1 / torch.log2(torch.arange(1, k + 1, 1).float() + 1.).to(hits.device)  # (k)

            metrics[f'recall@{k}'] = recall.cpu().float()
            metrics[f'ndcg@{k}'] = torch.einsum('bk,k->b', hits, discount_factor).cpu().float()

    return loss, metrics
def main():
    """Train the TIGER model on pre-built Arrow batches.

    Loads the semantic-id mapping, builds the train/validation/eval batch
    datasets, the model, the optimizer and the callback stack, then hands
    everything to ``TrainingRunner``.
    """
    fix_random_seed(SEED_VALUE)

    with open(SEMANTIC_MAPPING_PATH, 'r') as mapping_file:
        mappings = json.load(mapping_file)

    # Train batches are kept on the CPU and moved to DEVICE one at a time.
    train_batches = ArrowBatchDataset(TRAIN_BATCHES_DIR, device='cpu', preload=True)
    train_dataloader = DataLoader(
        train_batches,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        collate_fn=Collate()
    ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)

    valid_batches = ArrowBatchDataset(VALID_BATCHES_DIR, device=DEVICE, preload=True)
    eval_batches = ArrowBatchDataset(EVAL_BATCHES_DIR, device=DEVICE, preload=True)

    logits_processor_factory = partial(
        CorrectItemsLogitsProcessor,
        NUM_CODEBOOKS,
        CODEBOOK_SIZE,
        mappings,
        NUM_BEAMS
    )
    model = TigerModel(
        embedding_dim=EMBEDDING_DIM,
        codebook_size=CODEBOOK_SIZE,
        sem_id_len=NUM_CODEBOOKS,
        user_ids_count=NUM_USER_HASH,
        num_positions=NUM_POSITIONS,
        num_heads=NUM_HEADS,
        num_encoder_layers=NUM_LAYERS,
        num_decoder_layers=NUM_LAYERS,
        dim_feedforward=FEEDFORWARD_DIM,
        num_beams=NUM_BEAMS,
        num_return_sequences=TOP_K,
        activation='relu',
        d_kv=KV_DIM,
        dropout=DROPOUT,
        layer_norm_eps=1e-6,
        initializer_range=0.02,
        logits_processor=logits_processor_factory,
        use_microbatching=USE_MICROBATCHING,
        microbatch_size=MICROBATCH_SIZE
    ).to(DEVICE)

    total_params = 0
    trainable_params = 0
    for parameter in model.parameters():
        total_params += parameter.numel()
        if parameter.requires_grad:
            trainable_params += parameter.numel()

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LR,
    )

    EPOCH_NUM_STEPS = 1024  # int(len(train_dataloader) // NUM_EPOCHS)

    ranking_metric_names = [f'{metric}@{k}' for metric in ('recall', 'ndcg') for k in (5, 10, 20)]

    def collect_eval_metrics(model_outputs, _):
        # Scalar loss plus per-sample ranking metrics as plain lists.
        collected = {'loss': model_outputs['loss'].item()}
        for metric_name in ranking_metric_names:
            collected[metric_name] = model_outputs[metric_name].tolist()
        return collected

    def evaluation_callback(split_name, dataset):
        # Validation pass that averages the loss and every ranking metric.
        accumulators = {f'{split_name}/loss': cb.MeanAccumulator()}
        for metric_name in ranking_metric_names:
            accumulators[f'{split_name}/{metric_name}'] = cb.MeanAccumulator()
        return cb.Validation(
            dataset=dataset,
            callbacks=[
                cb.BatchMetrics(metrics=collect_eval_metrics, name=split_name),
                cb.MetricAccumulator(accumulators=accumulators),
            ],
        )

    callbacks = [
        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'].item(),
        }, name='train'),
        cb.MetricAccumulator(
            accumulators={
                'train/loss': cb.MeanAccumulator(),
            },
            reset_every_num_steps=EPOCH_NUM_STEPS
        ),

        evaluation_callback('validation', valid_batches).every_num_steps(EPOCH_NUM_STEPS),
        # The held-out eval split is scored four times less often.
        evaluation_callback('eval', eval_batches).every_num_steps(EPOCH_NUM_STEPS * 4),

        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),

        cb.EarlyStopping(
            metric='validation/ndcg@20',
            patience=40 * 4,
            minimize=False,
            model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
        ).every_num_steps(EPOCH_NUM_STEPS)
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()
VALID_BATCH_SIZE = 1024
NUM_USER_HASH = 2000
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 4

# Token layout of the unified vocabulary: all codebook tokens first, then
# the hashed-user tokens, then 10 utility slots at the very end.
UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
# Plain ints: a trailing comma here would silently turn each id into a
# 1-tuple that only works through NumPy broadcasting.
PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
class SemanticIdsMapper(Transform):
    """Replace item ids in a batch with their flattened semantic ids.

    Args:
        mapping: Dict from item-id string to that item's list of semantic
            ids. Item ids absent from the dict (below the largest id) get
            a row of ``-1`` in the lookup table.
        names: Batch prefixes to translate. For each ``name`` with a
            ``'<name>.ids'`` entry, ``'<name>.semantic.ids'`` and
            ``'<name>.semantic.length'`` are added. Defaults to no prefixes.
    """

    def __init__(self, mapping, names=None):
        super().__init__()
        self._mapping = mapping
        # A fresh list per instance instead of a shared mutable default.
        self._names = [] if names is None else names

        item_ids = [int(key) for key in mapping]
        max_item_id = max(item_ids)
        print(len(item_ids), min(item_ids), max_item_id)

        # Dense lookup table indexed by item id; gaps are marked with -1.
        missing_row = [-1] * NUM_CODEBOOKS
        data = [mapping.get(str(item_id), missing_row) for item_id in range(max_item_id + 1)]

        self._mapping_tensor = torch.tensor(data, dtype=torch.long)
        self._semantic_length = self._mapping_tensor.shape[-1]

        missing_count = (max_item_id + 1) - len(mapping)
        print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")

    def __call__(self, batch):
        """Add semantic ids and lengths for every configured prefix.

        Raises:
            AssertionError: If an id is outside the lookup table or maps
                to a missing (``-1``) entry.
        """
        for name in self._names:
            if f'{name}.ids' in batch:
                ids = batch[f'{name}.ids']
                lengths = batch[f'{name}.length']
                assert ids.min() >= 0
                assert ids.max() < self._mapping_tensor.shape[0]
                semantic_ids = self._mapping_tensor[ids].flatten()

                assert (semantic_ids != -1).all(), \
                    f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"

                batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
                # Every item expands into _semantic_length tokens.
                batch[f'{name}.semantic.length'] = lengths * self._semantic_length

        return batch
main(): + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + data = Dataset.create_timestamp_based( + train_json_path=INTERACTIONS_TRAIN_PATH, + validation_json_path=INTERACTIONS_VALID_PATH, + test_json_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True, + max_train_events=MAX_TRAIN_EVENTS + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + 
valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/beauty_exps/train_4.1_plum.py b/scripts/tiger/beauty_exps/train_4.1_plum.py new file mode 100644 index 0000000..8daf273 --- /dev/null +++ b/scripts/tiger/beauty_exps/train_4.1_plum.py @@ -0,0 +1,225 @@ +from functools import partial +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 'tiger_beauty_updated_quantile_4-1_plum_ws_2_dp_0.2' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 256 +NUM_POSITIONS = 20 +NUM_USER_HASH = 2000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + + 
+torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=partial( + CorrectItemsLogitsProcessor, + NUM_CODEBOOKS, + CODEBOOK_SIZE, + mappings, + NUM_BEAMS + ) + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + 
reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='eval/ndcg@20', + patience=40, + minimize=False, + 
model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/beauty_exps/train_4.2_plum.py b/scripts/tiger/beauty_exps/train_4.2_plum.py new file mode 100644 index 0000000..580bcb5 --- /dev/null +++ b/scripts/tiger/beauty_exps/train_4.2_plum.py @@ -0,0 +1,225 @@ +from functools import partial +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-2_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 'tiger_beauty_updated_quantile_4-2_plum_ws_2_dp_0.2' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 256 +NUM_POSITIONS = 20 +NUM_USER_HASH = 2000 
+NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=partial( + CorrectItemsLogitsProcessor, + NUM_CODEBOOKS, + CODEBOOK_SIZE, + mappings, + NUM_BEAMS + ) + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': 
model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + 
cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='eval/ndcg@20', + patience=40, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/beauty_exps/train_4.3_plum.py b/scripts/tiger/beauty_exps/train_4.3_plum.py new file mode 100644 index 0000000..f98e9fd --- /dev/null +++ b/scripts/tiger/beauty_exps/train_4.3_plum.py @@ -0,0 +1,225 @@ +from functools import partial +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-3_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 'tiger_beauty_updated_quantile_4-3_plum_ws_2_dp_0.2' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + 
+NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 256 +NUM_POSITIONS = 20 +NUM_USER_HASH = 2000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=partial( + CorrectItemsLogitsProcessor, + NUM_CODEBOOKS, + CODEBOOK_SIZE, + mappings, + NUM_BEAMS + ) + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) 
+ + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': 
cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='eval/ndcg@20', + patience=40, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/beauty_exps/varka_4.1_plum.py b/scripts/tiger/beauty_exps/varka_4.1_plum.py new file mode 100644 index 0000000..302e04e --- /dev/null +++ b/scripts/tiger/beauty_exps/varka_4.1_plum.py @@ -0,0 +1,287 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +from data import Dataset + + + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json') +INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json') +INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json') + +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_train_batches/') 
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_eval_batches/') + + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 2000 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? + + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = 
np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + data = [] + for i in range(len(mapping)): + data.append(mapping[str(i)]) + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + def __call__(self, batch): + for name in self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + 
metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + data = Dataset.create_timestamp_based( + train_json_path=INTERACTIONS_TRAIN_PATH, + validation_json_path=INTERACTIONS_VALID_PATH, + test_json_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True + ) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + 
batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/beauty_exps/varka_4.2_plum.py b/scripts/tiger/beauty_exps/varka_4.2_plum.py new file mode 100644 index 0000000..b00fef2 --- /dev/null +++ b/scripts/tiger/beauty_exps/varka_4.2_plum.py @@ -0,0 +1,288 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +from data import Dataset + + + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 
'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json') +INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json') +INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json') + +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-2_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_eval_batches/') + + + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 2000 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? 
+ + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + data = [] + for i in range(len(mapping)): + data.append(mapping[str(i)]) + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + def __call__(self, batch): + for name in 
self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # 
feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + data = Dataset.create_timestamp_based( + train_json_path=INTERACTIONS_TRAIN_PATH, + validation_json_path=INTERACTIONS_VALID_PATH, + test_json_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True + ) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, 
TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/beauty_exps/varka_4.3_plum.py b/scripts/tiger/beauty_exps/varka_4.3_plum.py new file mode 100644 index 0000000..2a96339 --- /dev/null +++ b/scripts/tiger/beauty_exps/varka_4.3_plum.py @@ -0,0 +1,289 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +from data import Dataset + + + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json') +INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json') +INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json') + +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-3_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_eval_batches/') + + + + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 
+NUM_USER_HASH = 2000 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? + + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = 
np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + data = [] + for i in range(len(mapping)): + data.append(mapping[str(i)]) + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + def __call__(self, batch): + for name in self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = 
{} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + data = Dataset.create_timestamp_based( + train_json_path=INTERACTIONS_TRAIN_PATH, + validation_json_path=INTERACTIONS_VALID_PATH, + test_json_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True + ) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + 
dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/data.py b/scripts/tiger/data.py index 188993a..a34accd 100644 --- a/scripts/tiger/data.py +++ b/scripts/tiger/data.py @@ -28,6 +28,116 @@ def __init__( self._num_items = num_items self._max_sequence_length = max_sequence_length + @classmethod + def create_timestamp_based( + cls, + train_json_path, + validation_json_path, + test_json_path, + max_sequence_length, + sampler_type, + min_sample_len=2, + is_extended=False + ): + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + with open(train_json_path, 'r') as f: + train_data = json.load(f) + with open(validation_json_path, 'r') as f: + validation_data = json.load(f) + with open(test_json_path, 'r') as f: + test_data = json.load(f) + + all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys()) + print(f"all users count: {len(all_users)}") + for user_id_str in all_users: + user_id = int(user_id_str) + + train_items = train_data.get(user_id_str, []) + validation_items = validation_data.get(user_id_str, []) + test_items = 
test_data.get(user_id_str, []) + + full_sequence = train_items + validation_items + test_items + if full_sequence: + max_item_id = max(max_item_id, max(full_sequence)) + + assert len(full_sequence) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items' + + if is_extended: + # sample = [1, 2] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + # sample = [1, 2, 3, 4, 5, 6, 7] + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + for prefix_length in range(min_sample_len, len(train_items) + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items[:prefix_length], + }) + else: + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items, + }) + + # разворачиваем каждый айтем из валидации в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + + current_history = train_items.copy() + for item in validation_items: + # эвал датасет сам отрезает таргет потом + sample_sequence = current_history + [item] + + if len(sample_sequence) >= min_sample_len: + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + current_history.append(item) + + # разворачиваем каждый айтем из теста в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + current_history = train_items + validation_items + + for item in test_items: + # эвал датасет сам отрезает таргет потом + sample_sequence = current_history + [item] + + if len(sample_sequence) >= min_sample_len: + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + + current_history.append(item) + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + print(f'Train dataset 
size: {len(train_dataset)}') + print(f'Validation dataset size: {len(validation_dataset)}') + print(f'Test dataset size: {len(test_dataset)}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + @classmethod def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False): max_item_id = 0 diff --git a/scripts/tiger/train.py b/scripts/tiger/train.py index f436dd4..1a2d347 100644 --- a/scripts/tiger/train.py +++ b/scripts/tiger/train.py @@ -14,10 +14,23 @@ from data import ArrowBatchDataset from models import TigerModel, CorrectItemsLogitsProcessor + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 'tiger_beauty_4-1_plum_ws_2_dp_0.2' + +# ОСТАЛЬНОЕ SEED_VALUE = 42 DEVICE = 'cuda' -EXPERIMENT_NAME = 'tiger_beauty' NUM_EPOCHS = 300 MAX_SEQ_LEN = 20 TRAIN_BATCH_SIZE = 256 @@ -30,13 +43,12 @@ NUM_LAYERS = 4 FEEDFORWARD_DIM = 1024 KV_DIM = 64 -DROPOUT = 0.1 +DROPOUT = 0.2 NUM_BEAMS = 30 TOP_K = 20 NUM_CODEBOOKS = 4 -LR = 3e-4 +LR = 0.0001 -IREC_PATH = '../../' torch.set_float32_matmul_precision('high') 
torch._dynamo.config.capture_scalar_outputs = True @@ -48,30 +60,30 @@ def main(): fix_random_seed(SEED_VALUE) - with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f: + with open(SEMANTIC_MAPPING_PATH, 'r') as f: mappings = json.load(f) - + train_dataloader = DataLoader( ArrowBatchDataset( - os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/'), - device='cpu', + TRAIN_BATCHES_DIR, + device='cpu', preload=True ), - batch_size=1, - shuffle=True, + batch_size=1, + shuffle=True, num_workers=0, - pin_memory=True, + pin_memory=True, collate_fn=Collate() ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) valid_dataloder = ArrowBatchDataset( - os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/'), + VALID_BATCHES_DIR, device=DEVICE, preload=True ) eval_dataloder = ArrowBatchDataset( - os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/'), + EVAL_BATCHES_DIR, device=DEVICE, preload=True ) @@ -177,22 +189,22 @@ def main(): ), ], ).every_num_steps(EPOCH_NUM_STEPS), - + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), - cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), cb.EarlyStopping( - metric='eval/ndcg@20', + metric='eval/ndcg@20', patience=40, minimize=False, - model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) ).every_num_steps(EPOCH_NUM_STEPS) # cb.Profiler( # wait=10, # warmup=10, # active=10, - # logdir=os.path.join(IREC_PATH, 'tensorboard_logs') + # logdir=TENSORBOARD_LOGDIR # ), # cb.StopAfterNumSteps(40) diff --git a/scripts/tiger/varka.py b/scripts/tiger/varka.py index ed47595..4dc3e02 100644 --- a/scripts/tiger/varka.py +++ b/scripts/tiger/varka.py @@ -15,6 +15,20 @@ from data import Dataset + + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_PATH = os.path.join(IREC_PATH, 'data/Beauty/inter.json') 
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/') + + +# ОСТАЛЬНОЕ + SEED_VALUE = 42 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') @@ -32,8 +46,6 @@ DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, -IREC_PATH = '../../' - class TigerProcessing(Transform): def __call__(self, batch): @@ -42,12 +54,12 @@ def __call__(self, batch): input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? - input_semantic_ids = np.concat([ + input_semantic_ids = np.concatenate([ input_semantic_ids, NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] ], axis=-1) - attention_mask = np.concat([ + attention_mask = np.concatenate([ attention_mask, np.ones((batch_size, 1), dtype=attention_mask.dtype) ], axis=-1) @@ -56,7 +68,7 @@ def __call__(self, batch): batch['input.mask'] = attention_mask target_semantic_ids = batch['labels.semantic.padded'] - target_semantic_ids = np.concat([ + target_semantic_ids = np.concatenate([ np.ones( (batch_size, 1), dtype=np.int64, @@ -73,7 +85,7 @@ class ToMasked(Transform): def __init__(self, prefix, is_right_aligned=False): self._prefix = prefix self._is_right_aligned = is_right_aligned - + def __call__(self, batch): data = batch[f'{self._prefix}.ids'] lengths = batch[f'{self._prefix}.length'] @@ -92,7 +104,7 @@ def __call__(self, batch): (batch_size, max_sequence_length, data.shape[-1]), dtype=data.dtype ) # (batch_size, max_seq_len, emb_dim) - + mask = np.arange(max_sequence_length)[None] < lengths[:, None] if self._is_right_aligned: @@ -117,10 +129,10 @@ def __init__(self, mapping, names=[]): data.append(mapping[str(i)]) self._mapping_tensor = torch.tensor(data, dtype=torch.long) self._semantic_length = 
self._mapping_tensor.shape[-1] - + def __call__(self, batch): for name in self._names: - if f'{name}.ids' in batch: + if f'{name}.ids' in batch: ids = batch[f'{name}.ids'] lengths = batch[f'{name}.length'] assert ids.min() >= 0 @@ -135,7 +147,7 @@ class UserHashing(Transform): def __init__(self, hash_size): super().__init__() self._hash_size = hash_size - + def __call__(self, batch): batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) return batch @@ -144,7 +156,7 @@ def __call__(self, batch): def save_batches_to_arrow(batches, output_dir): output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=False) - + for batch_idx, batch in enumerate(batches): length_groups = defaultdict(dict) metadata_groups = defaultdict(dict) @@ -164,7 +176,7 @@ def save_batches_to_arrow(batches, output_dir): else: # >2D массив - flatten и сохраняем shape length_groups[length][key] = value.flatten() - + for length, fields in length_groups.items(): arrow_dict = {} for k, v in fields.items(): @@ -173,11 +185,11 @@ def save_batches_to_arrow(batches, output_dir): arrow_dict[k] = pa.array(v) else: arrow_dict[k] = pa.array(v) - + table = pa.table(arrow_dict) if length in metadata_groups: table = table.replace_schema_metadata(metadata_groups[length]) - + feather.write_feather( table, output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", @@ -186,7 +198,7 @@ def save_batches_to_arrow(batches, output_dir): # arrow_dict = {k: pa.array(v) for k, v in fields.items()} # table = pa.table(arrow_dict) - + # feather.write_feather( # table, # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", @@ -196,15 +208,15 @@ def save_batches_to_arrow(batches, output_dir): def main(): data = Dataset.create( - inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter.json'), + inter_json_path=INTERACTIONS_PATH, max_sequence_length=MAX_SEQ_LEN, sampler_type='tiger', is_extended=True ) - with open(os.path.join(IREC_PATH, 
'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f: + with open(SEMANTIC_MAPPING_PATH, 'r') as f: mappings = json.load(f) - + train_dataset, valid_dataset, eval_dataset = data.get_datasets() train_dataloader = DataLoader( @@ -219,7 +231,7 @@ def main(): .map(ToMasked('item.semantic', is_right_aligned=True)) \ .map(ToMasked('labels.semantic', is_right_aligned=True)) \ .map(TigerProcessing()) - + valid_dataloader = DataLoader( dataset=valid_dataset, batch_size=VALID_BATCH_SIZE, @@ -251,17 +263,18 @@ def main(): train_batches = [] for train_batch in train_dataloader: train_batches.append(train_batch) - save_batches_to_arrow(train_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/')) - + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + valid_batches = [] for valid_batch in valid_dataloader: valid_batches.append(valid_batch) - save_batches_to_arrow(valid_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/')) - + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + eval_batches = [] for eval_batch in eval_dataloader: eval_batches.append(eval_batch) - save_batches_to_arrow(eval_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/')) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + if __name__ == '__main__': diff --git a/scripts/tiger/varka_timestamp_based.py b/scripts/tiger/varka_timestamp_based.py new file mode 100644 index 0000000..11343ea --- /dev/null +++ b/scripts/tiger/varka_timestamp_based.py @@ -0,0 +1,287 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +from data import Dataset + + + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json') 
+INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/valid_skip_set.json') +INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/test_set.json') + +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/') + + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 2000 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? 
+ + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + data = [] + for i in range(len(mapping)): + data.append(mapping[str(i)]) + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + def __call__(self, batch): + for name in 
self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # 
feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + data = Dataset.create_timestamp_based( + train_json_path=INTERACTIONS_TRAIN_PATH, + validation_json_path=INTERACTIONS_VALID_PATH, + test_json_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True + ) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, 
TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/sigir/Beauty/DatasetProcessing.ipynb b/sigir/Beauty/DatasetProcessing.ipynb new file mode 100644 index 0000000..b49f4ab --- /dev/null +++ b/sigir/Beauty/DatasetProcessing.ipynb @@ -0,0 +1,856 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "3bdb292f", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "\n", + "import json\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pickle\n", + "import polars as pl\n", + "\n", + "from transformers import LlamaModel, LlamaTokenizer\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from tqdm import tqdm as tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "66d9b312", + "metadata": {}, + "outputs": [], + "source": [ + "interactions_dataset_path = '../data/Beauty/Beauty_5.json'\n", + "metadata_path = '../data/Beauty/metadata.json'\n", + "\n", + "interactions_output_json_path = '../data/Beauty_new/inter_new.json'\n", + "interactions_output_parquet_path = '../data/Beauty_new/inter_new.parquet'\n", + "embeddings_output_path = '../data/Beauty_new/content_embeddings.pkl'\n", + "item_ids_mapping_output_path = '../data/Beauty_new/item_ids_mapping.json'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6ed4dffb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of events: 198502\n" + ] + } + ], + "source": [ + "df = defaultdict(list)\n", + "\n", + "with open(interactions_dataset_path, 'r') as f:\n", + " for line in f.readlines():\n", + " 
review = json.loads(line)\n", + " df['user_id'].append(review['reviewerID'])\n", + " df['item_id'].append(review['asin'])\n", + " df['timestamp'].append(review['unixReviewTime'])\n", + "\n", + "print(f'Number of events: {len(df[\"user_id\"])}')\n", + "\n", + "df = pl.from_dict(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c26746c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
user_iditem_idtimestamp
strstri64
"A1YJEY40YUW4SE""7806397051"1391040000
"A60XNB876KYML""7806397051"1397779200
"A3G6XNM240RMWA""7806397051"1378425600
"A1PQFP6SAJ6D80""7806397051"1386460800
"A38FVHZTNQ271F""7806397051"1382140800
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌────────────────┬────────────┬────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ str ┆ str ┆ i64 │\n", + "╞════════════════╪════════════╪════════════╡\n", + "│ A1YJEY40YUW4SE ┆ 7806397051 ┆ 1391040000 │\n", + "│ A60XNB876KYML ┆ 7806397051 ┆ 1397779200 │\n", + "│ A3G6XNM240RMWA ┆ 7806397051 ┆ 1378425600 │\n", + "│ A1PQFP6SAJ6D80 ┆ 7806397051 ┆ 1386460800 │\n", + "│ A38FVHZTNQ271F ┆ 7806397051 ┆ 1382140800 │\n", + "└────────────────┴────────────┴────────────┘" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "adcf5713", + "metadata": {}, + "outputs": [], + "source": [ + "filtered_df = df.clone()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0bbf9ba", + "metadata": {}, + "outputs": [], + "source": [ + "# Processing dataset to get core-5 state in case full dataset is provided\n", + "is_changed = True\n", + "threshold = 5\n", + "good_users = set()\n", + "good_items = set()\n", + "\n", + "while is_changed:\n", + " user_counts = filtered_df.group_by('user_id').agg(\n", + " pl.len().alias('user_count'),\n", + " )\n", + " item_counts = filtered_df.group_by('item_id').agg(\n", + " pl.len().alias('item_count'),\n", + " )\n", + "\n", + " good_users = user_counts.filter(pl.col('user_count') >= threshold).select(\n", + " 'user_id',\n", + " )\n", + " good_items = item_counts.filter(pl.col('item_count') >= threshold).select(\n", + " 'item_id',\n", + " )\n", + "\n", + " old_size = len(filtered_df)\n", + "\n", + " new_df = filtered_df.join(good_users, on='user_id', how='inner')\n", + " new_df = new_df.join(good_items, on='item_id', how='inner')\n", + "\n", + " new_size = len(new_df)\n", + "\n", + " filtered_df = new_df\n", + " is_changed = old_size != new_size\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + 
"id": "218a9348", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
user_iditem_idtimestamp
i64i64i64
001391040000
101397779200
201378425600
301386460800
401382140800
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────┬────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ i64 ┆ i64 │\n", + "╞═════════╪═════════╪════════════╡\n", + "│ 0 ┆ 0 ┆ 1391040000 │\n", + "│ 1 ┆ 0 ┆ 1397779200 │\n", + "│ 2 ┆ 0 ┆ 1378425600 │\n", + "│ 3 ┆ 0 ┆ 1386460800 │\n", + "│ 4 ┆ 0 ┆ 1382140800 │\n", + "└─────────┴─────────┴────────────┘" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unique_values = filtered_df[\"user_id\"].unique(maintain_order=True).to_list()\n", + "user_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n", + "\n", + "filtered_df = filtered_df.with_columns(\n", + " pl.col(\"user_id\").replace_strict(user_ids_mapping)\n", + ")\n", + "\n", + "unique_values = filtered_df[\"item_id\"].unique(maintain_order=True).to_list()\n", + "item_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n", + "\n", + "filtered_df = filtered_df.with_columns(\n", + " pl.col(\"item_id\").replace_strict(item_ids_mapping)\n", + ")\n", + "\n", + "filtered_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "34604fe6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 2)
old_item_idnew_item_id
stri64
"7806397051"0
"9759091062"1
"9788072216"2
"9790790961"3
"9790794231"4
" + ], + "text/plain": [ + "shape: (5, 2)\n", + "┌─────────────┬─────────────┐\n", + "│ old_item_id ┆ new_item_id │\n", + "│ --- ┆ --- │\n", + "│ str ┆ i64 │\n", + "╞═════════════╪═════════════╡\n", + "│ 7806397051 ┆ 0 │\n", + "│ 9759091062 ┆ 1 │\n", + "│ 9788072216 ┆ 2 │\n", + "│ 9790790961 ┆ 3 │\n", + "│ 9790794231 ┆ 4 │\n", + "└─────────────┴─────────────┘" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item_ids_mapping_df = pl.from_dict({\n", + " 'old_item_id': list(item_ids_mapping.keys()),\n", + " 'new_item_id': list(item_ids_mapping.values())\n", + "})\n", + "item_ids_mapping_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "99b54db807b9495c", + "metadata": {}, + "outputs": [], + "source": [ + "with open(item_ids_mapping_output_path, 'w') as f:\n", + " json.dump({str(k): v for k, v in item_ids_mapping.items()}, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6017e65c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
user_iditem_idtimestamp
i64i64i64
001391040000
101397779200
201378425600
301386460800
401382140800
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────┬────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ i64 ┆ i64 │\n", + "╞═════════╪═════════╪════════════╡\n", + "│ 0 ┆ 0 ┆ 1391040000 │\n", + "│ 1 ┆ 0 ┆ 1397779200 │\n", + "│ 2 ┆ 0 ┆ 1378425600 │\n", + "│ 3 ┆ 0 ┆ 1386460800 │\n", + "│ 4 ┆ 0 ┆ 1382140800 │\n", + "└─────────┴─────────┴────────────┘" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "filtered_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9efd1983", + "metadata": {}, + "outputs": [], + "source": [ + "filtered_df = filtered_df.sort([\"user_id\", \"timestamp\"])\n", + "\n", + "grouped_filtered_df = filtered_df.group_by(\"user_id\", maintain_order=True).agg(pl.all())" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "fd51c525", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 2)
old_item_idnew_item_id
stri64
"7806397051"0
"9759091062"1
"9788072216"2
"9790790961"3
"9790794231"4
" + ], + "text/plain": [ + "shape: (5, 2)\n", + "┌─────────────┬─────────────┐\n", + "│ old_item_id ┆ new_item_id │\n", + "│ --- ┆ --- │\n", + "│ str ┆ i64 │\n", + "╞═════════════╪═════════════╡\n", + "│ 7806397051 ┆ 0 │\n", + "│ 9759091062 ┆ 1 │\n", + "│ 9788072216 ┆ 2 │\n", + "│ 9790790961 ┆ 3 │\n", + "│ 9790794231 ┆ 4 │\n", + "└─────────────┴─────────────┘" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item_ids_mapping_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8b0821da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
user_iditem_idtimestamp
i64list[i64]list[i64]
0[6845, 7872, … 0][1318896000, 1318896000, … 1391040000]
1[815, 10405, … 232][1392422400, 1396224000, … 1397779200]
2[6049, 0, … 6608][1378425600, 1378425600, … 1400284800]
3[5521, 5160, … 0][1379116800, 1380931200, … 1386460800]
4[0, 10469, … 11389][1382140800, 1383523200, … 1388966400]
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────────────────┬─────────────────────────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ list[i64] ┆ list[i64] │\n", + "╞═════════╪═════════════════════╪═════════════════════════════════╡\n", + "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n", + "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n", + "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n", + "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n", + "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n", + "└─────────┴─────────────────────┴─────────────────────────────────┘" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grouped_filtered_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "dc222d59", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Users count: 22363\n", + "Items count: 12101\n", + "Actions count: 198502\n", + "Avg user history len: 8.876358270357287\n" + ] + } + ], + "source": [ + "print('Users count:', filtered_df.select('user_id').unique().shape[0])\n", + "print('Items count:', filtered_df.select('item_id').unique().shape[0])\n", + "print('Actions count:', filtered_df.shape[0])\n", + "print('Avg user history len:', np.mean(list(map(lambda x: x[0], grouped_filtered_df.select(pl.col('item_id').list.len()).rows()))))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a272855d-84b2-4414-ba9f-62647e1151cf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape: (5, 3)\n", + "┌─────┬─────────────────────┬─────────────────────────────────┐\n", + "│ uid ┆ item_ids ┆ timestamps │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ list[i64] ┆ list[i64] │\n", + 
"╞═════╪═════════════════════╪═════════════════════════════════╡\n", + "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n", + "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n", + "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n", + "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n", + "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n", + "└─────┴─────────────────────┴─────────────────────────────────┘\n" + ] + } + ], + "source": [ + "inter_new = grouped_filtered_df.select([\n", + " pl.col(\"user_id\").alias(\"uid\"),\n", + " pl.col(\"item_id\").alias(\"item_ids\"),\n", + " pl.col(\"timestamp\").alias(\"timestamps\")\n", + "])\n", + "\n", + "print(inter_new.head())" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "de5a853a-8ee2-42dd-a71a-6cc6f90d526c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Файл успешно сохранен: ../data/Beauty_new/inter_new.parquet\n" + ] + } + ], + "source": [ + "output_path_parquet = interactions_output_parquet_path\n", + "inter_new.write_parquet(output_path_parquet)\n", + "\n", + "print(f\"Файл успешно сохранен: {output_path_parquet}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d07a2e91", + "metadata": {}, + "outputs": [], + "source": [ + "json_data = {}\n", + "for user_id, item_ids, _ in grouped_filtered_df.iter_rows():\n", + " json_data[user_id] = item_ids\n", + "\n", + "with open(interactions_output_json_path, 'w') as f:\n", + " json.dump(json_data, f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "id": "237523fa", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Content embedding creation" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "6361c7a5", + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 5\u001b[0m, in \u001b[0;36mgetDF\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(path, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m line \u001b[38;5;129;01min\u001b[39;00m \u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreadlines\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 6\u001b[0m df[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28meval\u001b[39m(line)\n", + "File \u001b[0;32m/usr/lib/python3.10/codecs.py:319\u001b[0m, in \u001b[0;36mBufferedIncrementalDecoder.decode\u001b[0;34m(self, input, final)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m\n\u001b[0;32m--> 319\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m, final\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 320\u001b[0m \u001b[38;5;66;03m# decode input (taking the buffer into account)\u001b[39;00m\n\u001b[1;32m 321\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuffer \u001b[38;5;241m+\u001b[39m \u001b[38;5;28minput\u001b[39m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 11\u001b[0m\n\u001b[1;32m 7\u001b[0m i \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m 
\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m pd\u001b[38;5;241m.\u001b[39mDataFrame\u001b[38;5;241m.\u001b[39mfrom_dict(df, orient\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mindex\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 11\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mgetDF\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmetadata_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m df\u001b[38;5;241m.\u001b[39mhead()\n", + "Cell \u001b[0;32mIn[19], line 5\u001b[0m, in \u001b[0;36mgetDF\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m 3\u001b[0m df \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(path, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m line \u001b[38;5;129;01min\u001b[39;00m \u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreadlines\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 6\u001b[0m df[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28meval\u001b[39m(line)\n\u001b[1;32m 7\u001b[0m i \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "def getDF(path):\n", + " i = 0\n", + " df = {}\n", + " with open(path, 'r') as f:\n", + " for line in f.readlines():\n", + " df[i] = eval(line)\n", + " i += 1\n", + "\n", + " return pd.DataFrame.from_dict(df, orient=\"index\")\n", + "\n", + "df = getDF(metadata_path)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "971fa89c", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess(row: pd.Series):\n", + " row = row.fillna(\"None\")\n", + " return f\"Title: {row['title']}. Categories: {', '.join(row['categories'][0])}. 
Description: {row['description']}.\"\n", + "\n", + "\n", + "def get_data(metadata_df, item_ids_mapping_df):\n", + " filtered_df = metadata_df.join(\n", + " item_ids_mapping_df, \n", + " left_on=\"asin\", \n", + " right_on='old_item_id', \n", + " how=\"inner\"\n", + " ).select(pl.col('new_item_id'), pl.col('title'), pl.col('description'), pl.col('categories'))\n", + "\n", + " filtered_df = filtered_df.to_pandas()\n", + " filtered_df[\"combined_text\"] = filtered_df.apply(preprocess, axis=1)\n", + "\n", + " return filtered_df\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b0dd5d5", + "metadata": {}, + "outputs": [], + "source": [ + "data = get_data(pl.from_pandas(df), item_ids_mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12e622ff", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda:6')\n", + "\n", + "model_name = \"huggyllama/llama-7b\"\n", + "tokenizer = LlamaTokenizer.from_pretrained(model_name)\n", + "\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "model = LlamaModel.from_pretrained(model_name)\n", + "model = model.to(device)\n", + "model = model.eval()\n", + "\n", + "\n", + "class MyDataset:\n", + " def __init__(self, data):\n", + " self._data = list(zip(data.to_dict()['new_item_id'].values(), data.to_dict()['combined_text'].values()))\n", + "\n", + " def __len__(self):\n", + " return len(self._data)\n", + "\n", + " def __getitem__(self, idx):\n", + " text = self._data[idx][1]\n", + " inputs = tokenizer(text, return_tensors=\"pt\", max_length=1024, truncation=True, padding=\"max_length\")\n", + " return {\n", + " 'item_id': self._data[idx][0],\n", + " 'input_ids': inputs['input_ids'][0],\n", + " 'attention_mask': inputs['attention_mask'][0]\n", + " }\n", + " \n", + "\n", + "dataset = MyDataset(data)\n", + "loader = DataLoader(dataset, batch_size=8, drop_last=False, shuffle=False, num_workers=10)\n", + "\n", + 
"\n", + "new_df = {\n", + " 'item_id': [],\n", + " 'embedding': []\n", + "}\n", + "\n", + "for batch in tqdm(loader):\n", + " with torch.inference_mode():\n", + " outputs = model(\n", + " input_ids=batch[\"input_ids\"].to(device), \n", + " attention_mask=batch[\"attention_mask\"].to(device)\n", + " )\n", + " embeddings = outputs.last_hidden_state\n", + " \n", + " embeddings = outputs.last_hidden_state # (bs, sl, ed)\n", + " embeddings[(~batch[\"attention_mask\"].bool())] = 0. # (bs, sl, ed)\n", + "\n", + " new_df['item_id'] += batch['item_id'].tolist()\n", + " new_df['embedding'] += embeddings.mean(dim=1).tolist() # (bs, ed)\n", + "\n", + "\n", + "with open(embeddings_output_path, 'wb') as f:\n", + " pickle.dump(new_df, f)\n" + ] + }, + { + "cell_type": "markdown", + "id": "a6fffc4a-85f1-424e-b460-29e526df3317", + "metadata": {}, + "source": [ + "# Test" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "1f922431-e3c1-4587-86d1-04a58eb8ffee", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[0_1291403520.0).json\n", + "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[1291403520.0_1329626880.0).json\n", + "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[1329626880.0_1367850240.0).json\n", + "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[1367850240.0_inf).json\n", + "Интервал 0: 3485 пользователей, 10350 взаимодействий\n", + "Интервал 1: 5751 пользователей, 15837 взаимодействий\n", + "Интервал 2: 13543 пользователей, 61954 взаимодействий\n", + "Интервал 3: 18811 пользователей, 110361 взаимодействий\n" + ] + } + ], + "source": [ + "import polars as pl\n", + "import json\n", + "from typing import List, Dict\n", + "\n", + "def split_session_by_timestamps(\n", + " df: pl.DataFrame,\n", + " time_cutoffs: List[int],\n", + " output_dir: str = None,\n", + " return_dicts: bool = True\n", + ") -> List[Dict[int, List[int]]]:\n", + " \"\"\"\n", + " 
Разбивает датасет по временным интервалам и возвращает JSON-подобные словари.\n", + " \n", + " Args:\n", + " df: Polars DataFrame с колонками uid, item_ids (list), timestamps (list)\n", + " time_cutoffs: Лист временных точек для разбиения\n", + " output_dir: Директория для сохранения JSON файлов (опционально)\n", + " return_dicts: Возвращать ли словари (как json_data format)\n", + " \n", + " Returns:\n", + " Лист словарей в формате {user_id: [item_ids для интервала]}\n", + " \n", + " Example:\n", + " >>> df = pl.read_parquet(\"inter_new.parquet\")\n", + " >>> cutoffs = [1000000, 2000000, 3000000]\n", + " >>> dicts = split_session_by_timestamps(df, cutoffs, output_dir=\"./data\")\n", + " >>> # Получим 4 JSON файла за каждый интервал + последний\n", + " \"\"\"\n", + " \n", + " result_dicts = []\n", + " \n", + " def extract_interval(df_source, start, end=None):\n", + " \"\"\"Извлекает данные для одного временного интервала\"\"\"\n", + " q = df_source.lazy()\n", + " q = q.explode([\"item_ids\", \"timestamps\"])\n", + " \n", + " if end is not None:\n", + " q = q.filter(\n", + " (pl.col(\"timestamps\") >= start) & \n", + " (pl.col(\"timestamps\") < end)\n", + " )\n", + " else:\n", + " q = q.filter(\n", + " pl.col(\"timestamps\") >= start\n", + " )\n", + " \n", + " q = q.group_by(\"uid\").agg([\n", + " pl.col(\"item_ids\").alias(\"item_ids\")\n", + " ]).sort(\"uid\")\n", + " \n", + " return q.collect()\n", + " \n", + " # Генерируем интервалы\n", + " intervals = []\n", + " current_start = 0\n", + " for cutoff in time_cutoffs:\n", + " intervals.append((current_start, cutoff))\n", + " current_start = cutoff\n", + " # Последний интервал от последнего cutoff до бесконечности\n", + " intervals.append((current_start, None))\n", + " \n", + " # Обрабатываем каждый интервал\n", + " for start, end in intervals:\n", + " subset = extract_interval(df, start, end)\n", + " \n", + " # Конвертируем в JSON-подобный словарь\n", + " json_dict = {}\n", + " for user_id, item_ids in 
subset.iter_rows():\n", + " json_dict[user_id] = item_ids\n", + " \n", + " result_dicts.append(json_dict)\n", + " \n", + " # Опционально сохраняем файлы\n", + " if output_dir:\n", + " if end is not None:\n", + " filename = f\"inter_new_[{start}_{end}).json\"\n", + " else:\n", + " filename = f\"inter_new_[{start}_inf).json\"\n", + " \n", + " filepath = f\"{output_dir}/{filename}\"\n", + " with open(filepath, 'w') as f:\n", + " json.dump(json_dict, f, indent=2)\n", + " \n", + " print(f\"✓ Сохранено: {filepath}\")\n", + " \n", + " return result_dicts\n", + "\n", + "\n", + "# ==========================================\n", + "# Использование в ноутбуке\n", + "# ==========================================\n", + "\n", + "# Загружаем сохраненный Parquet файл\n", + "df = pl.read_parquet(interactions_output_parquet_path)\n", + "\n", + "# Определяем временные точки разбиения (можно использовать процентили или конкретные даты)\n", + "# Например: разбить на 70%, 80%, 90% от времени\n", + "df_timestamps = df.select(\n", + " pl.col(\"timestamps\").explode()\n", + ")\n", + "min_time = df_timestamps.select(pl.col(\"timestamps\").min()).item()\n", + "max_time = df_timestamps.select(pl.col(\"timestamps\").max()).item()\n", + "\n", + "# Разделяем на 4 части (train/val/test/test_final)\n", + "cutoffs = [\n", + " min_time + (max_time - min_time) * 0.7, # 70%\n", + " min_time + (max_time - min_time) * 0.8, # 80%\n", + " min_time + (max_time - min_time) * 0.9, # 90%\n", + "]\n", + "\n", + "# Применяем функцию\n", + "result_dicts = split_session_by_timestamps(\n", + " df, \n", + " cutoffs,\n", + " output_dir=\"../data/Beauty_new/splits\" # Опционально\n", + ")\n", + "\n", + "# Выводим статистику\n", + "for i, json_dict in enumerate(result_dicts):\n", + " total_interactions = sum(len(items) for items in json_dict.values())\n", + " print(f\"Интервал {i}: {len(json_dict)} пользователей, {total_interactions} взаимодействий\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": 
"73b5ec51-4d94-4021-9a21-3f4345ecdd26", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[0_inf).json\n" + ] + } + ], + "source": [ + "split_session_by_timestamps(\n", + " df, \n", + " [],\n", + " output_dir=\"../data/Beauty_new/splits\"\n", + ")\n", + "None" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sigir/Beauty/exps_data.ipynb b/sigir/Beauty/exps_data.ipynb new file mode 100644 index 0000000..2625231 --- /dev/null +++ b/sigir/Beauty/exps_data.ipynb @@ -0,0 +1,921 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "e2462a97-6705-44e1-a232-4dd78a5dfc85", + "metadata": {}, + "outputs": [], + "source": [ + "import polars as pl\n", + "import json\n", + "from typing import List, Dict" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fd38624d-5796-4aa5-929f-7e82c5544f6c", + "metadata": {}, + "outputs": [], + "source": [ + "interactions_output_parquet_path = '/home/jovyan/IRec/sigir/Beauty_new/inter_new.parquet'\n", + "# 1. 
Загружаем\n", + "df = pl.read_parquet(interactions_output_parquet_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ee127317-66b8-4f22-9109-94bcb8b1f1ae", + "metadata": {}, + "outputs": [], + "source": [ + "def split_session_by_timestamps(\n", + " df: pl.DataFrame,\n", + " time_cutoffs: List[int],\n", + " output_dir: str = None,\n", + " return_dicts: bool = True\n", + ") -> List[Dict[int, List[int]]]:\n", + " \"\"\"\n", + " Разбивает датасет по временным интервалам и возвращает JSON-подобные словари.\n", + " \n", + " Args:\n", + " df: Polars DataFrame с колонками uid, item_ids (list), timestamps (list)\n", + " time_cutoffs: Лист временных точек для разбиения\n", + " output_dir: Директория для сохранения JSON файлов (опционально)\n", + " return_dicts: Возвращать ли словари (как json_data format)\n", + " \n", + " Returns:\n", + " Лист словарей в формате {user_id: [item_ids для интервала]}\n", + " \n", + " Example:\n", + " >>> df = pl.read_parquet(\"inter_new.parquet\")\n", + " >>> cutoffs = [1000000, 2000000, 3000000]\n", + " >>> dicts = split_session_by_timestamps(df, cutoffs, output_dir=\"./data\")\n", + " >>> # Получим 4 JSON файла за каждый интервал + последний\n", + " \"\"\"\n", + " \n", + " result_dicts = []\n", + " \n", + " def extract_interval(df_source, start, end=None):\n", + " \"\"\"Извлекает данные для одного временного интервала\"\"\"\n", + " q = df_source.lazy()\n", + " q = q.explode([\"item_ids\", \"timestamps\"])\n", + " \n", + " if end is not None:\n", + " q = q.filter(\n", + " (pl.col(\"timestamps\") >= start) & \n", + " (pl.col(\"timestamps\") < end)\n", + " )\n", + " else:\n", + " q = q.filter(\n", + " pl.col(\"timestamps\") >= start\n", + " )\n", + " \n", + " q = q.group_by(\"uid\").agg([\n", + " pl.col(\"item_ids\").alias(\"item_ids\")\n", + " ]).sort(\"uid\")\n", + " \n", + " return q.collect()\n", + " \n", + " # Генерируем интервалы\n", + " intervals = []\n", + " current_start = 0\n", + " for cutoff in 
time_cutoffs:\n", + " intervals.append((current_start, cutoff))\n", + " current_start = cutoff\n", + " # Последний интервал от последнего cutoff до бесконечности\n", + " intervals.append((current_start, None))\n", + " \n", + " # Обрабатываем каждый интервал\n", + " for start, end in intervals:\n", + " subset = extract_interval(df, start, end)\n", + " \n", + " # Конвертируем в JSON-подобный словарь\n", + " json_dict = {}\n", + " for user_id, item_ids in subset.iter_rows():\n", + " json_dict[user_id] = item_ids\n", + " \n", + " result_dicts.append(json_dict)\n", + " \n", + " # Опционально сохраняем файлы\n", + " if output_dir:\n", + " if end is not None:\n", + " filename = f\"inter_new_[{start}_{end}).json\"\n", + " else:\n", + " filename = f\"inter_new_[{start}_inf).json\"\n", + " \n", + " filepath = f\"{output_dir}/{filename}\"\n", + " with open(filepath, 'w') as f:\n", + " json.dump(json_dict, f, indent=2)\n", + " \n", + " print(f\"✓ Сохранено: {filepath}\")\n", + " \n", + " return result_dicts" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "efc8b582-dd8a-4299-9c49-de906251de8a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cutoffs: [1402444800, 1403654400, 1404864000]\n", + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[0_1402444800).json\n", + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[1402444800_1403654400).json\n", + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[1403654400_1404864000).json\n", + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[1404864000_inf).json\n", + "Part 0 [Base]: 22029 users\n", + "Part 1 [Week -6]: 1854 users\n", + "Part 2 [Week -4]: 1945 users\n", + "Part 3 [Week -2]: 1381 users\n" + ] + } + ], + "source": [ + "global_max_time = df.select(\n", + " pl.col(\"timestamps\").explode().max()\n", + ").item()\n", + "\n", + "# 3. 
Размер окна (неделя)\n", + "days_val = 14\n", + "window_sec = days_val * 24 * 3600 \n", + "\n", + "# 4. Три отсечки с конца\n", + "cutoff_test_start = global_max_time - window_sec # T - 2w\n", + "cutoff_val_start = global_max_time - 2 * window_sec # T - 4w\n", + "cutoff_gap_start = global_max_time - 3 * window_sec # T - 6w\n", + "\n", + "cutoffs = [\n", + " int(cutoff_gap_start), # Граница Part 0 | Part 1\n", + " int(cutoff_val_start), # Граница Part 1 | Part 2\n", + " int(cutoff_test_start) # Граница Part 2 | Part 3\n", + "]\n", + "\n", + "print(f\"Cutoffs: {cutoffs}\")\n", + "\n", + "# 5. Разбиваем на 4 файла\n", + "# Part 0: Deep History\n", + "# Part 1: Pre-Validation (нужна для s1, но выкидывается для 'совсем короткого' s0?)\n", + "# *В вашем случае 4.2 просто 'на неделю короче', так что Part 1 все равно войдет в трейн Semantics, \n", + "# а выкинется только Part 2. Но если захотите еще короче - можно выкинуть и Part 1.*\n", + "# Part 2: Validation (Есть в 4.1, НЕТ в 4.2 для Semantics)\n", + "# Part 3: Test\n", + "\n", + "split_files = split_session_by_timestamps(\n", + " df, \n", + " cutoffs, \n", + " output_dir=\"/home/jovyan/IRec/sigir/Beauty_new/splits/test_splits\"\n", + ")\n", + "\n", + "names = [\"Base\", \"Week -6\", \"Week -4\", \"Week -2\"]\n", + "for i, d in enumerate(split_files):\n", + " print(f\"Part {i} [{names[i]}]: {len(d)} users\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d5ba172e-b430-40a3-a4fa-64366d02a015", + "metadata": {}, + "outputs": [], + "source": [ + "def merge_and_save(parts_to_merge, dirr, output_name):\n", + " merged = {}\n", + " print(f\"Merging {len(parts_to_merge)} files into {output_name}...\")\n", + " \n", + " for part in parts_to_merge:\n", + " # with open(fp, 'r') as f:\n", + " # part = json.load(f)\n", + " for uid, items in part.items():\n", + " if uid not in merged:\n", + " merged[uid] = []\n", + " merged[uid].extend(items)\n", + " \n", + " out_path = f\"{dirr}/{output_name}\"\n", + " with 
open(out_path, 'w') as f:\n", + " json.dump(merged, f)\n", + " print(f\"✓ Done: {out_path} (Users: {len(merged)})\")\n", + "\n", + "\n", + "# p0, p1, p2, p3 = split_files[0], split_files[1], split_files[2], split_files[3]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d116b7e0-9bf9-4104-86a0-69788a70cc14", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 2 files into exp_4_inter_tiger_train.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json (Users: 22129)\n", + "Merging 2 files into exp_4.1_inter_semantics_train.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json (Users: 22129)\n", + "Merging 1 files into exp_4.2_inter_semantics_train_short.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json (Users: 22029)\n", + "Merging 3 files into exp_4.3_inter_semantics_train_leak.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json (Users: 22265)\n", + "Merging 1 files into test_set.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/test_set.json (Users: 1381)\n", + "Merging 1 files into valid_skip_set.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/valid_skip_set.json (Users: 1945)\n", + "\n", + "All done!\n" + ] + } + ], + "source": [ + "EXP_DIR = \"../sigir/Beauty_new/splits/exp_data\"\n", + "\n", + "# Tiger: P0+P1\n", + "merge_and_save([p0, p1], EXP_DIR, \"exp_4_inter_tiger_train.json\")\n", + "\n", + "# 1. Exp 4.1 (Standard)\n", + "# Semantics: P0+P1 (Всё кроме пропуска и теста)\n", + "merge_and_save([p0, p1], EXP_DIR, \"exp_4.1_inter_semantics_train.json\")\n", + "\n", + "# 2. Exp 4.2 (Short Semantics)\n", + "# Semantics: P0 only (one split window shorter than 4.1: without P1)\n", + "merge_and_save([p0], EXP_DIR, \"exp_4.2_inter_semantics_train_short.json\")\n", + "\n", + "# 3. 
Exp 4.3 (Leak)\n", + "# Semantics: P0+P1+P2 (Видит валидацию)\n", + "merge_and_save([p0, p1, p2], EXP_DIR, \"exp_4.3_inter_semantics_train_leak.json\")\n", + "\n", + "# 4. Test Set (тест всех моделей)\n", + "merge_and_save([p3], EXP_DIR, \"test_set.json\")\n", + "\n", + "# 4. Valid Set (пропуск, имитируется разница трейна и теста чтобы потом дообучать)\n", + "merge_and_save([p2], EXP_DIR, \"valid_skip_set.json\")\n", + "\n", + "print(\"\\nAll done!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "9ae1d1e5-567d-471a-8f83-4039ecacc8d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 4 files into all_set.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/all_set.json (Users: 22363)\n" + ] + } + ], + "source": [ + "merge_and_save([p0, p1, p2, p3], EXP_DIR, \"all_set.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "328de16c-f61d-45be-8a72-5f0eaef612e8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Проверка Train сетов (должны быть префиксами):\n", + "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n", + "✅ [ПРЕФИКСЫ] Все 22029 массивов ОК. Полных совпадений: 18191\n", + "✅ [ПРЕФИКСЫ] Все 22265 массивов ОК. Полных совпадений: 20982\n", + "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n", + "\n", + "Проверка Test сета (должен быть суффиксом):\n", + "✅ [СУФФИКСЫ] Все 1381 массивов ОК. 
Полных совпадений: 98\n", + "\n", + "(Контроль) Проверка Test сета как префикса (должна упасть):\n", + "❌ [ПРЕФИКСЫ] Найдено 1283 ошибок.\n" + ] + } + ], + "source": [ + "with open(\"../data/Beauty/inter_new.json\", 'r') as f:\n", + " old_inter_new = json.load(f)\n", + "\n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json\", 'r') as ff:\n", + " first_sem = json.load(ff)\n", + " \n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json\", 'r') as ff:\n", + " second_sem = json.load(ff)\n", + " \n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json\", 'r') as ff:\n", + " third_sem = json.load(ff)\n", + " \n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json\", 'r') as ff:\n", + " tiger_sem = json.load(ff)\n", + "\n", + "with open(\"../sigir/Beauty_new/splits/exp_data/test_set.json\", 'r') as ff:\n", + " test_sem = json.load(ff)\n", + "\n", + "def check_prefix_match(full_data, subset_data, check_suffix=False):\n", + " \"\"\"\n", + " check_suffix=True включит режим проверки суффиксов (для теста).\n", + " \"\"\"\n", + " mismatch_count = 0\n", + " full_match_count = 0\n", + " \n", + " # Итерируемся по ключам сабсета, так как в full_data может быть больше юзеров\n", + " for user, sub_items in subset_data.items():\n", + " \n", + " # Проверяем есть ли такой юзер в исходнике\n", + " if user not in full_data:\n", + " print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " full_items = full_data[user]\n", + " \n", + " # Логика для проверки ПРЕФИКСА (начало совпадает)\n", + " if not check_suffix:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " # Сравниваем начало full с sub\n", + " if full_items[:len(sub_items)] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " 
else:\n", + " mismatch_count += 1\n", + " \n", + " # Логика для проверки СУФФИКСА (конец совпадает - для теста)\n", + " else:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " # Сравниваем конец full с sub\n", + " # Срез [-len:] берет последние N элементов\n", + " if full_items[-len(sub_items):] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + "\n", + " mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n", + " \n", + " if mismatch_count == 0:\n", + " print(f\"✅ [{mode}] Все {len(subset_data)} массивов ОК. Полных совпадений: {full_match_count}\")\n", + " else:\n", + " print(f\"❌ [{mode}] Найдено {mismatch_count} ошибок.\")\n", + "\n", + "# --- Запуск проверок ---\n", + "print(\"Проверка Train сетов (должны быть префиксами):\")\n", + "check_prefix_match(old_inter_new, first_sem)\n", + "check_prefix_match(old_inter_new, second_sem)\n", + "check_prefix_match(old_inter_new, third_sem)\n", + "check_prefix_match(old_inter_new, tiger_sem)\n", + "\n", + "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n", + "\n", + "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0715adfd", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'суа' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mсуа\u001b[49m\n", + "\u001b[0;31mNameError\u001b[0m: name 'суа' is not defined" + ] + } + ], + "source": [ + "суа" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f2df507d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Статистика по временным интервалам (Fixed Time Window) ---\n", + "Part 0 [Base]: 186516 events (Start -> 2014-06-11)\n", + "Part 1 [Gap (Week -6)]: 4073 events (2014-06-11 -> 2014-06-25)\n", + "Part 2 [Pre-Valid (Week -4)]: 4730 events (2014-06-25 -> 2014-07-09)\n", + "Part 3 [Test (Week -2)]: 3183 events (2014-07-09 -> Inf)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import polars as pl\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from datetime import datetime\n", + "\n", + "\n", + "# 2. Статистика по текущему временному разбиению\n", + "global_max_time = df.select(pl.col(\"timestamps\").explode().max()).item()\n", + "days_val = 14\n", + "window_sec = days_val * 24 * 3600 \n", + "\n", + "cutoffs = [\n", + " int(global_max_time - 3 * window_sec),\n", + " int(global_max_time - 2 * window_sec),\n", + " int(global_max_time - 1 * window_sec)\n", + "]\n", + "\n", + "print(\"--- Статистика по временным интервалам (Fixed Time Window) ---\")\n", + "intervals = [0] + cutoffs + [None]\n", + "labels = [\"Base\", \"Gap (Week -6)\", \"Pre-Valid (Week -4)\", \"Test (Week -2)\"]\n", + "\n", + "# Считаем события в каждом интервале\n", + "counts = []\n", + "for i in range(len(intervals)-1):\n", + " start, end = intervals[i], intervals[i+1]\n", + " \n", + " q = df.lazy().explode([\"timestamps\"])\n", + " if end is not None:\n", + " q = q.filter((pl.col(\"timestamps\") >= start) & (pl.col(\"timestamps\") < end))\n", + " else:\n", + " q = q.filter(pl.col(\"timestamps\") >= start)\n", + " \n", + " count = q.select(pl.len()).collect().item()\n", + " counts.append(count)\n", + " \n", + " end_str = datetime.fromtimestamp(end).strftime('%Y-%m-%d') if end else \"Inf\"\n", + " start_str = datetime.fromtimestamp(start).strftime('%Y-%m-%d') if start > 0 else \"Start\"\n", + " \n", + " print(f\"Part {i} [{labels[i]}]: {count} events ({start_str} -> {end_str})\")\n", + "\n", + "# 3. 
Гистограмма распределения событий во времени\n", + "all_timestamps = df.select(pl.col(\"timestamps\").explode()).to_series().to_numpy()\n", + "\n", + "plt.figure(figsize=(12, 6))\n", + "plt.hist(all_timestamps, bins=100, color='skyblue', alpha=0.7, label='Events')\n", + "\n", + "# Рисуем линии отсечек\n", + "colors = ['red', 'orange', 'green']\n", + "for cutoff, color, label in zip(cutoffs, colors, labels[1:]):\n", + " plt.axvline(x=cutoff, color=color, linestyle='--', linewidth=2, label=f'Cutoff: {label}')\n", + "\n", + "plt.title(\"Распределение взаимодействий во времени\")\n", + "plt.xlabel(\"Timestamp\")\n", + "plt.ylabel(\"Количество событий\")\n", + "plt.legend()\n", + "plt.grid(True, alpha=0.3)\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "id": "901e7400", + "metadata": {}, + "source": [ + "# QUANTILE CUTOFF" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8c691891", + "metadata": {}, + "outputs": [], + "source": [ + "def get_quantile_cutoffs(df, num_parts=4, base_ratio=None):\n", + " \"\"\"\n", + " Считает cutoffs так, чтобы разбить данные на части.\n", + " \n", + " Args:\n", + " num_parts: На сколько частей делить \"хвост\" истории.\n", + " base_ratio: Какую долю данных отдать в Base (самую первую часть). 
\n", + " Если None, делит всё поровну.\n", + " \"\"\"\n", + " # Достаем все таймстемпы в один плоский массив\n", + " # This materialises every timestamp in memory: fine for Beauty (~200K events), may be heavy for very large data (>100M)\n", + " all_ts = df.select(pl.col(\"timestamps\").explode()).to_series().sort()\n", + " total_events = len(all_ts)\n", + " \n", + " print(f\"Всего событий: {total_events}\")\n", + " \n", + " cutoffs = []\n", + " \n", + " if base_ratio:\n", + " # Scenario: Base takes base_ratio of the events (e.g. 80%), the remaining tail is split evenly into num_parts parts\n", + " # Tail share = 1 - base_ratio\n", + " # Each tail part = (1 - base_ratio) / num_parts\n", + " \n", + " base_idx = int(total_events * base_ratio)\n", + " cutoffs.append(all_ts[base_idx]) # Первый cutoff отделяет Base\n", + " \n", + " remaining_events = total_events - base_idx\n", + " part_size = remaining_events // num_parts # Split the tail into num_parts equal parts\n", + " \n", + " current_idx = base_idx\n", + " for _ in range(num_parts-1): # num_parts - 1 more boundaries give num_parts tail parts (num_parts cutoffs in total)\n", + " current_idx += part_size\n", + " cutoffs.append(all_ts[current_idx])\n", + " \n", + " else:\n", + " # Сценарий: Просто делим всё на N равных частей\n", + " step = total_events // num_parts\n", + " for i in range(1, num_parts):\n", + " idx = i * step\n", + " cutoffs.append(all_ts[idx])\n", + " \n", + " return cutoffs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "13c1466f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Всего событий: 198502\n", + "\n", + "--- Новые Cutoffs (по количеству событий) ---\n", + "Cutoffs: [1394150400, 1397001600, 1399939200, 1403049600]\n", + "[0, 1394150400, 1397001600, 1399939200, 1403049600, None]\n", + "Проверка количества событий в новых частях:\n", + "Part 0: 158689 events\n", + "Part 1: 9965 events\n", + "Part 2: 9701 events\n", + "Part 3: 10137 events\n", + "Part 4: 10010 events\n" + ] + } + ], + "source": [ + 
"equal_event_cutoffs = get_quantile_cutoffs(df, num_parts=4, base_ratio=0.8)\n", + "\n", + "print(\"\\n--- Новые Cutoffs (по количеству событий) ---\")\n", + "print(f\"Cutoffs: {equal_event_cutoffs}\")\n", + "\n", + "# Проверка распределения\n", + "intervals_eq = [0] + equal_event_cutoffs + [None]\n", + "print(intervals_eq)\n", + "print(\"Проверка количества событий в новых частях:\")\n", + "for i in range(len(intervals_eq)-1):\n", + " start, end = intervals_eq[i], intervals_eq[i+1]\n", + " q = df.lazy().explode([\"timestamps\"])\n", + " if end:\n", + " q = q.filter((pl.col(\"timestamps\") >= start) & (pl.col(\"timestamps\") < end))\n", + " else:\n", + " q = q.filter(pl.col(\"timestamps\") >= start)\n", + " count = q.select(pl.len()).collect().item()\n", + " print(f\"Part {i}: {count} events\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "4e7f7b46", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[0_1394150400).json\n", + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[1394150400_1399939200).json\n", + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[1399939200_1403049600).json\n", + "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[1403049600_inf).json\n", + "0 Base 20825 158689 \n", + "1 Gap 6816 19666 \n", + "2 Valid 3817 10137 \n", + "3 Test 3626 10010 \n" + ] + } + ], + "source": [ + "new_split_files = split_session_by_timestamps(\n", + " df, \n", + " [1394150400, 1399939200, 1403049600], \n", + " output_dir=\"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw\"\n", + ")\n", + "\n", + "names = [\"Base\", \"Gap\", \"Valid\", \"Test\"]\n", + "for i, d in enumerate(new_split_files):\n", + " num_users = len(d)\n", + " \n", + " num_events = sum(len(items) for items in 
d.values())\n", + " \n", + " print(f\"{i:<10} {names[i]:<10} {num_users:<10} {num_events:<10}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "82fd2bca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 2 files into exp_4_0.9_inter_tiger_train.json...\n", + "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json (Users: 21760)\n", + "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json...\n", + "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json (Users: 21760)\n", + "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json...\n", + "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json (Users: 20825)\n", + "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json...\n", + "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json (Users: 22079)\n", + "Merging 1 files into test_set.json...\n", + "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json (Users: 3626)\n", + "Merging 1 files into valid_set.json...\n", + "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json (Users: 3817)\n", + "Merging 4 files into all_set.json...\n", + "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/all_set.json (Users: 22363)\n", + "All done!\n" + ] + } + ], + "source": [ + "EXP_DIR = \"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps\"\n", + "\n", + "base_p, gap_p, valid_p, test_p = new_split_files[0], new_split_files[1], new_split_files[2], new_split_files[3]\n", + "\n", + "# Tiger: base + gap\n", + "merge_and_save([base_p, gap_p], EXP_DIR, \"exp_4_0.9_inter_tiger_train.json\")\n", + 
"\n", + "# 1. Exp 4.1 (Standard)\n", + "# Semantics: base + gap (Всё кроме валидации и теста)\n", + "merge_and_save([base_p, gap_p], EXP_DIR, \"exp_4-1_0.9_inter_semantics_train.json\")\n", + "\n", + "# 2. Exp 4.2 (Short Semantics)\n", + "# Semantics: base (Короче на пропуск, без gap)\n", + "merge_and_save([base_p], EXP_DIR, \"exp_4-2_0.8_inter_semantics_train.json\")\n", + "\n", + "# 3. Exp 4.3 (Leak)\n", + "# Semantics: base + gap + valid (Видит валидацию)\n", + "merge_and_save([base_p, gap_p, valid_p], EXP_DIR, \"exp_4-3_0.95_inter_semantics_train.json\")\n", + "\n", + "# 4. Test Set (тест всех моделей)\n", + "merge_and_save([test_p], EXP_DIR, \"test_set.json\")\n", + "\n", + "# 4. Valid Set (валидационный набор)\n", + "merge_and_save([valid_p], EXP_DIR, \"valid_set.json\")\n", + "\n", + "# 4. All Set (все данные)\n", + "merge_and_save([base_p, gap_p, valid_p, test_p], EXP_DIR, \"all_set.json\")\n", + "\n", + "print(\"All done!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d34b1c55", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Проверка Train сетов (должны быть префиксами):\n", + "доля событий всего 0.90:\n", + "✅ [ПРЕФИКСЫ] Все 21760 массивов ОК. Полных совпадений: 16175\n", + "доля событий всего 0.80:\n", + "✅ [ПРЕФИКСЫ] Все 20825 массивов ОК. Полных совпадений: 12129\n", + "доля событий всего 0.95:\n", + "✅ [ПРЕФИКСЫ] Все 22079 массивов ОК. Полных совпадений: 18737\n", + "доля событий всего 0.90:\n", + "✅ [ПРЕФИКСЫ] Все 21760 массивов ОК. Полных совпадений: 16175\n", + "\n", + "Проверка Test сета (должен быть суффиксом):\n", + "доля событий всего 0.05:\n", + "✅ [СУФФИКСЫ] Все 3626 массивов ОК. Полных совпадений: 284\n", + "\n", + "(Контроль) Проверка Test сета как префикса (должна упасть):\n", + "доля событий всего 0.05:\n", + "❌ [ПРЕФИКСЫ] Найдено 3342 ошибок.\n", + "доля событий всего 1.00:\n", + "✅ [ПРЕФИКСЫ] Все 22363 массивов ОК. 
Полных совпадений: 22363\n" + ] + } + ], + "source": [ + "with open(\"/home/jovyan/IRec/data/Beauty/inter_new.json\", 'r') as f:\n", + " old_inter_new = json.load(f)\n", + "\n", + "with open(\"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json\", 'r') as ff:\n", + " first_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json\", 'r') as ff:\n", + " second_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json\", 'r') as ff:\n", + " third_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json\", 'r') as ff:\n", + " tiger_sem = json.load(ff)\n", + "\n", + "with open(\"//home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json\", 'r') as ff:\n", + " test_sem = json.load(ff)\n", + "\n", + "with open(\"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/all_set.json\", 'r') as ff:\n", + " all_test_data = json.load(ff)\n", + "\n", + "def check_prefix_match(full_data, subset_data, check_suffix=False):\n", + " \"\"\"\n", + " check_suffix=True включит режим проверки суффиксов (для теста).\n", + " \"\"\"\n", + " mismatch_count = 0\n", + " full_match_count = 0\n", + "\n", + " num_events_full_data = sum(len(items) for items in full_data.values())\n", + " num_events_subset_data = sum(len(items) for items in subset_data.values())\n", + " print(f\"доля событий всего {(num_events_subset_data/num_events_full_data):.2f}:\")\n", + " \n", + " # Итерируемся по ключам сабсета, так как в full_data может быть больше юзеров\n", + " for user, sub_items in subset_data.items():\n", + " \n", + " # Проверяем есть ли такой юзер в исходнике\n", + " if user not in 
full_data:\n", + " print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " full_items = full_data[user]\n", + " \n", + " # Логика для проверки ПРЕФИКСА (начало совпадает)\n", + " if not check_suffix:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " # Сравниваем начало full с sub\n", + " if full_items[:len(sub_items)] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + " \n", + " # Логика для проверки СУФФИКСА (конец совпадает - для теста)\n", + " else:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " # Сравниваем конец full с sub\n", + " # Срез [-len:] берет последние N элементов\n", + " if full_items[-len(sub_items):] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + "\n", + " mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n", + " \n", + " if mismatch_count == 0:\n", + " print(f\"✅ [{mode}] Все {len(subset_data)} массивов ОК. 
Полных совпадений: {full_match_count}\")\n", + " else:\n", + " print(f\"❌ [{mode}] Найдено {mismatch_count} ошибок.\")\n", + "\n", + "# --- Запуск проверок ---\n", + "print(\"Проверка Train сетов (должны быть префиксами):\")\n", + "check_prefix_match(old_inter_new, first_sem)\n", + "check_prefix_match(old_inter_new, second_sem)\n", + "check_prefix_match(old_inter_new, third_sem)\n", + "check_prefix_match(old_inter_new, tiger_sem)\n", + "\n", + "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n", + "\n", + "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=False)\n", + "\n", + "check_prefix_match(old_inter_new, all_test_data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "501fae46", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Part 0 [Base]: 19666 events (2014-03-07 -> 2014-05-13)\n", + "Part 1 [Gap]: 10137 events (2014-05-13 -> 2014-06-18)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import polars as pl\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from datetime import datetime\n", + "labels = [\"Base\", \"Gap\", \"Valid\", \"Test\"]\n", + "\n", + "# Считаем события в каждом интервале\n", + "counts = []\n", + "intervals = [1394150400, 1399939200, 1403049600]\n", + "for i in range(len(intervals)-1):\n", + " start, end = intervals[i], intervals[i+1]\n", + " \n", + " q = df.lazy().explode([\"timestamps\"])\n", + " if end is not None:\n", + " q = q.filter((pl.col(\"timestamps\") >= start) & (pl.col(\"timestamps\") < end))\n", + " else:\n", + " q = q.filter(pl.col(\"timestamps\") >= start)\n", + " \n", + " count = q.select(pl.len()).collect().item()\n", + " counts.append(count)\n", + " \n", + " end_str = datetime.fromtimestamp(end).strftime('%Y-%m-%d') if end else \"Inf\"\n", + " start_str = datetime.fromtimestamp(start).strftime('%Y-%m-%d') if start > 0 else \"Start\"\n", + " \n", + " print(f\"Part {i} [{labels[i]}]: {count} events ({start_str} -> {end_str})\")\n", + "\n", + "# 3. 
Гистограмма распределения событий во времени\n", + "all_timestamps = df.select(pl.col(\"timestamps\").explode()).to_series().to_numpy()\n", + "\n", + "plt.figure(figsize=(12, 6))\n", + "plt.hist(all_timestamps, bins=100, color='skyblue', alpha=0.7, label='Events')\n", + "\n", + "# Рисуем линии отсечек\n", + "colors = ['red', 'orange', 'green']\n", + "for cutoff, color, label in zip(intervals, colors, labels[1:]):\n", + " plt.axvline(x=cutoff, color=color, linestyle='--', linewidth=2, label=f'Cutoff: {label}')\n", + "\n", + "plt.title(\"Распределение взаимодействий во времени\")\n", + "plt.xlabel(\"Timestamp\")\n", + "plt.ylabel(\"Количество событий\")\n", + "plt.legend()\n", + "plt.grid(True, alpha=0.3)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sigir/yambda_processing/YambdaDatasetProcessing.ipynb b/sigir/yambda_processing/YambdaDatasetProcessing.ipynb new file mode 100644 index 0000000..c36af65 --- /dev/null +++ b/sigir/yambda_processing/YambdaDatasetProcessing.ipynb @@ -0,0 +1,640 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "SbkKok0dfjjS" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from collections import defaultdict, Counter\n", + "from typing import Any, Dict, List, Optional, Tuple\n", + "\n", + "from datasets import load_dataset\n", + "\n", + "import numpy as np\n", + "\n", + "import polars as pl\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gwwdsnwBfjjT" + }, + "source": [ + "## 🛠️ Подготовка данных" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "viKiSaEKfjjT", + "outputId": "6229cbba-dc3b-4d15-a8e4-ac08e4e187d6" + }, + "outputs": [], + "source": [ + "format = 'sequential'\n", + "size = '50m'\n", + "events = 'listens'\n", + "# listens_data = load_dataset('yandex/yambda', data_dir=f'{format}/{size}', data_files=f'{events}.parquet')\n", + "# yambda_df = pl.from_arrow(listens_data['train'].data.table)\n", + "yambda_df = pl.read_parquet(\"/home/jovyan/yambda_sequential_50m/sequential/50m/listens.parquet\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VNanksDRfjjT", + "outputId": "e118e2b4-0076-475d-9104-5e1565dab7d9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ test_yambda_data_loading: OK\n" + ] + } + ], + "source": [ + "def test_yambda_data_loading():\n", + " assert isinstance(yambda_df, pl.DataFrame), 'yambda_df должен быть Polars DataFrame'\n", + " assert yambda_df.shape == (9238, 6), f'Неправильный размер: {yambda_df.shape}'\n", + "\n", + " expected_cols = {'uid', 'timestamp', 'item_id', 'is_organic', 'played_ratio_pct', 'track_length_seconds'}\n", + " assert set(yambda_df.columns) == expected_cols, f'Неправильные 
колонки: {yambda_df.columns}'\n", + "\n", + " assert yambda_df['item_id'].dtype == pl.List(pl.UInt32), 'item_id должен быть List[UInt32]'\n", + " assert yambda_df['timestamp'].dtype == pl.List(pl.UInt32), 'timestamp должен быть List[UInt32]'\n", + "\n", + " assert yambda_df['item_id'].list.len().min() > 0, 'Есть пустые истории'\n", + "\n", + " print('✅ test_yambda_data_loading: OK')\n", + "\n", + "test_yambda_data_loading()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 527 + }, + "id": "q33EG4wlc8ev", + "outputId": "01d03740-713e-46e8-d8c5-81ebf6b71546" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(shape: (5, 6)\n", + " ┌─────┬─────────────────────┬────────────┬─────────────┬─────────────────────┬─────────────────────┐\n", + " │ uid ┆ timestamp ┆ item_id ┆ is_organic ┆ played_ratio_pct ┆ track_length_second │\n", + " │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ s │\n", + " │ u32 ┆ list[u32] ┆ list[u32] ┆ list[u8] ┆ list[u16] ┆ --- │\n", + " │ ┆ ┆ ┆ ┆ ┆ list[u32] │\n", + " ╞═════╪═════════════════════╪════════════╪═════════════╪═════════════════════╪═════════════════════╡\n", + " │ 100 ┆ [39420, 39420, … ┆ [8326270, ┆ [0, 0, … 0] ┆ [100, 100, … 100] ┆ [170, 105, … 165] │\n", + " │ ┆ 25966140] ┆ 1441281, … ┆ ┆ ┆ │\n", + " │ ┆ ┆ 4734787] ┆ ┆ ┆ │\n", + " │ 200 ┆ [14329075, ┆ [3285270, ┆ [1, 1, … 1] ┆ [9, 28, … 100] ┆ [170, 170, … 145] │\n", + " │ ┆ 14329075, … ┆ 5253582, … ┆ ┆ ┆ │\n", + " │ ┆ 2545672… ┆ 3778807] ┆ ┆ ┆ │\n", + " │ 300 ┆ [54090, 54100, … ┆ [618910, ┆ [1, 1, … 1] ┆ [2, 4, … 15] ┆ [270, 130, … 210] │\n", + " │ ┆ 25907225] ┆ 8793425, … ┆ ┆ ┆ │\n", + " │ ┆ ┆ 9286415] ┆ ┆ ┆ │\n", + " │ 500 ┆ [22695440, ┆ [6417502, ┆ [0, 0, … 1] ┆ [100, 37, … 13] ┆ [225, 210, … 230] │\n", + " │ ┆ 22695690, … ┆ 6896222, … ┆ ┆ ┆ │\n", + " │ ┆ 2486145… ┆ 4077285] ┆ ┆ ┆ │\n", + " │ 600 ┆ [1329190, 1329405, ┆ [8077497, ┆ [0, 0, … 0] ┆ [100, 100, … 100] ┆ [245, 215, … 205] │\n", + " │ ┆ … 
25997540] ┆ 1865247, … ┆ ┆ ┆ │\n", + " │ ┆ ┆ 6481452] ┆ ┆ ┆ │\n", + " └─────┴─────────────────────┴────────────┴─────────────┴─────────────────────┴─────────────────────┘,\n", + " shape: (0, 6)\n", + " ┌─────┬───────────┬───────────┬────────────┬──────────────────┬──────────────────────┐\n", + " │ uid ┆ timestamp ┆ item_id ┆ is_organic ┆ played_ratio_pct ┆ track_length_seconds │\n", + " │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + " │ u32 ┆ list[u32] ┆ list[u32] ┆ list[u8] ┆ list[u16] ┆ list[u32] │\n", + " ╞═════╪═══════════╪═══════════╪════════════╪══════════════════╪══════════════════════╡\n", + " └─────┴───────────┴───────────┴────────────┴──────────────────┴──────────────────────┘)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(yambda_df.head(), yambda_df.filter(yambda_df['timestamp'].list.len() != yambda_df['item_id'].list.len()))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "-9ou8IARfjjT" + }, + "outputs": [ + { + "ename": "ColumnNotFoundError", + "evalue": "'explode' on column: 'is_organic' is invalid\n\nSchema at this point: Schema:\nname: _idx, field: UInt32\nname: uid, field: UInt32\nname: timestamp, field: List(UInt32)\nname: item_id, field: List(UInt32)\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mColumnNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m yambda_df \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2\u001b[0m \u001b[43myambda_df\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m 
\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43muid\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m%\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m200\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# надо убрать\u001b[39;49;00m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwith_row_index\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m_idx\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m----> 5\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplode\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtimestamp\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mitem_id\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mis_organic\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mplayed_ratio_pct\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrack_length_seconds\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m 
\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;241m.\u001b[39mfilter(\n\u001b[1;32m 13\u001b[0m (pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mis_organic\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m&\u001b[39m\n\u001b[1;32m 14\u001b[0m (pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mplayed_ratio_pct\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m50\u001b[39m)\n\u001b[1;32m 15\u001b[0m )\n\u001b[1;32m 16\u001b[0m \u001b[38;5;241m.\u001b[39mgroup_by([\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_idx\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124muid\u001b[39m\u001b[38;5;124m'\u001b[39m], maintain_order\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;241m.\u001b[39magg([\n\u001b[1;32m 18\u001b[0m pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtimestamp\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 19\u001b[0m pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mitem_id\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 20\u001b[0m ])\n\u001b[1;32m 21\u001b[0m \u001b[38;5;241m.\u001b[39mdrop(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_idx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 22\u001b[0m )\n", + "File \u001b[0;32m/usr/local/lib/python3.12/dist-packages/polars/dataframe/frame.py:8072\u001b[0m, in \u001b[0;36mDataFrame.explode\u001b[0;34m(self, columns, *more_columns)\u001b[0m\n\u001b[1;32m 8015\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mexplode\u001b[39m(\n\u001b[1;32m 8016\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 8017\u001b[0m columns: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m|\u001b[39m Expr \u001b[38;5;241m|\u001b[39m 
Sequence[\u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m|\u001b[39m Expr],\n\u001b[1;32m 8018\u001b[0m \u001b[38;5;241m*\u001b[39mmore_columns: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m|\u001b[39m Expr,\n\u001b[1;32m 8019\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataFrame:\n\u001b[1;32m 8020\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 8021\u001b[0m \u001b[38;5;124;03m Explode the dataframe to long format by exploding the given columns.\u001b[39;00m\n\u001b[1;32m 8022\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 8070\u001b[0m \u001b[38;5;124;03m └─────────┴─────────┘\u001b[39;00m\n\u001b[1;32m 8071\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 8072\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcolumns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmore_columns\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_eager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.12/dist-packages/polars/lazyframe/frame.py:2053\u001b[0m, in \u001b[0;36mLazyFrame.collect\u001b[0;34m(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, collapse_joins, no_optimization, streaming, engine, background, _eager, **_kwargs)\u001b[0m\n\u001b[1;32m 2051\u001b[0m \u001b[38;5;66;03m# Only for testing purposes\u001b[39;00m\n\u001b[1;32m 2052\u001b[0m callback \u001b[38;5;241m=\u001b[39m 
_kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost_opt_callback\u001b[39m\u001b[38;5;124m\"\u001b[39m, callback)\n\u001b[0;32m-> 2053\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrap_df(\u001b[43mldf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m)\n", + "\u001b[0;31mColumnNotFoundError\u001b[0m: 'explode' on column: 'is_organic' is invalid\n\nSchema at this point: Schema:\nname: _idx, field: UInt32\nname: uid, field: UInt32\nname: timestamp, field: List(UInt32)\nname: item_id, field: List(UInt32)\n" + ] + } + ], + "source": [ + "yambda_df = (\n", + " yambda_df\n", + " .filter(pl.col('uid') % 200 == 0) # надо убрать\n", + " .with_row_index('_idx')\n", + " .explode([\n", + " 'timestamp',\n", + " 'item_id',\n", + " 'is_organic',\n", + " 'played_ratio_pct',\n", + " 'track_length_seconds',\n", + " ])\n", + " .filter(\n", + " (pl.col('is_organic') == 0) &\n", + " (pl.col('played_ratio_pct') >= 50)\n", + " )\n", + " .group_by(['_idx', 'uid'], maintain_order=True)\n", + " .agg([\n", + " pl.col('timestamp'),\n", + " pl.col('item_id'),\n", + " ])\n", + " .drop('_idx')\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4HTturbHfjjU", + "outputId": "ea8b1d93-4997-441a-c6e3-2628acd11d7f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ test_yambda_filtering: OK\n" + ] + } + ], + "source": [ + "def test_yambda_filtering():\n", + " assert yambda_df.shape[0] == 4289, \\\n", + " f'Неправильное количество пользователей: {yambda_df.shape[0]}'\n", + "\n", + " expected_columns = {'uid', 'timestamp', 'item_id'}\n", + " actual_columns = set(yambda_df.columns)\n", + " assert actual_columns == expected_columns, \\\n", + " f'Неправильные колонки. 
Ожидалось: {expected_columns}, получено: {actual_columns}'\n", + "\n", + " assert yambda_df['timestamp'].dtype == pl.List(pl.UInt32), \\\n", + " f\"timestamp должен быть List[UInt32], получено: {yambda_df['timestamp'].dtype}\"\n", + " assert yambda_df['item_id'].dtype == pl.List(pl.UInt32), \\\n", + " f\"item_id должен быть List[UInt32], получено: {yambda_df['item_id'].dtype}\"\n", + "\n", + " seq_lengths = yambda_df['item_id'].list.len()\n", + " assert seq_lengths.min() >= 1, \\\n", + " f'Минимальная длина последовательности должна быть >= 1, получено: {seq_lengths.min()}'\n", + " assert seq_lengths.sum() == 7587469, \\\n", + " f'Общее количество событий неверно. Ожидалось: 7587469, получено: {seq_lengths.sum()}'\n", + "\n", + " unique_items = yambda_df.select('item_id').explode('item_id').unique().shape[0]\n", + " assert unique_items == 304787, \\\n", + " f'Количество уникальных айтемов неверно. Ожидалось: 304787, получено: {unique_items}'\n", + "\n", + " print('✅ test_yambda_filtering: OK')\n", + "\n", + "test_yambda_filtering()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Оригинальные эмбеддинги: (7721749, 3)\n", + "Колонки: ['item_id', 'embed', 'normalized_embed']\n", + "shape: (5, 3)\n", + "┌─────────┬─────────────────────────────────┬─────────────────────────────────┐\n", + "│ item_id ┆ embed ┆ normalized_embed │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ list[f64] ┆ list[f64] │\n", + "╞═════════╪═════════════════════════════════╪═════════════════════════════════╡\n", + "│ 2 ┆ [-1.534035, -0.366767, … 0.999… ┆ [-0.064638, -0.015454, … 0.042… │\n", + "│ 3 ┆ [-3.761467, -1.068254, … -2.66… ┆ [-0.163937, -0.046558, … -0.11… │\n", + "│ 4 ┆ [2.445533, -2.523603, … -0.536… ┆ [0.076272, -0.078707, … -0.016… │\n", + "│ 5 ┆ [0.832846, 0.116125, … -1.4857… ┆ [0.03149, 0.004391, … -0.05617… │\n", + "│ 6 ┆ [-2.431483, -0.56872, … 0.0946… ┆ [-0.10345, 
-0.024197, … 0.0040… │\n", + "└─────────┴─────────────────────────────────┴─────────────────────────────────┘\n" + ] + } + ], + "source": [ + "import polars as pl\n", + "import pandas as pd\n", + "import pickle\n", + "\n", + "# === 1. Загрузить оригинальные embeddings ===\n", + "embeddings_path = \"/home/jovyan/yambda_embeddings/embeddings.parquet\"\n", + "emb_df = pl.read_parquet(embeddings_path)\n", + "\n", + "print(f\"Оригинальные эмбеддинги: {emb_df.shape}\")\n", + "print(f\"Колонки: {emb_df.columns}\")\n", + "print(emb_df.head())" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Валидные item_id: 7721749\n", + "Было строк: 4289\n", + "Стало строк: 4138\n" + ] + } + ], + "source": [ + "valid_item_ids = set(emb_df['item_id'].to_list())\n", + "print(f\"\\nВалидные item_id: {len(valid_item_ids)}\")\n", + "valid_ids_pl = pl.Series(list(valid_item_ids))\n", + "\n", + "valid_item_ids = set(emb_df['item_id'].to_list())\n", + "valid_ids_pl = pl.Series(list(valid_item_ids))\n", + "\n", + "yambda_df_filtered = (\n", + " yambda_df\n", + " .with_columns(\n", + " pl.col(\"item_id\").list.eval(\n", + " pl.when(pl.element().is_in(valid_ids_pl))\n", + " .then(pl.int_range(pl.len()))\n", + " .otherwise(None)\n", + " ).list.drop_nulls().alias(\"valid_indices\")\n", + " )\n", + " .with_columns([\n", + " pl.col(\"item_id\").list.gather(pl.col(\"valid_indices\")),\n", + " pl.col(\"timestamp\").list.gather(pl.col(\"valid_indices\"))\n", + " ])\n", + " .drop(\"valid_indices\")\n", + " .filter(pl.col(\"item_id\").list.len() > 0)\n", + " .rename({\"item_id\": \"item_ids\", \"timestamp\": \"timestamps\"})\n", + ")\n", + "yambda_df_filtered = yambda_df_filtered.filter(yambda_df_filtered['item_ids'].list.len() >= 5)\n", + "\n", + "print(f\"Было строк: {yambda_df.shape[0]}\")\n", + "print(f\"Стало строк: {yambda_df_filtered.shape[0]}\")" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "3️⃣ Получите все уникальные ID треков из датасета и создайте маппинг: старый_id - новый_id, где новый_id находится в диапазоне от 0 до N - 1.\n", + "\n", + "Модели глубокого обучения требуют, чтобы категориальные признаки (в нашем случае ID треков) были представлены целыми числами в диапазоне от 0 до N-1, где N — количество уникальных треков. Датасет Yambda содержит оригинальные ID треков, которые могут быть разреженными (например, [100, 5000, 7, 12000, ...]) — это неэффективно для embedding-таблиц." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "unique_items = (\n", + " yambda_df_filtered\n", + " .select('item_ids')\n", + " .explode('item_ids')\n", + " .unique()\n", + " .sort('item_ids')\n", + ").with_row_index('new_item_ids')\n", + "\n", + "\n", + "item_mapping = dict(zip(unique_items['item_ids'], unique_items['new_item_ids']))\n", + "\n", + "\n", + "yambda_df_filtered = yambda_df_filtered.with_columns([\n", + " pl.col('item_ids')\n", + " .map_elements(\n", + " lambda items: [item_mapping[item] for item in items],\n", + " return_dtype=pl.List(pl.UInt32)\n", + " )\n", + " .alias('item_ids')\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ test_item_mapping: OK\n" + ] + } + ], + "source": [ + "def test_item_mapping():\n", + " assert unique_items.shape == (292865, 2), f'Неправильный размер unique_items: {unique_items.shape}'\n", + " assert set(unique_items.columns) == {'new_item_ids', 'item_ids'}, 'Неправильные колонки unique_items'\n", + "\n", + " assert len(item_mapping) == 292865, f'Неправильный размер item_mapping: {len(item_mapping)}'\n", + " assert item_mapping[50] == 0 and item_mapping[175] == 1 and item_mapping[195] == 2, \\\n", + " 'Неверные первые маппинги'\n", + "\n", + " new_ids = unique_items['new_item_ids']\n", + " 
assert new_ids.min() == 0 and new_ids.max() == 292864, 'new_item_id должны быть в [0, 292865,]'\n", + "\n", + " all_ids = yambda_df_filtered.select('item_ids').explode('item_ids')['item_ids']\n", + " assert all_ids.min() == 0 and all_ids.max() == 292864, 'item_id в yambda_df не обновлены'\n", + " assert all_ids.n_unique() == 292865, 'Количество уникальных item_id изменилось'\n", + "\n", + " print('✅ test_item_mapping: OK')\n", + "\n", + "test_item_mapping()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Маппинг использует: 292865 уникальных item_id\n", + "Переиндексированные эмбеддинги: (292865, 2)\n", + "shape: (5, 2)\n", + "┌─────────────────────────────────┬─────────┐\n", + "│ embedding ┆ item_id │\n", + "│ --- ┆ --- │\n", + "│ list[f64] ┆ u32 │\n", + "╞═════════════════════════════════╪═════════╡\n", + "│ [-0.0526, 0.048672, … -0.04217… ┆ 0 │\n", + "│ [0.090222, -0.00718, … -0.0862… ┆ 1 │\n", + "│ [-0.00822, -0.057882, … 0.2188… ┆ 2 │\n", + "│ [-0.107289, -0.034719, … -0.02… ┆ 3 │\n", + "│ [0.012762, -0.043315, … 0.1494… ┆ 4 │\n", + "└─────────────────────────────────┴─────────┘\n" + ] + } + ], + "source": [ + "print(f\"\\nМаппинг использует: {len(item_mapping)} уникальных item_id\")\n", + "\n", + "# === 3. 
Переиндексировать embeddings используя существующий маппинг ===\n", + "emb_df_reindexed = emb_df.with_columns(\n", + " pl.col('item_id')\n", + " .map_elements(\n", + " lambda x: item_mapping.get(x, None),\n", + " return_dtype=pl.UInt32\n", + " )\n", + " .alias('new_item_id')\n", + ").filter(pl.col('new_item_id').is_not_null()).drop('item_id', 'embed').rename({'new_item_id': 'item_id', 'normalized_embed': 'embedding'})\n", + "\n", + "print(f\"Переиндексированные эмбеддинги: {emb_df_reindexed.shape}\")\n", + "print(emb_df_reindexed.head())\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ test_item_mapping: OK\n" + ] + } + ], + "source": [ + "def test_emb_item_mapping():\n", + " all_ids = emb_df_reindexed['item_id']\n", + " assert all_ids.min() == 0 and all_ids.max() == 292864, 'item_id в yambda_df не обновлены'\n", + " assert all_ids.n_unique() == 292865, 'Количество уникальных item_id изменилось'\n", + "\n", + " print('✅ test_item_mapping: OK')\n", + "\n", + "test_emb_item_mapping()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "✓ Сохранены embeddings: /home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet\n" + ] + } + ], + "source": [ + "embeddings_output_parquet_path = \"/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet\"\n", + "emb_df_reindexed.write_parquet(embeddings_output_parquet_path)\n", + "print(f\"\\n✓ Сохранены embeddings: {embeddings_output_parquet_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Тест пройден: все 4138 строк синхронизированы\n" + ] + } + ], + "source": [ + "def test_integrity(df):\n", + " bad_rows = df.filter(\n", + " 
(pl.col(\"item_ids\").list.len() != pl.col(\"timestamps\").list.len()) | (pl.col(\"timestamps\").list.len() < 5)\n", + " )\n", + " \n", + " if bad_rows.height > 0:\n", + " print(f\"ОШИБКА: {bad_rows.height} строк рассинхронизированы!\")\n", + " raise ValueError(\"Рассинхрон массивов!\")\n", + " \n", + " print(f\"Тест пройден: все {df.height} строк синхронизированы\")\n", + "\n", + "test_integrity(yambda_df_filtered)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Сохранён filtered yambda_df: /home/jovyan/IRec/sigir/yambda_data/yambda_sequential_50m_filtered_reindexed.parquet\n" + ] + } + ], + "source": [ + "yambda_output_parquet_path = \"/home/jovyan/IRec/sigir/yambda_data/yambda_sequential_50m_filtered_reindexed.parquet\"\n", + "yambda_df_filtered.write_parquet(yambda_output_parquet_path)\n", + "print(f\"Сохранён filtered yambda_df: {yambda_output_parquet_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Сохранён маппинг: /home/jovyan/IRec/sigir/yambda_data/old_to_new_item_id_mapping.json\n" + ] + } + ], + "source": [ + "import json\n", + "mapping_output_path = \"/home/jovyan/IRec/sigir/yambda_data/old_to_new_item_id_mapping.json\"\n", + "\n", + "with open(mapping_output_path, 'w') as f:\n", + " json.dump({str(k): v for k, v in item_mapping.items()}, f, indent=2)\n", + "\n", + "print(f\"Сохранён маппинг: {mapping_output_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (10, 3)
uidtimestampsitem_ids
u32list[u32]list[u32]
600[1329190, 1329405, … 25997540][252026, 58171, … 201909]
800[121100, 121290, … 25977310][20844, 198210, … 60455]
1000[11335730, 11335925, … 25972225][46643, 57592, … 95670]
1400[280570, 280735, … 25993315][4634, 213798, … 104891]
1600[899275, 930305, … 25941890][223933, 154424, … 104876]
2000[18814620, 18828965, … 25225145][137828, 138498, … 19072]
2200[10053900, 10054120, … 25948025][4923, 231643, … 28122]
2400[14246260, 14246390, … 25999860][157350, 217652, … 75038]
2600[6089640, 6089915, … 25951140][9426, 202953, … 140393]
2800[19744285, 19744475, … 25894825][123607, 291065, … 272888]
" + ], + "text/plain": [ + "shape: (10, 3)\n", + "┌──────┬─────────────────────────────────┬────────────────────────────┐\n", + "│ uid ┆ timestamps ┆ item_ids │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ list[u32] ┆ list[u32] │\n", + "╞══════╪═════════════════════════════════╪════════════════════════════╡\n", + "│ 600 ┆ [1329190, 1329405, … 25997540] ┆ [252026, 58171, … 201909] │\n", + "│ 800 ┆ [121100, 121290, … 25977310] ┆ [20844, 198210, … 60455] │\n", + "│ 1000 ┆ [11335730, 11335925, … 2597222… ┆ [46643, 57592, … 95670] │\n", + "│ 1400 ┆ [280570, 280735, … 25993315] ┆ [4634, 213798, … 104891] │\n", + "│ 1600 ┆ [899275, 930305, … 25941890] ┆ [223933, 154424, … 104876] │\n", + "│ 2000 ┆ [18814620, 18828965, … 2522514… ┆ [137828, 138498, … 19072] │\n", + "│ 2200 ┆ [10053900, 10054120, … 2594802… ┆ [4923, 231643, … 28122] │\n", + "│ 2400 ┆ [14246260, 14246390, … 2599986… ┆ [157350, 217652, … 75038] │\n", + "│ 2600 ┆ [6089640, 6089915, … 25951140] ┆ [9426, 202953, … 140393] │\n", + "│ 2800 ┆ [19744285, 19744475, … 2589482… ┆ [123607, 291065, … 272888] │\n", + "└──────┴─────────────────────────────────┴────────────────────────────┘" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "yambda_df_filtered.head(10)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/sigir/yambda_processing/yambda_exps_data.ipynb b/sigir/yambda_processing/yambda_exps_data.ipynb new file mode 100644 index 0000000..e8f8d8b --- /dev/null +++ b/sigir/yambda_processing/yambda_exps_data.ipynb @@ 
-0,0 +1,1168 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "e2462a97-6705-44e1-a232-4dd78a5dfc85", + "metadata": {}, + "outputs": [], + "source": [ + "import polars as pl\n", + "import json\n", + "from typing import List, Dict" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fd38624d-5796-4aa5-929f-7e82c5544f6c", + "metadata": {}, + "outputs": [], + "source": [ + "interactions_output_parquet_path = '/home/jovyan/IRec/sigir/yambda_data/yambda_sequential_50m_filtered_reindexed.parquet'\n", + "df = pl.read_parquet(interactions_output_parquet_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "69066941", + "metadata": {}, + "outputs": [], + "source": [ + "def merge_and_save(parts_to_merge, dirr, output_name):\n", + " merged = {}\n", + " print(f\"Merging {len(parts_to_merge)} files into {output_name}...\")\n", + " \n", + " for part in parts_to_merge:\n", + " # with open(fp, 'r') as f:\n", + " # part = json.load(f)\n", + " for uid, items in part.items():\n", + " if uid not in merged:\n", + " merged[uid] = []\n", + " merged[uid].extend(items)\n", + " \n", + " out_path = f\"{dirr}/{output_name}\"\n", + " with open(out_path, 'w') as f:\n", + " json.dump(merged, f)\n", + " print(f\"✓ Done: {out_path} (Users: {len(merged)})\")\n", + "\n", + "\n", + "def merge_and_save_with_filter(parts_to_merge, dirr, output_name, min_history_len=5):\n", + " merged = {}\n", + " print(f\"Merging {len(parts_to_merge)} files into {output_name} (min len={min_history_len})...\")\n", + " \n", + " for part in parts_to_merge:\n", + " for uid, items in part.items():\n", + " if uid not in merged:\n", + " merged[uid] = []\n", + " merged[uid].extend(items)\n", + "\n", + " filtered_merged = {}\n", + " filtered_count = 0\n", + " \n", + " for uid, items in merged.items():\n", + " if len(items) >= min_history_len:\n", + " filtered_merged[uid] = items\n", + " else:\n", + " filtered_count += 1\n", + " \n", + " print(f\"Filtered 
{filtered_count} users with history < {min_history_len}\")\n", + " print(f\"Remaining: {len(filtered_merged)} users\")\n", + " \n", + " out_path = f\"{dirr}/{output_name}\"\n", + " with open(out_path, 'w') as f:\n", + " json.dump(filtered_merged, f)\n", + " print(f\"Done: {out_path} (Users: {len(filtered_merged)})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ee127317-66b8-4f22-9109-94bcb8b1f1ae", + "metadata": {}, + "outputs": [], + "source": [ + "def split_session_by_timestamps(\n", + " df: pl.DataFrame,\n", + " time_cutoffs: List[int],\n", + " output_dir: str = None,\n", + " return_dicts: bool = True\n", + ") -> List[Dict[int, List[int]]]:\n", + " \"\"\"\n", + " Args:\n", + " df: Polars DataFrame с колонками uid, item_ids (list), timestamps (list)\n", + " time_cutoffs: Лист временных точек для разбиения\n", + " output_dir: Директория для сохранения JSON файлов (опционально)\n", + " return_dicts: Возвращать ли словари (как json_data format)\n", + " \n", + " Возвращает лист словарей в формате {user_id: [item_ids для интервала]}\n", + " \"\"\"\n", + " \n", + " result_dicts = []\n", + " \n", + " def extract_interval(df_source, start, end=None):\n", + " q = df_source.lazy()\n", + " q = q.explode([\"item_ids\", \"timestamps\"])\n", + " \n", + " if end is not None:\n", + " q = q.filter(\n", + " (pl.col(\"timestamps\") >= start) & \n", + " (pl.col(\"timestamps\") < end)\n", + " )\n", + " else:\n", + " q = q.filter(\n", + " pl.col(\"timestamps\") >= start\n", + " )\n", + " \n", + " q = q.group_by(\"uid\").agg([\n", + " pl.col(\"item_ids\").alias(\"item_ids\")\n", + " ]).sort(\"uid\")\n", + " \n", + " return q.collect()\n", + " \n", + " intervals = []\n", + " current_start = 0\n", + " for cutoff in time_cutoffs:\n", + " intervals.append((current_start, cutoff))\n", + " current_start = cutoff\n", + "\n", + " intervals.append((current_start, None))\n", + "\n", + " for start, end in intervals:\n", + " subset = extract_interval(df, start, end)\n", + 
"\n", + " json_dict = {}\n", + " for user_id, item_ids in subset.iter_rows():\n", + " json_dict[user_id] = item_ids\n", + " \n", + " result_dicts.append(json_dict)\n", + "\n", + " if output_dir:\n", + " if end is not None:\n", + " filename = f\"inter_new_[{start}_{end}).json\"\n", + " else:\n", + " filename = f\"inter_new_[{start}_inf).json\"\n", + " \n", + " filepath = f\"{output_dir}/{filename}\"\n", + " with open(filepath, 'w') as f:\n", + " json.dump(json_dict, f, indent=2)\n", + " \n", + " print(f\"✓ Сохранено: {filepath}\")\n", + " \n", + " return result_dicts" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6cff8e7b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
uidtimestampsitem_ids
u32list[u32]list[u32]
600[1329190, 1329405, … 25997540][252026, 58171, … 201909]
800[121100, 121290, … 25977310][20844, 198210, … 60455]
1000[11335730, 11335925, … 25972225][46643, 57592, … 95670]
1400[280570, 280735, … 25993315][4634, 213798, … 104891]
1600[899275, 930305, … 25941890][223933, 154424, … 104876]
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌──────┬─────────────────────────────────┬────────────────────────────┐\n", + "│ uid ┆ timestamps ┆ item_ids │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ list[u32] ┆ list[u32] │\n", + "╞══════╪═════════════════════════════════╪════════════════════════════╡\n", + "│ 600 ┆ [1329190, 1329405, … 25997540] ┆ [252026, 58171, … 201909] │\n", + "│ 800 ┆ [121100, 121290, … 25977310] ┆ [20844, 198210, … 60455] │\n", + "│ 1000 ┆ [11335730, 11335925, … 2597222… ┆ [46643, 57592, … 95670] │\n", + "│ 1400 ┆ [280570, 280735, … 25993315] ┆ [4634, 213798, … 104891] │\n", + "│ 1600 ┆ [899275, 930305, … 25941890] ┆ [223933, 154424, … 104876] │\n", + "└──────┴─────────────────────────────────┴────────────────────────────┘" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "901e7400", + "metadata": {}, + "source": [ + "# QUANTILE CUTOFF" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8c691891", + "metadata": {}, + "outputs": [], + "source": [ + "def get_quantile_cutoffs(df, num_parts=4, base_ratio=None):\n", + " \"\"\"\n", + " Считает cutoffs так, чтобы разбить данные на части.\n", + " \n", + " Args:\n", + " num_parts: На сколько частей делить \"хвост\" истории.\n", + " base_ratio: Какую долю данных отдать в Base (самую первую часть). 
\n", + " Если None, делит всё поровну.\n", + " \"\"\"\n", + " # Достаем все таймстемпы в один плоский массив\n", + " # Это может занять память, если данных очень много (>100M), но для Beauty (2M) это ок\n", + " all_ts = df.select(pl.col(\"timestamps\").explode()).to_series().sort()\n", + " total_events = len(all_ts)\n", + " \n", + " print(f\"Всего событий: {total_events}\")\n", + " \n", + " cutoffs = []\n", + " \n", + " if base_ratio:\n", + " # Base занимает X% (например 80%), а остаток делим поровну на 3 части (Valid, Gap, Test)\n", + " # Остаток = 1 - base_ratio\n", + " # Каждая малая часть = (1 - base_ratio) / num_parts_tail\n", + " \n", + " base_idx = int(total_events * base_ratio)\n", + " cutoffs.append(all_ts[base_idx]) # Первый cutoff отделяет Base\n", + " \n", + " remaining_events = total_events - base_idx\n", + " part_size = remaining_events // num_parts # Делим остаток на 3 части (P1, P2, P3)\n", + " \n", + " current_idx = base_idx\n", + " for _ in range(num_parts-1): # Нам нужно еще 2 границы, чтобы получить 3 части\n", + " current_idx += part_size\n", + " cutoffs.append(all_ts[current_idx])\n", + " \n", + " else:\n", + " # Сценарий: Просто делим всё на N равных частей\n", + " step = total_events // num_parts\n", + " for i in range(1, num_parts):\n", + " idx = i * step\n", + " cutoffs.append(all_ts[idx])\n", + " \n", + " return cutoffs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "13c1466f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Всего событий: 7371990\n", + "\n", + "--- Новые Cutoffs (по количеству событий) ---\n", + "Cutoffs: [22138015, 23136375, 24137410, 25093085]\n", + "[0, 22138015, 23136375, 24137410, 25093085, None]\n" + ] + } + ], + "source": [ + "equal_event_cutoffs = get_quantile_cutoffs(df, num_parts=4, base_ratio=0.8)\n", + "\n", + "print(\"\\n--- Новые Cutoffs (по количеству событий) ---\")\n", + "print(f\"Cutoffs: {equal_event_cutoffs}\")\n", + "\n", + "# 
Проверка распределения\n", + "intervals_eq = [0] + equal_event_cutoffs + [None]\n", + "print(intervals_eq)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4e7f7b46", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[0_22138015).json\n", + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[22138015_24137410).json\n", + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[24137410_25093085).json\n", + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[25093085_inf).json\n", + "0 Base 3813 5897592 \n", + "1 Gap 3315 737198 \n", + "2 Valid 3120 368599 \n", + "3 Test 3154 368601 \n" + ] + } + ], + "source": [ + "new_split_files = split_session_by_timestamps(\n", + " df, \n", + " [22138015, 24137410, 25093085], \n", + " output_dir=\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw\"\n", + ")\n", + "\n", + "names = [\"Base\", \"Gap\", \"Valid\", \"Test\"]\n", + "for i, d in enumerate(new_split_files):\n", + " num_users = len(d)\n", + " \n", + " num_events = sum(len(items) for items in d.values())\n", + " \n", + " print(f\"{i:<10} {names[i]:<10} {num_users:<10} {num_events:<10}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "82fd2bca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 2 files into exp_4_0.9_inter_tiger_train.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json (Users: 4016)\n", + "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json (Users: 4016)\n", + "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json...\n", + 
"✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json (Users: 3813)\n", + "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json (Users: 4118)\n", + "Merging 1 files into test_set.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/test_set.json (Users: 3154)\n", + "Merging 1 files into valid_set.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/valid_set.json (Users: 3120)\n", + "Merging 4 files into all_set.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json (Users: 4138)\n", + "All done!\n" + ] + } + ], + "source": [ + "EXP_DIR = \"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps\"\n", + "\n", + "base_p, gap_p, valid_p, test_p = new_split_files[0], new_split_files[1], new_split_files[2], new_split_files[3]\n", + "\n", + "# Tiger: base + gap\n", + "merge_and_save([base_p, gap_p], EXP_DIR, \"exp_4_0.9_inter_tiger_train.json\")\n", + "\n", + "# 1. Exp 4.1 (Standard)\n", + "# Semantics: base + gap (Всё кроме валидации и теста)\n", + "merge_and_save([base_p, gap_p], EXP_DIR, \"exp_4-1_0.9_inter_semantics_train.json\")\n", + "\n", + "# 2. Exp 4.2 (Short Semantics)\n", + "# Semantics: base (Короче на пропуск, без gap)\n", + "merge_and_save([base_p], EXP_DIR, \"exp_4-2_0.8_inter_semantics_train.json\")\n", + "\n", + "# 3. Exp 4.3 (Leak)\n", + "# Semantics: base + gap + valid (Видит валидацию)\n", + "merge_and_save([base_p, gap_p, valid_p], EXP_DIR, \"exp_4-3_0.95_inter_semantics_train.json\")\n", + "\n", + "# 4. Test Set (тест всех моделей)\n", + "merge_and_save([test_p], EXP_DIR, \"test_set.json\")\n", + "\n", + "# 4. 
Valid Set (валидационный набор)\n", + "merge_and_save([valid_p], EXP_DIR, \"valid_set.json\")\n", + "\n", + "# 4. All Set (все данные)\n", + "merge_and_save([base_p, gap_p, valid_p, test_p], EXP_DIR, \"all_set.json\")\n", + "\n", + "print(\"All done!\")" + ] + }, + { + "cell_type": "code", + "id": "d34b1c55", + "metadata": { + "ExecuteTime": { + "end_time": "2025-12-11T08:56:58.546300Z", + "start_time": "2025-12-11T08:56:58.343394Z" + } + }, + "source": [ + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json\", 'r') as f:\n", + " old_inter_new = json.load(f)\n", + "\n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json\", 'r') as ff:\n", + " first_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json\", 'r') as ff:\n", + " second_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json\", 'r') as ff:\n", + " third_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json\", 'r') as ff:\n", + " tiger_sem = json.load(ff)\n", + "\n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/test_set.json\", 'r') as ff:\n", + " test_sem = json.load(ff)\n", + "\n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json\", 'r') as ff:\n", + " all_test_data = json.load(ff)\n", + "\n", + "def check_prefix_match(full_data, subset_data, check_suffix=False):\n", + " \"\"\"\n", + " check_suffix=True включит режим проверки суффиксов (для теста).\n", + " \"\"\"\n", + " mismatch_count = 0\n", + " full_match_count = 0\n", + "\n", + " num_events_full_data = sum(len(items) for items in 
full_data.values())\n", + " num_events_subset_data = sum(len(items) for items in subset_data.values())\n", + " print(f\"доля событий всего {(num_events_subset_data/num_events_full_data):.2f}:\")\n", + " \n", + " for user, sub_items in subset_data.items():\n", + " \n", + " if user not in full_data:\n", + " print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " full_items = full_data[user]\n", + " \n", + " if not check_suffix:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " if full_items[:len(sub_items)] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + "\n", + " else:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + "\n", + " if full_items[-len(sub_items):] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + "\n", + " mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n", + " \n", + " if mismatch_count == 0:\n", + " print(f\"OK [{mode}] Все {len(subset_data)} массивов ОК. 
Полных совпадений: {full_match_count}\")\n", + " else:\n", + " print(f\"NOT OK [{mode}] Найдено {mismatch_count} ошибок.\")\n", + "\n", + "# --- Запуск проверок ---\n", + "print(\"Проверка Train сетов (должны быть префиксами):\")\n", + "check_prefix_match(old_inter_new, first_sem)\n", + "check_prefix_match(old_inter_new, second_sem)\n", + "check_prefix_match(old_inter_new, third_sem)\n", + "check_prefix_match(old_inter_new, tiger_sem)\n", + "\n", + "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n", + "\n", + "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=False)\n", + "\n", + "check_prefix_match(old_inter_new, all_test_data)\n" + ], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json'", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mFileNotFoundError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43m/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mr\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m)\u001B[49m \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[1;32m 2\u001B[0m old_inter_new \u001B[38;5;241m=\u001B[39m json\u001B[38;5;241m.\u001B[39mload(f)\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m 
\u001B[38;5;28mopen\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m ff:\n", + "File \u001B[0;32m~/repositories/ucp-author-centric/ucp-env/lib/python3.9/site-packages/IPython/core/interactiveshell.py:310\u001B[0m, in \u001B[0;36m_modified_open\u001B[0;34m(file, *args, **kwargs)\u001B[0m\n\u001B[1;32m 303\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m file \u001B[38;5;129;01min\u001B[39;00m {\u001B[38;5;241m0\u001B[39m, \u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m2\u001B[39m}:\n\u001B[1;32m 304\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[1;32m 305\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mIPython won\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mt let you open fd=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfile\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m by default \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 306\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mas it is likely to crash IPython. 
If you know what you are doing, \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 307\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124myou can use builtins\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m open.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 308\u001B[0m )\n\u001B[0;32m--> 310\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mio_open\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: '/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json'" + ] + } + ], + "execution_count": 1 + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "c3a0adf2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "================================================================================\n", + "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n", + "================================================================================\n", + "\n", + "[exp_4-1_0.9] Анализ...\n", + " Юзеров в сплите: 4,016 / 4,138\n", + " ПУСТЫХ сессий: 15\n", + " ОБЩИХ ПРОБЛЕМ: 15\n", + "\n", + "[exp_4-2_0.8] Анализ...\n", + " Юзеров в сплите: 3,813 / 4,138\n", + " ПУСТЫХ сессий: 22\n", + " ОБЩИХ ПРОБЛЕМ: 22\n", + "\n", + "[exp_4-3_0.95] Анализ...\n", + " Юзеров в сплите: 4,118 / 4,138\n", + " ПУСТЫХ сессий: 7\n", + " ОБЩИХ ПРОБЛЕМ: 7\n", + "\n", + "[exp_4_0.9_tiger] Анализ...\n", + " Юзеров в сплите: 4,016 / 4,138\n", + " ПУСТЫХ сессий: 15\n", + " ОБЩИХ ПРОБЛЕМ: 15\n", + "\n", + "[test_set] Анализ...\n", + " Юзеров в сплите: 3,154 / 4,138\n", + " ПУСТЫХ сессий: 105\n", + " ОБЩИХ ПРОБЛЕМ: 105\n" + ] + } + ], + "source": [ + "def 
check_non_empty_splits(full_data, splits_data, split_names, min_history_len=2):\n", + " \"\"\"\n", + " Проверяет, что ни одна часть истории пользователя НЕ пустая во всех разбиениях.\n", + " \"\"\"\n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\")\n", + " print(\"=\"*80)\n", + " \n", + " all_users = set(full_data.keys())\n", + " total_issues = 0\n", + " \n", + " for i in range(len(split_names)):\n", + " split_name = split_names[i]\n", + " split_data = splits_data[i]\n", + " print(f\"\\n[{split_name}] Анализ...\")\n", + " \n", + " split_users = set(split_data.keys())\n", + " empty_sessions = []\n", + " \n", + " for user, items in split_data.items():\n", + " if not items or len(items) < min_history_len:\n", + " empty_sessions.append(user)\n", + " \n", + " issues_count = len(empty_sessions)\n", + " total_issues += issues_count\n", + " \n", + " print(f\" Юзеров в сплите: {len(split_users):,} / {len(all_users):,}\")\n", + " print(f\" ПУСТЫХ сессий: {len(empty_sessions)}\")\n", + " print(f\" ОБЩИХ ПРОБЛЕМ: {issues_count}\")\n", + " \n", + " if total_issues == 0:\n", + " print(\"\\nВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\")\n", + "\n", + "split_names = ['exp_4-1_0.9', 'exp_4-2_0.8', 'exp_4-3_0.95', 'exp_4_0.9_tiger', 'test_set']\n", + "splits_list = [first_sem, second_sem, third_sem, tiger_sem, test_sem]\n", + "\n", + "check_non_empty_splits(old_inter_new, splits_list, split_names)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "43aa0142", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 2 files into exp_4_0.9_inter_tiger_train.json (min len=2)...\n", + "Filtered 15 users with history < 2\n", + "Remaining: 4001 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4_0.9_inter_tiger_train.json (Users: 4001)\n", + "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json (min len=2)...\n", + "Filtered 15 users 
with history < 2\n", + "Remaining: 4001 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-1_0.9_inter_semantics_train.json (Users: 4001)\n", + "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json (min len=2)...\n", + "Filtered 22 users with history < 2\n", + "Remaining: 3791 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-2_0.8_inter_semantics_train.json (Users: 3791)\n", + "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json (min len=2)...\n", + "Filtered 7 users with history < 2\n", + "Remaining: 4111 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-3_0.95_inter_semantics_train.json (Users: 4111)\n", + "Merging 1 files into test_set.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/test_set.json (Users: 3154)\n", + "Merging 1 files into valid_set.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/valid_set.json (Users: 3120)\n", + "Merging 4 files into all_set.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/all_set.json (Users: 4138)\n", + "All done!\n" + ] + } + ], + "source": [ + "EXP_DIR_FILTERED = \"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered\"\n", + "\n", + "base_p, gap_p, valid_p, test_p = new_split_files[0], new_split_files[1], new_split_files[2], new_split_files[3]\n", + "\n", + "# Tiger: base + gap\n", + "merge_and_save_with_filter([base_p, gap_p], EXP_DIR_FILTERED, \"exp_4_0.9_inter_tiger_train.json\", min_history_len=2)\n", + "\n", + "# 1. Exp 4.1 (Standard)\n", + "# Semantics: base + gap (Всё кроме валидации и теста)\n", + "merge_and_save_with_filter([base_p, gap_p], EXP_DIR_FILTERED, \"exp_4-1_0.9_inter_semantics_train.json\", min_history_len=2)\n", + "\n", + "# 2. 
Exp 4.2 (Short Semantics)\n", + "# Semantics: base (Короче на пропуск, без gap)\n", + "merge_and_save_with_filter([base_p], EXP_DIR_FILTERED, \"exp_4-2_0.8_inter_semantics_train.json\", min_history_len=2)\n", + "\n", + "# 3. Exp 4.3 (Leak)\n", + "# Semantics: base + gap + valid (Видит валидацию)\n", + "merge_and_save_with_filter([base_p, gap_p, valid_p], EXP_DIR_FILTERED, \"exp_4-3_0.95_inter_semantics_train.json\", min_history_len=2)\n", + "\n", + "# 4. Test Set (тест всех моделей)\n", + "merge_and_save([test_p], EXP_DIR_FILTERED, \"test_set.json\")\n", + "\n", + "# 4. Valid Set (валидационный набор)\n", + "merge_and_save([valid_p], EXP_DIR_FILTERED, \"valid_set.json\")\n", + "\n", + "# 4. All Set (все данные)\n", + "merge_and_save([base_p, gap_p, valid_p, test_p], EXP_DIR_FILTERED, \"all_set.json\")\n", + "\n", + "print(\"All done!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9060beaa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Проверка Train сетов (должны быть префиксами):\n", + "доля событий всего 0.90:\n", + "✅ [ПРЕФИКСЫ] Все 4001 массивов ОК. Полных совпадений: 564\n", + "доля событий всего 0.80:\n", + "✅ [ПРЕФИКСЫ] Все 3791 массивов ОК. Полных совпадений: 343\n", + "доля событий всего 0.95:\n", + "✅ [ПРЕФИКСЫ] Все 4111 массивов ОК. Полных совпадений: 984\n", + "доля событий всего 0.90:\n", + "✅ [ПРЕФИКСЫ] Все 4001 массивов ОК. Полных совпадений: 564\n", + "\n", + "Проверка Test сета (должен быть суффиксом):\n", + "доля событий всего 0.05:\n", + "✅ [СУФФИКСЫ] Все 3154 массивов ОК. Полных совпадений: 20\n", + "\n", + "(Контроль) Проверка Test сета как префикса (должна упасть):\n", + "доля событий всего 0.05:\n", + "❌ [ПРЕФИКСЫ] Найдено 3134 ошибок.\n", + "доля событий всего 1.00:\n", + "✅ [ПРЕФИКСЫ] Все 4138 массивов ОК. 
Полных совпадений: 4138\n", + "\n", + "================================================================================\n", + "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n", + "================================================================================\n", + "\n", + "[exp_4-1_0.9] Анализ...\n", + " Юзеров в сплите: 4,001 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "[exp_4-2_0.8] Анализ...\n", + " Юзеров в сплите: 3,791 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "[exp_4-3_0.95] Анализ...\n", + " Юзеров в сплите: 4,111 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "[exp_4_0.9_tiger] Анализ...\n", + " Юзеров в сплите: 4,001 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "ВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\n" + ] + } + ], + "source": [ + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_expsx/exp_4-1_0.9_inter_semantics_train.json\", 'r') as ff:\n", + " filtered_first_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-2_0.8_inter_semantics_train.json\", 'r') as ff:\n", + " filtered_second_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-3_0.95_inter_semantics_train.json\", 'r') as ff:\n", + " filtered_third_sem = json.load(ff)\n", + " \n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4_0.9_inter_tiger_train.json\", 'r') as ff:\n", + " filtered_tiger_sem = json.load(ff)\n", + "\n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/valid_set.json\", 'r') as ff:\n", + " fiiltered_valid_sem = json.load(ff)\n", + "\n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/test_set.json\", 'r') as ff:\n", + " fiiltered_test_sem = json.load(ff)\n", + 
"\n", + "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/all_set.json\", 'r') as ff:\n", + " filtered_all_test_data = json.load(ff)\n", + "\n", + "# --- Запуск проверок ---\n", + "print(\"Проверка Train сетов (должны быть префиксами):\")\n", + "check_prefix_match(filtered_all_test_data, filtered_first_sem)\n", + "check_prefix_match(filtered_all_test_data, filtered_second_sem)\n", + "check_prefix_match(filtered_all_test_data, filtered_third_sem)\n", + "check_prefix_match(filtered_all_test_data, filtered_tiger_sem)\n", + "\n", + "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n", + "check_prefix_match(filtered_all_test_data, test_sem, check_suffix=True)\n", + "\n", + "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n", + "check_prefix_match(filtered_all_test_data, test_sem, check_suffix=False)\n", + "\n", + "check_prefix_match(filtered_all_test_data, all_test_data)\n", + "\n", + "split_names = ['exp_4-1_0.9', 'exp_4-2_0.8', 'exp_4-3_0.95', 'exp_4_0.9_tiger']\n", + "splits_list_filtered = [filtered_first_sem, filtered_second_sem, filtered_third_sem, filtered_tiger_sem]\n", + "\n", + "check_non_empty_splits(filtered_all_test_data, splits_list_filtered, split_names, min_history_len = 2)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "c540c8d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "для теста и валидации (может упасть и скорее всего упадет)\n", + "\n", + "================================================================================\n", + "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n", + "================================================================================\n", + "\n", + "[valid] Анализ...\n", + " Юзеров в сплите: 3,120 / 4,138\n", + " ПУСТЫХ сессий: 88\n", + " ОБЩИХ ПРОБЛЕМ: 88\n", + "\n", + "[test] Анализ...\n", + " Юзеров в сплите: 3,154 / 4,138\n", + " ПУСТЫХ сессий: 105\n", + " ОБЩИХ ПРОБЛЕМ: 
105\n" + ] + } + ], + "source": [ + "print(\"для теста и валидации (может упасть и скорее всего упадет)\")\n", + "vt_split_names = ['valid', 'test']\n", + "vt_splits_list_filtered = [fiiltered_valid_sem, test_sem]\n", + "\n", + "check_non_empty_splits(filtered_all_test_data, vt_splits_list_filtered, vt_split_names, min_history_len = 2)" + ] + }, + { + "cell_type": "markdown", + "id": "89efa96e", + "metadata": {}, + "source": [ + "# Разбиение YAMBDA по неделям" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "28e4ddc8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cutoffs: [25740785, 25827185, 25913585]\n", + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[0_25740785).json\n", + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[25740785_25827185).json\n", + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[25827185_25913585).json\n", + "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[25913585_inf).json\n", + "Part 0 [Base]: 4133 users\n", + "Part 1 [day -3]: 1381 users\n", + "Part 2 [day -2]: 1350 users\n", + "Part 3 [day -1]: 1403 users\n" + ] + } + ], + "source": [ + "global_max_time = df.select(\n", + " pl.col(\"timestamps\").explode().max()\n", + ").item()\n", + "\n", + "# 3. Размер окна (неделя)\n", + "days_val = 1\n", + "window_sec = days_val * 24 * 3600 \n", + "\n", + "# 4. 
Три отсечки с конца\n", + "cutoff_test_start = global_max_time - window_sec # T - 1w\n", + "cutoff_val_start = global_max_time - 2 * window_sec # T - 2w\n", + "cutoff_gap_start = global_max_time - 3 * window_sec # T - 3w\n", + "\n", + "cutoffs = [\n", + " int(cutoff_gap_start), # Граница Part 0 | Part 1\n", + " int(cutoff_val_start), # Граница Part 1 | Part 2\n", + " int(cutoff_test_start) # Граница Part 2 | Part 3\n", + "]\n", + "\n", + "print(f\"Cutoffs: {cutoffs}\")\n", + "\n", + "split_files = split_session_by_timestamps(\n", + " df, \n", + " cutoffs, \n", + " output_dir=\"/home/jovyan/IRec/data/Yambda/day-splits/raw\"\n", + ")\n", + "\n", + "names = [\"Base\", \"day -3\", \"day -2\", \"day -1\"]\n", + "for i, d in enumerate(split_files):\n", + " print(f\"Part {i} [{names[i]}]: {len(d)} users\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8d5b0c22", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 2 files into exp_4_0.9_inter_tiger_train.json (min len=2)...\n", + "Filtered 3 users with history < 2\n", + "Remaining: 4133 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4_0.9_inter_tiger_train.json (Users: 4133)\n", + "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json (min len=2)...\n", + "Filtered 3 users with history < 2\n", + "Remaining: 4133 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4-1_0.9_inter_semantics_train.json (Users: 4133)\n", + "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json (min len=2)...\n", + "Filtered 3 users with history < 2\n", + "Remaining: 4130 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4-2_0.8_inter_semantics_train.json (Users: 4130)\n", + "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json (min len=2)...\n", + "Filtered 3 users with history < 2\n", + "Remaining: 4133 users\n", + "✓ Done: 
/home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4-3_0.95_inter_semantics_train.json (Users: 4133)\n", + "Merging 1 files into test_set.json (min len=1)...\n", + "Filtered 0 users with history < 1\n", + "Remaining: 1403 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/test_set.json (Users: 1403)\n", + "Merging 1 files into valid_set.json (min len=1)...\n", + "Filtered 0 users with history < 1\n", + "Remaining: 1350 users\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/valid_set.json (Users: 1350)\n", + "Merging 4 files into all_set.json...\n", + "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/all_set.json (Users: 4138)\n", + "All done!\n" + ] + } + ], + "source": [ + "EXP_DIR_FILTERED = \"/home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered\"\n", + "\n", + "base_p, gap_p, valid_p, test_p = split_files[0], split_files[1], split_files[2], split_files[3]\n", + "\n", + "# Tiger: base + gap\n", + "merge_and_save_with_filter([base_p, gap_p], EXP_DIR_FILTERED, \"exp_4_0.9_inter_tiger_train.json\", min_history_len=2)\n", + "\n", + "# 1. Exp 4.1 (Standard)\n", + "# Semantics: base + gap (Всё кроме валидации и теста)\n", + "merge_and_save_with_filter([base_p, gap_p], EXP_DIR_FILTERED, \"exp_4-1_0.9_inter_semantics_train.json\", min_history_len=2)\n", + "\n", + "# 2. Exp 4.2 (Short Semantics)\n", + "# Semantics: base (Короче на пропуск, без gap)\n", + "merge_and_save_with_filter([base_p], EXP_DIR_FILTERED, \"exp_4-2_0.8_inter_semantics_train.json\", min_history_len=2)\n", + "\n", + "# 3. Exp 4.3 (Leak)\n", + "# Semantics: base + gap + valid (Видит валидацию)\n", + "merge_and_save_with_filter([base_p, gap_p, valid_p], EXP_DIR_FILTERED, \"exp_4-3_0.95_inter_semantics_train.json\", min_history_len=2)\n", + "\n", + "# 4. 
Test Set (тест всех моделей)\n", + "merge_and_save_with_filter([test_p], EXP_DIR_FILTERED, \"test_set.json\", min_history_len=1)\n", + "\n", + "# 4. Valid Set (валидационный набор)\n", + "merge_and_save_with_filter([valid_p], EXP_DIR_FILTERED, \"valid_set.json\", min_history_len=1)\n", + "\n", + "# 4. All Set (все данные)\n", + "merge_and_save([base_p, gap_p, valid_p, test_p], EXP_DIR_FILTERED, \"all_set.json\")\n", + "\n", + "print(\"All done!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c0b9b767", + "metadata": {}, + "outputs": [], + "source": [ + "def check_non_empty_splits(full_data, splits_data, split_names, min_history_len=2):\n", + " \"\"\"\n", + " Проверяет, что ни одна часть истории пользователя НЕ пустая во всех разбиениях.\n", + " \"\"\"\n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\")\n", + " print(\"=\"*80)\n", + " \n", + " all_users = set(full_data.keys())\n", + " total_issues = 0\n", + " \n", + " for i in range(len(split_names)):\n", + " split_name = split_names[i]\n", + " split_data = splits_data[i]\n", + " print(f\"\\n[{split_name}] Анализ...\")\n", + " \n", + " split_users = set(split_data.keys())\n", + " empty_sessions = []\n", + " \n", + " for user, items in split_data.items():\n", + " if not items or len(items) < min_history_len:\n", + " empty_sessions.append(user)\n", + " \n", + " issues_count = len(empty_sessions)\n", + " total_issues += issues_count\n", + " \n", + " print(f\" Юзеров в сплите: {len(split_users):,} / {len(all_users):,}\")\n", + " print(f\" ПУСТЫХ сессий: {len(empty_sessions)}\")\n", + " print(f\" ОБЩИХ ПРОБЛЕМ: {issues_count}\")\n", + " \n", + " if total_issues == 0:\n", + " print(\"\\nВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\")\n", + "\n", + "def check_prefix_match(full_data, subset_data, check_suffix=False):\n", + " \"\"\"\n", + " check_suffix=True включит режим проверки суффиксов (для теста).\n", + " \"\"\"\n", + " mismatch_count = 0\n", + " full_match_count = 0\n", + 
" \n", + " # Итерируемся по ключам сабсета, так как в full_data может быть больше юзеров\n", + " for user, sub_items in subset_data.items():\n", + " \n", + " # Проверяем есть ли такой юзер в исходнике\n", + " if user not in full_data:\n", + " print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " full_items = full_data[user]\n", + " \n", + " # Логика для проверки ПРЕФИКСА (начало совпадает)\n", + " if not check_suffix:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " # Сравниваем начало full с sub\n", + " if full_items[:len(sub_items)] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + " \n", + " # Логика для проверки СУФФИКСА (конец совпадает - для теста)\n", + " else:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " # Сравниваем конец full с sub\n", + " # Срез [-len:] берет последние N элементов\n", + " if full_items[-len(sub_items):] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + "\n", + " mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n", + " \n", + " if mismatch_count == 0:\n", + " print(f\"✅ [{mode}] Все {len(subset_data)} массивов ОК. Полных совпадений: {full_match_count}\")\n", + " else:\n", + " print(f\"❌ [{mode}] Найдено {mismatch_count} ошибок.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "36ac0115", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Проверка Train сетов (должны быть префиксами):\n", + "✅ [ПРЕФИКСЫ] Все 4133 массивов ОК. Полных совпадений: 2272\n", + "✅ [ПРЕФИКСЫ] Все 4130 массивов ОК. Полных совпадений: 1969\n", + "✅ [ПРЕФИКСЫ] Все 4133 массивов ОК. 
Полных совпадений: 2735\n", + "✅ [ПРЕФИКСЫ] Все 4133 массивов ОК. Полных совпадений: 2272\n", + "\n", + "Проверка Test сета (должен быть суффиксом):\n", + "✅ [СУФФИКСЫ] Все 1403 массивов ОК. Полных совпадений: 2\n", + "\n", + "(Контроль) Проверка Test сета как префикса (должна упасть):\n", + "❌ [ПРЕФИКСЫ] Найдено 1401 ошибок.\n", + "✅ [ПРЕФИКСЫ] Все 4138 массивов ОК. Полных совпадений: 4138\n" + ] + } + ], + "source": [ + "with open(f\"{EXP_DIR_FILTERED}/exp_4-1_0.9_inter_semantics_train.json\", 'r') as ff:\n", + " filtered_first_sem = json.load(ff)\n", + " \n", + "with open(f\"{EXP_DIR_FILTERED}/exp_4-2_0.8_inter_semantics_train.json\", 'r') as ff:\n", + " filtered_second_sem = json.load(ff)\n", + " \n", + "with open(f\"{EXP_DIR_FILTERED}/exp_4-3_0.95_inter_semantics_train.json\", 'r') as ff:\n", + " filtered_third_sem = json.load(ff)\n", + " \n", + "with open(f\"{EXP_DIR_FILTERED}/exp_4_0.9_inter_tiger_train.json\", 'r') as ff:\n", + " filtered_tiger_sem = json.load(ff)\n", + "\n", + "with open(f\"{EXP_DIR_FILTERED}/valid_set.json\", 'r') as ff:\n", + " fiiltered_valid_sem = json.load(ff)\n", + "\n", + "with open(f\"{EXP_DIR_FILTERED}/test_set.json\", 'r') as ff:\n", + " filtered_test_sem = json.load(ff)\n", + "\n", + "with open(f\"{EXP_DIR_FILTERED}/all_set.json\", 'r') as ff:\n", + " filtered_all_test_data = json.load(ff)\n", + "\n", + "# --- Запуск проверок ---\n", + "print(\"Проверка Train сетов (должны быть префиксами):\")\n", + "check_prefix_match(filtered_all_test_data, filtered_first_sem)\n", + "check_prefix_match(filtered_all_test_data, filtered_second_sem)\n", + "check_prefix_match(filtered_all_test_data, filtered_third_sem)\n", + "check_prefix_match(filtered_all_test_data, filtered_tiger_sem)\n", + "\n", + "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n", + "check_prefix_match(filtered_all_test_data, filtered_test_sem, check_suffix=True)\n", + "\n", + "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n", + 
"check_prefix_match(filtered_all_test_data, filtered_test_sem, check_suffix=False)\n", + "\n", + "check_prefix_match(filtered_all_test_data, filtered_all_test_data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2c65331b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "================================================================================\n", + "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n", + "================================================================================\n", + "\n", + "[exp_4-1_0.9] Анализ...\n", + " Юзеров в сплите: 4,133 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "[exp_4-2_0.8] Анализ...\n", + " Юзеров в сплите: 4,130 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "[exp_4-3_0.95] Анализ...\n", + " Юзеров в сплите: 4,133 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "[exp_4_0.9_tiger] Анализ...\n", + " Юзеров в сплите: 4,133 / 4,138\n", + " ПУСТЫХ сессий: 0\n", + " ОБЩИХ ПРОБЛЕМ: 0\n", + "\n", + "ВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\n" + ] + } + ], + "source": [ + "split_names = ['exp_4-1_0.9', 'exp_4-2_0.8', 'exp_4-3_0.95', 'exp_4_0.9_tiger']\n", + "splits_list_filtered = [filtered_first_sem, filtered_second_sem, filtered_third_sem, filtered_tiger_sem]\n", + "\n", + "check_non_empty_splits(filtered_all_test_data, splits_list_filtered, split_names, min_history_len = 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "f596db64", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 Base 4133 7264231 \n", + "1 day -3 1381 36676 \n", + "2 day -2 1350 35128 \n", + "3 day -1 1403 35955 \n" + ] + } + ], + "source": [ + "filtered_all_test_data.keys()\n", + "for i, d in enumerate(split_files):\n", + " num_users = len(d)\n", + " \n", + " num_events = sum(len(items) for items in d.values())\n", + " \n", + " print(f\"{i:<10} 
{names[i]:<10} {num_users:<10} {num_events:<10}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/irec/callbacks/stopping.py b/src/irec/callbacks/stopping.py index 3d1405f..bbe091f 100644 --- a/src/irec/callbacks/stopping.py +++ b/src/irec/callbacks/stopping.py @@ -44,14 +44,18 @@ def after_step(self, runner: Runner, context: RunnerContext): metric = context.metrics[self._metric] if self._best_metric is None: self._best_metric = metric - torch.save(runner.model.state_dict(), f'{self._model_path}_best_{round(self._best_metric, 4)}.pth') + save_path = f'{self._model_path}_best_{round(self._best_metric, 4)}.pth' + os.makedirs(os.path.dirname(save_path), exist_ok=True) + torch.save(runner.model.state_dict(), save_path) else: if (self._minimize and metric < self._best_metric) or (not self._minimize and metric > self._best_metric): self._wait = 0 old_metric = self._best_metric self._best_metric = metric # Saving new model - torch.save(runner.model.state_dict(), f'{self._model_path}_best_{round(self._best_metric, 4)}.pth') + save_path = f'{self._model_path}_best_{round(self._best_metric, 4)}.pth' + os.makedirs(os.path.dirname(save_path), exist_ok=True) + torch.save(runner.model.state_dict(), save_path) # Deleting old model if str(round(self._best_metric, 4)) != str(round(old_metric, 4)): os.remove(f'{self._model_path}_best_{round(old_metric, 4)}.pth')